mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)

gpt_sovits_v3

This commit is contained in:
parent 190dd6198c
commit c8f7604ba7

21 GPT_SoVITS/BigVGAN/LICENSE Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 NVIDIA CORPORATION.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
266 GPT_SoVITS/BigVGAN/README.md Normal file
@@ -0,0 +1,266 @@

## BigVGAN: A Universal Neural Vocoder with Large-Scale Training

#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon

[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)

[](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)

<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>

## News

- **Sep 2024 (v2.4):**
  - We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.

- **Jul 2024 (v2.3):**
  - General refactor and code improvements for improved readability.
  - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with an inference speed benchmark.

- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.

- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.

- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
  - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
  - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
  - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
  - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to a 44 kHz sampling rate and a 512x upsampling ratio.

## Installation

The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:

```shell
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda activate bigvgan
```

Clone the repository and install dependencies:

```shell
git clone https://github.com/NVIDIA/BigVGAN
cd BigVGAN
pip install -r requirements.txt
```

## Inference Quickstart using 🤗 Hugging Face Hub

The example below shows how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and synthesize a waveform from that mel spectrogram.

```python
device = 'cuda'

import torch
import bigvgan
import librosa
from meldataset import get_mel_spectrogram

# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)

# remove weight norm in the model and set to eval mode
model.remove_weight_norm()
model = model.eval().to(device)

# load wav file and compute mel spectrogram
wav_path = '/path/to/your/audio.wav'
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]

# compute mel spectrogram from the ground truth audio
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]

# generate waveform from mel
with torch.inference_mode():
    wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]

# you can convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
```
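
Continuing from the quickstart above, a minimal sketch of writing the int16 output to disk. It assumes `scipy` is installed (this repository's `inference.py` also uses `scipy.io.wavfile`); the output path is a placeholder.

```python
# Sketch: save the generated int16 waveform from the quickstart above to a wav file.
# Assumes `wav_gen_int16` and `model` are in scope; the output path is hypothetical.
from scipy.io.wavfile import write

output_path = 'output_generated.wav'
write(output_path, model.h.sampling_rate, wav_gen_int16[0])  # wav_gen_int16 has shape [1, T_time]
```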

## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>

You can run a local gradio demo with the commands below:

```shell
pip install -r demo/requirements.txt
python demo/app.py
```

## Training

Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:

```shell
cd filelists/LibriTTS && \
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
ln -s /path/to/your/LibriTTS/test-other test-other && \
cd ../..
```

Train the BigVGAN model. Below is an example command for training BigVGAN-v2 on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:

```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir filelists/LibriTTS \
--input_training_file filelists/LibriTTS/train-full.txt \
--input_validation_file filelists/LibriTTS/val-full.txt \
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
```

## Synthesis

Synthesize from a trained BigVGAN model. Below is an example command for generating audio from the model.
It computes mel spectrograms from the wav files in `--input_wavs_dir` and saves the generated audio to `--output_dir`.

```shell
python inference.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_wavs_dir /path/to/your/input_wav \
--output_dir /path/to/your/output_wav
```

`inference_e2e.py` supports synthesis directly from a mel spectrogram saved in `.npy` format, with shape `[1, channel, frame]` or `[channel, frame]`.
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.

Make sure that the STFT hyperparameters used to compute the mel spectrograms match the model's, as defined in the `config.json` of the corresponding model.

```shell
python inference_e2e.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_mels_dir /path/to/your/input_mel \
--output_dir /path/to/your/output_wav
```
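
For reference, a minimal sketch of how such `.npy` mel inputs could be prepared, reusing the quickstart pieces above. The paths are placeholders, and the mel parameters are taken from the loaded model's config (`model.h`) so they match the checkpoint.

```python
# Sketch: prepare .npy mel inputs (shape [C_mel, T_frame]) for inference_e2e.py.
# Paths are placeholders; model.h supplies the STFT/mel hyperparameters of the target checkpoint.
import os
import numpy as np
import torch
import librosa
import bigvgan
from meldataset import get_mel_spectrogram

model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)

input_mels_dir = '/path/to/your/input_mel'  # pass this directory via --input_mels_dir
os.makedirs(input_mels_dir, exist_ok=True)

wav, sr = librosa.load('/path/to/your/audio.wav', sr=model.h.sampling_rate, mono=True)
mel = get_mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0), model.h)  # [1, C_mel, T_frame]
np.save(os.path.join(input_mels_dir, 'audio.npy'), mel.squeeze(0).numpy())  # saved as [C_mel, T_frame]
```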

## Using Custom CUDA Kernel for Synthesis

You can apply the fast CUDA inference kernel by passing the `use_cuda_kernel` parameter when instantiating BigVGAN:

```python
generator = BigVGAN(h, use_cuda_kernel=True)
```

You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.

When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads it. The codebase has been tested using CUDA `12.1`.

Please make sure that both `nvcc` and `ninja` are installed on your system, and that the CUDA version targeted by `nvcc` matches the one your PyTorch build uses.

We recommend running `test_cuda_vs_torch_model.py` first to build the kernel and check its correctness. See the example command below and its output, which should report `[Success] test CUDA fused vs. plain torch BigVGAN inference`:

```shell
python tests/test_cuda_vs_torch_model.py \
--checkpoint_file /path/to/your/bigvgan_generator.pt
```

```shell
loading plain Pytorch BigVGAN
...
loading CUDA kernel BigVGAN with auto-build
Detected CUDA files, patching ldflags
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
Building extension module anti_alias_activation_cuda...
...
Loading extension module anti_alias_activation_cuda...
...
Loading '/path/to/your/bigvgan_generator.pt'
...
[Success] test CUDA fused vs. plain torch BigVGAN inference
 > mean_difference=0.0007238413265440613
...
```

If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, the CUDA kernel inference is incorrect. Please check whether the `nvcc` installed on your system is compatible with your PyTorch version.

## Pretrained Models

We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
You can download the generator weights (named `bigvgan_generator.pt`) and the discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) from the listed model repositories.

| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |

The paper results are based on the original 24 kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset.
We also provide 22 kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.

You can fine-tune the models by:

1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states; one way to do this is sketched below)
2. resuming training on your own audio dataset by specifying a `--checkpoint_path` that contains the checkpoints when launching `train.py`
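
For step 1, a minimal sketch using `huggingface_hub` (the repository ID comes from the table above; the local directory is a placeholder, and you should check `train.py` for the checkpoint naming it expects when resuming, since the Hub file names may need to be adapted):

```python
# Sketch: fetch a generator + discriminator/optimizer checkpoint pair for fine-tuning.
# The local directory is hypothetical; verify train.py's expected checkpoint names before resuming.
from huggingface_hub import hf_hub_download

repo_id = "nvidia/bigvgan_v2_24khz_100band_256x"
local_dir = "exp/bigvgan_v2_24khz_100band_256x"  # candidate --checkpoint_path for train.py

for filename in ["bigvgan_generator.pt", "bigvgan_discriminator_optimizer.pt", "config.json"]:
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
```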
## Training Details of BigVGAN-v2

Compared to the original BigVGAN, the pretrained BigVGAN-v2 checkpoints use `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.

Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as the default to fit training on a single A100 GPU. You can adjust `batch_size` to match your GPUs when fine-tuning.

When training BigVGAN-v2 from scratch with a small batch size, it can encounter the early divergence problem mentioned in the paper. In that case, we recommend lowering the `clip_grad_norm` value (e.g., `100`) for the early training iterations (e.g., the first 20000 steps) and then increasing it back to the default `500`.
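
As an illustration of the schedule just described, a minimal sketch that is not part of the repository's `train.py` (the step threshold and clip values simply mirror the recommendation above, and the model here is a stand-in):

```python
# Sketch of the recommended clip_grad_norm warm-up: 100 for the first 20000 steps, then 500.
import torch
import torch.nn as nn

def clip_value(step: int) -> float:
    return 100.0 if step < 20000 else 500.0

# Toy usage with a stand-in module; in practice this would be the generator inside the training loop.
model = nn.Linear(4, 4)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
step = 100  # hypothetical global step counter
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value(step))
```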
## Evaluation Results of BigVGAN-v2

Below are the objective results of the 24 kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements in these metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.

| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |

## Speed Benchmark

Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.

| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|:---:|:---:|:---:|:---:|:---:|:---:|
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
| | | True | 3916.5 | 163.2x | 1.3 |
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
| | | True | 5330.1 | 222.1x | 1.7 |
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
| | | True | 5761.7 | 240.1x | 4.4 |
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
| | | True | 1598.1 | 66.6x | 1.3 |
| | 2048 | False | 929.9 | 38.7x | 1.7 |
| | | True | 1971.3 | 82.1x | 1.6 |
| | 16384 | False | 943.4 | 39.3x | 5.0 |
| | | True | 2026.5 | 84.4x | 3.9 |
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
| | | True | 811.3 | 33.8x | 1.3 |
| | 2048 | False | 576.5 | 24.0x | 1.7 |
| | | True | 1023.0 | 42.6x | 1.5 |
| | 16384 | False | 589.4 | 24.6x | 5.0 |
| | | True | 1068.1 | 44.5x | 3.2 |

## Acknowledgements

We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.

## References

- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
126 GPT_SoVITS/BigVGAN/activations.py Normal file
@@ -0,0 +1,126 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        """
        super(Snake, self).__init__()
        self.in_features = in_features

        # Initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # Log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # Linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
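
A brief usage sketch for the activations above, illustrating the (B, C, T) shapes from the docstrings (this example is not part of the file and assumes it is importable as `activations`):

```python
# Usage sketch for Snake / SnakeBeta on a (B, C, T) tensor, matching the docstrings above.
import torch
from activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                 # (B, C, T) with 256 channels
snake = Snake(256, alpha_logscale=True)
snakebeta = SnakeBeta(256, alpha_logscale=True)
print(snake(x).shape, snakebeta(x).shape)    # both keep the input shape: torch.Size([2, 256, 100])
```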
493 GPT_SoVITS/BigVGAN/bigvgan.py Normal file
@@ -0,0 +1,493 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.

import os
import json
from pathlib import Path
from typing import Optional, Union, Dict

import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

import activations
from utils0 import init_weights, get_padding
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
from env import AttrDict

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download


def load_hparams_from_json(path) -> AttrDict:
    with open(path) as f:
        data = f.read()
    return AttrDict(json.loads(data))


class AMPBlock1(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in range(len(dilation))
            ]
        )
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(
            self.convs2
        )  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class AMPBlock2(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
    Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1

    Args:
        h (AttrDict): Hyperparameters.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
        activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str = None,
    ):
        super().__init__()

        self.h = h

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # Total number of conv layers

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        # Activation functions
        if activation == "snake":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.Snake(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=activations.SnakeBeta(
                            channels, alpha_logscale=h.snake_logscale
                        )
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class BigVGAN(
    torch.nn.Module,
    PyTorchModelHubMixin,
    # library_name="bigvgan",
    # repo_url="https://github.com/NVIDIA/BigVGAN",
    # docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
    # pipeline_tag="audio-to-audio",
    # license="mit",
    # tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
):
    """
    BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
    New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.

    Args:
        h (AttrDict): Hyperparameters.
        use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.

    Note:
        - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
        - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        super().__init__()
        self.h = h
        self.h["use_cuda_kernel"] = use_cuda_kernel

        # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
        if self.h.get("use_cuda_kernel", False):
            from alias_free_activation.cuda.activation1d import (
                Activation1d as CudaActivation1d,
            )

            Activation1d = CudaActivation1d
        else:
            Activation1d = TorchActivation1d

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        if h.resblock == "1":
            resblock_class = AMPBlock1
        elif h.resblock == "2":
            resblock_class = AMPBlock2
        else:
            raise ValueError(
                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
            )

        # Transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(
                    resblock_class(h, ch, k, d, activation=h.activation)
                )

        # Post-conv
        activation_post = (
            activations.Snake(ch, alpha_logscale=h.snake_logscale)
            if h.activation == "snake"
            else (
                activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
                if h.activation == "snakebeta"
                else None
            )
        )
        if activation_post is None:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.activation_post = Activation1d(activation=activation_post)

        # Whether to use bias for the final conv_post. Default to True for backward compatibility
        self.use_bias_at_final = h.get("use_bias_at_final", True)
        self.conv_post = weight_norm(
            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
        )

        # Weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

        # Final tanh activation. Defaults to True for backward compatibility
        self.use_tanh_at_final = h.get("use_tanh_at_final", True)

    def forward(self, x):
        # Pre-conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # Upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # Post-conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        # Final tanh activation
        if self.use_tanh_at_final:
            x = torch.tanh(x)
        else:
            x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]

        return x

    def remove_weight_norm(self):
        try:
            # print("Removing weight norm...")
            for l in self.ups:
                for l_i in l:
                    remove_weight_norm(l_i)
            for l in self.resblocks:
                l.remove_weight_norm()
            remove_weight_norm(self.conv_pre)
            remove_weight_norm(self.conv_post)
        except ValueError:
            print("[INFO] Model already removed weight norm. Skipping!")
            pass

    # Additional methods for huggingface_hub support
    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config.json from a Pytorch model to a local directory."""

        model_path = save_directory / "bigvgan_generator.pt"
        torch.save({"generator": self.state_dict()}, model_path)

        config_path = save_directory / "config.json"
        with open(config_path, "w") as config_file:
            json.dump(self.h, config_file, indent=4)

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: str,
        cache_dir: str,
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",  # Additional argument
        strict: bool = False,  # Additional argument
        use_cuda_kernel: bool = False,
        **model_kwargs,
    ):
        """Load Pytorch pretrained weights and return the loaded model."""

        # Download and load hyperparameters (h) used by BigVGAN
        if os.path.isdir(model_id):
            # print("Loading config.json from local directory")
            config_file = os.path.join(model_id, "config.json")
        else:
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        h = load_hparams_from_json(config_file)

        # instantiate BigVGAN using h
        if use_cuda_kernel:
            print(
                f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
            )
            print(
                f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
            )
            print(
                f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
            )
        model = cls(h, use_cuda_kernel=use_cuda_kernel)

        # Download and load pretrained generator weight
        if os.path.isdir(model_id):
            # print("Loading weights from local directory")
            model_file = os.path.join(model_id, "bigvgan_generator.pt")
        else:
            # print(f"Loading weights from {model_id}")
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="bigvgan_generator.pt",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )

        checkpoint_dict = torch.load(model_file, map_location=map_location)

        try:
            model.load_state_dict(checkpoint_dict["generator"])
        except RuntimeError:
            print(
                f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
            )
            model.remove_weight_norm()
            model.load_state_dict(checkpoint_dict["generator"])

        return model
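
Since `_from_pretrained` above handles the case where `model_id` is a local directory, a minimal sketch of loading from a local checkpoint folder (the path is a placeholder; the folder is assumed to contain `config.json` and `bigvgan_generator.pt`, as written by `_save_pretrained`):

```python
# Sketch: load BigVGAN from a local directory instead of the Hugging Face Hub.
# The directory path is hypothetical and must contain config.json and bigvgan_generator.pt.
import bigvgan

local_dir = "/path/to/your/bigvgan_v2_24khz_100band_256x"
model = bigvgan.BigVGAN.from_pretrained(local_dir, use_cuda_kernel=False)
model.remove_weight_norm()
model = model.eval()
```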
651
GPT_SoVITS/BigVGAN/discriminators.py
Normal file
651
GPT_SoVITS/BigVGAN/discriminators.py
Normal file
@ -0,0 +1,651 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv2d
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
from torchaudio.transforms import Spectrogram, Resample
|
||||
|
||||
from env import AttrDict
|
||||
from utils import get_padding
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
h: AttrDict,
|
||||
period: List[int],
|
||||
kernel_size: int = 5,
|
||||
stride: int = 3,
|
||||
use_spectral_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
self.d_mult = h.discriminator_channel_mult
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(
|
||||
Conv2d(
|
||||
1,
|
||||
int(32 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(128 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(128 * self.d_mult),
|
||||
int(512 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(512 * self.d_mult),
|
||||
int(1024 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(1024 * self.d_mult),
|
||||
int(1024 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
1,
|
||||
padding=(2, 0),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, 0.1)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, h: AttrDict):
|
||||
super().__init__()
|
||||
self.mpd_reshapes = h.mpd_reshapes
|
||||
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
|
||||
for rs in self.mpd_reshapes
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, resolution: List[List[int]]):
|
||||
super().__init__()
|
||||
|
||||
self.resolution = resolution
|
||||
assert (
|
||||
len(self.resolution) == 3
|
||||
), f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
self.lrelu_slope = 0.1
|
||||
|
||||
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||
print(
|
||||
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
|
||||
)
|
||||
norm_f = (
|
||||
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
)
|
||||
self.d_mult = cfg.discriminator_channel_mult
|
||||
if hasattr(cfg, "mrd_channel_mult"):
|
||||
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||
self.d_mult = cfg.mrd_channel_mult
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 3),
|
||||
padding=(1, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
x = F.pad(
|
||||
x,
|
||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
x = x.squeeze(1)
|
||||
x = torch.stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=False,
|
||||
return_complex=True,
|
||||
)
|
||||
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||
|
||||
return mag
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert (
|
||||
len(self.resolutions) == 3
|
||||
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorB(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
channels: int = 32,
|
||||
hop_factor: float = 0.25,
|
||||
bands: Tuple[Tuple[float, float], ...] = (
|
||||
(0.0, 0.1),
|
||||
(0.1, 0.25),
|
||||
(0.25, 0.5),
|
||||
(0.5, 0.75),
|
||||
(0.75, 1.0),
|
||||
),
|
||||
):
|
||||
super().__init__()
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.spec_fn = Spectrogram(
|
||||
n_fft=window_length,
|
||||
hop_length=int(window_length * hop_factor),
|
||||
win_length=window_length,
|
||||
power=None,
|
||||
)
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
|
||||
),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
|
||||
)
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
x = self.spec_fn(x)
|
||||
x = torch.view_as_real(x)
|
||||
x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
x_bands = self.spectrogram(x.squeeze(1))
|
||||
fmap = []
|
||||
x = []
|
||||
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for i, layer in enumerate(stack):
|
||||
band = layer(band)
|
||||
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||
if i > 0:
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
|
||||
x = torch.cat(x, dim=-1)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class MultiBandDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
h,
|
||||
):
|
||||
"""
|
||||
Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
|
||||
and the modified code adapted from https://github.com/gemelo-ai/vocos.
|
||||
"""
|
||||
super().__init__()
|
||||
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorCQT(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.filters = cfg["cqtd_filters"]
|
||||
self.max_filters = cfg["cqtd_max_filters"]
|
||||
self.filters_scale = cfg["cqtd_filters_scale"]
|
||||
self.kernel_size = (3, 9)
|
||||
self.dilations = cfg["cqtd_dilations"]
|
||||
self.stride = (1, 2)
|
||||
|
||||
self.in_channels = cfg["cqtd_in_channels"]
|
||||
self.out_channels = cfg["cqtd_out_channels"]
|
||||
self.fs = cfg["sampling_rate"]
|
||||
self.hop_length = hop_length
|
||||
self.n_octaves = n_octaves
|
||||
self.bins_per_octave = bins_per_octave
|
||||
|
||||
# Lazy-load
|
||||
from nnAudio import features
|
||||
|
||||
self.cqt_transform = features.cqt.CQT2010v2(
|
||||
sr=self.fs * 2,
|
||||
hop_length=self.hop_length,
|
||||
n_bins=self.bins_per_octave * self.n_octaves,
|
||||
bins_per_octave=self.bins_per_octave,
|
||||
output_format="Complex",
|
||||
pad_mode="constant",
|
||||
)
|
||||
|
||||
self.conv_pres = nn.ModuleList()
|
||||
for _ in range(self.n_octaves):
|
||||
self.conv_pres.append(
|
||||
nn.Conv2d(
|
||||
self.in_channels * 2,
|
||||
self.in_channels * 2,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.get_2d_padding(self.kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
|
||||
self.convs.append(
|
||||
nn.Conv2d(
|
||||
self.in_channels * 2,
|
||||
self.filters,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.get_2d_padding(self.kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
|
||||
)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
dilation=(dilation, 1),
|
||||
padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
|
||||
)
|
||||
)
|
||||
)
|
||||
in_chs = out_chs
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
|
||||
self.max_filters,
|
||||
)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding(
|
||||
(self.kernel_size[0], self.kernel_size[0])
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(
|
||||
out_chs,
|
||||
self.out_channels,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
|
||||
self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
|
||||
self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
|
||||
|
||||
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||
if self.cqtd_normalize_volume:
|
||||
print(
|
||||
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
)
|
||||
|
||||
def get_2d_padding(
|
||||
self,
|
||||
kernel_size: typing.Tuple[int, int],
|
||||
dilation: typing.Tuple[int, int] = (1, 1),
|
||||
):
|
||||
return (
|
||||
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
if self.cqtd_normalize_volume:
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
z = self.cqt_transform(x)
|
||||
|
||||
z_amplitude = z[:, :, :, 0].unsqueeze(1)
|
||||
z_phase = z[:, :, :, 1].unsqueeze(1)
|
||||
|
||||
z = torch.cat([z_amplitude, z_phase], dim=1)
|
||||
z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
|
||||
|
||||
latent_z = []
|
||||
for i in range(self.n_octaves):
|
||||
latent_z.append(
|
||||
self.conv_pres[i](
|
||||
z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
|
||||
]
|
||||
)
|
||||
)
|
||||
latent_z = torch.cat(latent_z, dim=-1)
|
||||
|
||||
for i, l in enumerate(self.convs):
|
||||
latent_z = l(latent_z)
|
||||
|
||||
latent_z = self.activation(latent_z)
|
||||
fmap.append(latent_z)
|
||||
|
||||
latent_z = self.conv_post(latent_z)
|
||||
|
||||
return latent_z, fmap
|
||||
|
||||
|
||||
class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
def __init__(self, cfg: AttrDict):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = cfg
|
||||
# Using get with defaults
|
||||
self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
|
||||
self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
|
||||
self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
|
||||
self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
|
||||
self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
|
||||
self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
|
||||
# Multi-scale params to loop over
|
||||
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
|
||||
"cqtd_bins_per_octaves", [24, 36, 48]
|
||||
)
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorCQT(
|
||||
self.cfg,
|
||||
hop_length=self.cfg["cqtd_hop_lengths"][i],
|
||||
n_octaves=self.cfg["cqtd_n_octaves"][i],
|
||||
bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
|
||||
)
|
||||
for i in range(len(self.cfg["cqtd_hop_lengths"]))
|
||||
]
|
||||
)
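# With the defaults above this builds three CQT discriminators, one per scale:
# hop lengths (512, 256, 256), 9 octaves each, and (24, 36, 48) bins per octave.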
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for disc in self.discriminators:
|
||||
y_d_r, fmap_r = disc(y)
|
||||
y_d_g, fmap_g = disc(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class CombinedDiscriminator(nn.Module):
|
||||
"""
|
||||
Wrapper for chaining multiple discriminator architectures.
|
||||
Example: combine MBD and CQTD into a single class
|
||||
"""
|
||||
|
||||
def __init__(self, list_discriminator: List[nn.Module]):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(list_discriminator)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for disc in self.discriminators:
|
||||
y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
|
||||
y_d_rs.extend(y_d_r)
|
||||
fmap_rs.extend(fmap_r)
|
||||
y_d_gs.extend(y_d_g)
|
||||
fmap_gs.extend(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
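# Minimal usage sketch (illustrative only; assumes `cfg` carries the fields each
# wrapped discriminator expects, e.g. the cqtd_* keys handled above):
#
#   combined = CombinedDiscriminator(
#       [MultiBandDiscriminator(cfg), MultiScaleSubbandCQTDiscriminator(cfg)]
#   )
#   y_d_rs, y_d_gs, fmap_rs, fmap_gs = combined(y, y_hat)  # y, y_hat: [B, 1, T]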
|
18
GPT_SoVITS/BigVGAN/env.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
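# A minimal sketch of how AttrDict behaves (illustrative values only):
if __name__ == "__main__":
    h = AttrDict({"sampling_rate": 24000, "n_fft": 1024})
    assert h.sampling_rate == 24000  # attribute access mirrors key access
    assert h["n_fft"] == h.n_fft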
|
89
GPT_SoVITS/BigVGAN/inference.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import librosa
|
||||
from utils import load_checkpoint
|
||||
from meldataset import get_mel_spectrogram
|
||||
from scipy.io.wavfile import write
|
||||
from env import AttrDict
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from bigvgan import BigVGAN as Generator
|
||||
|
||||
h = None
|
||||
device = None
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def inference(a, h):
|
||||
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||
|
||||
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
|
||||
filelist = os.listdir(a.input_wavs_dir)
|
||||
|
||||
os.makedirs(a.output_dir, exist_ok=True)
|
||||
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
with torch.no_grad():
|
||||
for i, filename in enumerate(filelist):
|
||||
# Load the ground truth audio and resample if necessary
|
||||
wav, sr = librosa.load(
|
||||
os.path.join(a.input_wavs_dir, filename), sr=h.sampling_rate, mono=True
|
||||
)
|
||||
wav = torch.FloatTensor(wav).to(device)
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filename)[0] + "_generated.wav"
|
||||
)
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Inference Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_wavs_dir", default="test_files")
|
||||
parser.add_argument("--output_dir", default="generated_files")
|
||||
parser.add_argument("--checkpoint_file", required=True)
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||
with open(config_file) as f:
|
||||
data = f.read()
|
||||
|
||||
global h
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
global device
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
inference(a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
102
GPT_SoVITS/BigVGAN/inference_e2e.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import glob
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
from env import AttrDict
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from bigvgan import BigVGAN as Generator
|
||||
|
||||
h = None
|
||||
device = None
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + "*")
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return ""
|
||||
return sorted(cp_list)[-1]
|
||||
|
||||
|
||||
def inference(a, h):
|
||||
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||
|
||||
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
|
||||
filelist = os.listdir(a.input_mels_dir)
|
||||
|
||||
os.makedirs(a.output_dir, exist_ok=True)
|
||||
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
with torch.no_grad():
|
||||
for i, filename in enumerate(filelist):
|
||||
# Load the mel spectrogram in .npy format
|
||||
x = np.load(os.path.join(a.input_mels_dir, filename))
|
||||
x = torch.FloatTensor(x).to(device)
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filename)[0] + "_generated_e2e.wav"
|
||||
)
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Inference Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_mels_dir", default="test_mel_files")
|
||||
parser.add_argument("--output_dir", default="generated_files_from_mel")
|
||||
parser.add_argument("--checkpoint_file", required=True)
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||
with open(config_file) as f:
|
||||
data = f.read()
|
||||
|
||||
global h
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
global device
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
inference(a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
254
GPT_SoVITS/BigVGAN/loss.py
Normal file
@ -0,0 +1,254 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
|
||||
|
||||
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""Compute distance between mel spectrograms. Can be used
|
||||
in a multi-scale way.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_mels : List[int]
|
||||
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
|
||||
window_lengths : List[int], optional
|
||||
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
|
||||
loss_fn : typing.Callable, optional
|
||||
How to compare each loss, by default nn.L1Loss()
|
||||
clamp_eps : float, optional
|
||||
Clamp on the log magnitude, below, by default 1e-5
|
||||
mag_weight : float, optional
|
||||
Weight of raw magnitude portion of loss, by default 0.0 (no amplification of the magnitude term)
|
||||
log_weight : float, optional
|
||||
Weight of log magnitude portion of loss, by default 1.0
|
||||
pow : float, optional
|
||||
Power to raise magnitude to before taking log, by default 1.0
|
||||
weight : float, optional
|
||||
Weight of this loss, by default 1.0
|
||||
match_stride : bool, optional
|
||||
Whether to match the stride of convolutional layers, by default False
|
||||
|
||||
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
|
||||
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate: int,
|
||||
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
|
||||
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
|
||||
loss_fn: typing.Callable = nn.L1Loss(),
|
||||
clamp_eps: float = 1e-5,
|
||||
mag_weight: float = 0.0,
|
||||
log_weight: float = 1.0,
|
||||
pow: float = 1.0,
|
||||
weight: float = 1.0,
|
||||
match_stride: bool = False,
|
||||
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
|
||||
mel_fmax: List[float] = [None, None, None, None, None, None, None],
|
||||
window_type: str = "hann",
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
STFTParams = namedtuple(
|
||||
"STFTParams",
|
||||
["window_length", "hop_length", "window_type", "match_stride"],
|
||||
)
|
||||
|
||||
self.stft_params = [
|
||||
STFTParams(
|
||||
window_length=w,
|
||||
hop_length=w // 4,
|
||||
match_stride=match_stride,
|
||||
window_type=window_type,
|
||||
)
|
||||
for w in window_lengths
|
||||
]
|
||||
self.n_mels = n_mels
|
||||
self.loss_fn = loss_fn
|
||||
self.clamp_eps = clamp_eps
|
||||
self.log_weight = log_weight
|
||||
self.mag_weight = mag_weight
|
||||
self.weight = weight
|
||||
self.mel_fmin = mel_fmin
|
||||
self.mel_fmax = mel_fmax
|
||||
self.pow = pow
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def get_window(
|
||||
window_type,
|
||||
window_length,
|
||||
):
|
||||
return signal.get_window(window_type, window_length)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
|
||||
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
|
||||
def mel_spectrogram(
|
||||
self,
|
||||
wav,
|
||||
n_mels,
|
||||
fmin,
|
||||
fmax,
|
||||
window_length,
|
||||
hop_length,
|
||||
match_stride,
|
||||
window_type,
|
||||
):
|
||||
"""
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert (
|
||||
hop_length == window_length // 4
|
||||
), "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
right_pad = 0
|
||||
pad = 0
|
||||
|
||||
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
|
||||
|
||||
window = self.get_window(window_type, window_length)
|
||||
window = torch.from_numpy(window).to(wav.device).float()
|
||||
|
||||
stft = torch.stft(
|
||||
wav.reshape(-1, T),
|
||||
n_fft=window_length,
|
||||
hop_length=hop_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
center=True,
|
||||
)
|
||||
_, nf, nt = stft.shape
|
||||
stft = stft.reshape(B, C, nf, nt)
|
||||
if match_stride:
|
||||
"""
|
||||
Drop the first two and last two frames, which were added because of padding. Now num_frames * hop_length = num_samples.
|
||||
"""
|
||||
stft = stft[..., 2:-2]
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(
|
||||
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
|
||||
)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
|
||||
return mel_spectrogram
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes mel loss between an estimate and a reference
|
||||
signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Estimate signal
|
||||
y : torch.Tensor
|
||||
Reference signal
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Mel loss.
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(
|
||||
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
||||
):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
"fmax": fmax,
|
||||
"window_length": s.window_length,
|
||||
"hop_length": s.hop_length,
|
||||
"match_stride": s.match_stride,
|
||||
"window_type": s.window_type,
|
||||
}
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(
|
||||
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(
|
||||
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
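# NOTE: mag_weight defaults to 0.0 above, so this term is effectively a no-op
# unless a non-zero magnitude weight is passed in.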
|
||||
|
||||
return loss
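# Usage sketch (illustrative values; `y_g_hat` and `y` are waveforms of shape [B, 1, T]):
#   fn_mel_loss = MultiScaleMelSpectrogramLoss(sampling_rate=24000)
#   loss_mel = fn_mel_loss(y_g_hat, y)  # scalar tensor; train.py scales it by lambda_melloss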
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(
|
||||
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2 # This equates to lambda=2.0 for the feature matching loss
|
||||
|
||||
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
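# The helpers above implement the least-squares GAN objective used by HiFi-GAN-style vocoders:
#   discriminator_loss: E[(1 - D(y))^2] + E[D(y_hat)^2]
#   generator_loss:     E[(1 - D(y_hat))^2]
#   feature_loss:       L1 distance between real/generated feature maps, scaled by 2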
|
396
GPT_SoVITS/BigVGAN/meldataset.py
Normal file
@ -0,0 +1,396 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import librosa
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
import pathlib
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple, Optional
|
||||
from env import AttrDict
|
||||
|
||||
MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1, to prevent int16 overflow (which results in popping sounds in corner cases)
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
return dynamic_range_compression_torch(magnitudes)
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
return dynamic_range_decompression_torch(magnitudes)
|
||||
|
||||
|
||||
mel_basis_cache = {}
|
||||
hann_window_cache = {}
|
||||
|
||||
|
||||
def mel_spectrogram(
|
||||
y: torch.Tensor,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
sampling_rate: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: int,
|
||||
fmax: int = None,
|
||||
center: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mel spectrogram of an input signal.
|
||||
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Input signal.
|
||||
n_fft (int): FFT size.
|
||||
num_mels (int): Number of mel bins.
|
||||
sampling_rate (int): Sampling rate of the input signal.
|
||||
hop_size (int): Hop size for STFT.
|
||||
win_size (int): Window size for STFT.
|
||||
fmin (int): Minimum frequency for mel filterbank.
|
||||
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
||||
center (bool): Whether to pad the input to center the frames. Default is False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
if torch.min(y) < -1.0:
|
||||
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
||||
if torch.max(y) > 1.0:
|
||||
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
||||
|
||||
device = y.device
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
mel_basis = mel_basis_cache[key]
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||
|
||||
mel_spec = torch.matmul(mel_basis, spec)
|
||||
mel_spec = spectral_normalize_torch(mel_spec)
|
||||
|
||||
return mel_spec
|
||||
|
||||
|
||||
def get_mel_spectrogram(wav, h):
|
||||
"""
|
||||
Generate mel spectrogram from a waveform using given hyperparameters.
|
||||
|
||||
Args:
|
||||
wav (torch.Tensor): Input waveform.
|
||||
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
return mel_spectrogram(
|
||||
wav,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
)
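# Usage sketch (hyperparameter values below are illustrative, not taken from a shipped config):
#   h = AttrDict(n_fft=1024, num_mels=100, sampling_rate=24000,
#                hop_size=256, win_size=1024, fmin=0, fmax=None)
#   wav = torch.randn(1, 24000).clamp(-1, 1)   # [B, T] waveform in [-1, 1]
#   mel = get_mel_spectrogram(wav, h)          # [B, num_mels, T // hop_size]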
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
training_files = []
|
||||
validation_files = []
|
||||
list_unseen_validation_files = []
|
||||
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(f"first training file: {training_files[0]}")
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(f"first validation file: {validation_files[0]}")
|
||||
|
||||
for i in range(len(a.list_input_unseen_validation_file)):
|
||||
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
|
||||
unseen_validation_files = [
|
||||
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(
|
||||
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
|
||||
)
|
||||
list_unseen_validation_files.append(unseen_validation_files)
|
||||
|
||||
return training_files, validation_files, list_unseen_validation_files
|
||||
|
||||
|
||||
class MelDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
training_files: List[str],
|
||||
hparams: AttrDict,
|
||||
segment_size: int,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
sampling_rate: int,
|
||||
fmin: int,
|
||||
fmax: Optional[int],
|
||||
split: bool = True,
|
||||
shuffle: bool = True,
|
||||
device: str = None,
|
||||
fmax_loss: Optional[int] = None,
|
||||
fine_tuning: bool = False,
|
||||
base_mels_path: str = None,
|
||||
is_seen: bool = True,
|
||||
):
|
||||
self.audio_files = training_files
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_files)
|
||||
self.hparams = hparams
|
||||
self.is_seen = is_seen
|
||||
if self.is_seen:
|
||||
self.name = pathlib.Path(self.audio_files[0]).parts[0]
|
||||
else:
|
||||
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
|
||||
|
||||
self.segment_size = segment_size
|
||||
self.sampling_rate = sampling_rate
|
||||
self.split = split
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmax_loss = fmax_loss
|
||||
self.device = device
|
||||
self.fine_tuning = fine_tuning
|
||||
self.base_mels_path = base_mels_path
|
||||
|
||||
print("[INFO] checking dataset integrity...")
|
||||
for i in tqdm(range(len(self.audio_files))):
|
||||
assert os.path.exists(
|
||||
self.audio_files[i]
|
||||
), f"{self.audio_files[i]} not found"
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
try:
|
||||
filename = self.audio_files[index]
|
||||
|
||||
# Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
|
||||
# Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
|
||||
# The on-the-fly resampling during training will be done only for the obtained random chunk
|
||||
audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
|
||||
|
||||
# Main logic that uses <mel, audio> pair for training BigVGAN
|
||||
if not self.fine_tuning:
|
||||
if self.split: # Training step
|
||||
# Obtain randomized audio chunk
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
# Adjust segment size to crop if the source sr is different
|
||||
target_segment_size = math.ceil(
|
||||
self.segment_size
|
||||
* (source_sampling_rate / self.sampling_rate)
|
||||
)
|
||||
else:
|
||||
target_segment_size = self.segment_size
|
||||
|
||||
# Compute upper bound index for the random chunk
|
||||
random_chunk_upper_bound = max(
|
||||
0, audio.shape[0] - target_segment_size
|
||||
)
|
||||
|
||||
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||
if audio.shape[0] >= target_segment_size:
|
||||
audio_start = random.randint(0, random_chunk_upper_bound)
|
||||
audio = audio[audio_start : audio_start + target_segment_size]
|
||||
else:
|
||||
audio = np.pad(
|
||||
audio,
|
||||
(0, target_segment_size - audio.shape[0]),
|
||||
mode="constant",
|
||||
)
|
||||
|
||||
# Resample audio chunk to self.sampling_rate
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=source_sampling_rate,
|
||||
target_sr=self.sampling_rate,
|
||||
)
|
||||
if audio.shape[0] > self.segment_size:
|
||||
# trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
|
||||
audio = audio[: self.segment_size]
|
||||
|
||||
else: # Validation step
|
||||
# Resample full audio clip to target sampling rate
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=source_sampling_rate,
|
||||
target_sr=self.sampling_rate,
|
||||
)
|
||||
# Trim last elements to match audio length to self.hop_size * n for evaluation
|
||||
if (audio.shape[0] % self.hop_size) != 0:
|
||||
audio = audio[: -(audio.shape[0] % self.hop_size)]
|
||||
|
||||
# BigVGAN is trained using volume-normalized waveform
|
||||
audio = librosa.util.normalize(audio) * 0.95
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0) # [B(1), self.segment_size]
|
||||
|
||||
# Compute mel spectrogram corresponding to audio
|
||||
mel = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
center=False,
|
||||
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||
|
||||
# Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
|
||||
else:
|
||||
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||
assert (
|
||||
source_sampling_rate == self.sampling_rate
|
||||
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0) # [B(1), T_time]
|
||||
|
||||
# Load pre-computed mel from disk
|
||||
mel = np.load(
|
||||
os.path.join(
|
||||
self.base_mels_path,
|
||||
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
|
||||
)
|
||||
)
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
if len(mel.shape) < 3:
|
||||
mel = mel.unsqueeze(0) # ensure [B, C, T]
|
||||
|
||||
if self.split:
|
||||
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||
|
||||
if audio.size(1) >= self.segment_size:
|
||||
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start
|
||||
* self.hop_size : (mel_start + frames_per_seg)
|
||||
* self.hop_size,
|
||||
]
|
||||
|
||||
# Pad pre-computed mel and audio to matching lengths to ensure fine-tuning runs without error.
|
||||
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is an integer multiple of self.hop_size
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
||||
)
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
|
||||
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||
mel_loss = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax_loss,
|
||||
center=False,
|
||||
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||
|
||||
# Shape sanity checks
|
||||
assert (
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size
|
||||
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
# If an error occurs while loading the data, skip this sample and load another random sample into the batch
|
||||
except Exception as e:
|
||||
if self.fine_tuning:
|
||||
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
|
||||
)
|
||||
return self[random.randrange(len(self))]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_files)
|
13
GPT_SoVITS/BigVGAN/requirements.txt
Normal file
@ -0,0 +1,13 @@
|
||||
torch
|
||||
numpy
|
||||
librosa>=0.8.1
|
||||
scipy
|
||||
tensorboard
|
||||
soundfile
|
||||
matplotlib
|
||||
pesq
|
||||
auraloss
|
||||
tqdm
|
||||
nnAudio
|
||||
ninja
|
||||
huggingface_hub>=0.23.4
|
777
GPT_SoVITS/BigVGAN/train.py
Normal file
@ -0,0 +1,777 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from env import AttrDict, build_env
|
||||
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
|
||||
|
||||
from bigvgan import BigVGAN
|
||||
from discriminators import (
|
||||
MultiPeriodDiscriminator,
|
||||
MultiResolutionDiscriminator,
|
||||
MultiBandDiscriminator,
|
||||
MultiScaleSubbandCQTDiscriminator,
|
||||
)
|
||||
from loss import (
|
||||
feature_loss,
|
||||
generator_loss,
|
||||
discriminator_loss,
|
||||
MultiScaleMelSpectrogramLoss,
|
||||
)
|
||||
|
||||
from utils import (
|
||||
plot_spectrogram,
|
||||
plot_spectrogram_clipped,
|
||||
scan_checkpoint,
|
||||
load_checkpoint,
|
||||
save_checkpoint,
|
||||
save_audio,
|
||||
)
|
||||
import torchaudio as ta
|
||||
from pesq import pesq
|
||||
from tqdm import tqdm
|
||||
import auraloss
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def train(rank, a, h):
|
||||
if h.num_gpus > 1:
|
||||
# initialize distributed
|
||||
init_process_group(
|
||||
backend=h.dist_config["dist_backend"],
|
||||
init_method=h.dist_config["dist_url"],
|
||||
world_size=h.dist_config["world_size"] * h.num_gpus,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# Set seed and device
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank:d}")
|
||||
|
||||
# Define BigVGAN generator
|
||||
generator = BigVGAN(h).to(device)
|
||||
|
||||
# Define discriminators. MPD is used by default
|
||||
mpd = MultiPeriodDiscriminator(h).to(device)
|
||||
|
||||
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||
print(
|
||||
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||
mrd = MultiBandDiscriminator(h).to(device)
|
||||
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||
print(
|
||||
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||
else: # Fallback to original MRD in BigVGAN-v1
|
||||
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||
|
||||
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||
if h.get("use_multiscale_melloss", False):
|
||||
print(
|
||||
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
|
||||
)
|
||||
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||
sampling_rate=h.sampling_rate
|
||||
) # NOTE: accepts waveform as input
|
||||
else:
|
||||
fn_mel_loss_singlescale = F.l1_loss
|
||||
|
||||
# Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
|
||||
if rank == 0:
|
||||
print(generator)
|
||||
print(mpd)
|
||||
print(mrd)
|
||||
print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
|
||||
print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
|
||||
print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
|
||||
os.makedirs(a.checkpoint_path, exist_ok=True)
|
||||
print(f"Checkpoints directory: {a.checkpoint_path}")
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||
cp_g = scan_checkpoint(
|
||||
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
|
||||
)
|
||||
cp_do = scan_checkpoint(
|
||||
a.checkpoint_path,
|
||||
prefix="do_",
|
||||
renamed_file="bigvgan_discriminator_optimizer.pt",
|
||||
)
|
||||
|
||||
# Load the latest checkpoint if exists
|
||||
steps = 0
|
||||
if cp_g is None or cp_do is None:
|
||||
state_dict_do = None
|
||||
last_epoch = -1
|
||||
else:
|
||||
state_dict_g = load_checkpoint(cp_g, device)
|
||||
state_dict_do = load_checkpoint(cp_do, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
mpd.load_state_dict(state_dict_do["mpd"])
|
||||
mrd.load_state_dict(state_dict_do["mrd"])
|
||||
steps = state_dict_do["steps"] + 1
|
||||
last_epoch = state_dict_do["epoch"]
|
||||
|
||||
# Initialize DDP, optimizers, and schedulers
|
||||
if h.num_gpus > 1:
|
||||
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(
|
||||
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
|
||||
)
|
||||
optim_d = torch.optim.AdamW(
|
||||
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||
h.learning_rate,
|
||||
betas=[h.adam_b1, h.adam_b2],
|
||||
)
|
||||
|
||||
if state_dict_do is not None:
|
||||
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
|
||||
# Define training and validation datasets
|
||||
|
||||
"""
|
||||
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||
Example: trained on LibriTTS, validate on VCTK
|
||||
"""
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = (
|
||||
get_dataset_filelist(a)
|
||||
)
|
||||
|
||||
trainset = MelDataset(
|
||||
training_filelist,
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
shuffle=False if h.num_gpus > 1 else True,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=True,
|
||||
)
|
||||
|
||||
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
||||
|
||||
train_loader = DataLoader(
|
||||
trainset,
|
||||
num_workers=h.num_workers,
|
||||
shuffle=False,
|
||||
sampler=train_sampler,
|
||||
batch_size=h.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
validset = MelDataset(
|
||||
validation_filelist,
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
False,
|
||||
False,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=True,
|
||||
)
|
||||
validation_loader = DataLoader(
|
||||
validset,
|
||||
num_workers=1,
|
||||
shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
list_unseen_validset = []
|
||||
list_unseen_validation_loader = []
|
||||
for i in range(len(list_unseen_validation_filelist)):
|
||||
unseen_validset = MelDataset(
|
||||
list_unseen_validation_filelist[i],
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
False,
|
||||
False,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=False,
|
||||
)
|
||||
unseen_validation_loader = DataLoader(
|
||||
unseen_validset,
|
||||
num_workers=1,
|
||||
shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
list_unseen_validset.append(unseen_validset)
|
||||
list_unseen_validation_loader.append(unseen_validation_loader)
|
||||
|
||||
# Tensorboard logger
|
||||
sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
|
||||
|
||||
"""
|
||||
Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
|
||||
If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
|
||||
"""
|
||||
|
||||
def validate(rank, a, h, loader, mode="seen"):
|
||||
assert rank == 0, "validate should only run on rank=0"
|
||||
generator.eval()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
val_err_tot = 0
|
||||
val_pesq_tot = 0
|
||||
val_mrstft_tot = 0
|
||||
|
||||
# Modules for evaluation metrics
|
||||
pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
|
||||
loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
|
||||
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
os.makedirs(
|
||||
os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
|
||||
exist_ok=True,
|
||||
)
|
||||
os.makedirs(
|
||||
os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
print(f"step {steps} {mode} speaker validation...")
|
||||
|
||||
# Loop over validation set and compute metrics
|
||||
for j, batch in enumerate(tqdm(loader)):
|
||||
x, y, _, y_mel = batch
|
||||
y = y.to(device)
|
||||
if hasattr(generator, "module"):
|
||||
y_g_hat = generator.module(x.to(device))
|
||||
else:
|
||||
y_g_hat = generator(x.to(device))
|
||||
y_mel = y_mel.to(device, non_blocking=True)
|
||||
y_g_hat_mel = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
|
||||
|
||||
# PESQ calculation. Only evaluate PESQ on speech signals (PESQ errors out on non-speech)
|
||||
if (
|
||||
not "nonspeech" in mode
|
||||
): # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
|
||||
# Resample to 16000 for pesq
|
||||
y_16k = pesq_resampler(y)
|
||||
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
y_g_hat_int_16k = (
|
||||
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
)
|
||||
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||
|
||||
# MRSTFT calculation
|
||||
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
|
||||
|
||||
# Log audio and figures to Tensorboard
|
||||
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||
if steps >= 0:
|
||||
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y[0],
|
||||
os.path.join(
|
||||
a.checkpoint_path,
|
||||
"samples",
|
||||
f"gt_{mode}",
|
||||
f"{j:04d}.wav",
|
||||
),
|
||||
h.sampling_rate,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"gt_{mode}/y_spec_{j}",
|
||||
plot_spectrogram(x[0]),
|
||||
steps,
|
||||
)
|
||||
|
||||
sw.add_audio(
|
||||
f"generated_{mode}/y_hat_{j}",
|
||||
y_g_hat[0],
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y_g_hat[0, 0],
|
||||
os.path.join(
|
||||
a.checkpoint_path,
|
||||
"samples",
|
||||
f"{mode}_{steps:08d}",
|
||||
f"{j:04d}.wav",
|
||||
),
|
||||
h.sampling_rate,
|
||||
)
|
||||
# Spectrogram of synthesized audio
|
||||
y_hat_spec = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"generated_{mode}/y_hat_spec_{j}",
|
||||
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
|
||||
steps,
|
||||
)
|
||||
|
||||
"""
|
||||
Visualization of the spectrogram difference between GT and synthesized audio; differences higher than 1 are clipped for better visualization.
|
||||
"""
|
||||
spec_delta = torch.clamp(
|
||||
torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
|
||||
min=1e-6,
|
||||
max=1.0,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"delta_dclip1_{mode}/spec_{j}",
|
||||
plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
|
||||
steps,
|
||||
)
|
||||
|
||||
val_err = val_err_tot / (j + 1)
|
||||
val_pesq = val_pesq_tot / (j + 1)
|
||||
val_mrstft = val_mrstft_tot / (j + 1)
|
||||
# Log evaluation metrics to Tensorboard
|
||||
sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
|
||||
sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
|
||||
sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
|
||||
|
||||
generator.train()
|
||||
|
||||
# If the checkpoint is loaded, start with validation loop
|
||||
if steps != 0 and rank == 0 and not a.debug:
|
||||
if not a.skip_seen:
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
validation_loader,
|
||||
mode=f"seen_{train_loader.dataset.name}",
|
||||
)
|
||||
for i in range(len(list_unseen_validation_loader)):
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
list_unseen_validation_loader[i],
|
||||
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||
)
|
||||
# Exit the script if --evaluate is set to True
|
||||
if a.evaluate:
|
||||
exit()
|
||||
|
||||
# Main training loop
|
||||
generator.train()
|
||||
mpd.train()
|
||||
mrd.train()
|
||||
for epoch in range(max(0, last_epoch), a.training_epochs):
|
||||
if rank == 0:
|
||||
start = time.time()
|
||||
print(f"Epoch: {epoch + 1}")
|
||||
|
||||
if h.num_gpus > 1:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
for i, batch in enumerate(train_loader):
|
||||
if rank == 0:
|
||||
start_b = time.time()
|
||||
x, y, _, y_mel = batch
|
||||
|
||||
x = x.to(device, non_blocking=True)
|
||||
y = y.to(device, non_blocking=True)
|
||||
y_mel = y_mel.to(device, non_blocking=True)
|
||||
y = y.unsqueeze(1)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
|
||||
y_df_hat_r, y_df_hat_g
|
||||
)
|
||||
|
||||
# MRD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
|
||||
y_ds_hat_r, y_ds_hat_g
|
||||
)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
# Set clip_grad_norm value
|
||||
clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
|
||||
|
||||
# Whether to freeze D for initial training steps
|
||||
if steps >= a.freeze_step:
|
||||
loss_disc_all.backward()
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
|
||||
mpd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
|
||||
mrd.parameters(), clip_grad_norm
|
||||
)
|
||||
optim_d.step()
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
|
||||
)
|
||||
grad_norm_mpd = 0.0
|
||||
grad_norm_mrd = 0.0
|
||||
|
||||
# Generator
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
lambda_melloss = h.get(
|
||||
"lambda_melloss", 45.0
|
||||
) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||
loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
|
||||
|
||||
# MPD loss
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
|
||||
# MRD loss
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
|
||||
if steps >= a.freeze_step:
|
||||
loss_gen_all = (
|
||||
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
|
||||
)
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(
|
||||
generator.parameters(), clip_grad_norm
|
||||
)
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to stdout
|
||||
print(
|
||||
f"Steps: {steps:d}, "
|
||||
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||
f"Mel Error: {mel_error:4.3f}, "
|
||||
f"s/b: {time.time() - start_b:4.3f} "
|
||||
f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
|
||||
f"grad_norm_g: {grad_norm_g:4.3f}"
|
||||
)
|
||||
|
||||
# Checkpointing
|
||||
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"generator": (
|
||||
generator.module if h.num_gpus > 1 else generator
|
||||
).state_dict()
|
||||
},
|
||||
)
|
||||
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
"mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
|
||||
"optim_g": optim_g.state_dict(),
|
||||
"optim_d": optim_d.state_dict(),
|
||||
"steps": steps,
|
||||
"epoch": epoch,
|
||||
},
|
||||
)
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to tensorboard
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||
sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
|
||||
sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
|
||||
sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
|
||||
sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
|
||||
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||
|
||||
# Validation
|
||||
if steps % a.validation_interval == 0:
|
||||
# Plot training input x so far used
|
||||
for i_x in range(x.shape[0]):
|
||||
sw.add_figure(
|
||||
f"training_input/x_{i_x}",
|
||||
plot_spectrogram(x[i_x].cpu()),
|
||||
steps,
|
||||
)
|
||||
sw.add_audio(
|
||||
f"training_input/y_{i_x}",
|
||||
y[i_x][0],
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
|
||||
# Seen and unseen speakers validation loops
|
||||
if not a.debug and steps != 0:
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
validation_loader,
|
||||
mode=f"seen_{train_loader.dataset.name}",
|
||||
)
|
||||
for i in range(len(list_unseen_validation_loader)):
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
list_unseen_validation_loader[i],
|
||||
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||
)
|
||||
steps += 1
|
||||
|
||||
# BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Training Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--group_name", default=None)
|
||||
|
||||
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||
parser.add_argument(
|
||||
"--input_training_file", default="tests/LibriTTS/train-full.txt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_wavs_dir",
|
||||
nargs="+",
|
||||
default=["tests/LibriTTS", "tests/LibriTTS"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_validation_file",
|
||||
nargs="+",
|
||||
default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
|
||||
)
|
||||
|
||||
parser.add_argument("--checkpoint_path", default="exp/bigvgan")
|
||||
parser.add_argument("--config", default="")
|
||||
|
||||
parser.add_argument("--training_epochs", default=100000, type=int)
|
||||
parser.add_argument("--stdout_interval", default=5, type=int)
|
||||
parser.add_argument("--checkpoint_interval", default=50000, type=int)
|
||||
parser.add_argument("--summary_interval", default=100, type=int)
|
||||
parser.add_argument("--validation_interval", default=50000, type=int)
|
||||
|
||||
parser.add_argument(
|
||||
"--freeze_step",
|
||||
default=0,
|
||||
type=int,
|
||||
help="freeze D for the first specified steps. G only uses regression loss for these steps.",
|
||||
)
|
||||
|
||||
parser.add_argument("--fine_tuning", default=False, type=bool)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="debug mode. skips validation loop throughout training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluate",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="only run evaluation from checkpoint and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_subsample",
|
||||
default=5,
|
||||
type=int,
|
||||
help="subsampling during evaluation loop",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_seen",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="skip seen dataset. useful for test set inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_audio",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="save audio of test set inference to disk",
|
||||
)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
with open(a.config) as f:
|
||||
data = f.read()
|
||||
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
build_env(a.config, "config.json", a.checkpoint_path)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print(f"Batch size per GPU: {h.batch_size}")
|
||||
else:
|
||||
pass
|
||||
|
||||
if h.num_gpus > 1:
|
||||
mp.spawn(
|
||||
train,
|
||||
nprocs=h.num_gpus,
|
||||
args=(
|
||||
a,
|
||||
h,
|
||||
),
|
||||
)
|
||||
else:
|
||||
train(0, a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
99
GPT_SoVITS/BigVGAN/utils0.py
Normal file
@ -0,0 +1,99 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import matplotlib
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(
|
||||
spectrogram,
|
||||
aspect="auto",
|
||||
origin="lower",
|
||||
interpolation="none",
|
||||
vmin=1e-6,
|
||||
vmax=clip_max,
|
||||
)
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
print(f"Saving checkpoint to {filepath}")
|
||||
torch.save(obj, filepath)
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
|
||||
# Fallback to original scanning logic first
|
||||
pattern = os.path.join(cp_dir, prefix + "????????")
|
||||
cp_list = glob.glob(pattern)
|
||||
|
||||
if len(cp_list) > 0:
|
||||
last_checkpoint_path = sorted(cp_list)[-1]
|
||||
print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
|
||||
return last_checkpoint_path
|
||||
|
||||
# If no pattern-based checkpoints are found, check for renamed file
|
||||
if renamed_file:
|
||||
renamed_path = os.path.join(cp_dir, renamed_file)
|
||||
if os.path.isfile(renamed_path):
|
||||
print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
|
||||
return renamed_path
|
||||
|
||||
return None
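# Example: scan_checkpoint("exp/bigvgan", "g_", renamed_file="bigvgan_generator.pt")
# returns the newest "g_????????" checkpoint if one exists, otherwise the renamed
# Hugging Face file if present, and None when neither is found.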
|
||||
|
||||
|
||||
def save_audio(audio, path, sr):
|
||||
# audio: 1-D torch tensor with values in [-1, 1]
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
write(path, sr, audio)
|