Reimplement run_generator to be compatible with v3/v4 models, sync the code with the main repo.

Jarod Mica 2025-04-25 00:35:41 -07:00
commit 344ca488d9
147 changed files with 1548 additions and 616 deletions

View File

@@ -1,8 +0,0 @@
docs
logs
output
reference
SoVITS_weights
GPT_weights
TEMP
.git

View File

@@ -1,42 +0,0 @@
# Base CUDA image
FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
LABEL maintainer="breakstring@hotmail.com"
LABEL version="dev-20240209"
LABEL description="Docker image for GPT-SoVITS"
# Install 3rd party apps
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=Etc/UTC
RUN apt-get update && \
apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
git lfs install && \
rm -rf /var/lib/apt/lists/*
# Copy only requirements.txt initially to leverage Docker cache
WORKDIR /workspace
COPY requirements.txt /workspace/
RUN pip install --no-cache-dir -r requirements.txt
# Define a build-time argument for image type
ARG IMAGE_TYPE=full
# Conditional logic based on the IMAGE_TYPE argument
# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
COPY ./Docker /workspace/Docker
# elite 类型的镜像里面不包含额外的模型
RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
chmod +x /workspace/Docker/download.sh && \
/workspace/Docker/download.sh && \
python /workspace/Docker/download.py && \
python -m nltk.downloader averaged_perceptron_tagger cmudict; \
fi
# Copy the rest of the application
COPY . /workspace
EXPOSE 9871 9872 9873 9874 9880
CMD ["python", "webui.py"]

View File

@@ -1,11 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
+from GPT_SoVITS.AR.data.bucket_sampler import DistributedBucketSampler
+from GPT_SoVITS.AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
-from AR.data.bucket_sampler import DistributedBucketSampler
-from AR.data.dataset import Text2SemanticDataset

class Text2SemanticDataModule(LightningDataModule):
def __init__(

View File

@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, Dataset
version = os.environ.get("version", None)
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence
# from config import exp_dir

View File

@@ -9,10 +9,9 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
-from AR.models.t2s_model import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam

class Text2SemanticLightningModule(LightningModule):

View File

@@ -9,10 +9,9 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
-from AR.models.t2s_model_onnx import Text2SemanticDecoder
-from AR.modules.lr_schedulers import WarmupCosineLRSchedule
-from AR.modules.optim import ScaledAdam
+from GPT_SoVITS.AR.models.t2s_model_onnx import Text2SemanticDecoder
+from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
+from GPT_SoVITS.AR.modules.optim import ScaledAdam

class Text2SemanticLightningModule(LightningModule):

View File

@@ -9,7 +9,7 @@ from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
-from AR.models.utils import (
+from GPT_SoVITS.AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
@@ -18,8 +18,8 @@ from AR.models.utils import (
sample,
topk_sampling,
)
-from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

default_config = {
"embedding_dim": 512,
@@ -933,3 +933,140 @@ class Text2SemanticDecoder(nn.Module):
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
def infer_panel_generator(
self,
x: torch.LongTensor,
x_lens: torch.LongTensor,
prompts: torch.LongTensor,
bert_feature: torch.LongTensor,
cumulation_amount: int,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs,
):
"""
Generator method that yields generated tokens based on a specified cumulative amount.
Args:
x (torch.LongTensor): Input phoneme IDs.
x_lens (torch.LongTensor): Lengths of the input sequences.
prompts (torch.LongTensor): Initial prompt tokens.
bert_feature (torch.LongTensor): BERT features corresponding to the input.
cumulation_amount (int): Number of tokens to generate before yielding.
top_k (int): Top-k sampling.
top_p (int): Top-p sampling.
early_stop_num (int): Early stopping number.
temperature (float): Sampling temperature.
repetition_penalty (float): Repetition penalty.
Yields:
torch.LongTensor: Generated tokens since the last yield.
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device)
stop = False
# Initialize cumulative token counter
tokens_since_last_yield = 0
# Initialize last yield index
prefix_len = y.shape[1] if y is not None else 0
last_yield_idx = prefix_len
k_cache = None
v_cache = None
################### first step ##########################
if y is not None and y.shape[1] > 0:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int64, device=x.device)
ref_free = True
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
xy_attn_mask = xy_attn_mask.unsqueeze(0).expand(bsz * self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.view(bsz, self.num_head, src_len, src_len).to(device=x.device, dtype=torch.bool)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if idx < 11: # Ensure at least 10 tokens are generated before stopping
logits = logits[:, :-1]
samples = sample(
logits,
y,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
)[0]
y = torch.concat([y, samples], dim=1)
tokens_since_last_yield += 1
if tokens_since_last_yield >= cumulation_amount:
generated_tokens = y[:, last_yield_idx:]
yield generated_tokens
last_yield_idx = y.shape[1]
tokens_since_last_yield = 0
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("Using early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("Bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# Update for next step
y_emb = self.ar_audio_embedding(y[:, -1:])
y_len += 1
xy_pos = (
y_emb * self.ar_audio_position.x_scale
+ self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len - 1].to(
dtype=y_emb.dtype, device=y_emb.device
)
)
# After loop ends, yield any remaining tokens
if last_yield_idx < y.shape[1]:
generated_tokens = y[:, last_yield_idx:]
yield generated_tokens
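
For orientation, here is a minimal sketch of how a caller might consume this generator. The t2s_decoder handle and the input tensor names are placeholders for this illustration only (the actual caller is TTS.run_generator further below), so treat it as an assumption-laden example rather than part of the commit:

import torch

# Hypothetical driver: collect semantic tokens in chunks of ~50 as they stream out.
collected = []
for chunk in t2s_decoder.infer_panel_generator(
    phoneme_ids,           # (1, T_phoneme) LongTensor of input phoneme IDs
    phoneme_ids_len,       # (1,) LongTensor with the sequence length
    prompt_semantic,       # (1, T_prompt) LongTensor of prompt tokens (or an empty prompt)
    bert_features,         # (1, 1024, T_phoneme) BERT feature tensor
    cumulation_amount=50,  # yield roughly every 50 newly generated tokens
    top_k=5,
    top_p=1.0,
    temperature=1.0,
):
    collected.append(chunk)  # each chunk is (1, n_new_tokens) since the last yield
    # ...decode chunk to audio here for low-latency playback...
full_sequence = torch.cat(collected, dim=1)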

View File

@@ -1,12 +1,17 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
+from tqdm import tqdm
+from GPT_SoVITS.AR.modules.embedding_onnx import SinePositionalEmbedding
+from GPT_SoVITS.AR.modules.embedding_onnx import TokenEmbedding
+from GPT_SoVITS.AR.modules.transformer_onnx import LayerNorm
+from GPT_SoVITS.AR.modules.transformer_onnx import TransformerEncoder
+from GPT_SoVITS.AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
-from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
-from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer

default_config = {
"embedding_dim": 512,

View File

@@ -9,7 +9,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
-from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
+from torch.nn import functional as F
+from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -152,14 +153,14 @@ class MultiheadAttention(Module):
bias=bias,
**factory_kwargs,
)
-self.in_proj_weight = self.in_proj_linear.weight
+self.in_proj_weight = self.in_proj_lineGPT_SoVITS.AR.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
-self.in_proj_bias = self.in_proj_linear.bias
+self.in_proj_bias = self.in_proj_lineGPT_SoVITS.AR.bias
else:
self.register_parameter("in_proj_bias", None)

View File

@@ -8,7 +8,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
-from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
+from torch.nn import functional as F
+from GPT_SoVITS.AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched

class MultiheadAttention(Module):
@@ -102,14 +103,14 @@ class MultiheadAttention(Module):
bias=bias,
**factory_kwargs,
)
-self.in_proj_weight = self.in_proj_linear.weight
+self.in_proj_weight = self.in_proj_lineGPT_SoVITS.AR.weight
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
-self.in_proj_bias = self.in_proj_linear.bias
+self.in_proj_bias = self.in_proj_lineGPT_SoVITS.AR.bias
else:
self.register_parameter("in_proj_bias", None)

View File

@@ -10,8 +10,8 @@ from typing import Tuple
from typing import Union

import torch
-from AR.modules.activation import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F

View File

@@ -10,8 +10,8 @@ from typing import Tuple
from typing import Union

import torch
-from AR.modules.activation_onnx import MultiheadAttention
-from AR.modules.scaling import BalancedDoubleSwish
+from GPT_SoVITS.AR.modules.activation_onnx import MultiheadAttention
+from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F

View File

@@ -9,7 +9,7 @@ import regex
from gruut import sentences
from gruut.const import Sentence
from gruut.const import Word
-from AR.text_processing.symbols import SYMBOL_TO_ID
+from GPT_SoVITS.AR.text_processing.symbols import SYMBOL_TO_ID

class GruutPhonemizer:

View File

@@ -18,27 +18,31 @@ from typing import List, Tuple, Union
import ffmpeg
import librosa
import numpy as np
+import random
import torch
import torch.nn.functional as F
+import traceback
import yaml
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from BigVGAN.bigvgan import BigVGAN
-from feature_extractor.cnhubert import CNHubert
-from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
+from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from transformers import AutoModelForMaskedLM, AutoTokenizer
+from huggingface_hub import snapshot_download
-from tools.audio_sr import AP_BWE
-from tools.i18n.i18n import I18nAuto, scan_language_list
-from tools.my_utils import load_audio
-from TTS_infer_pack.text_segmentation_method import splits
-from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from GPT_SoVITS.tools.audio_sr import AP_BWE
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.tools.my_utils import load_audio
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
+from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
+LIBRARY_NAME = "GPT_SoVITS"

spec_min = -12
@@ -149,28 +153,28 @@ class NO_PROMPT_ERROR(Exception):
# configs/tts_infer.yaml
"""
custom:
-bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+bert_base_path: pretrained_models/chinese-roberta-wwm-ext-large
+cnhuhbert_base_path: pretrained_models/chinese-hubert-base
device: cpu
is_half: false
-t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+t2s_weights_path: pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+vits_weights_path: pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
-t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
+t2s_weights_path: pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+vits_weights_path: pretrained_models/s2G488k.pth
version: v1
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
-t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+t2s_weights_path: pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+vits_weights_path: pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
version: v2
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
@@ -323,8 +327,10 @@ class TTS_Config:
if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
self.cnhuhbert_base_path = self.default_configs[version]["cnhuhbert_base_path"]
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
-self.update_configs()
+repo_name="lj1995/GPT-SoVITS"
+snapshot_download(repo_id=repo_name, local_dir=os.path.dirname(self.bert_base_path))
+self.update_configs()

self.max_sec = None
self.hz: int = 50
self.semantic_frame_rate: str = "25hz"
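
The added snapshot_download call pre-fetches the pretrained weights from the Hugging Face Hub so that the relative default paths above resolve locally. In isolation it behaves roughly like the hedged sketch below; the local_dir value here is illustrative, since the code derives it from bert_base_path:

from huggingface_hub import snapshot_download

# Mirror the lj1995/GPT-SoVITS model repo into the folder that holds the BERT weights,
# e.g. "pretrained_models", so that pretrained_models/chinese-roberta-wwm-ext-large exists on disk.
snapshot_download(repo_id="lj1995/GPT-SoVITS", local_dir="pretrained_models")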
@@ -1295,6 +1301,16 @@ class TTS:
finally:
self.empty_cache()

+def empty_cache(self):
+try:
+gc.collect() # 触发gc的垃圾回收。避免内存一直增长。
+if "cuda" in str(self.configs.device):
+torch.cuda.empty_cache()
+elif str(self.configs.device) == "mps":
+torch.mps.empty_cache()
+except:
+pass

def empty_cache(self):
try:
gc.collect() # 触发gc的垃圾回收。避免内存一直增长。
@@ -1558,3 +1574,160 @@ class TTS:
audio_fragments[i + 1] = f2_
return torch.cat(audio_fragments, 0)
@torch.no_grad()
def run_generator(self, inputs: dict):
"""
Streaming inference using infer_panel_generator and zero-cross splitting for v1-v4.
Yields tuples of (sampling_rate, np.ndarray audio fragment).
"""
# Initialize parameters
self.stop_flag = False
text = inputs.get("text", "")
text_lang = inputs.get("text_lang", "")
ref_audio_path = inputs.get("ref_audio_path", "")
aux_ref_audio_paths = inputs.get("aux_ref_audio_paths", [])
prompt_text = inputs.get("prompt_text", "")
prompt_lang = inputs.get("prompt_lang", "")
top_k = inputs.get("top_k", 5)
top_p = inputs.get("top_p", 1)
temperature = inputs.get("temperature", 1)
text_split_method = inputs.get("text_split_method", "cut0")
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
seed = inputs.get("seed", -1)
seed = -1 if seed in [None, ""] else seed
set_seed(seed)
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 8)
super_sampling = inputs.get("super_sampling", False)
search_length = inputs.get("search_length", 32000 * 5)
num_zeroes = inputs.get("num_zeroes", 5)
cumulation_amount = inputs.get("cumulation_amount", 50)
# Prepare reference audio
if ref_audio_path and ref_audio_path != self.prompt_cache["ref_audio_path"]:
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} not exists")
self.set_ref_audio(ref_audio_path)
# Auxiliary refs
self.prompt_cache["aux_ref_audio_paths"] = aux_ref_audio_paths or []
self.prompt_cache["refer_spec"] = [self.prompt_cache["refer_spec"][0]]
for p in aux_ref_audio_paths or []:
if p and os.path.exists(p):
self.prompt_cache["refer_spec"].append(self._get_ref_spec(p))
# Prompt text handling
no_prompt = prompt_text in [None, ""]
if not no_prompt:
prompt_text = prompt_text.strip("\n")
if prompt_text and prompt_text[-1] not in splits:
prompt_text += "" if prompt_lang != "en" else "."
phones_p, bert_p, norm_p = self.text_preprocessor.segment_and_extract_feature_for_text(
prompt_text, prompt_lang, self.configs.version
)
self.prompt_cache.update({
"prompt_text": prompt_text,
"prompt_lang": prompt_lang,
"phones": phones_p,
"bert_features": bert_p,
"norm_text": norm_p,
})
# Text to semantic preprocessing
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
if not data:
sr = self.vocoder_configs["sr"] if self.configs.use_vocoder else self.configs.sampling_rate
yield sr, np.zeros(1, dtype=np.int16)
return
# Single-batch conversion
batches, _ = self.to_batch(
data,
prompt_data=None if no_prompt else self.prompt_cache,
batch_size=1,
threshold=batch_threshold,
split_bucket=False,
device=self.configs.device,
precision=self.precision,
)
item = batches[0]
phones = item["phones"][0]
all_ids = item["all_phones"][0]
all_lens = item["all_phones_len"][0]
all_bert = item["all_bert_features"][0]
max_len = item["max_len"]
# Prepare semantic prompt
if not no_prompt:
prompt_sem = self.prompt_cache["prompt_semantic"].unsqueeze(0).to(self.configs.device)
else:
prompt_sem = None
# Reference spectrograms
refer_spec = [s.to(dtype=self.precision, device=self.configs.device) for s in self.prompt_cache["refer_spec"]]
# Streaming via generator
from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index
zc_idx1 = zc_idx2 = crossing_dir = 0
first = True
last = False
gen_list = []
for gen_tokens in self.t2s_model.model.infer_panel_generator(
all_ids.unsqueeze(0).to(self.configs.device),
all_lens.unsqueeze(0).to(self.configs.device),
prompt_sem,
all_bert.unsqueeze(0).to(self.configs.device),
cumulation_amount=cumulation_amount,
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
):
gen_list.append(gen_tokens)
total = sum([t.size(1) for t in gen_list])
toks = torch.cat(gen_list, dim=1)[:, :total]
eos = self.t2s_model.model.EOS
has_eos = (toks == eos).any()
if has_eos:
toks = toks.masked_fill(toks == eos, 0)
last = True
first = False
# Decode to waveform
pred = toks.unsqueeze(0)
phone_t = phones.unsqueeze(0).to(self.configs.device)
if not self.configs.use_vocoder:
w = self.vits_model.decode(pred, phone_t, refer_spec, speed=speed_factor).detach()[0,0,:]
else:
w = self.using_vocoder_synthesis(pred, phone_t, speed=speed_factor, sample_steps=sample_steps)
w = w.cpu().numpy().astype(np.float32)
mv = np.abs(w).max()
if mv > 1.0:
w /= mv
# Zero-cross splitting
start = len(w) - search_length
if start < 0:
search_length = len(w)
start = 0
center = zc_idx2
off = int(search_length // 2)
sr = self.vocoder_configs["sr"] if self.configs.use_vocoder else self.configs.sampling_rate
if first:
zc_idx1, crossing_dir = find_zero_zone(w, start, search_length, num_zeroes)
frag = w[:zc_idx1]
print(len(frag))
frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
yield sr, frag_int16
first = False
zc_idx2 = zc_idx1
elif last:
zc1 = find_matching_index(w, center, off, crossing_dir)
frag = w[zc1:]
print(len(frag))
frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
yield sr, frag_int16
zc_idx2 = zc_idx1
else:
zc1 = find_matching_index(w, center, off, crossing_dir)
zc_idx1, crossing_dir = find_zero_zone(w, start, search_length, num_zeroes)
frag = w[zc1:zc_idx1]
print(len(frag))
frag_int16 = (frag * np.iinfo(np.int16).max).astype(np.int16)
yield sr, frag_int16
zc_idx2 = zc_idx1
self.empty_cache()
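
As a rough usage illustration, a caller can stream the yielded fragments straight to a playback buffer or stitch them together afterwards. The tts_pipeline name and the request values below are assumptions for this sketch; because the fragments are already cut at matched zero crossings, plain concatenation should join them cleanly:

import numpy as np

# Hypothetical consumer of TTS.run_generator (tts_pipeline stands for an initialized TTS instance).
request = {
    "text": "Hello there, this is a streaming synthesis test.",
    "text_lang": "en",
    "ref_audio_path": "reference.wav",      # placeholder path
    "prompt_text": "reference transcript",  # placeholder transcript
    "prompt_lang": "en",
    "cumulation_amount": 50,
}
fragments = []
for sr, fragment in tts_pipeline.run_generator(request):
    fragments.append(fragment)              # each fragment is an int16 numpy array
audio = np.concatenate(fragments)
# e.g. soundfile.write("streamed_output.wav", audio, sr)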

View File

@@ -9,15 +9,15 @@ sys.path.append(now_dir)
import re
import torch
-from text.LangSegmenter import LangSegmenter
-from text import chinese
+from GPT_SoVITS.text.LangSegmenter import LangSegmenter
+from GPT_SoVITS.text import chinese
from typing import Dict, List, Tuple
-from text.cleaner import clean_text
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
-from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language

View File

@@ -0,0 +1,203 @@
import numpy as np
import wave
import struct
def read_wav_file(filename):
"""
Reads a WAV file and returns the sample rate and data as a numpy array.
"""
with wave.open(filename, 'rb') as wf:
sample_rate = wf.getframerate()
n_frames = wf.getnframes()
sample_width = wf.getsampwidth()
n_channels = wf.getnchannels()
audio_data = wf.readframes(n_frames)
# Determine the format string for struct unpacking
fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * n_frames * n_channels
audio_samples = struct.unpack(fmt, audio_data)
audio_array = np.array(audio_samples, dtype=int)
# If stereo, reshape the array
if n_channels > 1:
audio_array = audio_array.reshape(-1, n_channels)
return sample_rate, audio_array, sample_width, n_channels
def write_wav_file(filename, sample_rate, data, sample_width, n_channels):
"""
Writes numpy array data to a WAV file.
"""
with wave.open(filename, 'wb') as wf:
wf.setnchannels(n_channels)
wf.setsampwidth(sample_width)
wf.setframerate(sample_rate)
# Flatten the array if it's multi-dimensional
if data.ndim > 1:
data = data.flatten()
# Pack the data into bytes
fmt = "<" + {1:'b', 2:'h', 4:'i'}[sample_width] * len(data)
byte_data = struct.pack(fmt, *data)
wf.writeframes(byte_data)
def find_zero_zone(chunk, start_index, search_length, num_zeroes=11):
zone = chunk[start_index:start_index + search_length]
print(f"Zero-crossing search zone: Start={start_index}, Length={len(zone)}")
zero_threshold = 1.0e-4
# Check for y consecutive zeros
for idx in range(len(zone), -1 + num_zeroes, -1):
index_to_start = idx-num_zeroes
abs_zone = np.abs(zone[index_to_start:idx])
if np.all(abs_zone < zero_threshold):
index_midpoint = index_to_start + int(num_zeroes // 2)
return (start_index + index_midpoint), None
print("Falling back to zero crossing due to no zero zone found. You may hear more prominent pops and clicks in the audio. Try increasing search length or cumulative tokens.")
return find_zero_crossing(chunk, start_index, search_length)
def find_zero_crossing(chunk, start_index, search_length):
# If the model is falling back on the this function, it might be a bad indicator that the search length is too low
zone = chunk[start_index:start_index + search_length]
sign_changes = np.where(np.diff(np.sign(zone)) != 0)[0]
if len(sign_changes) == 0:
raise ("No zero-crossings found in this zone. This should not be happening, debugging time.")
else:
zc_index = start_index + sign_changes[0] + 1
print(f"Zero-crossing found at index {zc_index}")
# Determine the crossing direction in chunk1
prev_value = chunk[zc_index - 1]
curr_value = chunk[zc_index]
crossing_direction = np.sign(curr_value) - np.sign(prev_value)
print(f"Crossing direction in chunk1: {np.sign(prev_value)} to {np.sign(curr_value)}")
return zc_index, crossing_direction
def find_matching_index(chunk, center_index, max_offset, crossing_direction):
"""
Finds a zero-crossing in data that matches the specified crossing direction,
starting from center_index and searching outward.
"""
if crossing_direction == None:
return center_index # if zero zone
# fall back for zero_crossing
data_length = len(chunk)
print(f"Center index in chunk2: {center_index}")
for offset in range(max_offset + 1):
# Check index bounds
idx_forward = center_index + offset
idx_backward = center_index - offset
found = False
# Check forward direction
if idx_forward < data_length - 1:
prev_sign = np.sign(chunk[idx_forward])
curr_sign = np.sign(chunk[idx_forward + 1])
direction = curr_sign - prev_sign
if direction == crossing_direction:
print(f"Matching zero-crossing found at index {idx_forward + 1} (forward)")
return idx_forward + 1
# Check backward direction
if idx_backward > 0:
prev_sign = np.sign(chunk[idx_backward - 1])
curr_sign = np.sign(chunk[idx_backward])
direction = curr_sign - prev_sign
if direction == crossing_direction:
print(f"Matching zero-crossing found at index {idx_backward} (backward)")
return idx_backward
print("No matching zero-crossings found in this zone.")
return None
# legacy, just for history. delete me sometime
def splice_chunks(chunk1, chunk2, search_length, y):
"""
Splices two audio chunks at zero-crossing points.
"""
# Define the zone to search in chunk1
start_index1 = len(chunk1) - search_length
if start_index1 < 0:
start_index1 = 0
search_length = len(chunk1)
print(f"Searching for zero-crossing in chunk1 from index {start_index1} to {len(chunk1)}")
# Find zero-crossing in chunk1
zc_index1, crossing_direction = find_zero_crossing(chunk1, start_index1, search_length, y)
if zc_index1 is None:
print("No zero-crossing found in chunk1 within the specified zone.")
return None
# Define the zone to search in chunk2 near the same index
# Since chunk2 overlaps with chunk1, we can assume that index positions correspond
# Adjusted search in chunk2
# You can adjust this value if needed
center_index = zc_index1 # Assuming alignment between chunk1 and chunk2
max_offset = search_length
# Ensure center_index is within bounds
if center_index < 0:
center_index = 0
elif center_index >= len(chunk2):
center_index = len(chunk2) - 1
print(f"Searching for matching zero-crossing in chunk2 around index {center_index} with max offset {max_offset}")
zc_index2 = find_matching_zero_crossing(chunk2, center_index, max_offset, crossing_direction)
if zc_index2 is None:
print("No matching zero-crossing found in chunk2.")
return None
print(f"Zero-crossing in chunk1 at index {zc_index1}, chunk2 at index {zc_index2}")
# Splice the chunks
new_chunk = np.concatenate((chunk1[:zc_index1], chunk2[zc_index2:]))
print(f"Spliced chunk length: {len(new_chunk)}")
return new_chunk
# legacy, just for history. delete me sometime
def process_audio_chunks(filenames, sample_rate, x, y, output_filename):
"""
Processes and splices a list of audio chunks.
"""
# Read the first chunk
sr, chunk_data, sample_width, n_channels = read_wav_file(filenames[0])
if sr != sample_rate:
print(f"Sample rate mismatch in {filenames[0]}")
return
print(f"Processing {filenames[0]}")
# Initialize the combined audio with the first chunk
combined_audio = chunk_data
# Process remaining chunks
for filename in filenames[1:]:
sr, next_chunk_data, _, _ = read_wav_file(filename)
if sr != sample_rate:
print(f"Sample rate mismatch in {filename}")
return
print(f"Processing {filename}")
# Splice the current combined audio with the next chunk
new_combined = splice_chunks(combined_audio, next_chunk_data, x, y)
if new_combined is None:
print(f"Failed to splice chunks between {filename} and previous chunk.")
return
combined_audio = new_combined
# Write the final combined audio to output file
write_wav_file(output_filename, sample_rate, combined_audio, sample_width, n_channels)
print(f"Final audio saved to {output_filename}")
# Main execution
if __name__ == "__main__":
# User-specified parameters
sample_rate = 32000 # Sample rate in Hz
x = 500 # Number of frames to search from the end of the chunk
y = 10 # Number of consecutive zeros to look for
output_filename = "combined_output.wav"
folder_with_chunks = "output_chunks"
import os
def absolute_file_paths(directory):
path = os.path.abspath(directory)
return [entry.path for entry in os.scandir(path) if entry.is_file()]
# List of input audio chunk filenames in sequential order
filenames = absolute_file_paths(folder_with_chunks)
# Process and splice the audio chunks
process_audio_chunks(filenames, sample_rate, x, y, output_filename)
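
To make the intended pairing of find_zero_zone and find_matching_index concrete, here is a small hedged sketch of joining two consecutive fragments at compatible cut points; the signals and parameters are synthetic and chosen only for illustration:

import numpy as np
from GPT_SoVITS.TTS_infer_pack.zero_crossing import find_zero_zone, find_matching_index

sr = 32000
t = np.arange(sr) / sr
chunk1 = np.sin(2 * np.pi * 220 * t).astype(np.float32)  # synthetic fragment 1
chunk2 = np.sin(2 * np.pi * 220 * t).astype(np.float32)  # synthetic fragment 2

search_length = 2000
start = len(chunk1) - search_length

# Cut point in the first fragment: a run of near-zero samples, or a plain zero crossing as fallback.
cut1, direction = find_zero_zone(chunk1, start, search_length, num_zeroes=5)

# Matching cut point in the second fragment: search outward from the same index and require the
# same crossing direction so the join stays continuous (may be None if nothing matches).
cut2 = find_matching_index(chunk2, cut1, max_offset=search_length // 2, crossing_direction=direction)

spliced = np.concatenate([chunk1[:cut1], chunk2[cut2:]])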

View File

@@ -116,8 +116,10 @@ import soundfile as sf
from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
+from importlib.resources import files
from io import BytesIO
-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from pydantic import BaseModel
@@ -127,7 +129,7 @@ i18n = I18nAuto()
cut_method_names = get_cut_method_names()

parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
+parser.add_argument("-c", "--tts_config", type=str, default=None, help="tts_infer路径")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
args = parser.parse_args()
@@ -138,7 +140,7 @@ host = args.bind_addr
argv = sys.argv

if config_path in [None, ""]:
-config_path = "GPT-SoVITS/configs/tts_infer.yaml"
+config_path = str(files("GPT_SoVITS").joinpath("configs/tts_infer.yaml"))

tts_config = TTS_Config(config_path)
print(tts_config)
@@ -434,7 +436,7 @@ async def tts_get_endpoint(
@APP.post("/tts")
async def tts_post_endpoint(request: TTS_Request):
-req = request.dict()
+req = request.model_dump()
return await tts_handle(req)
@@ -498,3 +500,6 @@ if __name__ == "__main__":
traceback.print_exc()
os.kill(os.getpid(), signal.SIGTERM)
exit(0)

+if __name__ == "__main__":
+main()
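
The new default resolves the bundled config through importlib.resources, so the API can find the packaged configs/tts_infer.yaml when GPT_SoVITS is installed as a package rather than run from a source checkout. In isolation the lookup (assuming the package is importable) is just:

from importlib.resources import files

# Locate configs/tts_infer.yaml inside the installed GPT_SoVITS package,
# independent of the current working directory.
config_path = str(files("GPT_SoVITS").joinpath("configs/tts_infer.yaml"))
print(config_path)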

View File

@@ -3,7 +3,7 @@ custom:
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
-t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v1:

View File

@@ -3,11 +3,5 @@ import sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
-from text.g2pw import G2PWPinyin
+from GPT_SoVITS.text.g2pw import G2PWPinyin

-g2pw = G2PWPinyin(
-model_dir="GPT_SoVITS/text/G2PWModel",
-model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
-v_to_u=False,
-neutral_tone_with_five=True,
-)
+g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)

View File

@@ -12,8 +12,8 @@ from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer

from feature_extractor import cnhubert
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from module.models_onnx import SynthesizerTrn
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert

View File

@@ -13,7 +13,7 @@ from transformers import (
HubertModel,
)

-import utils
+import GPT_SoVITS.utils
import torch.nn as nn

cnhubert_base_path = None

View File

@@ -2,7 +2,7 @@ import argparse
import os
import soundfile as sf

-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav

i18n = I18nAuto()

View File

@@ -5,8 +5,7 @@ from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushB
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
import soundfile as sf

-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto

i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@@ -305,7 +304,7 @@ class GPTSoVITSGUI(QMainWindow):
result = "Audio saved to " + output_wav_path
-self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
+self.status_bGPT_SoVITS.AR.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
self.output_text.append("处理结果:\n" + result)

View File

@@ -124,12 +124,12 @@ def set_seed(seed):
from time import time as ttime

-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from peft import LoraConfig, get_peft_model
-from text import cleaned_text_to_sequence
-from text.cleaner import clean_text
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.text import cleaned_text_to_sequence
+from GPT_SoVITS.text.cleaner import clean_text
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -165,8 +165,8 @@ dict_language_v2 = {
}
dict_language = dict_language_v1 if version == "v1" else dict_language_v2

-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+tokenizer = AutoTokenizer.from_pretrained(bert_path, local_files_only=True)
+bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, local_files_only=True)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
@@ -406,6 +406,7 @@ def init_bigvgan():
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
+local_files_only=True
) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
@@ -518,11 +519,8 @@ def get_first(text):
text = re.split(pattern, text)[0].strip()
return text

-from text import chinese
+from GPT_SoVITS.text import chinese

-def get_phones_and_bert(text, language, version, final=False):
+def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
formattext = text
while " " in formattext:

View File

@@ -50,10 +50,9 @@ bert_path = os.environ.get("bert_path", None)
version = model_version = os.environ.get("version", "v2")

import gradio as gr

-from TTS_infer_pack.text_segmentation_method import get_method
-from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
-from tools.i18n.i18n import I18nAuto, scan_language_list
+from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language

View File

@@ -3,8 +3,8 @@ import torch
from torch import nn
from torch.nn import functional as F

-from module import commons
-from module.modules import LayerNorm
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.modules import LayerNorm

class Encoder(nn.Module):
@@ -325,7 +325,7 @@ class MultiHeadAttention(nn.Module):
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
-length: an integer scalar.
+length: an integer scalGPT_SoVITS.AR.
Returns:
a Tensor with shape [1, 1, length, length]
"""

View File

@@ -3,7 +3,7 @@ import torch
from torch import nn
from torch.nn import functional as F

-from module import commons
+from GPT_SoVITS.module import commons
from typing import Optional
@@ -288,7 +288,7 @@ class MultiHeadAttention(nn.Module):
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
-length: an integer scalar.
+length: an integer scalGPT_SoVITS.AR.
Returns:
a Tensor with shape [1, 1, length, length]
"""

View File

@@ -5,10 +5,10 @@ import torch
import torch.utils.data
from tqdm import tqdm

-from module.mel_processing import spectrogram_torch, spec_to_mel_torch
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.module.mel_processing import spectrogram_torch, spec_to_mel_torch
+from GPT_SoVITS.text import cleaned_text_to_sequence
import torch.nn.functional as F
-from tools.my_utils import load_audio
+from GPT_SoVITS.tools.my_utils import load_audio

version = os.environ.get("version", None)

View File

@@ -7,19 +7,19 @@ import torch
from torch import nn
from torch.nn import functional as F

-from module import commons
-from module import modules
-from module import attentions
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.mrte_model import MRTE
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.mrte_model import MRTE
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer

# from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2

from torch.cuda.amp import autocast
import contextlib
import random

View File

@@ -4,20 +4,20 @@ import torch
from torch import nn
from torch.nn import functional as F

-from module import commons
-from module import modules
-from module import attentions_onnx as attentions
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module import modules
+from GPT_SoVITS.module import attentions_onnx as attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from module.commons import init_weights, get_padding
-from module.quantize import ResidualVectorQuantizer
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.quantize import ResidualVectorQuantizer

# from text import symbols
-from text import symbols as symbols_v1
-from text import symbols2 as symbols_v2
+from GPT_SoVITS.text import symbols as symbols_v1
+from GPT_SoVITS.text import symbols2 as symbols_v2

class StochasticDurationPredictor(nn.Module):

View File

@@ -7,9 +7,9 @@ from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

-from module import commons
-from module.commons import init_weights, get_padding
-from module.transforms import piecewise_rational_quadratic_transform
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.commons import init_weights, get_padding
+from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform

import torch.distributions as D

View File

@@ -3,7 +3,7 @@
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm

-from module.attentions import MultiHeadAttention
+from GPT_SoVITS.module.attentions import MultiHeadAttention

class MRTE(nn.Module):

View File

@@ -12,7 +12,7 @@ import typing as tp
import torch
from torch import nn

-from module.core_vq import ResidualVectorQuantization
+from GPT_SoVITS.module.core_vq import ResidualVectorQuantization

@dataclass

View File

@@ -1,18 +1,22 @@
import torch
import torchaudio
-from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
-from feature_extractor import cnhubert
-from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
+from GPT_SoVITS.AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
+from GPT_SoVITS.feature_extractor import cnhubert
+from GPT_SoVITS.module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
+from GPT_SoVITS.text import cleaned_text_to_sequence
+import soundfile
+from GPT_SoVITS.tools.my_utils import load_audio
+import os
import json
import os
import soundfile
-from text import cleaned_text_to_sequence
+from GPT_SoVITS.text import cleaned_text_to_sequence

def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):

View File

@@ -17,9 +17,9 @@ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get("version", None)
import traceback
import os.path

-from text.cleaner import clean_text
+from GPT_SoVITS.text.cleaner import clean_text
from transformers import AutoModelForMaskedLM, AutoTokenizer
-from tools.my_utils import clean_path
+from GPT_SoVITS.tools.my_utils import clean_path

# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]

View File

@@ -25,7 +25,7 @@ import librosa
now_dir = os.getcwd()
sys.path.append(now_dir)

-from tools.my_utils import load_audio, clean_path
+from GPT_SoVITS.tools.my_utils import load_audio, clean_path

# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path

View File

@@ -38,10 +38,10 @@ import logging
import utils

if version != "v3":
-from module.models import SynthesizerTrn
+from GPT_SoVITS.module.models import SynthesizerTrn
else:
-from module.models import SynthesizerTrnV3 as SynthesizerTrn
+from GPT_SoVITS.module.models import SynthesizerTrnV3 as SynthesizerTrn
-from tools.my_utils import clean_path
+from GPT_SoVITS.tools.my_utils import clean_path

logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G

View File

@@ -4,7 +4,7 @@ from time import time as ttime
import shutil
import os
import torch

-from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.tools.i18n.i18n import I18nAuto

i18n = I18nAuto()

View File

@@ -9,9 +9,9 @@ import platform
from pathlib import Path

import torch
-from AR.data.data_module import Text2SemanticDataModule
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from AR.utils.io import load_yaml_config
+from GPT_SoVITS.AR.data.data_module import Text2SemanticDataModule
+from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from GPT_SoVITS.AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
@@ -20,10 +20,12 @@ from pytorch_lightning.strategies import DDPStrategy
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
+from GPT_SoVITS.AR.utils import get_newest_ckpt

from collections import OrderedDict

-from AR.utils import get_newest_ckpt
-from process_ckpt import my_save
+from GPT_SoVITS.AR.utils import get_newest_ckpt
+from GPT_SoVITS.process_ckpt import my_save

class my_model_ckpt(ModelCheckpoint):

View File

@@ -24,19 +24,19 @@ logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)

from random import randint

-from module import commons
-from module.data_utils import (
+from GPT_SoVITS.module import commons
+from GPT_SoVITS.module.data_utils import (
DistributedBucketSampler,
TextAudioSpeakerCollate,
TextAudioSpeakerLoader,
)
-from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from module.models import (
+from GPT_SoVITS.module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
+from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from GPT_SoVITS.module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
-from process_ckpt import savee
+from GPT_SoVITS.process_ckpt import savee

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
@@ -71,7 +71,7 @@ def main():
def run(rank, n_gpus, hps):
global global_step
if rank == 0:
-logger = utils.get_logger(hps.data.exp_dir)
+logger = GPT_SoVITS.utils.get_logger(hps.data.exp_dir)
logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
@ -204,7 +204,7 @@ def run(rank, n_gpus, hps):
net_d = net_d.to(device) net_d = net_d.to(device)
try: # 如果能加载自动resume try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"), utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
net_d, net_d,
optim_d, optim_d,
@ -212,7 +212,7 @@ def run(rank, n_gpus, hps):
if rank == 0: if rank == 0:
logger.info("loaded D") logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = GPT_SoVITS.utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"), utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g, net_g,
optim_g, optim_g,
@ -479,30 +479,30 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
image_dict = None image_dict = None
try: ###Some people installed the wrong version of matplotlib. try: ###Some people installed the wrong version of matplotlib.
image_dict = { image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy( "slice/mel_org": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy(), y_mel[0].data.cpu().numpy(),
), ),
"slice/mel_gen": utils.plot_spectrogram_to_numpy( "slice/mel_gen": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy(), y_hat_mel[0].data.cpu().numpy(),
), ),
"all/mel": utils.plot_spectrogram_to_numpy( "all/mel": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy(), mel[0].data.cpu().numpy(),
), ),
"all/stats_ssl": utils.plot_spectrogram_to_numpy( "all/stats_ssl": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy(), stats_ssl[0].data.cpu().numpy(),
), ),
} }
except: except:
pass pass
if image_dict: if image_dict:
utils.summarize( GPT_SoVITS.utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
scalars=scalar_dict, scalars=scalar_dict,
) )
else: else:
utils.summarize( GPT_SoVITS.utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
scalars=scalar_dict, scalars=scalar_dict,
@ -510,7 +510,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_g, net_g,
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
@ -520,7 +520,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
"G_{}.pth".format(global_step), "G_{}.pth".format(global_step),
), ),
) )
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
@ -531,7 +531,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
), ),
) )
else: else:
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_g, net_g,
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
@ -541,7 +541,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
"G_{}.pth".format(233333333333), "G_{}.pth".format(233333333333),
), ),
) )
utils.save_checkpoint( GPT_SoVITS.utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
@ -644,7 +644,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
) )
image_dict.update( image_dict.update(
{ {
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( f"gen/mel_{batch_idx}_{test}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy(), y_hat_mel[0].cpu().numpy(),
), ),
} }
@ -656,7 +656,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
) )
image_dict.update( image_dict.update(
{ {
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()), f"gt/mel_{batch_idx}": GPT_SoVITS.utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
}, },
) )
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
@ -666,7 +666,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
# f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :] # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
# }) # })
utils.summarize( GPT_SoVITS.utils.summarize(
writer=writer_eval, writer=writer_eval,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,

View File

@ -1,11 +1,11 @@
import os import os
# if os.environ.get("version","v1")=="v1": # if os.environ.get("version","v1")=="v1":
# from text.symbols import symbols # from GPT_SoVITS.text.symbols import symbols
# else: # else:
# from text.symbols2 import symbols # from GPT_SoVITS.text.symbols2 import symbols
from text import symbols as symbols_v1 from GPT_SoVITS.text import symbols as symbols_v1
from text import symbols2 as symbols_v2 from GPT_SoVITS.text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)} _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)} _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}

View File

@ -4,8 +4,8 @@ import re
import cn2an import cn2an
import ToJyutping import ToJyutping
from text.symbols import punctuation from GPT_SoVITS.text.symbols import punctuation
from text.zh_normalization.text_normlization import TextNormalizer from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn") normalizer = lambda x: cn2an.transform(x, "an2cn")
@ -195,7 +195,7 @@ def get_jyutping(text):
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
from text import chinese_bert from GPT_SoVITS.text import chinese_bert
return chinese_bert.get_bert_feature(text, word2ph) return chinese_bert.get_bert_feature(text, word2ph)

View File

@ -4,9 +4,9 @@ import re
import cn2an import cn2an
from pypinyin import lazy_pinyin, Style from pypinyin import lazy_pinyin, Style
from text.symbols import punctuation from GPT_SoVITS.text.symbols import punctuation
from text.tone_sandhi import ToneSandhi from GPT_SoVITS.text.tone_sandhi import ToneSandhi
from text.zh_normalization.text_normlization import TextNormalizer from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn") normalizer = lambda x: cn2an.transform(x, "an2cn")

View File

@ -5,9 +5,9 @@ import cn2an
from pypinyin import lazy_pinyin, Style from pypinyin import lazy_pinyin, Style
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
from text.symbols import punctuation from GPT_SoVITS.text.symbols import punctuation
from text.tone_sandhi import ToneSandhi from GPT_SoVITS.text.tone_sandhi import ToneSandhi
from text.zh_normalization.text_normlization import TextNormalizer from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn") normalizer = lambda x: cn2an.transform(x, "an2cn")
@ -28,7 +28,7 @@ import jieba_fast.posseg as psg
is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False
if is_g2pw: if is_g2pw:
# print("当前使用g2pw进行拼音推理") # print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation from GPT_SoVITS.text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path) parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin( g2pw = G2PWPinyin(

View File

@ -1,14 +1,14 @@
from text import cleaned_text_to_sequence from GPT_SoVITS.text import cleaned_text_to_sequence
import os import os
# if os.environ.get("version","v1")=="v1": # if os.environ.get("version","v1")=="v1":
# from text import chinese # from GPT_SoVITS.text import chinese
# from text.symbols import symbols # from GPT_SoVITS.text.symbols import symbols
# else: # else:
# from text import chinese2 as chinese # from GPT_SoVITS.text import chinese2 as chinese
# from text.symbols2 import symbols # from GPT_SoVITS.text.symbols2 import symbols
from text import symbols as symbols_v1 from GPT_SoVITS.text import symbols as symbols_v1
from text import symbols2 as symbols_v2 from GPT_SoVITS.text import symbols2 as symbols_v2
special = [ special = [
# ("%", "zh", "SP"), # ("%", "zh", "SP"),
@ -34,7 +34,7 @@ def clean_text(text, language, version=None):
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol, version) return clean_special(text, language, special_s, target_symbol, version)
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]]) language_module = __import__("GPT_SoVITS.text." + language_module_map[language], fromlist=[language_module_map[language]])
if hasattr(language_module, "text_normalize"): if hasattr(language_module, "text_normalize"):
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
else: else:
@ -69,7 +69,7 @@ def clean_special(text, language, special_s, target_symbol, version=None):
特殊静音段sp符号处理 特殊静音段sp符号处理
""" """
text = text.replace(special_s, ",") text = text.replace(special_s, ",")
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]]) language_module = __import__("GPT_SoVITS.text."+language_module_map[language],fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
new_ph = [] new_ph = []
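
The dispatch above loads each language's cleaner with __import__ and a fromlist argument so the submodule object itself is returned rather than the top-level package. A minimal, self-contained sketch of the same pattern follows; the package layout and map entries are illustrative stand-ins, not the project's exact modules.

import importlib

# Hypothetical language -> submodule map; the project keeps a similar dict.
language_module_map = {"zh": "chinese2", "en": "english", "ja": "japanese"}

def load_language_module(package: str, language: str):
    # importlib.import_module("pkg.sub") behaves like
    # __import__("pkg.sub", fromlist=["sub"]): both return the submodule object.
    return importlib.import_module(package + "." + language_module_map[language])

# Usage, assuming GPT_SoVITS.text.english is importable in the environment:
# english = load_language_module("GPT_SoVITS.text", "en")
# norm_text = english.text_normalize("It costs $3.50")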

View File

@ -4,9 +4,9 @@ import re
import wordsegment import wordsegment
from g2p_en import G2p from g2p_en import G2p
from text.symbols import punctuation from GPT_SoVITS.text.symbols import punctuation
from text.symbols2 import symbols from GPT_SoVITS.text.symbols2 import symbols
from builtins import str as unicode from builtins import str as unicode
from text.en_normalization.expend import normalize from GPT_SoVITS.text.en_normalization.expend import normalize

View File

@ -1 +1 @@
from text.g2pw.g2pw import * from GPT_SoVITS.text.g2pw.g2pw import *

View File

@ -77,8 +77,7 @@ except Exception:
pass pass
from text.symbols import punctuation from GPT_SoVITS.text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks: # Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile( _japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"

View File

@ -56,7 +56,7 @@ if os.name == "nt":
G2p = win_G2p G2p = win_G2p
from text.symbols2 import symbols from GPT_SoVITS.text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals. # This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = ( _korean_classifiers = (

View File

@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from text.zh_normalization.text_normlization import * from GPT_SoVITS.text.zh_normalization.text_normlization import *

View File

@ -9,7 +9,7 @@ import torch
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from tqdm import tqdm from tqdm import tqdm
from tools.asr.config import check_fw_local_models from GPT_SoVITS.tools.asr.config import check_fw_local_models
# fmt: off # fmt: off
language_code_list = [ language_code_list = [
@ -72,8 +72,13 @@ def execute_asr(input_folder, output_folder, model_size, language, precision):
if info.language == "zh": if info.language == "zh":
print("检测为中文文本, 转 FunASR 处理") print("检测为中文文本, 转 FunASR 处理")
if "only_asr" not in globals(): if "only_asr" not in globals():
from tools.asr.funasr_asr import only_asr # 如果用英文就不需要导入下载模型 from GPT_SoVITS.tools.asr.funasr_asr import only_asr # 如果用英文就不需要导入下载模型
text = only_asr(file_path, language=info.language.lower()) text = only_asr(file_path, language=info.language.lower())
if text == "": if text == "":
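
The globals() guard above defers a heavy import so the FunASR backend is only pulled in when a Chinese utterance is actually detected. A small illustration of that lazy-import idiom, with stand-in names rather than the project's real modules:

def transcribe(path: str, language: str) -> str:
    # Import the expensive backend only on first use; afterwards the name
    # already exists in globals() and the import is skipped.
    if "heavy_backend" not in globals():
        global heavy_backend
        import json as heavy_backend  # stand-in for the real FunASR import
    return heavy_backend.dumps({"file": path, "lang": language})

# The first call pays the import cost; subsequent calls reuse the module.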

View File

@ -9,8 +9,8 @@ import torch
import torchaudio.functional as aF import torchaudio.functional as aF
# from attrdict import AttrDict####will be bug in py3.10 # from attrdict import AttrDict####will be bug in py3.10
from datasets1.dataset import amp_pha_stft, amp_pha_istft from GPT_SoVITS.tools.AP_BWE_main.datasets1.dataset import amp_pha_stft, amp_pha_istft
from models.model import APNet_BWE_Model from GPT_SoVITS.tools.AP_BWE_main.models.model import APNet_BWE_Model
class AP_BWE: class AP_BWE:

View File

@ -101,6 +101,7 @@
"实际输入的目标文本(每句):": "Texto alvo realmente inserido (por frase):", "实际输入的目标文本(每句):": "Texto alvo realmente inserido (por frase):",
"实际输入的目标文本:": "Texto alvo realmente inserido:", "实际输入的目标文本:": "Texto alvo realmente inserido:",
"导出文件格式": "Formato de arquivo de exportação", "导出文件格式": "Formato de arquivo de exportação",
"已关闭": " Fechado", "已关闭": " Fechado",
"已完成": " Concluído", "已完成": " Concluído",
"已开启": " Ativado", "已开启": " Ativado",
@ -110,6 +111,21 @@
"开启": "Ativar ", "开启": "Ativar ",
"开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. Não preencher o texto de referência também equivale a ativar.", "开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. Não preencher o texto de referência também equivale a ativar.",
"微调训练": "Treinamento de ajuste fino", "微调训练": "Treinamento de ajuste fino",
"开启GPT训练": "Ativar treinamento GPT",
"开启SSL提取": "Ativar extração SSL",
"开启SoVITS训练": "Ativar treinamento SoVITS",
"开启TTS推理WebUI": "Abrir TTS Inference WebUI",
"开启UVR5-WebUI": "Abrir UVR5-WebUI",
"开启一键三连": "Ativar um clique",
"开启打标WebUI": "Abrir Labeling WebUI",
"开启文本获取": "Ativar obtenção de texto",
"开启无参考文本模式。不填参考文本亦相当于开启。": "Ativar o modo sem texto de referência. Não preencher o texto de referência também equivale a ativGPT_SoVITS.AR.",
"开启离线批量ASR": "Ativar ASR offline em lote",
"开启语义token提取": "Ativar extração de token semântico",
"开启语音切割": "Ativar corte de voz",
"开启语音降噪": "Ativar redução de ruído de voz",
"怎么切": "Como cortar", "怎么切": "Como cortar",
"总训练轮数total_epoch": "Total de epoch de treinamento", "总训练轮数total_epoch": "Total de epoch de treinamento",
"总训练轮数total_epoch不建议太高": "Total de epoch de treinamento, não é recomendável um valor muito alto", "总训练轮数total_epoch不建议太高": "Total de epoch de treinamento, não é recomendável um valor muito alto",

View File

@ -3,7 +3,7 @@ import traceback
import ffmpeg import ffmpeg
import numpy as np import numpy as np
import gradio as gr import gradio as gr
from tools.i18n.i18n import I18nAuto from GPT_SoVITS.tools.i18n.i18n import I18nAuto
import pandas as pd import pandas as pd
i18n = I18nAuto(language=os.environ.get("language", "Auto")) i18n = I18nAuto(language=os.environ.get("language", "Auto"))

View File

@ -1,6 +1,7 @@
# This code is modified from https://github.com/ZFTurbo/ # This code is modified from https://github.com/ZFTurbo/
import os import os
import warnings import warnings
import subprocess
import librosa import librosa
import numpy as np import numpy as np
@ -160,7 +161,7 @@ class Roformer_Loader:
batch_data.append(part) batch_data.append(part)
batch_locations.append((i, length)) batch_locations.append((i, length))
i += step i += step
progress_bar.update(1) progress_bar.update(1)
if len(batch_data) >= batch_size or (i >= mix.shape[1]): if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0) arr = torch.stack(batch_data, dim=0)
@ -189,7 +190,7 @@ class Roformer_Loader:
# Remove pad # Remove pad
estimated_sources = estimated_sources[..., border:-border] estimated_sources = estimated_sources[..., border:-border]
progress_bar.close() progress_bar.close()
if self.config["training"]["target_instrument"] is None: if self.config["training"]["target_instrument"] is None:
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)} return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
@ -253,7 +254,10 @@ class Roformer_Loader:
sf.write(path, data, sr) sf.write(path, data, sr)
else: else:
sf.write(path, data, sr) sf.write(path, data, sr)
os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format)) subprocess.run(
["ffmpeg", "-i", path, "-vn", path[:-3] + format, "-q:a", "2", "-y"],
check=True,
)
try: try:
os.remove(path) os.remove(path)
except: except:
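
The last hunk replaces a shell-interpolated os.system call with an argument-list subprocess.run invocation, which sidesteps quoting problems when paths contain spaces or quotes and surfaces ffmpeg failures via check=True. A short sketch of the same conversion; the helper name is illustrative and ffmpeg must be on PATH:

import os
import subprocess

def reencode_audio(path: str, fmt: str) -> str:
    # Mirror the loader's naming: swap the three-character extension for fmt.
    out_path = path[:-3] + fmt
    # Passing a list means ffmpeg receives the raw path, so no shell quoting is
    # needed; check=True raises CalledProcessError on a non-zero exit code.
    subprocess.run(["ffmpeg", "-i", path, "-vn", out_path, "-q:a", "2", "-y"], check=True)
    os.remove(path)  # drop the intermediate file once conversion succeeds
    return out_path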

Some files were not shown because too many files have changed in this diff.