Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-18 15:59:51 +08:00)

Commit 344ca488d9: Reimplement run_generator to be compatible with v3/v4 models, sync the code with the main repo.
@@ -1,8 +0,0 @@
-docs
-logs
-output
-reference
-SoVITS_weights
-GPT_weights
-TEMP
-.git
Dockerfile (deleted, 42 lines)
@@ -1,42 +0,0 @@
-# Base CUDA image
-FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
-
-LABEL maintainer="breakstring@hotmail.com"
-LABEL version="dev-20240209"
-LABEL description="Docker image for GPT-SoVITS"
-
-
-# Install 3rd party apps
-ENV DEBIAN_FRONTEND=noninteractive
-ENV TZ=Etc/UTC
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
-    git lfs install && \
-    rm -rf /var/lib/apt/lists/*
-
-# Copy only requirements.txt initially to leverage Docker cache
-WORKDIR /workspace
-COPY requirements.txt /workspace/
-RUN pip install --no-cache-dir -r requirements.txt
-
-# Define a build-time argument for image type
-ARG IMAGE_TYPE=full
-
-# Conditional logic based on the IMAGE_TYPE argument
-# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
-COPY ./Docker /workspace/Docker
-# The "elite" image type does not bundle the extra models
-RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
-    chmod +x /workspace/Docker/download.sh && \
-    /workspace/Docker/download.sh && \
-    python /workspace/Docker/download.py && \
-    python -m nltk.downloader averaged_perceptron_tagger cmudict; \
-    fi
-
-
-# Copy the rest of the application
-COPY . /workspace
-
-EXPOSE 9871 9872 9873 9874 9880
-
-CMD ["python", "webui.py"]
@@ -1,11 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
 # reference: https://github.com/lifeiteng/vall-e
 from pytorch_lightning import LightningDataModule
+from GPT_SoVITS.AR.data.bucket_sampler import DistributedBucketSampler
+from GPT_SoVITS.AR.data.dataset import Text2SemanticDataset
 from torch.utils.data import DataLoader
 
-from AR.data.bucket_sampler import DistributedBucketSampler
-from AR.data.dataset import Text2SemanticDataset
-
 
 class Text2SemanticDataModule(LightningDataModule):
     def __init__(
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset
 
 version = os.environ.get("version", None)
 
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence
 
 # from config import exp_dir
 
@@ -9,10 +9,9 @@ from typing import Dict
 
 import torch
 from pytorch_lightning import LightningModule
-from AR.models.t2s_model import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam
 
 
 class Text2SemanticLightningModule(LightningModule):
@@ -9,10 +9,9 @@ from typing import Dict
 
 import torch
 from pytorch_lightning import LightningModule
-from AR.models.t2s_model_onnx import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model_onnx import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam
 
 
 class Text2SemanticLightningModule(LightningModule):
@@ -9,7 +9,7 @@ from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
 from tqdm import tqdm
 
-from AR.models.utils import (
+from GPT_SoVITS.AR.models.utils import (
     dpo_loss,
     get_batch_logps,
     make_pad_mask,
@@ -18,8 +18,8 @@ from AR.models.utils import (
     sample,
     topk_sampling,
 )
-from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
 
 default_config = {
     "embedding_dim": 512,
@@ -933,3 +933,140 @@ class Text2SemanticDecoder(nn.Module):
         return self.infer_panel_naive(
             x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
         )
+
+    def infer_panel_generator(
+        self,
+        x: torch.LongTensor,
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,
+        bert_feature: torch.LongTensor,
+        cumulation_amount: int,
+        top_k: int = -100,
+        top_p: int = 100,
+        early_stop_num: int = -1,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.35,
+        **kwargs,
+    ):
+        """
+        Generator method that yields generated tokens based on a specified cumulative amount.
+
+        Args:
+            x (torch.LongTensor): Input phoneme IDs.
+            x_lens (torch.LongTensor): Lengths of the input sequences.
+            prompts (torch.LongTensor): Initial prompt tokens.
+            bert_feature (torch.LongTensor): BERT features corresponding to the input.
+            cumulation_amount (int): Number of tokens to generate before yielding.
+            top_k (int): Top-k sampling.
+            top_p (int): Top-p sampling.
+            early_stop_num (int): Early stopping number.
+            temperature (float): Sampling temperature.
+            repetition_penalty (float): Repetition penalty.
+
+        Yields:
+            torch.LongTensor: Generated tokens since the last yield.
+        """
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
+
+        # AR Decoder
+        y = prompts
+
+        x_len = x.shape[1]
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device)
+        stop = False
+
+        # Initialize cumulative token counter
+        tokens_since_last_yield = 0
+        # Initialize last yield index
+        prefix_len = y.shape[1] if y is not None else 0
+        last_yield_idx = prefix_len
+
+        k_cache = None
+        v_cache = None
+
+        ################### first step ##########################
+        if y is not None and y.shape[1] > 0:
+            y_emb = self.ar_audio_embedding(y)
+            y_len = y_emb.shape[1]
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+            ref_free = False
+        else:
+            y_emb = None
+            y_len = 0
+            xy_pos = x
+            y = torch.zeros(x.shape[0], 0, dtype=torch.int64, device=x.device)
+            ref_free = True
+
+        bsz = x.shape[0]
+        src_len = x_len + y_len
+        x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
+        y_attn_mask = F.pad(
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
+            (x_len, 0),
+            value=False,
+        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+        xy_attn_mask = xy_attn_mask.unsqueeze(0).expand(bsz * self.num_head, -1, -1)
+        xy_attn_mask = xy_attn_mask.view(bsz, self.num_head, src_len, src_len).to(device=x.device, dtype=torch.bool)
+
+        for idx in tqdm(range(1500)):
+            if xy_attn_mask is not None:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
+            else:
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
+
+            logits = self.ar_predict_layer(xy_dec[:, -1])
+
+            if idx == 0:
+                xy_attn_mask = None
+            if idx < 11:  # Ensure at least 10 tokens are generated before stopping
+                logits = logits[:, :-1]
+
+            samples = sample(
+                logits,
+                y,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                temperature=temperature,
+            )[0]
+
+            y = torch.concat([y, samples], dim=1)
+            tokens_since_last_yield += 1
+
+            if tokens_since_last_yield >= cumulation_amount:
+                generated_tokens = y[:, last_yield_idx:]
+                yield generated_tokens
+                last_yield_idx = y.shape[1]
+                tokens_since_last_yield = 0
+
+            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+                print("Using early stop num:", early_stop_num)
+                stop = True
+
+            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
+                stop = True
+
+            if stop:
+                if y.shape[1] == 0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
+                    print("Bad zero prediction")
+                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
+                break
+
+            # Update for next step
+            y_emb = self.ar_audio_embedding(y[:, -1:])
+            y_len += 1
+            xy_pos = (
+                y_emb * self.ar_audio_position.x_scale
+                + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len - 1].to(
+                    dtype=y_emb.dtype, device=y_emb.device
+                )
+            )
+
+        # After loop ends, yield any remaining tokens
+        if last_yield_idx < y.shape[1]:
+            generated_tokens = y[:, last_yield_idx:]
+            yield generated_tokens
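The yield-every-`cumulation_amount` pattern above can be exercised in isolation. Below is a minimal, self-contained sketch (it does not touch the real decoder; `fake_token_stream` and its values are purely illustrative) that follows the same contract: yield a tensor of the tokens produced since the last yield, then yield any remainder once generation stops.

import torch

def fake_token_stream(total_tokens: int, cumulation_amount: int):
    # Toy stand-in for infer_panel_generator: emits 1-row tensors of newly "sampled" tokens.
    y = torch.zeros(1, 0, dtype=torch.long)
    last_yield_idx = 0
    since_last = 0
    for t in range(total_tokens):
        y = torch.cat([y, torch.tensor([[t]])], dim=1)  # pretend token t was just sampled
        since_last += 1
        if since_last >= cumulation_amount:
            yield y[:, last_yield_idx:]
            last_yield_idx = y.shape[1]
            since_last = 0
    if last_yield_idx < y.shape[1]:  # remainder after the loop ends
        yield y[:, last_yield_idx:]

chunks = list(fake_token_stream(total_tokens=23, cumulation_amount=10))
print([c.shape[1] for c in chunks])  # [10, 10, 3]: fragments arrive as they accumulate
assert torch.cat(chunks, dim=1).shape[1] == 23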
@@ -1,12 +1,17 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import torch
+from tqdm import tqdm
+
+from GPT_SoVITS.AR.modules.embedding_onnx import SinePositionalEmbedding
+from GPT_SoVITS.AR.modules.embedding_onnx import TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer_onnx import LayerNorm
+from GPT_SoVITS.AR.modules.transformer_onnx import TransformerEncoder
+from GPT_SoVITS.AR.modules.transformer_onnx import TransformerEncoderLayer
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
 
-from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
-
 
 default_config = {
     "embedding_dim": 512,
@@ -9,7 +9,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
 
-from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
+from torch.nn import functional as F
+from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
 
 F.multi_head_attention_forward = multi_head_attention_forward_patched
 
@@ -152,14 +153,14 @@ class MultiheadAttention(Module):
             bias=bias,
             **factory_kwargs,
         )
-        self.in_proj_weight = self.in_proj_linear.weight
+        self.in_proj_weight = self.in_proj_lineGPT_SoVITS.AR.weight
 
         self.register_parameter("q_proj_weight", None)
         self.register_parameter("k_proj_weight", None)
         self.register_parameter("v_proj_weight", None)
 
         if bias:
-            self.in_proj_bias = self.in_proj_linear.bias
+            self.in_proj_bias = self.in_proj_lineGPT_SoVITS.AR.bias
         else:
             self.register_parameter("in_proj_bias", None)
 
@@ -8,7 +8,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
 
-from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
+from torch.nn import functional as F
+from GPT_SoVITS.AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
 
 
 class MultiheadAttention(Module):
@@ -102,14 +103,14 @@ class MultiheadAttention(Module):
             bias=bias,
             **factory_kwargs,
         )
-        self.in_proj_weight = self.in_proj_linear.weight
+        self.in_proj_weight = self.in_proj_lineGPT_SoVITS.AR.weight
 
         self.register_parameter("q_proj_weight", None)
         self.register_parameter("k_proj_weight", None)
         self.register_parameter("v_proj_weight", None)
 
         if bias:
-            self.in_proj_bias = self.in_proj_linear.bias
+            self.in_proj_bias = self.in_proj_lineGPT_SoVITS.AR.bias
         else:
             self.register_parameter("in_proj_bias", None)
 
@@ -10,8 +10,8 @@ from typing import Tuple
 from typing import Union
 
 import torch
-from AR.modules.activation import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
 from torch import nn
 from torch import Tensor
 from torch.nn import functional as F
@@ -10,8 +10,8 @@ from typing import Tuple
 from typing import Union
 
 import torch
-from AR.modules.activation_onnx import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation_onnx import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
 from torch import nn
 from torch import Tensor
 from torch.nn import functional as F
@@ -9,7 +9,7 @@ import regex
 from gruut import sentences
 from gruut.const import Sentence
 from gruut.const import Word
-from AR.text_processing.symbols import SYMBOL_TO_ID
+from GPT_SoVITS.AR.text_processing.symbols import SYMBOL_TO_ID
 
 
 class GruutPhonemizer:
@@ -18,27 +18,31 @@ from typing import List, Tuple, Union
 import ffmpeg
 import librosa
 import numpy as np
+import random
 import torch
 import torch.nn.functional as F
+import traceback
 import yaml
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from BigVGAN.bigvgan import BigVGAN
-from feature_extractor.cnhubert import CNHubert
-from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
+from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 from peft import LoraConfig, get_peft_model
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+from huggingface_hub import snapshot_download
 
-from tools.audio_sr import AP_BWE
-from tools.i18n.i18n import I18nAuto, scan_language_list
-from tools.my_utils import load_audio
-from TTS_infer_pack.text_segmentation_method import splits
-from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from GPT_SoVITS.tools.audio_sr import AP_BWE
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.tools.my_utils import load_audio
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
+from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
+LIBRARY_NAME = "GPT_SoVITS"
 
 
 spec_min = -12
@@ -149,28 +153,28 @@ class NO_PROMPT_ERROR(Exception):
 # configs/tts_infer.yaml
 """
 custom:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  bert_base_path: pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: pretrained_models/chinese-hubert-base
   device: cpu
   is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+  t2s_weights_path: pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  vits_weights_path: pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
   version: v2
 v1:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
   is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
+  t2s_weights_path: pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+  vits_weights_path: pretrained_models/s2G488k.pth
   version: v1
 v2:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cpu
   is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+  t2s_weights_path: pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  vits_weights_path: pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
   version: v2
 v3:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
@@ -323,8 +327,10 @@ class TTS_Config:
         if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
             self.cnhuhbert_base_path = self.default_configs[version]["cnhuhbert_base_path"]
             print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
 
+        repo_name = "lj1995/GPT-SoVITS"
+        snapshot_download(repo_id=repo_name, local_dir=os.path.dirname(self.bert_base_path))
         self.update_configs()
 
         self.max_sec = None
         self.hz: int = 50
         self.semantic_frame_rate: str = "25hz"
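For context, `snapshot_download` comes from `huggingface_hub` (imported at the top of this file) and mirrors an entire model repository into a local directory, which is what makes the relative default paths above resolve. A hedged sketch of the call pattern; the target directory below is illustrative, while TTS_Config derives the real one from `bert_base_path`:

from huggingface_hub import snapshot_download

# Downloads (or reuses the cached copy of) every file in the repo and places it under local_dir.
local_path = snapshot_download(
    repo_id="lj1995/GPT-SoVITS",
    local_dir="pretrained_models",  # illustrative target directory
)
print(local_path)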
@@ -1294,7 +1300,17 @@ class TTS:
             raise e
         finally:
             self.empty_cache()
 
+    def empty_cache(self):
+        try:
+            gc.collect()  # Trigger garbage collection so memory does not keep growing.
+            if "cuda" in str(self.configs.device):
+                torch.cuda.empty_cache()
+            elif str(self.configs.device) == "mps":
+                torch.mps.empty_cache()
+        except:
+            pass
+
     def empty_cache(self):
         try:
             gc.collect()  # Trigger garbage collection so memory does not keep growing.
@@ -1558,3 +1574,160 @@ class TTS:
             audio_fragments[i + 1] = f2_
 
         return torch.cat(audio_fragments, 0)
+
+    @torch.no_grad()
+    def run_generator(self, inputs: dict):
+        """
+        Streaming inference using infer_panel_generator and zero-cross splitting for v1-v4.
+        Yields tuples of (sampling_rate, np.ndarray audio fragment).
+        """
+        # Initialize parameters
+        self.stop_flag = False
+        text = inputs.get("text", "")
+        text_lang = inputs.get("text_lang", "")
+        ref_audio_path = inputs.get("ref_audio_path", "")
+        aux_ref_audio_paths = inputs.get("aux_ref_audio_paths", [])
+        prompt_text = inputs.get("prompt_text", "")
+        prompt_lang = inputs.get("prompt_lang", "")
+        top_k = inputs.get("top_k", 5)
+        top_p = inputs.get("top_p", 1)
+        temperature = inputs.get("temperature", 1)
+        text_split_method = inputs.get("text_split_method", "cut0")
+        batch_threshold = inputs.get("batch_threshold", 0.75)
+        speed_factor = inputs.get("speed_factor", 1.0)
+        seed = inputs.get("seed", -1)
+        seed = -1 if seed in [None, ""] else seed
+        set_seed(seed)
+        repetition_penalty = inputs.get("repetition_penalty", 1.35)
+        sample_steps = inputs.get("sample_steps", 8)
+        super_sampling = inputs.get("super_sampling", False)
+        search_length = inputs.get("search_length", 32000 * 5)
+        num_zeroes = inputs.get("num_zeroes", 5)
+        cumulation_amount = inputs.get("cumulation_amount", 50)
+        # Prepare reference audio
+        if ref_audio_path and ref_audio_path != self.prompt_cache["ref_audio_path"]:
+            if not os.path.exists(ref_audio_path):
+                raise ValueError(f"{ref_audio_path} not exists")
+            self.set_ref_audio(ref_audio_path)
+        # Auxiliary refs
+        self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths or []
+        self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]]
+        for p in aux_ref_audio_paths or []:
+            if p and os.path.exists(p):
+                self.prompt_cache["refer_spec"].append(self._get_ref_spec(p))
+        # Prompt text handling
+        no_prompt = prompt_text in [None, ""]
+        if not no_prompt:
+            prompt_text = prompt_text.strip("\n")
+            if prompt_text and prompt_text[-1] not in splits:
+                prompt_text += "。" if prompt_lang != "en" else "."
+            phones_p, bert_p, norm_p = self.text_preprocessor.segment_and_extract_feature_for_text(
+                prompt_text, prompt_lang, self.configs.version
+            )
+            self.prompt_cache.update({
+                "prompt_text": prompt_text,
+                "prompt_lang": prompt_lang,
+                "phones": phones_p,
+                "bert_features": bert_p,
+                "norm_text": norm_p,
+            })
+        # Text to semantic preprocessing
+        data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
+        if not data:
+            sr = self.vocoder_configs["sr"] if self.configs.use_vocoder else self.configs.sampling_rate
+            yield sr, np.zeros(1, dtype=np.int16)
+            return
+        # Single-batch conversion
+        batches, _ = self.to_batch(
+            data,
+            prompt_data=None if no_prompt else self.prompt_cache,
+            batch_size=1,
+            threshold=batch_threshold,
+            split_bucket=False,
+            device=self.configs.device,
+            precision=self.precision,
+        )
+        item = batches[0]
+        phones = item["phones"][0]
+        all_ids = item["all_phones"][0]
+        all_lens = item["all_phones_len"][0]
+        all_bert = item["all_bert_features"][0]
+        max_len = item["max_len"]
+        # Prepare semantic prompt
+        if not no_prompt:
+            prompt_sem = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
+        else:
+            prompt_sem = None
+        # Reference spectrograms
+        refer_spec = [s.to(dtype=self.precision, device=self.configs.device) for s in self.prompt_cache["refer_spec"]]
+        # Streaming via generator
+        from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index
+        zc_idx1 = zc_idx2 = crossing_dir = 0
+        first = True
+        last = False
+        gen_list = []
+        for gen_tokens in self.t2s_model.model.infer_panel_generator(
+            all_ids.unsqueeze(0).to(self.configs.device),
+            all_lens.unsqueeze(0).to(self.configs.device),
+            prompt_sem,
+            all_bert.unsqueeze(0).to(self.configs.device),
+            cumulation_amount=cumulation_amount,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            early_stop_num=self.configs.hz * self.configs.max_sec,
+            max_len=max_len,
+            repetition_penalty=repetition_penalty,
+        ):
+            gen_list.append(gen_tokens)
+            total = sum([t.size(1) for t in gen_list])
+            toks = torch.cat(gen_list, dim=1)[:, :total]
+            eos = self.t2s_model.model.EOS
+            has_eos = (toks == eos).any()
+            if has_eos:
+                toks = toks.masked_fill(toks == eos, 0)
+                last = True
+                first = False
+            # Decode to waveform
+            pred = toks.unsqueeze(0)
+            phone_t = phones.unsqueeze(0).to(self.configs.device)
+            if not self.configs.use_vocoder:
+                w = self.vits_model.decode(pred, phone_t, refer_spec, speed=speed_factor).detach()[0, 0, :]
+            else:
+                w = self.using_vocoder_synthesis(pred, phone_t, speed=speed_factor, sample_steps=sample_steps)
+            w = w.cpu().numpy().astype(np.float32)
+            mv = np.abs(w).max()
+            if mv > 1.0:
+                w /= mv
+            # Zero-cross splitting
+            start = len(w) - search_length
+            if start < 0:
+                search_length = len(w)
+                start = 0
+            center = zc_idx2
+            off = int(search_length // 2)
+            sr = self.vocoder_configs["sr"] if self.configs.use_vocoder else self.configs.sampling_rate
+            if first:
+                zc_idx1, crossing_dir = find_zero_zone(w, start, search_length, num_zeroes)
+                frag = w[:zc_idx1]
+                print(len(frag))
+                frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
+                yield sr, frag_int16
+                first = False
+                zc_idx2 = zc_idx1
+            elif last:
+                zc1 = find_matching_index(w, center, off, crossing_dir)
+                frag = w[zc1:]
+                print(len(frag))
+                frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
+                yield sr, frag_int16
+                zc_idx2 = zc_idx1
+            else:
+                zc1 = find_matching_index(w, center, off, crossing_dir)
+                zc_idx1, crossing_dir = find_zero_zone(w, start, search_length, num_zeroes)
+                frag = w[zc1:zc_idx1]
+                print(len(frag))
+                frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
+                yield sr, frag_int16
+                zc_idx2 = zc_idx1
+        self.empty_cache()
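A hedged sketch of how the new streaming entry point might be consumed; the `TTS`/`TTS_Config` construction follows the existing pattern in this file, and the file paths and input values below are placeholders, not part of this commit:

import wave
import numpy as np
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

tts = TTS(TTS_Config("GPT_SoVITS/configs/tts_infer.yaml"))  # assumed config path

inputs = {
    "text": "Streaming synthesis test.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",             # placeholder reference audio
    "prompt_text": "Reference transcript.",  # placeholder transcript
    "prompt_lang": "en",
    "cumulation_amount": 50,                 # yield roughly every 50 semantic tokens
}

with wave.open("streamed.wav", "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)  # run_generator yields int16 fragments
    initialized = False
    for sr, fragment in tts.run_generator(inputs):
        if not initialized:
            wf.setframerate(sr)
            initialized = True
        wf.writeframes(fragment.astype(np.int16).tobytes())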
@@ -9,15 +9,15 @@ sys.path.append(now_dir)
 
 import re
 import torch
-from text.LangSegmenter import LangSegmenter
-from text import chinese
+from GPT_SoVITS.text.LangSegmenter import LangSegmenter
+from GPT_SoVITS.text import chinese
 from typing import Dict, List, Tuple
-from text.cleaner import clean_text
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
GPT_SoVITS/TTS_infer_pack/zero_crossing.py (new file, 203 lines)

import numpy as np
import wave
import struct

def read_wav_file(filename):
    """
    Reads a WAV file and returns the sample rate and data as a numpy array.
    """
    with wave.open(filename, 'rb') as wf:
        sample_rate = wf.getframerate()
        n_frames = wf.getnframes()
        sample_width = wf.getsampwidth()
        n_channels = wf.getnchannels()

        audio_data = wf.readframes(n_frames)
        # Determine the format string for struct unpacking
        fmt = "<" + {1: 'b', 2: 'h', 4: 'i'}[sample_width] * n_frames * n_channels
        audio_samples = struct.unpack(fmt, audio_data)
        audio_array = np.array(audio_samples, dtype=int)

        # If stereo, reshape the array
        if n_channels > 1:
            audio_array = audio_array.reshape(-1, n_channels)
        return sample_rate, audio_array, sample_width, n_channels

def write_wav_file(filename, sample_rate, data, sample_width, n_channels):
    """
    Writes numpy array data to a WAV file.
    """
    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(n_channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        # Flatten the array if it's multi-dimensional
        if data.ndim > 1:
            data = data.flatten()
        # Pack the data into bytes
        fmt = "<" + {1: 'b', 2: 'h', 4: 'i'}[sample_width] * len(data)
        byte_data = struct.pack(fmt, *data)
        wf.writeframes(byte_data)

def find_zero_zone(chunk, start_index, search_length, num_zeroes=11):
    zone = chunk[start_index:start_index + search_length]
    print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")

    zero_threshold = 1.0e-4
    # Check for num_zeroes consecutive zeros
    for idx in range(len(zone), -1 + num_zeroes, -1):
        index_to_start = idx - num_zeroes
        abs_zone = np.abs(zone[index_to_start:idx])
        if np.all(abs_zone < zero_threshold):
            index_midpoint = index_to_start + int(num_zeroes // 2)
            return (start_index + index_midpoint), None

    print("Falling back to zero crossing due to no zero zone found. You may hear more prominent pops and clicks in the audio. Try increasing search length or cumulative tokens.")
    return find_zero_crossing(chunk, start_index, search_length)

def find_zero_crossing(chunk, start_index, search_length):
    # If the model is falling back on this function, it might be a bad indicator that the search length is too low

    zone = chunk[start_index:start_index + search_length]
    sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]

    if len(sign_changes) == 0:
        raise ("No zero-crossings found in this zone. This should not be happening, debugging time.")
    else:
        zc_index = start_index + sign_changes[0] + 1
        print(f"Zero-crossing found at index {zc_index}")
        # Determine the crossing direction in chunk1
        prev_value = chunk[zc_index - 1]
        curr_value = chunk[zc_index]
        crossing_direction = np.sign(curr_value) - np.sign(prev_value)
        print(f"Crossing direction in chunk1: {np.sign(prev_value)} to {np.sign(curr_value)}")
        return zc_index, crossing_direction

def find_matching_index(chunk, center_index, max_offset, crossing_direction):
    """
    Finds a zero-crossing in data that matches the specified crossing direction,
    starting from center_index and searching outward.
    """
    if crossing_direction == None:
        return center_index  # if zero zone

    # fall back for zero_crossing
    data_length = len(chunk)
    print(f"Center index in chunk2: {center_index}")
    for offset in range(max_offset + 1):
        # Check index bounds
        idx_forward = center_index + offset
        idx_backward = center_index - offset
        found = False

        # Check forward direction
        if idx_forward < data_length - 1:
            prev_sign = np.sign(chunk[idx_forward])
            curr_sign = np.sign(chunk[idx_forward + 1])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
                return idx_forward + 1

        # Check backward direction
        if idx_backward > 0:
            prev_sign = np.sign(chunk[idx_backward - 1])
            curr_sign = np.sign(chunk[idx_backward])
            direction = curr_sign - prev_sign
            if direction == crossing_direction:
                print(f"Matching zero-crossing found at index {idx_backward} (backward)")
                return idx_backward

    print("No matching zero-crossings found in this zone.")
    return None

# legacy, just for history. delete me sometime
def splice_chunks(chunk1, chunk2, search_length, y):
    """
    Splices two audio chunks at zero-crossing points.
    """
    # Define the zone to search in chunk1
    start_index1 = len(chunk1) - search_length
    if start_index1 < 0:
        start_index1 = 0
        search_length = len(chunk1)
    print(f"Searching for zero-crossing in chunk1 from index {start_index1} to {len(chunk1)}")
    # Find zero-crossing in chunk1
    zc_index1, crossing_direction = find_zero_crossing(chunk1, start_index1, search_length, y)
    if zc_index1 is None:
        print("No zero-crossing found in chunk1 within the specified zone.")
        return None

    # Define the zone to search in chunk2 near the same index
    # Since chunk2 overlaps with chunk1, we can assume that index positions correspond
    # Adjusted search in chunk2
    # You can adjust this value if needed
    center_index = zc_index1  # Assuming alignment between chunk1 and chunk2
    max_offset = search_length

    # Ensure center_index is within bounds
    if center_index < 0:
        center_index = 0
    elif center_index >= len(chunk2):
        center_index = len(chunk2) - 1

    print(f"Searching for matching zero-crossing in chunk2 around index {center_index} with max offset {max_offset}")

    zc_index2 = find_matching_zero_crossing(chunk2, center_index, max_offset, crossing_direction)

    if zc_index2 is None:
        print("No matching zero-crossing found in chunk2.")
        return None

    print(f"Zero-crossing in chunk1 at index {zc_index1}, chunk2 at index {zc_index2}")
    # Splice the chunks
    new_chunk = np.concatenate((chunk1[:zc_index1], chunk2[zc_index2:]))
    print(f"Spliced chunk length: {len(new_chunk)}")
    return new_chunk

# legacy, just for history. delete me sometime
def process_audio_chunks(filenames, sample_rate, x, y, output_filename):
    """
    Processes and splices a list of audio chunks.
    """
    # Read the first chunk
    sr, chunk_data, sample_width, n_channels = read_wav_file(filenames[0])
    if sr != sample_rate:
        print(f"Sample rate mismatch in {filenames[0]}")
        return
    print(f"Processing {filenames[0]}")
    # Initialize the combined audio with the first chunk
    combined_audio = chunk_data
    # Process remaining chunks
    for filename in filenames[1:]:
        sr, next_chunk_data, _, _ = read_wav_file(filename)
        if sr != sample_rate:
            print(f"Sample rate mismatch in {filename}")
            return
        print(f"Processing {filename}")
        # Splice the current combined audio with the next chunk
        new_combined = splice_chunks(combined_audio, next_chunk_data, x, y)
        if new_combined is None:
            print(f"Failed to splice chunks between {filename} and previous chunk.")
            return
        combined_audio = new_combined
    # Write the final combined audio to output file
    write_wav_file(output_filename, sample_rate, combined_audio, sample_width, n_channels)
    print(f"Final audio saved to {output_filename}")

# Main execution
if __name__ == "__main__":
    # User-specified parameters
    sample_rate = 32000  # Sample rate in Hz
    x = 500  # Number of frames to search from the end of the chunk
    y = 10  # Number of consecutive zeros to look for
    output_filename = "combined_output.wav"
    folder_with_chunks = "output_chunks"
    import os
    def absolute_file_paths(directory):
        path = os.path.abspath(directory)
        return [entry.path for entry in os.scandir(path) if entry.is_file()]
    # List of input audio chunk filenames in sequential order
    filenames = absolute_file_paths(folder_with_chunks)
    # Process and splice the audio chunks
    process_audio_chunks(filenames, sample_rate, x, y, output_filename)
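To make the splitting logic concrete, here is a small self-contained sketch using synthetic sine data (not real TTS output) that joins two overlapping fragments at matching zero crossings with `find_zero_zone` and `find_matching_index`; the import path assumes the module added above is importable:

import numpy as np
from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index

sr = 32000
t = np.arange(sr) / sr
chunk1 = np.sin(2 * np.pi * 220 * t).astype(np.float32)  # first rendered fragment
chunk2 = np.sin(2 * np.pi * 220 * t).astype(np.float32)  # overlapping re-render of the same region

search_length = 2000
start = len(chunk1) - search_length

# Pick a cut point near the end of chunk1 (a run of near-zero samples, or a plain zero crossing).
cut1, direction = find_zero_zone(chunk1, start, search_length, num_zeroes=5)

# Find the corresponding point in chunk2 and splice the two fragments there.
cut2 = cut1 if direction is None else find_matching_index(chunk2, cut1, search_length // 2, direction)
joined = np.concatenate([chunk1[:cut1], chunk2[cut2:]])
print(len(joined), "samples after splicing")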
@@ -116,8 +116,10 @@ import soundfile as sf
 from fastapi import FastAPI, Response
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
+
+from importlib.resources import files
 from io import BytesIO
-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
 from pydantic import BaseModel
@@ -127,7 +129,7 @@ i18n = I18nAuto()
 cut_method_names = get_cut_method_names()
 
 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer path")
+parser.add_argument("-c", "--tts_config", type=str, default=None, help="tts_infer path")
 parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
 args = parser.parse_args()
@@ -138,7 +140,7 @@ host = args.bind_addr
 argv = sys.argv
 
 if config_path in [None, ""]:
-    config_path = "GPT-SoVITS/configs/tts_infer.yaml"
+    config_path = str(files("GPT_SoVITS").joinpath("configs/tts_infer.yaml"))
 
 tts_config = TTS_Config(config_path)
 print(tts_config)
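`importlib.resources.files()` resolves a path inside an installed package, so the default config no longer depends on the current working directory. A small illustration, assuming the GPT_SoVITS package is importable:

from importlib.resources import files

# Resolve a data file shipped inside the GPT_SoVITS package; with a source checkout this
# falls back to the on-disk path of the package directory.
cfg = files("GPT_SoVITS").joinpath("configs/tts_infer.yaml")
print(str(cfg))       # absolute path usable by TTS_Config
print(cfg.is_file())  # True when the packaged config is present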
@@ -434,7 +436,7 @@ async def tts_get_endpoint(
 
 @APP.post("/tts")
 async def tts_post_endpoint(request: TTS_Request):
-    req = request.dict()
+    req = request.model_dump()
     return await tts_handle(req)
 
 
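`.dict()` is the pydantic v1 spelling; pydantic v2 renames it to `.model_dump()`. A minimal illustration with a stand-in model (not the actual TTS_Request fields):

from pydantic import BaseModel

class EchoRequest(BaseModel):  # stand-in model for illustration only
    text: str = ""
    media_type: str = "wav"

req = EchoRequest(text="hello").model_dump()  # pydantic v2 equivalent of .dict()
print(req)  # {'text': 'hello', 'media_type': 'wav'}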
@@ -498,3 +500,6 @@ if __name__ == "__main__":
         traceback.print_exc()
         os.kill(os.getpid(), signal.SIGTERM)
         exit(0)
+
+if __name__ == "__main__":
+    main()
@@ -3,7 +3,7 @@ custom:
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cuda
   is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
 v1:
@@ -3,11 +3,5 @@ import sys
 
 now_dir = os.getcwd()
 sys.path.insert(0, now_dir)
-from text.g2pw import G2PWPinyin
+from GPT_SoVITS.text.g2pw import G2PWPinyin
 
-g2pw = G2PWPinyin(
-    model_dir="GPT_SoVITS/text/G2PWModel",
-    model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
-    v_to_u=False,
-    neutral_tone_with_five=True,
-)
+g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
@@ -12,8 +12,8 @@ from torch.nn import functional as F
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from feature_extractor import cnhubert
 
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from module.models_onnx import SynthesizerTrn
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.module.models_onnx import SynthesizerTrn
 
 from inference_webui import get_phones_and_bert
 
@@ -13,7 +13,7 @@ from transformers import (
     HubertModel,
 )
 
-import utils
+import GPT_SoVITS.utils
 import torch.nn as nn
 
 cnhubert_base_path = None
@@ -2,7 +2,7 @@ import argparse
 import os
 import soundfile as sf
 
-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
 
 i18n = I18nAuto()
@@ -5,8 +5,7 @@ from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushB
 from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
 import soundfile as sf
 
-from tools.i18n.i18n import I18nAuto
-
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto()
 
 from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@@ -305,7 +304,7 @@ class GPTSoVITSGUI(QMainWindow):
 
         result = "Audio saved to " + output_wav_path
 
-        self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
+        self.status_bGPT_SoVITS.AR.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
         self.output_text.append("处理结果:\n" + result)
 
 
@@ -124,12 +124,12 @@ def set_seed(seed):
 
 from time import time as ttime
 
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from peft import LoraConfig, get_peft_model
-from text import cleaned_text_to_sequence
-from text.cleaner import clean_text
+from GPT_SoVITS.text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
 
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -165,8 +165,8 @@ dict_language_v2 = {
 }
 dict_language = dict_language_v1 if version == "v1" else dict_language_v2
 
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+tokenizer = AutoTokenizer.from_pretrained(bert_path, local_files_only=True)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, local_files_only=True)
 if is_half == True:
     bert_model = bert_model.half().to(device)
 else:
@@ -406,6 +406,7 @@ def init_bigvgan():
     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
         "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
         use_cuda_kernel=False,
+        local_files_only=True
     )  # if True, RuntimeError: Ninja is required to load C++ extensions
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
@@ -518,11 +519,8 @@ def get_first(text):
     text = re.split(pattern, text)[0].strip()
     return text
 
-
-from text import chinese
-
-
-def get_phones_and_bert(text, language, version, final=False):
+from GPT_SoVITS.text import chinese
+def get_phones_and_bert(text,language,version,final=False):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
         formattext = text
         while " " in formattext:
@@ -50,10 +50,9 @@ bert_path = os.environ.get("bert_path", None)
 version = model_version = os.environ.get("version", "v2")
 
 import gradio as gr
-from TTS_infer_pack.text_segmentation_method import get_method
-from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
-
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -3,8 +3,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from module import commons
-from module.modules import LayerNorm
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.modules import LayerNorm
 
 
 class Encoder(nn.Module):
@@ -325,7 +325,7 @@ class MultiHeadAttention(nn.Module):
     def _attention_bias_proximal(self, length):
         """Bias for self-attention to encourage attention to close positions.
         Args:
-            length: an integer scalar.
+            length: an integer scalGPT_SoVITS.AR.
         Returns:
             a Tensor with shape [1, 1, length, length]
         """
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from module import commons
+from GPT_SoVITS.module import commons
 
 from typing import Optional
 
@@ -288,7 +288,7 @@ class MultiHeadAttention(nn.Module):
     def _attention_bias_proximal(self, length):
         """Bias for self-attention to encourage attention to close positions.
         Args:
-            length: an integer scalar.
+            length: an integer scalGPT_SoVITS.AR.
        Returns:
             a Tensor with shape [1, 1, length, length]
         """
@@ -5,10 +5,10 @@ import torch
 import torch.utils.data
 from tqdm import tqdm

-from module.mel_processing import spectrogram_torch, spec_to_mel_torch
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.module.mel_processing import spectrogram_torch, spec_to_mel_torch
+from GPT_SoVITS.text import cleaned_text_to_sequence
 import torch.nn.functional as F
-from tools.my_utils import load_audio
+from GPT_SoVITS.tools.my_utils import load_audio

 version = os.environ.get("version", None)

@@ -7,19 +7,19 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from module import commons
-from module import modules
-from module import attentions
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions
 from f5_tts.model import DiT
 from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.mrte_model import MRTE
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.mrte_model import MRTE
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer

 # from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2
 from torch.cuda.amp import autocast
 import contextlib
 import random
@@ -4,20 +4,20 @@ import torch
 from torch import nn
 from torch.nn import functional as F

-from module import commons
-from module import modules
-from module import attentions_onnx as attentions
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions_onnx as attentions

 from f5_tts.model import DiT

 from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer

 # from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2


 class StochasticDurationPredictor(nn.Module):
@@ -7,9 +7,9 @@ from torch.nn import functional as F
 from torch.nn import Conv1d
 from torch.nn.utils import weight_norm, remove_weight_norm

-from module import commons
-from module.commons import init_weights, get_padding
-from module.transforms import piecewise_rational_quadratic_transform
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform
 import torch.distributions as D


@@ -3,7 +3,7 @@
 import torch
 from torch import nn
 from torch.nn.utils import remove_weight_norm, weight_norm
-from module.attentions import MultiHeadAttention
+from GPT_SoVITS.module.attentions import MultiHeadAttention


 class MRTE(nn.Module):
@@ -12,7 +12,7 @@ import typing as tp
 import torch
 from torch import nn

-from module.core_vq import ResidualVectorQuantization
+from GPT_SoVITS.module.core_vq import ResidualVectorQuantization


 @dataclass
@@ -1,18 +1,22 @@
 import torch
 import torchaudio
-from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
-from feature_extractor import cnhubert
-from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
+from GPT_SoVITS.AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
+from GPT_SoVITS.feature_extractor import cnhubert
+from GPT_SoVITS.module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn

 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path
 ssl_model = cnhubert.get_model()
+from GPT_SoVITS.text import cleaned_text_to_sequence
+import soundfile
+from GPT_SoVITS.tools.my_utils import load_audio
+import os
 import json
 import os

 import soundfile
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence


 def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
@@ -17,9 +17,9 @@ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
 version = os.environ.get("version", None)
 import traceback
 import os.path
-from text.cleaner import clean_text
+from GPT_SoVITS.text.cleaner import clean_text
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from tools.my_utils import clean_path
+from GPT_SoVITS.tools.my_utils import clean_path

 # inp_text=sys.argv[1]
 # inp_wav_dir=sys.argv[2]
@@ -25,7 +25,7 @@ import librosa

 now_dir = os.getcwd()
 sys.path.append(now_dir)
-from tools.my_utils import load_audio, clean_path
+from GPT_SoVITS.tools.my_utils import load_audio, clean_path

 # from config import cnhubert_base_path
 # cnhubert.cnhubert_base_path=cnhubert_base_path
@@ -38,10 +38,10 @@ import logging
 import utils

 if version != "v3":
-    from module.models import SynthesizerTrn
+    from GPT_SoVITS.module.models import SynthesizerTrn
 else:
-    from module.models import SynthesizerTrnV3 as SynthesizerTrn
-from tools.my_utils import clean_path
+    from GPT_SoVITS.module.models import SynthesizerTrnV3 as SynthesizerTrn
+from GPT_SoVITS.tools.my_utils import clean_path

 logging.getLogger("numba").setLevel(logging.WARNING)
 # from config import pretrained_s2G
@@ -4,7 +4,7 @@ from time import time as ttime
 import shutil
 import os
 import torch
-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()

@@ -9,9 +9,9 @@ import platform
 from pathlib import Path

 import torch
-from AR.data.data_module import Text2SemanticDataModule
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from AR.utils.io import load_yaml_config
+from GPT_SoVITS.AR.data.data_module import Text2SemanticDataModule
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.AR.utils.io import load_yaml_config
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
@@ -20,10 +20,12 @@ from pytorch_lightning.strategies import DDPStrategy
 logging.getLogger("numba").setLevel(logging.WARNING)
 logging.getLogger("matplotlib").setLevel(logging.WARNING)
 torch.set_float32_matmul_precision("high")
+from GPT_SoVITS.AR.utils import get_newest_ckpt
+
 from collections import OrderedDict

-from AR.utils import get_newest_ckpt
-from process_ckpt import my_save
+from GPT_SoVITS.AR.utils import get_newest_ckpt
+from GPT_SoVITS.process_ckpt import my_save


 class my_model_ckpt(ModelCheckpoint):
@@ -24,19 +24,19 @@ logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
 from random import randint

-from module import commons
-from module.data_utils import (
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.data_utils import (
     DistributedBucketSampler,
     TextAudioSpeakerCollate,
     TextAudioSpeakerLoader,
 )
-from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from module.models import (
+from GPT_SoVITS.module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from GPT_SoVITS.module.models import (
     MultiPeriodDiscriminator,
     SynthesizerTrn,
 )
-from process_ckpt import savee
+from GPT_SoVITS.process_ckpt import savee

 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = False
@@ -71,7 +71,7 @@ def main():
 def run(rank, n_gpus, hps):
     global global_step
     if rank == 0:
-        logger = utils.get_logger(hps.data.exp_dir)
+        logger = GPT_SoVITS.utils.get_logger(hps.data.exp_dir)
         logger.info(hps)
         # utils.check_git_hash(hps.s2_ckpt_dir)
         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
@@ -204,7 +204,7 @@ def run(rank, n_gpus, hps):
         net_d = net_d.to(device)

     try: # 如果能加载自动resume
-        _, _, _, epoch_str = utils.load_checkpoint(
+        _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
             utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
             net_d,
             optim_d,
@@ -212,7 +212,7 @@ def run(rank, n_gpus, hps):
         if rank == 0:
             logger.info("loaded D")
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
-        _, _, _, epoch_str = utils.load_checkpoint(
+        _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
             utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
             net_g,
             optim_g,
@@ -479,30 +479,30 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
             image_dict = None
             try: ###Some people installed the wrong version of matplotlib.
                 image_dict = {
-                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
+                    "slice/mel_org": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
                         y_mel[0].data.cpu().numpy(),
                     ),
-                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+                    "slice/mel_gen": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
                         y_hat_mel[0].data.cpu().numpy(),
                     ),
-                    "all/mel": utils.plot_spectrogram_to_numpy(
+                    "all/mel": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
                         mel[0].data.cpu().numpy(),
                     ),
-                    "all/stats_ssl": utils.plot_spectrogram_to_numpy(
+                    "all/stats_ssl": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
                         stats_ssl[0].data.cpu().numpy(),
                     ),
                 }
             except:
                 pass
             if image_dict:
-                utils.summarize(
+                GPT_SoVITS.utils.summarize(
                     writer=writer,
                     global_step=global_step,
                     images=image_dict,
                     scalars=scalar_dict,
                 )
             else:
-                utils.summarize(
+                GPT_SoVITS.utils.summarize(
                     writer=writer,
                     global_step=global_step,
                     scalars=scalar_dict,
@@ -510,7 +510,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
         global_step += 1
     if epoch % hps.train.save_every_epoch == 0 and rank == 0:
         if hps.train.if_save_latest == 0:
-            utils.save_checkpoint(
+            GPT_SoVITS.utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
@@ -520,7 +520,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                     "G_{}.pth".format(global_step),
                 ),
             )
-            utils.save_checkpoint(
+            GPT_SoVITS.utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
@@ -531,7 +531,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                 ),
             )
         else:
-            utils.save_checkpoint(
+            GPT_SoVITS.utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
@@ -541,7 +541,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                     "G_{}.pth".format(233333333333),
                 ),
             )
-            utils.save_checkpoint(
+            GPT_SoVITS.utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
@@ -644,7 +644,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             )
             image_dict.update(
                 {
-                    f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
+                    f"gen/mel_{batch_idx}_{test}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
                         y_hat_mel[0].cpu().numpy(),
                     ),
                 }
@@ -656,7 +656,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             )
             image_dict.update(
                 {
-                    f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
+                    f"gt/mel_{batch_idx}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
                 },
             )
             audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
@@ -666,7 +666,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
         # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
         # })

-    utils.summarize(
+    GPT_SoVITS.utils.summarize(
         writer=writer_eval,
         global_step=global_step,
         images=image_dict,
@@ -1,11 +1,11 @@
 import os
 # if os.environ.get("version","v1")=="v1":
-# from text.symbols import symbols
+# from GPT_SoVITS.text.symbols import symbols
 # else:
-# from text.symbols2 import symbols
+# from GPT_SoVITS.text.symbols2 import symbols

-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2

 _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
 _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
@@ -4,8 +4,8 @@ import re
 import cn2an
 import ToJyutping

-from text.symbols import punctuation
-from text.zh_normalization.text_normlization import TextNormalizer
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer

 normalizer = lambda x: cn2an.transform(x, "an2cn")

@@ -195,7 +195,7 @@ def get_jyutping(text):


 def get_bert_feature(text, word2ph):
-    from text import chinese_bert
+    from GPT_SoVITS.text import chinese_bert

     return chinese_bert.get_bert_feature(text, word2ph)

@@ -4,9 +4,9 @@ import re
 import cn2an
 from pypinyin import lazy_pinyin, Style

-from text.symbols import punctuation
-from text.tone_sandhi import ToneSandhi
-from text.zh_normalization.text_normlization import TextNormalizer
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.tone_sandhi import ToneSandhi
+from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer

 normalizer = lambda x: cn2an.transform(x, "an2cn")

@@ -5,9 +5,9 @@ import cn2an
 from pypinyin import lazy_pinyin, Style
 from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

-from text.symbols import punctuation
-from text.tone_sandhi import ToneSandhi
-from text.zh_normalization.text_normlization import TextNormalizer
+from GPT_SoVITS.text.symbols import punctuation
+from GPT_SoVITS.text.tone_sandhi import ToneSandhi
+from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer

 normalizer = lambda x: cn2an.transform(x, "an2cn")

@@ -28,7 +28,7 @@ import jieba_fast.posseg as psg
 is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False
 if is_g2pw:
     # print("当前使用g2pw进行拼音推理")
-    from text.g2pw import G2PWPinyin, correct_pronunciation
+    from GPT_SoVITS.text.g2pw import G2PWPinyin, correct_pronunciation

     parent_directory = os.path.dirname(current_file_path)
     g2pw = G2PWPinyin(
@@ -1,14 +1,14 @@
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence
 import os
 # if os.environ.get("version","v1")=="v1":
-# from text import chinese
-# from text.symbols import symbols
+# from GPT_SoVITS.text import chinese
+# from GPT_SoVITS.text.symbols import symbols
 # else:
-# from text import chinese2 as chinese
-# from text.symbols2 import symbols
+# from GPT_SoVITS.text import chinese2 as chinese
+# from GPT_SoVITS.text.symbols2 import symbols

-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2

 special = [
     # ("%", "zh", "SP"),
@@ -34,7 +34,7 @@ def clean_text(text, language, version=None):
     for special_s, special_l, target_symbol in special:
         if special_s in text and language == special_l:
             return clean_special(text, language, special_s, target_symbol, version)
-    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
+    language_module = __import__("GPT_SoVITS.text." + language_module_map[language], fromlist=[language_module_map[language]])
     if hasattr(language_module, "text_normalize"):
         norm_text = language_module.text_normalize(text)
     else:
@@ -69,7 +69,7 @@ def clean_special(text, language, special_s, target_symbol, version=None):
    特殊静音段sp符号处理
    """
    text = text.replace(special_s, ",")
-    language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
+    language_module = __import__("GPT_SoVITS.text."+language_module_map[language],fromlist=[language_module_map[language]])
    norm_text = language_module.text_normalize(text)
    phones = language_module.g2p(norm_text)
    new_ph = []
@@ -4,9 +4,9 @@ import re
 import wordsegment
 from g2p_en import G2p

-from text.symbols import punctuation
+from GPT_SoVITS.text.symbols import punctuation

-from text.symbols2 import symbols
+from GPT_SoVITS.text.symbols2 import symbols

 from builtins import str as unicode
 from text.en_normalization.expend import normalize
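Note: the clean_text()/clean_special() hunks above change only the string handed to __import__, so the per-language module is now looked up inside the GPT_SoVITS.text package. A small sketch of that lookup under an assumed mapping (the real table lives in the cleaner module; the dict below is a stand-in for illustration):

# Stand-in for the cleaner's mapping; values are module names inside GPT_SoVITS/text/.
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

def load_language_module(language: str):
    # __import__("pkg.sub", fromlist=["sub"]) returns the pkg.sub submodule itself;
    # without fromlist it would return the top-level GPT_SoVITS package, which is
    # why the fromlist argument is kept in the rewritten call.
    name = "GPT_SoVITS.text." + language_module_map[language]
    return __import__(name, fromlist=[language_module_map[language]])

# importlib.import_module(name) is the equivalent, arguably clearer, spelling.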
@@ -1 +1 @@
-from text.g2pw.g2pw import *
+from GPT_SoVITS.text.g2pw.g2pw import *
@@ -77,8 +77,7 @@ except Exception:
     pass

-
-from text.symbols import punctuation
+from GPT_SoVITS.text.symbols import punctuation

 # Regular expression matching Japanese without punctuation marks:
 _japanese_characters = re.compile(
     r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@@ -56,7 +56,7 @@ if os.name == "nt":
     G2p = win_G2p


-from text.symbols2 import symbols
+from GPT_SoVITS.text.symbols2 import symbols

 # This is a list of Korean classifiers preceded by pure Korean numerals.
 _korean_classifiers = (
@@ -11,4 +11,4 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from text.zh_normalization.text_normlization import *
+from GPT_SoVITS.text.zh_normalization.text_normlization import *
@@ -9,7 +9,7 @@ import torch
 from faster_whisper import WhisperModel
 from tqdm import tqdm

-from tools.asr.config import check_fw_local_models
+from GPT_SoVITS.tools.asr.config import check_fw_local_models

 # fmt: off
 language_code_list = [
@@ -72,8 +72,13 @@ def execute_asr(input_folder, output_folder, model_size, language, precision):

         if info.language == "zh":
             print("检测为中文文本, 转 FunASR 处理")
+<<<<<<< HEAD:tools/asr/fasterwhisper_asr.py
             if "only_asr" not in globals():
                 from tools.asr.funasr_asr import only_asr # 如果用英文就不需要导入下载模型
+=======
+            if("only_asr" not in globals()):
+                from GPT_SoVITS.tools.asr.funasr_asr import only_asr #如果用英文就不需要导入下载模型
+>>>>>>> main:GPT_SoVITS/tools/asr/fasterwhisper_asr.py
             text = only_asr(file_path, language=info.language.lower())

         if text == "":
@@ -9,8 +9,8 @@ import torch
 import torchaudio.functional as aF
 # from attrdict import AttrDict####will be bug in py3.10

-from datasets1.dataset import amp_pha_stft, amp_pha_istft
-from models.model import APNet_BWE_Model
+from GPT_SoVITS.tools.AP_BWE_main.datasets1.dataset import amp_pha_stft, amp_pha_istft
+from GPT_SoVITS.tools.AP_BWE_main.models.model import APNet_BWE_Model


 class AP_BWE:
@@ -101,6 +101,7 @@
     "实际输入的目标文本(每句):": "Texto alvo realmente inserido (por frase):",
     "实际输入的目标文本:": "Texto alvo realmente inserido:",
     "导出文件格式": "Formato de arquivo de exportação",
+<<<<<<< HEAD:tools/i18n/locale/pt_BR.json
     "已关闭": " Fechado",
     "已完成": " Concluído",
     "已开启": " Ativado",
@@ -110,6 +111,21 @@
     "开启": "Ativar ",
     "开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. Não preencher o texto de referência também equivale a ativar.",
     "微调训练": "Treinamento de ajuste fino",
+=======
+    "开启GPT训练": "Ativar treinamento GPT",
+    "开启SSL提取": "Ativar extração SSL",
+    "开启SoVITS训练": "Ativar treinamento SoVITS",
+    "开启TTS推理WebUI": "Abrir TTS Inference WebUI",
+    "开启UVR5-WebUI": "Abrir UVR5-WebUI",
+    "开启一键三连": "Ativar um clique",
+    "开启打标WebUI": "Abrir Labeling WebUI",
+    "开启文本获取": "Ativar obtenção de texto",
+    "开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. Não preencher o texto de referência também equivale a ativGPT_SoVITS.AR.",
+    "开启离线批量ASR": "Ativar ASR offline em lote",
+    "开启语义token提取": "Ativar extração de token semântico",
+    "开启语音切割": "Ativar corte de voz",
+    "开启语音降噪": "Ativar redução de ruído de voz",
+>>>>>>> main:GPT_SoVITS/tools/i18n/locale/pt_BR.json
     "怎么切": "Como cortar",
     "总训练轮数total_epoch": "Total de epoch de treinamento",
     "总训练轮数total_epoch,不建议太高": "Total de epoch de treinamento, não é recomendável um valor muito alto",
@@ -3,7 +3,7 @@ import traceback
 import ffmpeg
 import numpy as np
 import gradio as gr
-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
 import pandas as pd

 i18n = I18nAuto(language=os.environ.get("language", "Auto"))
@@ -1,230 +1,230 @@
 import numpy as np


 # This function is obtained from librosa.
 def get_rms(
     y,
     frame_length=2048,
     hop_length=512,
     pad_mode="constant",
 ):
     padding = (int(frame_length // 2), int(frame_length // 2))
     y = np.pad(y, padding, mode=pad_mode)

     axis = -1
     # put our new within-frame axis at the end for now
     out_strides = y.strides + tuple([y.strides[axis]])
     # Reduce the shape on the framing axis
     x_shape_trimmed = list(y.shape)
     x_shape_trimmed[axis] -= frame_length - 1
     out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
     xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
     if axis < 0:
         target_axis = axis - 1
     else:
         target_axis = axis + 1
     xw = np.moveaxis(xw, -1, target_axis)
     # Downsample along the target axis
     slices = [slice(None)] * xw.ndim
     slices[axis] = slice(0, None, hop_length)
     x = xw[tuple(slices)]

     # Calculate power
     power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)

     return np.sqrt(power)


 class Slicer:
     def __init__(
         self,
         sr: int,
         threshold: float = -40.0,
         min_length: int = 5000,
         min_interval: int = 300,
         hop_size: int = 20,
         max_sil_kept: int = 5000,
     ):
         if not min_length >= min_interval >= hop_size:
             raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
         if not max_sil_kept >= hop_size:
             raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
         min_interval = sr * min_interval / 1000
         self.threshold = 10 ** (threshold / 20.0)
         self.hop_size = round(sr * hop_size / 1000)
         self.win_size = min(round(min_interval), 4 * self.hop_size)
         self.min_length = round(sr * min_length / 1000 / self.hop_size)
         self.min_interval = round(min_interval / self.hop_size)
         self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

     def _apply_slice(self, waveform, begin, end):
         if len(waveform.shape) > 1:
             return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
         else:
             return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]

     # @timeit
     def slice(self, waveform):
         if len(waveform.shape) > 1:
             samples = waveform.mean(axis=0)
         else:
             samples = waveform
         if samples.shape[0] <= self.min_length:
             return [waveform]
         rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
         sil_tags = []
         silence_start = None
         clip_start = 0
         for i, rms in enumerate(rms_list):
             # Keep looping while frame is silent.
             if rms < self.threshold:
                 # Record start of silent frames.
                 if silence_start is None:
                     silence_start = i
                 continue
             # Keep looping while frame is not silent and silence start has not been recorded.
             if silence_start is None:
                 continue
             # Clear recorded silence start if interval is not enough or clip is too short
             is_leading_silence = silence_start == 0 and i > self.max_sil_kept
             need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
             if not is_leading_silence and not need_slice_middle:
                 silence_start = None
                 continue
             # Need slicing. Record the range of silent frames to be removed.
             if i - silence_start <= self.max_sil_kept:
                 pos = rms_list[silence_start : i + 1].argmin() + silence_start
                 if silence_start == 0:
                     sil_tags.append((0, pos))
                 else:
                     sil_tags.append((pos, pos))
                 clip_start = pos
             elif i - silence_start <= self.max_sil_kept * 2:
                 pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                 pos += i - self.max_sil_kept
                 pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                 pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                 if silence_start == 0:
                     sil_tags.append((0, pos_r))
                     clip_start = pos_r
                 else:
                     sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                     clip_start = max(pos_r, pos)
             else:
                 pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
                 pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
                 if silence_start == 0:
                     sil_tags.append((0, pos_r))
                 else:
                     sil_tags.append((pos_l, pos_r))
                 clip_start = pos_r
             silence_start = None
         # Deal with trailing silence.
         total_frames = rms_list.shape[0]
         if silence_start is not None and total_frames - silence_start >= self.min_interval:
             silence_end = min(total_frames, silence_start + self.max_sil_kept)
             pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
             sil_tags.append((pos, total_frames + 1))
         # Apply and return slices.
         ####音频+起始时间+终止时间
         if len(sil_tags) == 0:
             return [[waveform, 0, int(total_frames * self.hop_size)]]
         else:
             chunks = []
             if sil_tags[0][0] > 0:
                 chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
             for i in range(len(sil_tags) - 1):
                 chunks.append(
                     [
                         self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
                         int(sil_tags[i][1] * self.hop_size),
                         int(sil_tags[i + 1][0] * self.hop_size),
                     ]
                 )
             if sil_tags[-1][1] < total_frames:
                 chunks.append(
                     [
                         self._apply_slice(waveform, sil_tags[-1][1], total_frames),
                         int(sil_tags[-1][1] * self.hop_size),
                         int(total_frames * self.hop_size),
                     ]
                 )
             return chunks


 def main():
     import os.path
     from argparse import ArgumentParser

     import librosa
     import soundfile

     parser = ArgumentParser()
     parser.add_argument("audio", type=str, help="The audio to be sliced")
     parser.add_argument("--out", type=str, help="Output directory of the sliced audio clips")
     parser.add_argument(
         "--db_thresh",
         type=float,
         required=False,
         default=-40,
         help="The dB threshold for silence detection",
     )
     parser.add_argument(
         "--min_length",
         type=int,
         required=False,
         default=5000,
         help="The minimum milliseconds required for each sliced audio clip",
     )
     parser.add_argument(
         "--min_interval",
         type=int,
         required=False,
         default=300,
         help="The minimum milliseconds for a silence part to be sliced",
     )
     parser.add_argument(
         "--hop_size",
         type=int,
         required=False,
         default=10,
         help="Frame length in milliseconds",
     )
     parser.add_argument(
         "--max_sil_kept",
         type=int,
         required=False,
         default=500,
         help="The maximum silence length kept around the sliced clip, presented in milliseconds",
     )
     args = parser.parse_args()
     out = args.out
     if out is None:
         out = os.path.dirname(os.path.abspath(args.audio))
     audio, sr = librosa.load(args.audio, sr=None, mono=False)
     slicer = Slicer(
         sr=sr,
         threshold=args.db_thresh,
         min_length=args.min_length,
         min_interval=args.min_interval,
         hop_size=args.hop_size,
         max_sil_kept=args.max_sil_kept,
     )
     chunks = slicer.slice(audio)
     if not os.path.exists(out):
         os.makedirs(out)
     for i, chunk in enumerate(chunks):
         if len(chunk.shape) > 1:
             chunk = chunk.T
         soundfile.write(
             os.path.join(
                 out,
                 "%s_%d.wav" % (os.path.basename(args.audio).rsplit(".", maxsplit=1)[0], i),
             ),
             chunk,
             sr,
         )


 if __name__ == "__main__":
     main()
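Note: for context on the Slicer shown above, a short usage sketch under assumed inputs (the file name is hypothetical); it follows the return convention of slice() as written:

import librosa
import soundfile

audio, sr = librosa.load("example.wav", sr=None, mono=False)  # hypothetical input file
slicer = Slicer(sr=sr, threshold=-40.0, min_length=5000, min_interval=300, hop_size=20, max_sil_kept=500)
for i, item in enumerate(slicer.slice(audio)):
    # slice() normally yields [waveform, start_sample, end_sample]; very short
    # inputs are returned as the bare waveform, so handle both shapes.
    chunk = item[0] if isinstance(item, list) else item
    if len(chunk.shape) > 1:
        chunk = chunk.T
    soundfile.write(f"example_{i}.wav", chunk, sr)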
@@ -1,6 +1,7 @@
 # This code is modified from https://github.com/ZFTurbo/
 import os
 import warnings
+import subprocess

 import librosa
 import numpy as np
@@ -160,7 +161,7 @@ class Roformer_Loader:
                     batch_data.append(part)
                     batch_locations.append((i, length))
                     i += step
-                    progress_bar.update(1)
+                    progress_bGPT_SoVITS.AR.update(1)

                 if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                     arr = torch.stack(batch_data, dim=0)
@@ -189,7 +190,7 @@ class Roformer_Loader:
         # Remove pad
         estimated_sources = estimated_sources[..., border:-border]

-        progress_bar.close()
+        progress_bGPT_SoVITS.AR.close()

         if self.config["training"]["target_instrument"] is None:
             return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
@@ -253,7 +254,10 @@ class Roformer_Loader:
                 sf.write(path, data, sr)
             else:
                 sf.write(path, data, sr)
-                os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format))
+                subprocess.run(
+                    ["ffmpeg", "-i", path, "-vn", path[:-3] + format, "-q:a", "2", "-y"],
+                    check=True,
+                )
                 try:
                     os.remove(path)
                 except:
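Note: the final hunk replaces the shell-interpolated os.system call with subprocess.run and an argument list. A brief sketch of why that matters, under an assumed file name (illustrative only): with a list argv no shell parses the command, so spaces or quotes in the path cannot break or inject into it, and check=True raises instead of failing silently.

import subprocess

path = "vocal output (1).wav"  # hypothetical name containing spaces and parentheses
target = path[:-3] + "mp3"

# Each list element reaches ffmpeg verbatim; no quoting is needed, and a
# non-zero exit status raises subprocess.CalledProcessError.
subprocess.run(["ffmpeg", "-i", path, "-vn", target, "-q:a", "2", "-y"], check=True)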
Some files were not shown because too many files have changed in this diff.