more import changes

This commit is contained in:
Jarod Mica 2024-12-23 01:58:31 -08:00
parent c5490bb2a2
commit 54bcce13d2
26 changed files with 226 additions and 75 deletions

View File

@ -17,7 +17,7 @@ from transformers import AutoTokenizer
version = os.environ.get('version',None)
from text import cleaned_text_to_sequence
from GPT_SoVITS.text import cleaned_text_to_sequence
# from config import exp_dir

View File

@ -884,6 +884,157 @@ class Text2SemanticDecoder(nn.Module):
return y[:, :-1], 0
return y[:, :-1], idx - 1
def infer_panel_generator(
self,
x: torch.LongTensor,
x_lens: torch.LongTensor,
prompts: torch.LongTensor,
bert_feature: torch.LongTensor,
cumulation_amount: int,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
"""
Generator method that yields generated tokens based on a specified cumulative amount.
Args:
x (torch.LongTensor): Input phoneme IDs.
x_lens (torch.LongTensor): Lengths of the input sequences.
prompts (torch.LongTensor): Initial prompt tokens.
bert_feature (torch.LongTensor): BERT features corresponding to the input.
cumulation_amount (int): Number of tokens to generate before yielding.
top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (float): If set to < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
early_stop_num (int): Early stopping after generating a certain number of tokens.
temperature (float): The value used to module the next token probabilities.
repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
**kwargs: Additional keyword arguments.
Yields:
torch.LongTensor: Generated tokens since the last yield.
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device)
stop = False
# Initialize cumulative token counter
tokens_since_last_yield = 0
# Initialize last yield index
if y is not None:
prefix_len = y.shape[1]
else:
prefix_len = 0
last_yield_idx = prefix_len
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int64, device=x.device)
ref_free = True
bsz = x.shape[0]
src_len = x_len + y_len
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len), # Extend x_attn_mask to include y tokens
value=True,
)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
xy_attn_mask = xy_attn_mask.unsqueeze(0).expand(bsz * self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.view(bsz, self.num_head, src_len, src_len)
xy_attn_mask = xy_attn_mask.to(device=x.device, dtype=torch.bool)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if idx < 11: # Ensure at least 10 tokens are generated before stopping
logits = logits[:, :-1]
samples = sample(
logits,
y,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
)[0]
y = torch.concat([y, samples], dim=1)
tokens_since_last_yield += 1
if tokens_since_last_yield >= cumulation_amount:
# Yield back the generated tokens since last yield
generated_tokens = y[:, last_yield_idx:]
# print(generated_tokens)
yield generated_tokens
last_yield_idx = y.shape[1]
tokens_since_last_yield = 0
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("Using early stop num:", early_stop_num)
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("Bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# Update for next step
y_emb = self.ar_audio_embedding(y[:, -1:])
y_len += 1
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len - 1
].to(dtype=y_emb.dtype, device=y_emb.device)
# After loop ends, yield any remaining tokens
if last_yield_idx < y.shape[1]:
generated_tokens = y[:, last_yield_idx:]
yield generated_tokens
def infer_panel(
self,

View File

@ -8,10 +8,10 @@ sys.path.append(now_dir)
import re
import torch
import LangSegment
from text import chinese
from GPT_SoVITS.text import chinese
from typing import Dict, List, Tuple
from text.cleaner import clean_text
from text import cleaned_text_to_sequence
from GPT_SoVITS.text.cleaner import clean_text
from GPT_SoVITS.text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

View File

@ -1,5 +1,5 @@
import os, sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin
from GPT_SoVITS.text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)

View File

@ -3,7 +3,7 @@
import argparse
from typing import Optional
from my_utils import load_audio
from text import cleaned_text_to_sequence
from GPT_SoVITS.text import cleaned_text_to_sequence
import torch
import torchaudio
@ -14,7 +14,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from feature_extractor import cnhubert
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
from GPT_SoVITS.module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert

View File

@ -83,12 +83,12 @@ from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn
from GPT_SoVITS.module.models import SynthesizerTrn
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from GPT_SoVITS.text import cleaned_text_to_sequence
from GPT_SoVITS.text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from GPT_SoVITS.module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio
from tools.i18n.i18n import I18nAuto, scan_language_list
@ -303,7 +303,7 @@ def get_first(text):
text = re.split(pattern, text)[0].strip()
return text
from text import chinese
from GPT_SoVITS.text import chinese
def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")

View File

@ -9,18 +9,18 @@ import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module import modules
from module import attentions
from GPT_SoVITS.module import commons
from GPT_SoVITS.module import modules
from GPT_SoVITS.module import attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from GPT_SoVITS.module.commons import init_weights, get_padding
from GPT_SoVITS.module.mrte_model import MRTE
from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
# from GPT_SoVITS.text import symbols
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib

View File

@ -5,17 +5,17 @@ import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module import modules
from module import attentions_onnx as attentions
from GPT_SoVITS.module import commons
from GPT_SoVITS.module import modules
from GPT_SoVITS.module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from GPT_SoVITS.module.commons import init_weights, get_padding
from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
# from GPT_SoVITS.text import symbols
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
from torch.cuda.amp import autocast

View File

@ -7,9 +7,9 @@ from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from module import commons
from module.commons import init_weights, get_padding
from module.transforms import piecewise_rational_quadratic_transform
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.commons import init_weights, get_padding
from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform
import torch.distributions as D

View File

@ -3,7 +3,7 @@
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
from GPT_SoVITS.module.attentions import MultiHeadAttention
class MRTE(nn.Module):

View File

@ -13,7 +13,7 @@ import typing as tp
import torch
from torch import nn
from module.core_vq import ResidualVectorQuantization
from GPT_SoVITS.module.core_vq import ResidualVectorQuantization
@dataclass

View File

@ -1,4 +1,4 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from GPT_SoVITS.module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from GPT_SoVITS.AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
@ -8,7 +8,7 @@ from feature_extractor import cnhubert
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
from GPT_SoVITS.text import cleaned_text_to_sequence
import soundfile
from tools.my_utils import load_audio
import os

View File

@ -18,7 +18,7 @@ import sys, numpy as np, traceback, pdb
import os.path
from glob import glob
from tqdm import tqdm
from text.cleaner import clean_text
from GPT_SoVITS.text.cleaner import clean_text
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from tools.my_utils import clean_path

View File

@ -23,7 +23,7 @@ import torch.multiprocessing as mp
from glob import glob
from tqdm import tqdm
import logging, librosa, utils
from module.models import SynthesizerTrn
from GPT_SoVITS.module.models import SynthesizerTrn
from tools.my_utils import clean_path
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G

View File

@ -18,19 +18,19 @@ logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from GPT_SoVITS.module import commons
from module.data_utils import (
from GPT_SoVITS.module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
from GPT_SoVITS.module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from GPT_SoVITS.module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
torch.backends.cudnn.benchmark = False

View File

@ -1,11 +1,11 @@
import os
# if os.environ.get("version","v1")=="v1":
# from text.symbols import symbols
# from GPT_SoVITS.text.symbols import symbols
# else:
# from text.symbols2 import symbols
# from GPT_SoVITS.text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}

View File

@ -5,8 +5,8 @@ import re
import cn2an
from pyjyutping import jyutping
from text.symbols import punctuation
from text.zh_normalization.text_normlization import TextNormalizer
from GPT_SoVITS.text.symbols import punctuation
from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn")
@ -182,7 +182,7 @@ def get_jyutping(text):
def get_bert_feature(text, word2ph):
from text import chinese_bert
from GPT_SoVITS.text import chinese_bert
return chinese_bert.get_bert_feature(text, word2ph)

View File

@ -5,9 +5,9 @@ import re
import cn2an
from pypinyin import lazy_pinyin, Style
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi
from text.zh_normalization.text_normlization import TextNormalizer
from GPT_SoVITS.text.symbols import punctuation
from GPT_SoVITS.text.tone_sandhi import ToneSandhi
from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn")

View File

@ -6,9 +6,9 @@ import cn2an
from pypinyin import lazy_pinyin, Style
from pypinyin.contrib.tone_convert import to_normal, to_finals_tone3, to_initials, to_finals
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi
from text.zh_normalization.text_normlization import TextNormalizer
from GPT_SoVITS.text.symbols import punctuation
from GPT_SoVITS.text.tone_sandhi import ToneSandhi
from GPT_SoVITS.text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn")
@ -25,7 +25,7 @@ import jieba_fast.posseg as psg
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
if is_g2pw:
print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation
from GPT_SoVITS.text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)

View File

@ -1,14 +1,14 @@
from text import cleaned_text_to_sequence
from GPT_SoVITS.text import cleaned_text_to_sequence
import os
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# from GPT_SoVITS.text import chinese
# from GPT_SoVITS.text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
# from GPT_SoVITS.text import chinese2 as chinese
# from GPT_SoVITS.text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
special = [
# ("%", "zh", "SP"),

View File

@ -4,9 +4,9 @@ import re
import wordsegment
from g2p_en import G2p
from text.symbols import punctuation
from GPT_SoVITS.text.symbols import punctuation
from text.symbols2 import symbols
from GPT_SoVITS.text.symbols2 import symbols
import unicodedata
from builtins import str as unicode

View File

@ -1 +1 @@
from text.g2pw.g2pw import *
from GPT_SoVITS.text.g2pw.g2pw import *

View File

@ -31,7 +31,7 @@ except Exception as e:
pass
from text.symbols import punctuation
from GPT_SoVITS.text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"

View File

@ -5,7 +5,7 @@ from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p
from text.symbols2 import symbols
from GPT_SoVITS.text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'

View File

@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from text.zh_normalization.text_normlization import *
from GPT_SoVITS.text.zh_normalization.text_normlization import *

10
api.py
View File

@ -162,11 +162,11 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from GPT_SoVITS.module.models import SynthesizerTrn
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch
from GPT_SoVITS.text import cleaned_text_to_sequence
from GPT_SoVITS.text.cleaner import clean_text
from GPT_SoVITS.module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio
import config as global_config
import logging
@ -312,7 +312,7 @@ def get_bert_inf(phones, word2ph, norm_text, language):
return bert
from text import chinese
from GPT_SoVITS.text import chinese
def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")