Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-07 15:19:59 +08:00)

Commit 574f667c71 (parent 7c43b41e6d): support v2 model.
@ -17,6 +17,9 @@ from transformers import AutoTokenizer

from text import cleaned_text_to_sequence

version = os.environ.get('version', None)

# from config import exp_dir

@ -33,7 +36,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0

    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = (
            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
        )
        padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
        padded_sequences.append(padded_seq)

@ -45,16 +48,16 @@ class Text2SemanticDataset(Dataset):

    """dataset class for text tokens to semantic model training."""

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
        max_sec: int = 100,
        pad_val: int = 1024,
        # min value of phoneme/sec
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
    ) -> None:
        super().__init__()

@ -125,7 +128,7 @@ class Text2SemanticDataset(Dataset):

        for i in range(semantic_data_len):
            # iterate over the items in order
            # get str
            item_name = self.semantic_data.iloc[i, 0]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]

@ -135,13 +138,13 @@ class Text2SemanticDataset(Dataset):

                num_not_in += 1
                continue

            semantic_str = self.semantic_data.iloc[i, 1]
            # get token list
            semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
            # shape (T); no need to reshape to (1, T), since len() is needed below
            # filter out samples that are too long
            if (
                len(semantic_ids) > self.max_sec * self.hz
            ):  ######### 1: estimate the total duration from the token count and drop clips longer than max_sec (60 s in the config); 40*25=1k
                num_deleted_bigger += 1
                continue

@ -149,7 +152,7 @@ class Text2SemanticDataset(Dataset):

            phoneme = phoneme.split(" ")

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")

@ -157,7 +160,7 @@ class Text2SemanticDataset(Dataset):

                continue
            # if len(phoneme_ids) > 400:  ########### 2: changed to a constant limit of (max semantic length) / 2.5
            if (
                len(phoneme_ids) > self.max_sec * self.hz / 2.5
            ):  ########### 2: changed to a constant limit of (max semantic length) / 2.5
                num_deleted_ps += 1
                continue

@ -168,7 +171,7 @@ class Text2SemanticDataset(Dataset):

            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            if (
                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
            ):  ########## 4: keep 3~25 phones per second
                num_deleted_ps += 1
                # print(item_name)
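For context on the hunks above: the dataset module now reads a `version` environment variable at import time and forwards it to cleaned_text_to_sequence, so the variable has to be set before the module is imported. A minimal sketch (the module path and the data file paths are assumptions, not part of this commit):

    import os

    # The module captures `version` once, at import time, so set it first.
    os.environ["version"] = "v2"

    from AR.data.dataset import Text2SemanticDataset  # assumed module path

    dataset = Text2SemanticDataset(
        phoneme_path="logs/my_exp/2-name2text.txt",        # hypothetical training artifacts
        semantic_path="logs/my_exp/6-name2semantic.tsv",
        max_sec=100,
        pad_val=1024,
    )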
@ -1,10 +1,12 @@

from copy import deepcopy
import math
import os, sys
import random
import traceback

from tqdm import tqdm
from loguru import logger

now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg

@ -26,6 +28,7 @@ from my_utils import load_audio

from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor

i18n = I18nAuto()

# configs/tts_infer.yaml

@ -49,7 +52,8 @@ custom:

"""


def set_seed(seed: int):
    seed = int(seed)
    seed = seed if seed != -1 else random.randrange(1 << 32)
    print(f"Set seed to {seed}")
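The reworked set_seed helper above treats -1 as a request for a random seed and returns the value it actually applied. A small usage sketch (the import path is an assumption):

    from TTS_infer_pack.TTS import set_seed  # assumed module path

    used_seed = set_seed(-1)   # -1 draws a random 32-bit seed; the chosen value is returned
    set_seed(42)               # fixed seed for reproducible sampling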
@ -71,40 +75,41 @@ def set_seed(seed:int):
|
||||
pass
|
||||
return seed
|
||||
|
||||
|
||||
class TTS_Config:
|
||||
default_configs={
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
}
|
||||
configs:dict = None
|
||||
def __init__(self, configs: Union[dict, str]=None):
|
||||
default_configs = {
|
||||
"device": "cpu",
|
||||
"is_half": False,
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
}
|
||||
configs: dict = None
|
||||
|
||||
def __init__(self, configs: Union[dict, str] = None):
|
||||
|
||||
# 设置默认配置文件路径
|
||||
configs_base_path:str = "GPT_SoVITS/configs/"
|
||||
configs_base_path: str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
self.configs_path: str = os.path.join(configs_base_path, "tts_infer.yaml")
|
||||
|
||||
if configs in ["", None]:
|
||||
if not os.path.exists(self.configs_path):
|
||||
self.save_configs()
|
||||
print(f"Create default config file at {self.configs_path}")
|
||||
configs:dict = {"default": deepcopy(self.default_configs)}
|
||||
configs: dict = {"default": deepcopy(self.default_configs)}
|
||||
|
||||
if isinstance(configs, str):
|
||||
self.configs_path = configs
|
||||
configs:dict = self._load_configs(self.configs_path)
|
||||
configs: dict = self._load_configs(self.configs_path)
|
||||
|
||||
assert isinstance(configs, dict)
|
||||
default_configs:dict = configs.get("default", None)
|
||||
default_configs: dict = configs.get("default", None)
|
||||
if default_configs is not None:
|
||||
self.default_configs = default_configs
|
||||
|
||||
self.configs:dict = configs.get("custom", deepcopy(self.default_configs))
|
||||
|
||||
self.configs: dict = configs.get("custom", deepcopy(self.default_configs))
|
||||
|
||||
self.device = self.configs.get("device", torch.device("cpu"))
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
@ -113,7 +118,6 @@ class TTS_Config:
|
||||
self.bert_base_path = self.configs.get("bert_base_path", None)
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
|
||||
|
||||
|
||||
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
|
||||
self.t2s_weights_path = self.default_configs['t2s_weights_path']
|
||||
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
|
||||
@ -128,29 +132,27 @@ class TTS_Config:
|
||||
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
|
||||
self.update_configs()
|
||||
|
||||
|
||||
self.max_sec = None
|
||||
self.hz:int = 50
|
||||
self.semantic_frame_rate:str = "25hz"
|
||||
self.segment_size:int = 20480
|
||||
self.filter_length:int = 2048
|
||||
self.sampling_rate:int = 32000
|
||||
self.hop_length:int = 640
|
||||
self.win_length:int = 2048
|
||||
self.n_speakers:int = 300
|
||||
self.hz: int = 50
|
||||
self.semantic_frame_rate: str = "25hz"
|
||||
self.segment_size: int = 20480
|
||||
self.filter_length: int = 2048
|
||||
self.sampling_rate: int = 32000
|
||||
self.hop_length: int = 640
|
||||
self.win_length: int = 2048
|
||||
self.n_speakers: int = 300
|
||||
|
||||
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
self.languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
|
||||
|
||||
|
||||
def _load_configs(self, configs_path: str)->dict:
|
||||
def _load_configs(self, configs_path: str) -> dict:
|
||||
with open(configs_path, 'r') as f:
|
||||
configs = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
return configs
|
||||
|
||||
def save_configs(self, configs_path:str=None)->None:
|
||||
configs={
|
||||
"default":self.default_configs,
|
||||
def save_configs(self, configs_path: str = None) -> None:
|
||||
configs = {
|
||||
"default": self.default_configs,
|
||||
}
|
||||
if self.configs is not None:
|
||||
configs["custom"] = self.update_configs()
|
||||
@ -162,11 +164,11 @@ class TTS_Config:
|
||||
|
||||
def update_configs(self):
|
||||
self.config = {
|
||||
"device" : str(self.device),
|
||||
"is_half" : self.is_half,
|
||||
"t2s_weights_path" : self.t2s_weights_path,
|
||||
"vits_weights_path" : self.vits_weights_path,
|
||||
"bert_base_path" : self.bert_base_path,
|
||||
"device": str(self.device),
|
||||
"is_half": self.is_half,
|
||||
"t2s_weights_path": self.t2s_weights_path,
|
||||
"vits_weights_path": self.vits_weights_path,
|
||||
"bert_base_path": self.bert_base_path,
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||
}
|
||||
return self.config
|
||||
@ -194,63 +196,58 @@ class TTS:

        if isinstance(configs, TTS_Config):
            self.configs = configs
        else:
            self.configs: TTS_Config = TTS_Config(configs)

        self.t2s_model: Text2SemanticLightningModule = None
        self.vits_model: SynthesizerTrn = None
        self.bert_tokenizer: AutoTokenizer = None
        self.bert_model: AutoModelForMaskedLM = None
        self.cnhuhbert_model: CNHubert = None
        self.version = "v1"

        self._init_models()

        self.text_preprocessor: TextPreprocessor = \
            TextPreprocessor(self.bert_model,
                             self.bert_tokenizer,
                             self.configs.device, version=self.version)

        self.prompt_cache: dict = {
            "ref_audio_path": None,
            "prompt_semantic": None,
            "refer_spec": None,
            "prompt_text": None,
            "prompt_lang": None,
            "phones": None,
            "bert_features": None,
            "norm_text": None,
        }

        self.stop_flag: bool = False
        self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32

    def _init_models(self, ):
        self.init_t2s_weights(self.configs.t2s_weights_path)
        self.init_vits_weights(self.configs.vits_weights_path)
        self.init_bert_weights(self.configs.bert_base_path)
        self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
        # self.enable_half_precision(self.configs.is_half)

    def init_cnhuhbert_weights(self, base_path: str):
        print(f"Loading CNHuBERT weights from {base_path}")
        self.cnhuhbert_model = CNHubert(base_path)
        self.cnhuhbert_model = self.cnhuhbert_model.eval()
        self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.cnhuhbert_model = self.cnhuhbert_model.half()

    def init_bert_weights(self, base_path: str):
        print(f"Loading BERT weights from {base_path}")
        self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
        self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
        self.bert_model = self.bert_model.eval()
        self.bert_model = self.bert_model.to(self.configs.device)
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.bert_model = self.bert_model.half()

    def init_vits_weights(self, weights_path: str):
@ -266,6 +263,12 @@ class TTS:

        self.configs.win_length = hps["data"]["win_length"]
        self.configs.n_speakers = hps["data"]["n_speakers"]
        self.configs.semantic_frame_rate = "25hz"
        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
            hps['model']['version'] = "v1"
        else:
            hps['model']['version'] = "v2"
            self.version = "v2"
        logger.debug(self.version)
        kwargs = hps["model"]
        vits_model = SynthesizerTrn(
            self.configs.filter_length // 2 + 1,
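As a standalone illustration of the detection rule introduced above (a sketch, not code from the commit; the checkpoint keys follow the hunk, and 322 is the v1 text-embedding row count it checks for):

    import torch

    def detect_sovits_version(weights_path: str) -> str:
        # Guess whether a SoVITS checkpoint uses the v1 or v2 symbol set.
        dict_s2 = torch.load(weights_path, map_location="cpu")
        rows = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
        # v1 checkpoints carry a 322-row text embedding; anything else is treated as v2.
        return "v1" if rows == 322 else "v2"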
@ -281,10 +284,9 @@ class TTS:
|
||||
vits_model = vits_model.eval()
|
||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
self.vits_model = vits_model
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||
self.vits_model = self.vits_model.half()
|
||||
|
||||
|
||||
def init_t2s_weights(self, weights_path: str):
|
||||
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
@ -298,10 +300,10 @@ class TTS:
|
||||
t2s_model = t2s_model.to(self.configs.device)
|
||||
t2s_model = t2s_model.eval()
|
||||
self.t2s_model = t2s_model
|
||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
|
||||
def enable_half_precision(self, enable: bool = True, save: bool = True):
|
||||
def enable_half_precision(self, enable: bool = True):
|
||||
'''
|
||||
To enable half precision for the TTS model.
|
||||
Args:
|
||||
@ -314,15 +316,14 @@ class TTS:
|
||||
|
||||
self.configs.is_half = enable
|
||||
self.precision = torch.float16 if enable else torch.float32
|
||||
if save:
|
||||
self.configs.save_configs()
|
||||
self.configs.save_configs()
|
||||
if enable:
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model =self.t2s_model.half()
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
if self.vits_model is not None:
|
||||
self.vits_model = self.vits_model.half()
|
||||
if self.bert_model is not None:
|
||||
self.bert_model =self.bert_model.half()
|
||||
self.bert_model = self.bert_model.half()
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||
else:
|
||||
@ -335,15 +336,14 @@ class TTS:
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||
|
||||
def set_device(self, device: torch.device, save: bool = True):
|
||||
def set_device(self, device: torch.device):
|
||||
'''
|
||||
To set the device for all models.
|
||||
Args:
|
||||
device: torch.device, the device to use for all models.
|
||||
'''
|
||||
self.configs.device = device
|
||||
if save:
|
||||
self.configs.save_configs()
|
||||
self.configs.save_configs()
|
||||
if self.t2s_model is not None:
|
||||
self.t2s_model = self.t2s_model.to(device)
|
||||
if self.vits_model is not None:
|
||||
@ -353,7 +353,7 @@ class TTS:
|
||||
if self.cnhuhbert_model is not None:
|
||||
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
|
||||
|
||||
def set_ref_audio(self, ref_audio_path:str):
|
||||
def set_ref_audio(self, ref_audio_path: str):
|
||||
'''
|
||||
To set the reference audio for the TTS model,
|
||||
including the prompt_semantic and refer_spepc.
|
||||
@ -362,10 +362,6 @@ class TTS:
|
||||
'''
|
||||
self._set_prompt_semantic(ref_audio_path)
|
||||
self._set_ref_spec(ref_audio_path)
|
||||
self._set_ref_audio_path(ref_audio_path)
|
||||
|
||||
def _set_ref_audio_path(self, ref_audio_path):
|
||||
self.prompt_cache["ref_audio_path"] = ref_audio_path
|
||||
|
||||
def _set_ref_spec(self, ref_audio_path):
|
||||
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
|
||||
@ -386,8 +382,7 @@ class TTS:
|
||||
# self.refer_spec = spec
|
||||
self.prompt_cache["refer_spec"] = spec
|
||||
|
||||
|
||||
def _set_prompt_semantic(self, ref_wav_path:str):
|
||||
def _set_prompt_semantic(self, ref_wav_path: str):
|
||||
zero_wav = np.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
dtype=np.float16 if self.configs.is_half else np.float32,
|
||||
@ -415,12 +410,12 @@ class TTS:
|
||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
|
||||
seq = sequences[0]
|
||||
ndim = seq.dim()
|
||||
if axis < 0:
|
||||
axis += ndim
|
||||
dtype:torch.dtype = seq.dtype
|
||||
dtype: torch.dtype = seq.dtype
|
||||
pad_value = torch.tensor(pad_value, dtype=dtype)
|
||||
seq_lengths = [seq.shape[axis] for seq in sequences]
|
||||
if max_length is None:
|
||||
@ -436,15 +431,15 @@ class TTS:
|
||||
batch = torch.stack(padded_sequences)
|
||||
return batch
|
||||
|
||||
def to_batch(self, data:list,
|
||||
prompt_data:dict=None,
|
||||
batch_size:int=5,
|
||||
threshold:float=0.75,
|
||||
split_bucket:bool=True,
|
||||
device:torch.device=torch.device("cpu"),
|
||||
precision:torch.dtype=torch.float32,
|
||||
def to_batch(self, data: list,
|
||||
prompt_data: dict = None,
|
||||
batch_size: int = 5,
|
||||
threshold: float = 0.75,
|
||||
split_bucket: bool = True,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float32,
|
||||
):
|
||||
_data:list = []
|
||||
_data: list = []
|
||||
index_and_len_list = []
|
||||
for idx, item in enumerate(data):
|
||||
norm_text_len = len(item["norm_text"])
|
||||
@ -457,29 +452,28 @@ class TTS:
|
||||
|
||||
batch_index_list_len = 0
|
||||
pos = 0
|
||||
while pos <index_and_len_list.shape[0]:
|
||||
while pos < index_and_len_list.shape[0]:
|
||||
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
|
||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||
pos_end = min(pos + batch_size, index_and_len_list.shape[0])
|
||||
while pos < pos_end:
|
||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
|
||||
if (score>=threshold) or (pos_end-pos==1):
|
||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||
batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||
score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
|
||||
if (score >= threshold) or (pos_end - pos == 1):
|
||||
batch_index = index_and_len_list[pos:pos_end, 0].tolist()
|
||||
batch_index_list_len += len(batch_index)
|
||||
batch_index_list.append(batch_index)
|
||||
pos = pos_end
|
||||
break
|
||||
pos_end=pos_end-1
|
||||
pos_end = pos_end - 1
|
||||
|
||||
assert batch_index_list_len == len(data)
|
||||
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
if i%batch_size == 0:
|
||||
if i % batch_size == 0:
|
||||
batch_index_list.append([])
|
||||
batch_index_list[-1].append(i)
|
||||
|
||||
|
||||
for batch_idx, index_list in enumerate(batch_index_list):
|
||||
item_list = [data[idx] for idx in index_list]
|
||||
phones_list = []
|
||||
@ -493,14 +487,14 @@ class TTS:
|
||||
phones_max_len = 0
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||
.to(dtype=precision, device=device)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1) \
|
||||
.to(dtype=precision, device=device)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
|
||||
phones = torch.LongTensor(item["phones"]).to(device)
|
||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||
else:
|
||||
all_bert_features = item["bert_features"]\
|
||||
.to(dtype=precision, device=device)
|
||||
all_bert_features = item["bert_features"] \
|
||||
.to(dtype=precision, device=device)
|
||||
phones = torch.LongTensor(item["phones"]).to(device)
|
||||
all_phones = phones
|
||||
# norm_text = item["norm_text"]
|
||||
@ -519,7 +513,6 @@ class TTS:
|
||||
all_phones_batch = all_phones_list
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
|
||||
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||
@ -551,7 +544,7 @@ class TTS:
|
||||
|
||||
return _data, batch_index_list
|
||||
|
||||
def recovery_order(self, data:list, batch_index_list:list)->list:
|
||||
def recovery_order(self, data: list, batch_index_list: list) -> list:
|
||||
'''
|
||||
Recovery the order of the audio according to the batch_index_list.
|
||||
|
||||
@ -563,20 +556,20 @@ class TTS:
|
||||
list (List[np.ndarray]): the data in the original order.
|
||||
'''
|
||||
length = len(sum(batch_index_list, []))
|
||||
_data = [None]*length
|
||||
_data = [None] * length
|
||||
for i, index_list in enumerate(batch_index_list):
|
||||
for j, index in enumerate(index_list):
|
||||
_data[index] = data[i][j]
|
||||
return _data
|
||||
|
||||
def stop(self,):
|
||||
def stop(self, ):
|
||||
'''
|
||||
Stop the inference process.
|
||||
'''
|
||||
self.stop_flag = True
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, inputs:dict):
|
||||
def run(self, inputs: dict):
|
||||
"""
|
||||
Text to speech inference.
|
||||
|
||||
@ -606,16 +599,16 @@ class TTS:
|
||||
Tuple[int, np.ndarray]: sampling rate and audio data.
|
||||
"""
|
||||
########## variables initialization ###########
|
||||
self.stop_flag:bool = False
|
||||
text:str = inputs.get("text", "")
|
||||
text_lang:str = inputs.get("text_lang", "")
|
||||
ref_audio_path:str = inputs.get("ref_audio_path", "")
|
||||
prompt_text:str = inputs.get("prompt_text", "")
|
||||
prompt_lang:str = inputs.get("prompt_lang", "")
|
||||
top_k:int = inputs.get("top_k", 5)
|
||||
top_p:float = inputs.get("top_p", 1)
|
||||
temperature:float = inputs.get("temperature", 1)
|
||||
text_split_method:str = inputs.get("text_split_method", "cut0")
|
||||
self.stop_flag: bool = False
|
||||
text: str = inputs.get("text", "")
|
||||
text_lang: str = inputs.get("text_lang", "")
|
||||
ref_audio_path: str = inputs.get("ref_audio_path", "")
|
||||
prompt_text: str = inputs.get("prompt_text", "")
|
||||
prompt_lang: str = inputs.get("prompt_lang", "")
|
||||
top_k: int = inputs.get("top_k", 5)
|
||||
top_p: float = inputs.get("top_p", 1)
|
||||
temperature: float = inputs.get("temperature", 1)
|
||||
text_split_method: str = inputs.get("text_split_method", "cut0")
|
||||
batch_size = inputs.get("batch_size", 1)
|
||||
batch_threshold = inputs.get("batch_threshold", 0.75)
|
||||
speed_factor = inputs.get("speed_factor", 1.0)
|
||||
@ -644,7 +637,7 @@ class TTS:
|
||||
if split_bucket:
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
|
||||
if fragment_interval<0.01:
|
||||
if fragment_interval < 0.01:
|
||||
fragment_interval = 0.01
|
||||
print(i18n("分段间隔过小,已自动设置为0.01"))
|
||||
|
||||
@ -657,14 +650,15 @@ class TTS:

        assert prompt_lang in self.configs.languages

        if ref_audio_path in [None, ""] and \
                ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
            raise ValueError(
                "ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")

        ###### setting reference audio and prompt text preprocessing ########
        t0 = ttime()
        if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
            self.set_ref_audio(ref_audio_path)

        self.text_preprocessor.version = self.version
        if not no_prompt_text:
            prompt_text = prompt_text.strip("\n")
            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
@ -674,37 +668,37 @@ class TTS:
|
||||
self.prompt_cache["prompt_lang"] = prompt_lang
|
||||
phones, bert_features, norm_text = \
|
||||
self.text_preprocessor.segment_and_extract_feature_for_text(
|
||||
prompt_text,
|
||||
prompt_lang)
|
||||
prompt_text,
|
||||
prompt_lang)
|
||||
self.prompt_cache["phones"] = phones
|
||||
self.prompt_cache["bert_features"] = bert_features
|
||||
self.prompt_cache["norm_text"] = norm_text
|
||||
|
||||
###### text preprocessing ########
|
||||
t1 = ttime()
|
||||
data:list = None
|
||||
data: list = None
|
||||
if not return_fragment:
|
||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
||||
if len(data) == 0:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
dtype=np.int16)
|
||||
return
|
||||
|
||||
batch_index_list:list = None
|
||||
batch_index_list: list = None
|
||||
data, batch_index_list = self.to_batch(data,
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=split_bucket,
|
||||
device=self.configs.device,
|
||||
precision=self.precision
|
||||
)
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=split_bucket,
|
||||
device=self.configs.device,
|
||||
precision=self.precision
|
||||
)
|
||||
else:
|
||||
print(i18n("############ 切分文本 ############"))
|
||||
texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
|
||||
data = []
|
||||
for i in range(len(texts)):
|
||||
if i%batch_size == 0:
|
||||
if i % batch_size == 0:
|
||||
data.append([])
|
||||
data[-1].append(texts[i])
|
||||
|
||||
@ -712,10 +706,11 @@ class TTS:
|
||||
batch_data = []
|
||||
print(i18n("############ 提取文本Bert特征 ############"))
|
||||
for text in tqdm(batch_texts):
|
||||
phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang)
|
||||
phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text,
|
||||
text_lang)
|
||||
if phones is None:
|
||||
continue
|
||||
res={
|
||||
res = {
|
||||
"phones": phones,
|
||||
"bert_features": bert_features,
|
||||
"norm_text": norm_text,
|
||||
@ -724,16 +719,15 @@ class TTS:
|
||||
if len(batch_data) == 0:
|
||||
return None
|
||||
batch, _ = self.to_batch(batch_data,
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=False,
|
||||
device=self.configs.device,
|
||||
precision=self.precision
|
||||
)
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=False,
|
||||
device=self.configs.device,
|
||||
precision=self.precision
|
||||
)
|
||||
return batch[0]
|
||||
|
||||
|
||||
t2 = ttime()
|
||||
try:
|
||||
print("############ 推理 ############")
|
||||
@ -748,21 +742,21 @@ class TTS:
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
batch_phones:List[torch.LongTensor] = item["phones"]
|
||||
batch_phones: List[torch.LongTensor] = item["phones"]
|
||||
# batch_phones:torch.LongTensor = item["phones"]
|
||||
batch_phones_len:torch.LongTensor = item["phones_len"]
|
||||
all_phoneme_ids:torch.LongTensor = item["all_phones"]
|
||||
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features:torch.LongTensor = item["all_bert_features"]
|
||||
norm_text:str = item["norm_text"]
|
||||
batch_phones_len: torch.LongTensor = item["phones_len"]
|
||||
all_phoneme_ids: torch.LongTensor = item["all_phones"]
|
||||
all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features: torch.LongTensor = item["all_bert_features"]
|
||||
norm_text: str = item["norm_text"]
|
||||
max_len = item["max_len"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
if no_prompt_text :
|
||||
if no_prompt_text:
|
||||
prompt = None
|
||||
else:
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(
|
||||
self.configs.device)
|
||||
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
@ -780,8 +774,8 @@ class TTS:
|
||||
t4 = ttime()
|
||||
t_34 += t4 - t3
|
||||
|
||||
refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\
|
||||
.to(dtype=self.precision, device=self.configs.device)
|
||||
refer_audio_spec: torch.Tensor = self.prompt_cache["refer_spec"] \
|
||||
.to(dtype=self.precision, device=self.configs.device)
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
@ -804,15 +798,18 @@ class TTS:
|
||||
# ## vits并行推理 method 2
|
||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
|
||||
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))]
|
||||
audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate for i in
|
||||
range(0, len(pred_semantic_list))]
|
||||
audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
|
||||
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
|
||||
_batch_audio_fragment = (self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec
|
||||
).detach()[0, 0, :])
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec
|
||||
).detach()[0, 0, :])
|
||||
audio_frag_end_idx.insert(0, 0)
|
||||
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
|
||||
batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]] for i in
|
||||
range(1, len(audio_frag_end_idx))]
|
||||
|
||||
# ## vits串行推理
|
||||
# for i, idx in enumerate(idx_list):
|
||||
@ -830,35 +827,35 @@ class TTS:
|
||||
if return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
|
||||
yield self.audio_postprocess([batch_audio_fragment],
|
||||
self.configs.sampling_rate,
|
||||
None,
|
||||
speed_factor,
|
||||
False,
|
||||
fragment_interval
|
||||
)
|
||||
self.configs.sampling_rate,
|
||||
None,
|
||||
speed_factor,
|
||||
False,
|
||||
fragment_interval
|
||||
)
|
||||
else:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
dtype=np.int16)
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
yield self.audio_postprocess(audio,
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket,
|
||||
fragment_interval
|
||||
)
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket,
|
||||
fragment_interval
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# 必须返回一个空音频, 否则会导致显存不释放。
|
||||
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||
dtype=np.int16)
|
||||
dtype=np.int16)
|
||||
# 重置模型, 否则会导致显存释放不完全。
|
||||
del self.t2s_model
|
||||
del self.vits_model
|
||||
@ -872,7 +869,6 @@ class TTS:
|
||||
|
||||
def empty_cache(self):
|
||||
try:
|
||||
gc.collect() # 触发gc的垃圾回收。避免内存一直增长。
|
||||
if "cuda" in str(self.configs.device):
|
||||
torch.cuda.empty_cache()
|
||||
elif str(self.configs.device) == "mps":
|
||||
@ -881,34 +877,32 @@ class TTS:
|
||||
pass
|
||||
|
||||
def audio_postprocess(self,
|
||||
audio:List[torch.Tensor],
|
||||
sr:int,
|
||||
batch_index_list:list=None,
|
||||
speed_factor:float=1.0,
|
||||
split_bucket:bool=True,
|
||||
fragment_interval:float=0.3
|
||||
)->Tuple[int, np.ndarray]:
|
||||
audio: List[torch.Tensor],
|
||||
sr: int,
|
||||
batch_index_list: list = None,
|
||||
speed_factor: float = 1.0,
|
||||
split_bucket: bool = True,
|
||||
fragment_interval: float = 0.3
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * fragment_interval),
|
||||
dtype=self.precision,
|
||||
device=self.configs.device
|
||||
)
|
||||
int(self.configs.sampling_rate * fragment_interval),
|
||||
dtype=self.precision,
|
||||
device=self.configs.device
|
||||
)
|
||||
|
||||
for i, batch in enumerate(audio):
|
||||
for j, audio_fragment in enumerate(batch):
|
||||
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
|
||||
if max_audio>1: audio_fragment/=max_audio
|
||||
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
|
||||
max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
|
||||
if max_audio > 1: audio_fragment /= max_audio
|
||||
audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
|
||||
audio[i][j] = audio_fragment.cpu().numpy()
|
||||
|
||||
|
||||
if split_bucket:
|
||||
audio = self.recovery_order(audio, batch_index_list)
|
||||
else:
|
||||
# audio = [item for batch in audio for item in batch]
|
||||
audio = sum(audio, [])
|
||||
|
||||
|
||||
audio = np.concatenate(audio, 0)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
|
||||
@ -921,9 +915,7 @@ class TTS:
|
||||
return sr, audio
|
||||
|
||||
|
||||
|
||||
|
||||
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
|
||||
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
|
||||
# 将 NumPy 数组转换为原始 PCM 流
|
||||
raw_audio = input_audio.astype(np.int16).tobytes()
|
||||
|
||||
|
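Taken together, the TTS.py changes keep the public entry points the same: build a TTS_Config, construct TTS, and iterate over run(). A usage sketch under assumptions (the import path, the reference clip, and its transcript are hypothetical; the input keys mirror the variable-initialization block of run() shown above):

    import soundfile as sf  # only used here to save the result

    from TTS_infer_pack.TTS import TTS, TTS_Config  # assumed module path

    config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # falls back to the defaults above if missing
    tts = TTS(config)  # the detected model version is tracked in tts.version

    inputs = {
        "text": "Hello, this is a short synthesis test.",
        "text_lang": "en",
        "ref_audio_path": "ref/speaker.wav",            # hypothetical reference clip
        "prompt_text": "Transcript of the reference.",  # hypothetical
        "prompt_lang": "en",
        "top_k": 5,
        "top_p": 1,
        "temperature": 1,
        "text_split_method": "cut0",
        "batch_size": 1,
        "speed_factor": 1.0,
        "return_fragment": False,
    }

    # run() is a generator yielding (sampling_rate, int16 waveform) tuples.
    for sr, wav in tts.run(inputs):
        sf.write("output.wav", wav, sr)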
@ -1,7 +1,8 @@

from loguru import logger
import os, sys

from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)

@ -18,14 +19,16 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto()
punctuation = set(['!', '?', '…', ',', '.', '-', " "])


def get_first(text: str) -> str:
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def merge_short_text_in_array(texts: str, threshold: int) -> list:
    if (len(texts)) < 2:
        return texts
    result = []

@ -43,28 +46,29 @@ def merge_short_text_in_array(texts: str, threshold: int) -> list:

    return result


class TextPreprocessor:
    def __init__(self, bert_model: AutoModelForMaskedLM,
                 tokenizer: AutoTokenizer, device: torch.device, version: str = "v2"):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        self.device = device
        self.version = version
        logger.debug(self.version)

    def preprocess(self, text: str, lang: str, text_split_method: str) -> List[Dict]:
        print(i18n("############ 切分文本 ############"))
        text = self.replace_consecutive_punctuation(text)  # the variable name here is probably a typo
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        print(i18n("############ 提取文本Bert特征 ############"))
        for text in tqdm(texts):
            if not re.sub(r"\W+", "", text):
                # skip the segment if it contains nothing but symbols
                continue
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
            if phones is None:
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
@ -72,7 +76,7 @@ class TextPreprocessor:
|
||||
result.append(res)
|
||||
return result
|
||||
|
||||
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
|
||||
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
||||
text = text.strip("\n")
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
text = "。" + text if lang != "en" else "." + text
|
||||
@ -90,13 +94,9 @@ class TextPreprocessor:
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
|
||||
|
||||
for text in _texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||
|
||||
@ -110,7 +110,7 @@ class TextPreprocessor:
|
||||
print(texts)
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
|
||||
def segment_and_extract_feature_for_text(self, texts: list, language: str) -> Tuple[list, torch.Tensor, str]:
|
||||
textlist, langlist = self.seg_text(texts, language)
|
||||
if len(textlist) == 0:
|
||||
return None, None, None
|
||||
@ -118,13 +118,12 @@ class TextPreprocessor:
|
||||
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
|
||||
return phones, bert_features, norm_text
|
||||
|
||||
def seg_text(self, text: str, language: str) -> Tuple[list, list]:
|
||||
|
||||
def seg_text(self, text:str, language:str)->Tuple[list, list]:
|
||||
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language in ["auto", "zh", "ja"]:
|
||||
LangSegment.setfilters(["zh","ja","en","ko"])
|
||||
LangSegment.setfilters(["zh", "ja", "en", "ko"])
|
||||
for tmp in LangSegment.getTexts(text):
|
||||
if tmp["text"] == "":
|
||||
continue
|
||||
@ -134,7 +133,7 @@ class TextPreprocessor:
|
||||
langlist.append("en")
|
||||
else:
|
||||
# 因无法区别中日文汉字,以用户输入为准
|
||||
langlist.append(language if language!="auto" else tmp["lang"])
|
||||
langlist.append(language if language != "auto" else tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
LangSegment.setfilters(["en"])
|
||||
@ -145,14 +144,14 @@ class TextPreprocessor:
|
||||
textlist.append(formattext)
|
||||
langlist.append("en")
|
||||
|
||||
elif language in ["all_zh","all_ja"]:
|
||||
elif language in ["all_zh", "all_ja"]:
|
||||
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
language = language.replace("all_","")
|
||||
language = language.replace("all_", "")
|
||||
if text == "":
|
||||
return [],[]
|
||||
return [], []
|
||||
textlist.append(formattext)
|
||||
langlist.append(language)
|
||||
|
||||
@ -161,8 +160,7 @@ class TextPreprocessor:
|
||||
|
||||
return textlist, langlist
|
||||
|
||||
|
||||
def extract_bert_feature(self, textlist:list, langlist:list):
|
||||
def extract_bert_feature(self, textlist: list, langlist: list):
|
||||
phones_list = []
|
||||
bert_feature_list = []
|
||||
norm_text_list = []
|
||||
@ -179,8 +177,7 @@ class TextPreprocessor:
|
||||
norm_text = ''.join(norm_text_list)
|
||||
return phones_list, bert_feature, norm_text
|
||||
|
||||
|
||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
@ -195,13 +192,13 @@ class TextPreprocessor:

        phone_level_feature = torch.cat(phone_level_feature, dim=0)
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str):
        phones, word2ph, norm_text = clean_text(text, language, version=self.version)
        phones = cleaned_text_to_sequence(phones, self.version)
        return phones, word2ph, norm_text

    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
        language = language.replace("all_", "")
        if language == "zh":
            feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
        else:
@ -212,23 +209,19 @@ class TextPreprocessor:
|
||||
|
||||
return feature
|
||||
|
||||
def process_text(self,texts):
|
||||
_text=[]
|
||||
if all(text in [None, " ", "\n",""] for text in texts):
|
||||
def process_text(self, texts):
|
||||
_text = []
|
||||
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||
raise ValueError(i18n("请输入有效文本"))
|
||||
for text in texts:
|
||||
if text in [None, " ", ""]:
|
||||
if text in [None, " ", ""]:
|
||||
pass
|
||||
else:
|
||||
_text.append(text)
|
||||
return _text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(self,text):
|
||||
def replace_consecutive_punctuation(self, text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
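To close out the TextPreprocessor changes: the class now takes a version argument (defaulting to "v2") and threads it through clean_text and cleaned_text_to_sequence. A minimal sketch, assuming the BERT path from the TTS_Config defaults and the import path used elsewhere in this commit:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    from TTS_infer_pack.TextPreprocessor import TextPreprocessor

    bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(bert_path)

    # `version` selects the v1 or v2 phoneme symbol set downstream.
    tp = TextPreprocessor(bert_model, tokenizer, device, version="v2")
    segments = tp.preprocess("今天天气不错。", "zh", "cut0")
    for seg in segments:
        print(seg["norm_text"], len(seg["phones"]), seg["bert_features"].shape)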
GPT_SoVITS/text/.gitignore (vendored, new file, 3 lines)
@ -0,0 +1,3 @@

G2PWModel
__pycache__
*.zip
@ -1,15 +1,27 @@

import os
# if os.environ.get("version","v1")=="v1":
#     from text.symbols import symbols
# else:
#     from text.symbols2 import symbols

from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}


def cleaned_text_to_sequence(cleaned_text, version=None):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    if version is None: version = os.environ.get('version', 'v2')
    if version == "v1":
        phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
    else:
        phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]

    return phones
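The version argument is optional; when it is omitted, the function falls back to the `version` environment variable and finally to "v2". A short sketch (the phone symbols used here are assumed to be present in both symbol tables):

    import os

    from text import cleaned_text_to_sequence

    # An explicit argument wins.
    ids_v2 = cleaned_text_to_sequence(["n", "i2", "h", "ao3"], version="v2")

    # Otherwise the environment variable is consulted, defaulting to "v2".
    os.environ["version"] = "v1"
    ids_v1 = cleaned_text_to_sequence(["n", "i2", "h", "ao3"])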
|
209
GPT_SoVITS/text/cantonese.py
Normal file
209
GPT_SoVITS/text/cantonese.py
Normal file
@ -0,0 +1,209 @@
|
||||
# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py
|
||||
|
||||
import sys
|
||||
import re
|
||||
import cn2an
|
||||
|
||||
from pyjyutping import jyutping
|
||||
from text.symbols import punctuation
|
||||
from text.zh_normalization.text_normlization import TextNormalizer
|
||||
|
||||
normalizer = lambda x: cn2an.transform(x, "an2cn")
|
||||
|
||||
INITIALS = [
|
||||
"aa",
|
||||
"aai",
|
||||
"aak",
|
||||
"aap",
|
||||
"aat",
|
||||
"aau",
|
||||
"ai",
|
||||
"au",
|
||||
"ap",
|
||||
"at",
|
||||
"ak",
|
||||
"a",
|
||||
"p",
|
||||
"b",
|
||||
"e",
|
||||
"ts",
|
||||
"t",
|
||||
"dz",
|
||||
"d",
|
||||
"kw",
|
||||
"k",
|
||||
"gw",
|
||||
"g",
|
||||
"f",
|
||||
"h",
|
||||
"l",
|
||||
"m",
|
||||
"ng",
|
||||
"n",
|
||||
"s",
|
||||
"y",
|
||||
"w",
|
||||
"c",
|
||||
"z",
|
||||
"j",
|
||||
"ong",
|
||||
"on",
|
||||
"ou",
|
||||
"oi",
|
||||
"ok",
|
||||
"o",
|
||||
"uk",
|
||||
"ung",
|
||||
]
|
||||
INITIALS += ["sp", "spl", "spn", "sil"]
|
||||
|
||||
|
||||
rep_map = {
|
||||
":": ",",
|
||||
";": ",",
|
||||
",": ",",
|
||||
"。": ".",
|
||||
"!": "!",
|
||||
"?": "?",
|
||||
"\n": ".",
|
||||
"·": ",",
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
"$": ".",
|
||||
"“": "'",
|
||||
"”": "'",
|
||||
'"': "'",
|
||||
"‘": "'",
|
||||
"’": "'",
|
||||
"(": "'",
|
||||
")": "'",
|
||||
"(": "'",
|
||||
")": "'",
|
||||
"《": "'",
|
||||
"》": "'",
|
||||
"【": "'",
|
||||
"】": "'",
|
||||
"[": "'",
|
||||
"]": "'",
|
||||
"—": "-",
|
||||
"~": "-",
|
||||
"~": "-",
|
||||
"「": "'",
|
||||
"」": "'",
|
||||
}
|
||||
|
||||
|
||||
def replace_punctuation(text):
|
||||
# text = text.replace("嗯", "恩").replace("呣", "母")
|
||||
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
|
||||
return replaced_text
|
||||
|
||||
|
||||
def text_normalize(text):
|
||||
tx = TextNormalizer()
|
||||
sentences = tx.normalize(text)
|
||||
dest_text = ""
|
||||
for sentence in sentences:
|
||||
dest_text += replace_punctuation(sentence)
|
||||
return dest_text
|
||||
|
||||
|
||||
punctuation_set=set(punctuation)
|
||||
def jyuping_to_initials_finals_tones(jyuping_syllables):
|
||||
initials_finals = []
|
||||
tones = []
|
||||
word2ph = []
|
||||
|
||||
for syllable in jyuping_syllables:
|
||||
if syllable in punctuation:
|
||||
initials_finals.append(syllable)
|
||||
tones.append(0)
|
||||
word2ph.append(1) # Add 1 for punctuation
|
||||
elif syllable == "_":
|
||||
initials_finals.append(syllable)
|
||||
tones.append(0)
|
||||
word2ph.append(1) # Add 1 for underscore
|
||||
else:
|
||||
try:
|
||||
tone = int(syllable[-1])
|
||||
syllable_without_tone = syllable[:-1]
|
||||
except ValueError:
|
||||
tone = 0
|
||||
syllable_without_tone = syllable
|
||||
|
||||
for initial in INITIALS:
|
||||
if syllable_without_tone.startswith(initial):
|
||||
if syllable_without_tone.startswith("nga"):
|
||||
initials_finals.extend(
|
||||
[
|
||||
syllable_without_tone[:2],
|
||||
syllable_without_tone[2:] or syllable_without_tone[-1],
|
||||
]
|
||||
)
|
||||
# tones.extend([tone, tone])
|
||||
tones.extend([-1, tone])
|
||||
word2ph.append(2)
|
||||
else:
|
||||
final = syllable_without_tone[len(initial) :] or initial[-1]
|
||||
initials_finals.extend([initial, final])
|
||||
# tones.extend([tone, tone])
|
||||
tones.extend([-1, tone])
|
||||
word2ph.append(2)
|
||||
break
|
||||
assert len(initials_finals) == len(tones)
|
||||
|
||||
###魔改为辅音+带音调的元音
|
||||
phones=[]
|
||||
for a,b in zip(initials_finals,tones):
|
||||
if(b not in [-1,0]):###防止粤语和普通话重合开头加Y,如果是标点,不加。
|
||||
todo="%s%s"%(a,b)
|
||||
else:todo=a
|
||||
if(todo not in punctuation_set):todo="Y%s"%todo
|
||||
phones.append(todo)
|
||||
|
||||
# return initials_finals, tones, word2ph
|
||||
return phones, word2ph
|
||||
|
||||
|
||||
def get_jyutping(text):
|
||||
jp = jyutping.convert(text)
|
||||
# print(1111111,jp)
|
||||
for symbol in punctuation:
|
||||
jp = jp.replace(symbol, " " + symbol + " ")
|
||||
jp_array = jp.split()
|
||||
return jp_array
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
from text import chinese_bert
|
||||
|
||||
return chinese_bert.get_bert_feature(text, word2ph)
|
||||
|
||||
|
||||
def g2p(text):
|
||||
# word2ph = []
|
||||
jyuping = get_jyutping(text)
|
||||
# print(jyuping)
|
||||
# phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping)
|
||||
phones, word2ph = jyuping_to_initials_finals_tones(jyuping)
|
||||
# phones = ["_"] + phones + ["_"]
|
||||
# tones = [0] + tones + [0]
|
||||
# word2ph = [1] + word2ph + [1]
|
||||
return phones, word2ph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
|
||||
text = "佢個鋤頭太短啦。"
|
||||
text = text_normalize(text)
|
||||
# phones, tones, word2ph = g2p(text)
|
||||
phones, word2ph = g2p(text)
|
||||
# print(phones, tones, word2ph)
|
||||
print(phones, word2ph)
|
@ -54,6 +54,26 @@ def replace_punctuation(text):
|
||||
return replaced_text
|
||||
|
||||
|
||||
def replace_punctuation_with_en(text):
|
||||
text = text.replace("嗯", "恩").replace("呣", "母")
|
||||
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
|
||||
return replaced_text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
return result
|
||||
|
||||
|
||||
def g2p(text):
|
||||
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
|
||||
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
|
||||
@ -158,6 +178,23 @@ def text_normalize(text):
|
||||
dest_text = ""
|
||||
for sentence in sentences:
|
||||
dest_text += replace_punctuation(sentence)
|
||||
|
||||
# 避免重复标点引起的参考泄露
|
||||
dest_text = replace_consecutive_punctuation(dest_text)
|
||||
return dest_text
|
||||
|
||||
|
||||
# 不排除英文的文本格式化
|
||||
def mix_text_normalize(text):
|
||||
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
|
||||
tx = TextNormalizer()
|
||||
sentences = tx.normalize(text)
|
||||
dest_text = ""
|
||||
for sentence in sentences:
|
||||
dest_text += replace_punctuation_with_en(sentence)
|
||||
|
||||
# 避免重复标点引起的参考泄露
|
||||
dest_text = replace_consecutive_punctuation(dest_text)
|
||||
return dest_text
|
||||
|
||||
|
||||
|
308
GPT_SoVITS/text/chinese2.py
Normal file
308
GPT_SoVITS/text/chinese2.py
Normal file
@ -0,0 +1,308 @@
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import cn2an
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
from pypinyin.contrib.tone_convert import to_normal, to_finals_tone3, to_initials, to_finals
|
||||
|
||||
from text.symbols import punctuation
|
||||
from text.tone_sandhi import ToneSandhi
|
||||
from text.zh_normalization.text_normlization import TextNormalizer
|
||||
|
||||
normalizer = lambda x: cn2an.transform(x, "an2cn")
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
pinyin_to_symbol_map = {
|
||||
line.split("\t")[0]: line.strip().split("\t")[1]
|
||||
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
|
||||
}
|
||||
|
||||
import jieba_fast.posseg as psg
|
||||
|
||||
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
|
||||
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
|
||||
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
|
||||
if is_g2pw:
|
||||
print("当前使用g2pw进行拼音推理")
|
||||
from text.g2pw import G2PWPinyin, correct_pronunciation
|
||||
parent_directory = os.path.dirname(current_file_path)
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
||||
|
||||
rep_map = {
|
||||
":": ",",
|
||||
";": ",",
|
||||
",": ",",
|
||||
"。": ".",
|
||||
"!": "!",
|
||||
"?": "?",
|
||||
"\n": ".",
|
||||
"·": ",",
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
"$": ".",
|
||||
"/": ",",
|
||||
"—": "-",
|
||||
"~": "…",
|
||||
"~":"…",
|
||||
}
|
||||
|
||||
tone_modifier = ToneSandhi()
|
||||
|
||||
|
||||
def replace_punctuation(text):
|
||||
text = text.replace("嗯", "恩").replace("呣", "母")
|
||||
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
|
||||
return replaced_text
|
||||
|
||||
|
||||
def g2p(text):
|
||||
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
|
||||
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
|
||||
phones, word2ph = _g2p(sentences)
|
||||
return phones, word2ph
|
||||
|
||||
|
||||
def _get_initials_finals(word):
|
||||
initials = []
|
||||
finals = []
|
||||
|
||||
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
|
||||
orig_finals = lazy_pinyin(
|
||||
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
|
||||
)
|
||||
|
||||
for c, v in zip(orig_initials, orig_finals):
|
||||
initials.append(c)
|
||||
finals.append(v)
|
||||
return initials, finals
|
||||
|
||||
|
||||
must_erhua = {
|
||||
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
|
||||
}
|
||||
not_erhua = {
|
||||
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
|
||||
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
|
||||
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
|
||||
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
|
||||
"狗儿", "少儿"
|
||||
}
|
||||
def _merge_erhua(initials: list[str],
|
||||
finals: list[str],
|
||||
word: str,
|
||||
pos: str) -> list[list[str]]:
|
||||
"""
|
||||
Do erhub.
|
||||
"""
|
||||
# fix er1
|
||||
for i, phn in enumerate(finals):
|
||||
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
|
||||
finals[i] = 'er2'
|
||||
|
||||
# 发音
|
||||
if word not in must_erhua and (word in not_erhua or
|
||||
pos in {"a", "j", "nr"}):
|
||||
return initials, finals
|
||||
|
||||
# "……" 等情况直接返回
|
||||
if len(finals) != len(word):
|
||||
return initials, finals
|
||||
|
||||
assert len(finals) == len(word)
|
||||
|
||||
# 与前一个字发同音
|
||||
new_initials = []
|
||||
new_finals = []
|
||||
for i, phn in enumerate(finals):
|
||||
if i == len(finals) - 1 and word[i] == "儿" and phn in {
|
||||
"er2", "er5"
|
||||
} and word[-2:] not in not_erhua and new_finals:
|
||||
phn = "er" + new_finals[-1][-1]
|
||||
|
||||
new_initials.append(initials[i])
|
||||
new_finals.append(phn)
|
||||
|
||||
return new_initials, new_finals
|
||||
|
||||
|
||||
def _g2p(segments):
    phones_list = []
    word2ph = []
    for seg in segments:
        pinyins = []
        # Replace all English words in the sentence
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        initials = []
        finals = []

        if not is_g2pw:
            for word, pos in seg_cut:
                if pos == "eng":
                    continue
                sub_initials, sub_finals = _get_initials_finals(word)
                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
                # 儿化
                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
                initials.append(sub_initials)
                finals.append(sub_finals)
                # assert len(sub_initials) == len(sub_finals) == len(word)
            initials = sum(initials, [])
            finals = sum(finals, [])
            print("pypinyin结果", initials, finals)
        else:
            # g2pw采用整句推理
            pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)

            pre_word_length = 0
            for word, pos in seg_cut:
                sub_initials = []
                sub_finals = []
                now_word_length = pre_word_length + len(word)

                if pos == 'eng':
                    pre_word_length = now_word_length
                    continue

                word_pinyins = pinyins[pre_word_length:now_word_length]

                # 多音字消歧
                word_pinyins = correct_pronunciation(word, word_pinyins)

                for pinyin in word_pinyins:
                    if pinyin[0].isalpha():
                        sub_initials.append(to_initials(pinyin))
                        sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
                    else:
                        sub_initials.append(pinyin)
                        sub_finals.append(pinyin)

                pre_word_length = now_word_length
                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
                # 儿化
                sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
                initials.append(sub_initials)
                finals.append(sub_finals)

            initials = sum(initials, [])
            finals = sum(finals, [])
            # print("g2pw结果",initials,finals)

        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            # NOTE: post process for pypinyin outputs
            # we discriminate i, ii and iii
            if c == v:
                assert c in punctuation
                phone = [c]
                word2ph.append(1)
            else:
                v_without_tone = v[:-1]
                tone = v[-1]

                pinyin = c + v_without_tone
                assert tone in "12345"

                if c:
                    # 多音节
                    v_rep_map = {
                        "uei": "ui",
                        "iou": "iu",
                        "uen": "un",
                    }
                    if v_without_tone in v_rep_map.keys():
                        pinyin = c + v_rep_map[v_without_tone]
                else:
                    # 单音节
                    pinyin_rep_map = {
                        "ing": "ying",
                        "i": "yi",
                        "in": "yin",
                        "u": "wu",
                    }
                    if pinyin in pinyin_rep_map.keys():
                        pinyin = pinyin_rep_map[pinyin]
                    else:
                        single_rep_map = {
                            "v": "yu",
                            "e": "e",
                            "i": "y",
                            "u": "w",
                        }
                        if pinyin[0] in single_rep_map.keys():
                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
                new_v = new_v + tone
                phone = [new_c, new_v]
                word2ph.append(len(phone))

            phones_list += phone
    return phones_list, word2ph

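# A minimal end-to-end sketch, assuming the v2 symbol set and the g2pw model
# configured at the top of this file (values are illustrative):
#   g2p(text_normalize("你好"))
#   # -> (['n', 'i2', 'h', 'ao3'], [2, 2])   # phone list plus per-character phone counts
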
def replace_punctuation_with_en(text):
    text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

    replaced_text = re.sub(
        r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
    )

    return replaced_text


def replace_consecutive_punctuation(text):
    punctuations = ''.join(re.escape(p) for p in punctuation)
    pattern = f'([{punctuations}])([{punctuations}])+'
    result = re.sub(pattern, r'\1', text)
    return result


def text_normalize(text):
    # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
    tx = TextNormalizer()
    sentences = tx.normalize(text)
    dest_text = ""
    for sentence in sentences:
        dest_text += replace_punctuation(sentence)

    # 避免重复标点引起的参考泄露
    dest_text = replace_consecutive_punctuation(dest_text)
    return dest_text


# 不排除英文的文本格式化
def mix_text_normalize(text):
    # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
    tx = TextNormalizer()
    sentences = tx.normalize(text)
    dest_text = ""
    for sentence in sentences:
        dest_text += replace_punctuation_with_en(sentence)

    # 避免重复标点引起的参考泄露
    dest_text = replace_consecutive_punctuation(dest_text)
    return dest_text

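# A rough sketch of the normalization pipeline, assuming PaddleSpeech's
# TextNormalizer is importable: numbers/units are expanded to Chinese readings,
# punctuation is mapped via replace_punctuation, and repeated marks collapse.
#   text_normalize("等等!!")   # -> "等等!"
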
if __name__ == "__main__":
    text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
    text = "呣呣呣~就是…大人的鼹鼠党吧?"
    text = "你好"
    text = text_normalize(text)
    print(g2p(text))


# # 示例用法
# text = "这是一个示例文本:,你好!这是一个测试..."
# print(g2p_paddle(text))  # 输出: 这是一个示例文本你好这是一个测试
@ -1,6 +1,15 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
from text import cleaned_text_to_sequence
import os
# if os.environ.get("version","v1")=="v1":
#     from text import chinese
#     from text.symbols import symbols
# else:
#     from text import chinese2 as chinese
#     from text.symbols2 import symbols

from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

language_module_map = {"zh": chinese, "ja": japanese, "en": english}
special = [
    # ("%", "zh", "SP"),
    ("¥", "zh", "SP2"),
@ -9,34 +18,58 @@ special = [
]


def clean_text(text, language):
def clean_text(text, language, version=None):
    if version is None:version=os.environ.get('version', 'v2')
    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}

    if(language not in language_module_map):
        language="en"
        text=" "
    for special_s, special_l, target_symbol in special:
        if special_s in text and language == special_l:
            return clean_special(text, language, special_s, target_symbol)
    language_module = language_module_map[language]
    norm_text = language_module.text_normalize(text)
    if language == "zh":
            return clean_special(text, language, special_s, target_symbol, version)
    language_module = __import__("text."+language_module_map[language], fromlist=[language_module_map[language]])
    if hasattr(language_module, "text_normalize"):
        norm_text = language_module.text_normalize(text)
    else:
        norm_text=text
    if language == "zh" or language=="yue":##########
        phones, word2ph = language_module.g2p(norm_text)
        assert len(phones) == sum(word2ph)
        assert len(norm_text) == len(word2ph)
    elif language == "en":
        phones = language_module.g2p(norm_text)
        if len(phones) < 4:
            phones = [','] * (4 - len(phones)) + phones
        word2ph = None
    else:
        phones = language_module.g2p(norm_text)
        word2ph = None

    for ph in phones:
        assert ph in symbols
    phones = ['UNK' if ph not in symbols else ph for ph in phones]
    return phones, word2ph, norm_text


def clean_special(text, language, special_s, target_symbol):
def clean_special(text, language, special_s, target_symbol, version=None):
    if version is None:version=os.environ.get('version', 'v2')
    if version == "v1":
        symbols = symbols_v1.symbols
        language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        symbols = symbols_v2.symbols
        language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}

    """
    特殊静音段sp符号处理
    """
    text = text.replace(special_s, ",")
    language_module = language_module_map[language]
    language_module = __import__("text."+language_module_map[language], fromlist=[language_module_map[language]])
    norm_text = language_module.text_normalize(text)
    phones = language_module.g2p(norm_text)
    new_ph = []
@ -49,9 +82,11 @@ def clean_special(text, language, special_s, target_symbol):
    return new_ph, phones[1], norm_text


def text_to_sequence(text, language):
def text_to_sequence(text, language, version=None):
    version = os.environ.get('version', version)
    if version is None:version='v2'
    phones = clean_text(text)
    return cleaned_text_to_sequence(phones)
    return cleaned_text_to_sequence(phones, version)


if __name__ == "__main__":
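# A minimal usage sketch of the version switch, assuming the default 'v2'
# environment and the module names listed in language_module_map:
#   phones, word2ph, norm_text = clean_text("你好", "zh", version="v2")   # routes to text.chinese2 + symbols2
#   phones, word2ph, norm_text = clean_text("안녕하세요", "ko")             # ko/yue exist only in v2; word2ph is None here
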
GPT_SoVITS/text/cmudict_cache.pickle (new binary file, not shown)
@ -1,2 +1,3 @@
CHATGPT CH AE1 T JH IY1 P IY1 T IY1
JSON JH EY1 S AH0 N
CONDA K AA1 N D AH0
@ -4,9 +4,9 @@ import re
import wordsegment
from g2p_en import G2p

from string import punctuation
from text.symbols import punctuation

from text import symbols
from text.symbols2 import symbols

import unicodedata
from builtins import str as unicode
@ -110,6 +110,13 @@ def replace_phs(phs):
    return phs_new


def replace_consecutive_punctuation(text):
    punctuations = ''.join(re.escape(p) for p in punctuation)
    pattern = f'([{punctuations}])([{punctuations}])+'
    result = re.sub(pattern, r'\1', text)
    return result


def read_dict():
    g2p_dict = {}
    start_line = 49
@ -234,6 +241,9 @@ def text_normalize(text):
    text = re.sub(r"(?i)i\.e\.", "that is", text)
    text = re.sub(r"(?i)e\.g\.", "for example", text)

    # 避免重复标点引起的参考泄露
    text = replace_consecutive_punctuation(text)

    return text

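# A minimal sketch, assuming punctuation is the set imported from text.symbols above:
#   replace_consecutive_punctuation("Hello!!!")   # -> "Hello!"
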
GPT_SoVITS/text/g2pw/__init__.py (new file, 1 line)
@ -0,0 +1 @@
from text.g2pw.g2pw import *
GPT_SoVITS/text/g2pw/dataset.py (new file, 166 lines)
@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils import tokenize_and_map
|
||||
|
||||
ANCHOR_CHAR = '▁'
|
||||
|
||||
|
||||
def prepare_onnx_input(tokenizer,
|
||||
labels: List[str],
|
||||
char2phonemes: Dict[str, List[int]],
|
||||
chars: List[str],
|
||||
texts: List[str],
|
||||
query_ids: List[int],
|
||||
use_mask: bool=False,
|
||||
window_size: int=None,
|
||||
max_len: int=512) -> Dict[str, np.array]:
|
||||
if window_size is not None:
|
||||
truncated_texts, truncated_query_ids = _truncate_texts(
|
||||
window_size=window_size, texts=texts, query_ids=query_ids)
|
||||
input_ids = []
|
||||
token_type_ids = []
|
||||
attention_masks = []
|
||||
phoneme_masks = []
|
||||
char_ids = []
|
||||
position_ids = []
|
||||
|
||||
for idx in range(len(texts)):
|
||||
text = (truncated_texts if window_size else texts)[idx].lower()
|
||||
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
||||
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(
|
||||
tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
|
||||
text, query_id, tokens, text2token, token2text = _truncate(
|
||||
max_len=max_len,
|
||||
text=text,
|
||||
query_id=query_id,
|
||||
tokens=tokens,
|
||||
text2token=text2token,
|
||||
token2text=token2text)
|
||||
|
||||
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
|
||||
|
||||
input_id = list(
|
||||
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
|
||||
|
||||
query_char = text[query_id]
|
||||
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
|
||||
if use_mask else [1] * len(labels)
|
||||
char_id = chars.index(query_char)
|
||||
position_id = text2token[
|
||||
query_id] + 1 # [CLS] token locate at first place
|
||||
|
||||
input_ids.append(input_id)
|
||||
token_type_ids.append(token_type_id)
|
||||
attention_masks.append(attention_mask)
|
||||
phoneme_masks.append(phoneme_mask)
|
||||
char_ids.append(char_id)
|
||||
position_ids.append(position_id)
|
||||
|
||||
outputs = {
|
||||
'input_ids': np.array(input_ids).astype(np.int64),
|
||||
'token_type_ids': np.array(token_type_ids).astype(np.int64),
|
||||
'attention_masks': np.array(attention_masks).astype(np.int64),
|
||||
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
|
||||
'char_ids': np.array(char_ids).astype(np.int64),
|
||||
'position_ids': np.array(position_ids).astype(np.int64),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
||||
def _truncate_texts(window_size: int, texts: List[str],
|
||||
query_ids: List[int]) -> Tuple[List[str], List[int]]:
|
||||
truncated_texts = []
|
||||
truncated_query_ids = []
|
||||
for text, query_id in zip(texts, query_ids):
|
||||
start = max(0, query_id - window_size // 2)
|
||||
end = min(len(text), query_id + window_size // 2)
|
||||
truncated_text = text[start:end]
|
||||
truncated_texts.append(truncated_text)
|
||||
|
||||
truncated_query_id = query_id - start
|
||||
truncated_query_ids.append(truncated_query_id)
|
||||
return truncated_texts, truncated_query_ids
|
||||
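# A minimal sketch of the windowing above, assuming window_size=4 and a query
# on the third character (index 2): the text is clipped around the query and
# the query index is shifted into the clipped string.
#   _truncate_texts(window_size=4, texts=["广州的广东话"], query_ids=[2])
#   # -> (["广州的广"], [2])
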
|
||||
|
||||
def _truncate(max_len: int,
|
||||
text: str,
|
||||
query_id: int,
|
||||
tokens: List[str],
|
||||
text2token: List[int],
|
||||
token2text: List[Tuple[int]]):
|
||||
truncate_len = max_len - 2
|
||||
if len(tokens) <= truncate_len:
|
||||
return (text, query_id, tokens, text2token, token2text)
|
||||
|
||||
token_position = text2token[query_id]
|
||||
|
||||
token_start = token_position - truncate_len // 2
|
||||
token_end = token_start + truncate_len
|
||||
font_exceed_dist = -token_start
|
||||
back_exceed_dist = token_end - len(tokens)
|
||||
if font_exceed_dist > 0:
|
||||
token_start += font_exceed_dist
|
||||
token_end += font_exceed_dist
|
||||
elif back_exceed_dist > 0:
|
||||
token_start -= back_exceed_dist
|
||||
token_end -= back_exceed_dist
|
||||
|
||||
start = token2text[token_start][0]
|
||||
end = token2text[token_end - 1][1]
|
||||
|
||||
return (text[start:end], query_id - start, tokens[token_start:token_end], [
|
||||
i - token_start if i is not None else None
|
||||
for i in text2token[start:end]
|
||||
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
|
||||
|
||||
|
||||
def get_phoneme_labels(polyphonic_chars: List[List[str]]
|
||||
) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
if char not in char2phonemes:
|
||||
char2phonemes[char] = []
|
||||
char2phonemes[char].append(labels.index(phoneme))
|
||||
return labels, char2phonemes
|
||||
|
||||
|
||||
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
|
||||
) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
labels = sorted(
|
||||
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
if char not in char2phonemes:
|
||||
char2phonemes[char] = []
|
||||
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
|
||||
return labels, char2phonemes
|
GPT_SoVITS/text/g2pw/g2pw.py (new file, 154 lines)
@ -0,0 +1,154 @@
|
||||
# This code is modified from https://github.com/mozillazg/pypinyin-g2pW
|
||||
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from pypinyin.constants import RE_HANS
|
||||
from pypinyin.core import Pinyin, Style
|
||||
from pypinyin.seg.simpleseg import simple_seg
|
||||
from pypinyin.converter import UltimateConverter
|
||||
from pypinyin.contrib.tone_convert import to_tone
|
||||
from .onnx_api import G2PWOnnxConverter
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
CACHE_PATH = os.path.join(current_file_path, "polyphonic.pickle")
|
||||
PP_DICT_PATH = os.path.join(current_file_path, "polyphonic.rep")
|
||||
PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
|
||||
|
||||
|
||||
class G2PWPinyin(Pinyin):
|
||||
def __init__(self, model_dir='G2PWModel/', model_source=None,
|
||||
enable_non_tradional_chinese=True,
|
||||
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
|
||||
self._g2pw = G2PWOnnxConverter(
|
||||
model_dir=model_dir,
|
||||
style='pinyin',
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
self._converter = Converter(
|
||||
self._g2pw, v_to_u=v_to_u,
|
||||
neutral_tone_with_five=neutral_tone_with_five,
|
||||
tone_sandhi=tone_sandhi,
|
||||
)
|
||||
|
||||
def get_seg(self, **kwargs):
|
||||
return simple_seg
|
||||
|
||||
|
||||
class Converter(UltimateConverter):
|
||||
def __init__(self, g2pw_instance, v_to_u=False,
|
||||
neutral_tone_with_five=False,
|
||||
tone_sandhi=False, **kwargs):
|
||||
super(Converter, self).__init__(
|
||||
v_to_u=v_to_u,
|
||||
neutral_tone_with_five=neutral_tone_with_five,
|
||||
tone_sandhi=tone_sandhi, **kwargs)
|
||||
|
||||
self._g2pw = g2pw_instance
|
||||
|
||||
def convert(self, words, style, heteronym, errors, strict, **kwargs):
|
||||
pys = []
|
||||
if RE_HANS.match(words):
|
||||
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
|
||||
errors=errors, strict=strict)
|
||||
post_data = self.post_pinyin(words, heteronym, pys)
|
||||
if post_data is not None:
|
||||
pys = post_data
|
||||
|
||||
pys = self.convert_styles(
|
||||
pys, words, style, heteronym, errors, strict)
|
||||
|
||||
else:
|
||||
py = self.handle_nopinyin(words, style=style, errors=errors,
|
||||
heteronym=heteronym, strict=strict)
|
||||
if py:
|
||||
pys.extend(py)
|
||||
|
||||
return _remove_dup_and_empty(pys)
|
||||
|
||||
def _to_pinyin(self, han, style, heteronym, errors, strict, **kwargs):
|
||||
pinyins = []
|
||||
|
||||
g2pw_pinyin = self._g2pw(han)
|
||||
|
||||
if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
|
||||
return super(Converter, self).convert(
|
||||
han, Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
|
||||
for i, item in enumerate(g2pw_pinyin[0]):
|
||||
if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
|
||||
py = super(Converter, self).convert(
|
||||
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
pinyins.extend(py)
|
||||
else:
|
||||
pinyins.append([to_tone(item)])
|
||||
|
||||
return pinyins
|
||||
|
||||
|
||||
def _remove_dup_items(lst, remove_empty=False):
|
||||
new_lst = []
|
||||
for item in lst:
|
||||
if remove_empty and not item:
|
||||
continue
|
||||
if item not in new_lst:
|
||||
new_lst.append(item)
|
||||
return new_lst
|
||||
|
||||
|
||||
def _remove_dup_and_empty(lst_list):
|
||||
new_lst_list = []
|
||||
for lst in lst_list:
|
||||
lst = _remove_dup_items(lst, remove_empty=True)
|
||||
if lst:
|
||||
new_lst_list.append(lst)
|
||||
else:
|
||||
new_lst_list.append([''])
|
||||
|
||||
return new_lst_list
|
||||
|
||||
|
||||
def cache_dict(polyphonic_dict, file_path):
|
||||
with open(file_path, "wb") as pickle_file:
|
||||
pickle.dump(polyphonic_dict, pickle_file)
|
||||
|
||||
|
||||
def get_dict():
|
||||
if os.path.exists(CACHE_PATH):
|
||||
with open(CACHE_PATH, "rb") as pickle_file:
|
||||
polyphonic_dict = pickle.load(pickle_file)
|
||||
else:
|
||||
polyphonic_dict = read_dict()
|
||||
cache_dict(polyphonic_dict, CACHE_PATH)
|
||||
|
||||
return polyphonic_dict
|
||||
|
||||
|
||||
def read_dict():
|
||||
polyphonic_dict = {}
|
||||
with open(PP_DICT_PATH) as f:
|
||||
line = f.readline()
|
||||
while line:
|
||||
key, value_str = line.split(':')
|
||||
value = eval(value_str.strip())
|
||||
polyphonic_dict[key.strip()] = value
|
||||
line = f.readline()
|
||||
with open(PP_FIX_DICT_PATH) as f:
|
||||
line = f.readline()
|
||||
while line:
|
||||
key, value_str = line.split(':')
|
||||
value = eval(value_str.strip())
|
||||
polyphonic_dict[key.strip()] = value
|
||||
line = f.readline()
|
||||
return polyphonic_dict
|
||||
|
||||
|
||||
def correct_pronunciation(word,word_pinyins):
|
||||
if word in pp_dict:
|
||||
word_pinyins = pp_dict[word]
|
||||
|
||||
return word_pinyins
|
||||
|
||||
|
||||
pp_dict = get_dict()
|
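# A minimal sketch of the polyphonic override, assuming "地壳" is listed in
# polyphonic.rep as ['di4','qiao4']: pinyin for a listed word is replaced
# wholesale by the dictionary entry.
#   correct_pronunciation("地壳", ['di4', 'ke2'])   # -> ['di4', 'qiao4']
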
GPT_SoVITS/text/g2pw/onnx_api.py (new file, 241 lines)
@ -0,0 +1,241 @@
|
||||
# This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
|
||||
# This code is modified from https://github.com/GitYCC/g2pW
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
import json
|
||||
import os
|
||||
import zipfile,requests
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
from opencc import OpenCC
|
||||
from transformers import AutoTokenizer
|
||||
from pypinyin import pinyin
|
||||
from pypinyin import Style
|
||||
|
||||
from .dataset import get_char_phoneme_labels
|
||||
from .dataset import get_phoneme_labels
|
||||
from .dataset import prepare_onnx_input
|
||||
from .utils import load_config
|
||||
from ..zh_normalization.char_convert import tranditional_to_simplified
|
||||
|
||||
model_version = '1.1'
|
||||
|
||||
|
||||
def predict(session, onnx_input: Dict[str, Any],
|
||||
labels: List[str]) -> Tuple[List[str], List[float]]:
|
||||
all_preds = []
|
||||
all_confidences = []
|
||||
probs = session.run([], {
|
||||
"input_ids": onnx_input['input_ids'],
|
||||
"token_type_ids": onnx_input['token_type_ids'],
|
||||
"attention_mask": onnx_input['attention_masks'],
|
||||
"phoneme_mask": onnx_input['phoneme_masks'],
|
||||
"char_ids": onnx_input['char_ids'],
|
||||
"position_ids": onnx_input['position_ids']
|
||||
})[0]
|
||||
|
||||
preds = np.argmax(probs, axis=1).tolist()
|
||||
max_probs = []
|
||||
for index, arr in zip(preds, probs.tolist()):
|
||||
max_probs.append(arr[index])
|
||||
all_preds += [labels[pred] for pred in preds]
|
||||
all_confidences += max_probs
|
||||
|
||||
return all_preds, all_confidences
|
||||
|
||||
|
||||
def download_and_decompress(model_dir: str='G2PWModel/'):
|
||||
if not os.path.exists(model_dir):
|
||||
parent_directory = os.path.dirname(model_dir)
|
||||
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
|
||||
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
|
||||
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
|
||||
print("Downloading g2pw model...")
|
||||
modelscope_url = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
|
||||
with requests.get(modelscope_url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(zip_dir, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
print("Extracting g2pw model...")
|
||||
with zipfile.ZipFile(zip_dir, "r") as zip_ref:
|
||||
zip_ref.extractall(parent_directory)
|
||||
|
||||
os.rename(extract_dir, extract_dir_new)
|
||||
|
||||
return model_dir
|
||||
|
||||
class G2PWOnnxConverter:
|
||||
def __init__(self,
|
||||
model_dir: str='G2PWModel/',
|
||||
style: str='bopomofo',
|
||||
model_source: str=None,
|
||||
enable_non_tradional_chinese: bool=False):
|
||||
uncompress_path = download_and_decompress(model_dir)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, 'g2pW.onnx'),
|
||||
sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
# sess_options=sess_options)
|
||||
self.config = load_config(
|
||||
config_path=os.path.join(uncompress_path, 'config.py'),
|
||||
use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
||||
|
||||
polyphonic_chars_path = os.path.join(uncompress_path,
|
||||
'POLYPHONIC_CHARS.txt')
|
||||
monophonic_chars_path = os.path.join(uncompress_path,
|
||||
'MONOPHONIC_CHARS.txt')
|
||||
self.polyphonic_chars = [
|
||||
line.split('\t')
|
||||
for line in open(polyphonic_chars_path, encoding='utf-8').read()
|
||||
.strip().split('\n')
|
||||
]
|
||||
self.non_polyphonic = {
|
||||
'一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
|
||||
'肖', '瘙', '誒', '泊', '听', '噢'
|
||||
}
|
||||
self.non_monophonic = {'似', '攢'}
|
||||
self.monophonic_chars = [
|
||||
line.split('\t')
|
||||
for line in open(monophonic_chars_path, encoding='utf-8').read()
|
||||
.strip().split('\n')
|
||||
]
|
||||
self.labels, self.char2phonemes = get_char_phoneme_labels(
|
||||
polyphonic_chars=self.polyphonic_chars
|
||||
) if self.config.use_char_phoneme else get_phoneme_labels(
|
||||
polyphonic_chars=self.polyphonic_chars)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
|
||||
self.polyphonic_chars_new = set(self.chars)
|
||||
for char in self.non_polyphonic:
|
||||
if char in self.polyphonic_chars_new:
|
||||
self.polyphonic_chars_new.remove(char)
|
||||
|
||||
self.monophonic_chars_dict = {
|
||||
char: phoneme
|
||||
for char, phoneme in self.monophonic_chars
|
||||
}
|
||||
for char in self.non_monophonic:
|
||||
if char in self.monophonic_chars_dict:
|
||||
self.monophonic_chars_dict.pop(char)
|
||||
|
||||
self.pos_tags = [
|
||||
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
||||
]
|
||||
|
||||
with open(
|
||||
os.path.join(uncompress_path,
|
||||
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
||||
'r',
|
||||
encoding='utf-8') as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
'bopomofo': lambda x: x,
|
||||
'pinyin': self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(
|
||||
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
||||
'r',
|
||||
encoding='utf-8') as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC('s2tw')
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
||||
tone = bopomofo[-1]
|
||||
assert tone in '12345'
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
else:
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
|
||||
def __call__(self, sentences: List[str]) -> List[List[str]]:
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
if self.enable_opencc:
|
||||
translated_sentences = []
|
||||
for sent in sentences:
|
||||
translated_sent = self.cc.convert(sent)
|
||||
assert len(translated_sent) == len(sent)
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(
|
||||
sentences=sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
|
||||
onnx_input = prepare_onnx_input(
|
||||
tokenizer=self.tokenizer,
|
||||
labels=self.labels,
|
||||
char2phonemes=self.char2phonemes,
|
||||
chars=self.chars,
|
||||
texts=texts,
|
||||
query_ids=query_ids,
|
||||
use_mask=self.config.use_mask,
|
||||
window_size=None)
|
||||
|
||||
preds, confidences = predict(
|
||||
session=self.session_g2pW,
|
||||
onnx_input=onnx_input,
|
||||
labels=self.labels)
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(' ')[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
results[sent_id][query_id] = self.style_convert_func(pred)
|
||||
|
||||
return results
|
||||
|
||||
def _prepare_data(
|
||||
self, sentences: List[str]
|
||||
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
# pypinyin works better for Simplified Chinese than Traditional Chinese
|
||||
sent_s = tranditional_to_simplified(sent)
|
||||
pypinyin_result = pinyin(
|
||||
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
for i, char in enumerate(sent):
|
||||
if char in self.polyphonic_chars_new:
|
||||
texts.append(sent)
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
elif char in self.monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(
|
||||
self.monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
else:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
|
||||
partial_results.append(partial_result)
|
||||
return texts, query_ids, sent_ids, partial_results
|
GPT_SoVITS/text/g2pw/polyphonic-fix.rep (new file, 45024 lines; diff suppressed because it is too large)
GPT_SoVITS/text/g2pw/polyphonic.pickle (new binary file, not shown)
GPT_SoVITS/text/g2pw/polyphonic.rep (new file, 53 lines)
@ -0,0 +1,53 @@
|
||||
湖泊: ['hu2','po1']
|
||||
地壳: ['di4','qiao4']
|
||||
柏树: ['bai3','shu4']
|
||||
曝光: ['bao4','guang1']
|
||||
弹力: ['tan2','li4']
|
||||
字帖: ['zi4','tie4']
|
||||
口吃: ['kou3','chi1']
|
||||
包扎: ['bao1','za1']
|
||||
哪吒: ['ne2','zha1']
|
||||
说服: ['shuo1','fu2']
|
||||
识字: ['shi2','zi4']
|
||||
骨头: ['gu3','tou5']
|
||||
对称: ['dui4','chen4']
|
||||
口供: ['kou3','gong4']
|
||||
抹布: ['ma1','bu4']
|
||||
露背: ['lu4','bei4']
|
||||
圈养: ['juan4', 'yang3']
|
||||
眼眶: ['yan3', 'kuang4']
|
||||
品行: ['pin3','xing2']
|
||||
颤抖: ['chan4','dou3']
|
||||
差不多: ['cha4','bu5','duo1']
|
||||
鸭绿江: ['ya1','lu4','jiang1']
|
||||
撒切尔: ['sa4','qie4','er3']
|
||||
比比皆是: ['bi3','bi3','jie1','shi4']
|
||||
身无长物: ['shen1','wu2','chang2','wu4']
|
||||
手里: ['shou2','li3']
|
||||
关卡: ['guan1','qia3']
|
||||
怀揣: ['huai2','chuai1']
|
||||
挑剔: ['tiao1','ti4']
|
||||
供称: ['gong4','cheng1']
|
||||
作坊: ['zuo1', 'fang5']
|
||||
中医: ['zhong1','yi1']
|
||||
嚷嚷: ['rang1','rang5']
|
||||
商厦: ['shang1','sha4']
|
||||
大厦: ['da4','sha4']
|
||||
刹车: ['sha1','che1']
|
||||
嘚瑟: ['de4','se5']
|
||||
朝鲜: ['chao2','xian3']
|
||||
阿房宫: ['e1','pang2','gong1']
|
||||
阿胶: ['e1','jiao1']
|
||||
咖喱: ['ga1','li5']
|
||||
时分: ['shi2','fen1']
|
||||
蚌埠: ['beng4','bu4']
|
||||
驯服: ['xun4','fu2']
|
||||
幸免于难: ['xing4','mian3','yu2','nan4']
|
||||
恶行: ['e4','xing2']
|
||||
唉: ['ai4']
|
||||
扎实: ['zha1','shi2']
|
||||
干将: ['gan4','jiang4']
|
||||
陈威行: ['chen2', 'wei1', 'hang2']
|
||||
郭晟: ['guo1', 'sheng4']
|
||||
中标: ['zhong4', 'biao1']
|
||||
抗住: ['kang2', 'zhu4']
|
GPT_SoVITS/text/g2pw/utils.py (new file, 145 lines)
@ -0,0 +1,145 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def wordize_and_map(text: str):
|
||||
words = []
|
||||
index_map_from_text_to_word = []
|
||||
index_map_from_word_to_text = []
|
||||
while len(text) > 0:
|
||||
match_space = re.match(r'^ +', text)
|
||||
if match_space:
|
||||
space_str = match_space.group(0)
|
||||
index_map_from_text_to_word += [None] * len(space_str)
|
||||
text = text[len(space_str):]
|
||||
continue
|
||||
|
||||
match_en = re.match(r'^[a-zA-Z0-9]+', text)
|
||||
if match_en:
|
||||
en_word = match_en.group(0)
|
||||
|
||||
word_start_pos = len(index_map_from_text_to_word)
|
||||
word_end_pos = word_start_pos + len(en_word)
|
||||
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
|
||||
|
||||
index_map_from_text_to_word += [len(words)] * len(en_word)
|
||||
|
||||
words.append(en_word)
|
||||
text = text[len(en_word):]
|
||||
else:
|
||||
word_start_pos = len(index_map_from_text_to_word)
|
||||
word_end_pos = word_start_pos + 1
|
||||
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
|
||||
|
||||
index_map_from_text_to_word += [len(words)]
|
||||
|
||||
words.append(text[0])
|
||||
text = text[1:]
|
||||
return words, index_map_from_text_to_word, index_map_from_word_to_text
|
||||
|
||||
|
||||
def tokenize_and_map(tokenizer, text: str):
|
||||
words, text2word, word2text = wordize_and_map(text=text)
|
||||
|
||||
tokens = []
|
||||
index_map_from_token_to_text = []
|
||||
for word, (word_start, word_end) in zip(words, word2text):
|
||||
word_tokens = tokenizer.tokenize(word)
|
||||
|
||||
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
|
||||
index_map_from_token_to_text.append((word_start, word_end))
|
||||
tokens.append('[UNK]')
|
||||
else:
|
||||
current_word_start = word_start
|
||||
for word_token in word_tokens:
|
||||
word_token_len = len(re.sub(r'^##', '', word_token))
|
||||
index_map_from_token_to_text.append(
|
||||
(current_word_start, current_word_start + word_token_len))
|
||||
current_word_start = current_word_start + word_token_len
|
||||
tokens.append(word_token)
|
||||
|
||||
index_map_from_text_to_token = text2word
|
||||
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
|
||||
for token_pos in range(token_start, token_end):
|
||||
index_map_from_text_to_token[token_pos] = i
|
||||
|
||||
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
|
||||
|
||||
|
||||
def _load_config(config_path: os.PathLike):
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location('__init__', config_path)
|
||||
config = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config)
|
||||
return config
|
||||
|
||||
|
||||
default_config_dict = {
|
||||
'manual_seed': 1313,
|
||||
'model_source': 'bert-base-chinese',
|
||||
'window_size': 32,
|
||||
'num_workers': 2,
|
||||
'use_mask': True,
|
||||
'use_char_phoneme': False,
|
||||
'use_conditional': True,
|
||||
'param_conditional': {
|
||||
'affect_location': 'softmax',
|
||||
'bias': True,
|
||||
'char-linear': True,
|
||||
'pos-linear': False,
|
||||
'char+pos-second': True,
|
||||
'char+pos-second_lowrank': False,
|
||||
'lowrank_size': 0,
|
||||
'char+pos-second_fm': False,
|
||||
'fm_size': 0,
|
||||
'fix_mode': None,
|
||||
'count_json': 'train.count.json'
|
||||
},
|
||||
'lr': 5e-5,
|
||||
'val_interval': 200,
|
||||
'num_iter': 10000,
|
||||
'use_focal': False,
|
||||
'param_focal': {
|
||||
'alpha': 0.0,
|
||||
'gamma': 0.7
|
||||
},
|
||||
'use_pos': True,
|
||||
'param_pos ': {
|
||||
'weight': 0.1,
|
||||
'pos_joint_training': True,
|
||||
'train_pos_path': 'train.pos',
|
||||
'valid_pos_path': 'dev.pos',
|
||||
'test_pos_path': 'test.pos'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_config(config_path: os.PathLike, use_default: bool=False):
|
||||
config = _load_config(config_path)
|
||||
if use_default:
|
||||
for attr, val in default_config_dict.items():
|
||||
if not hasattr(config, attr):
|
||||
setattr(config, attr, val)
|
||||
elif isinstance(val, dict):
|
||||
d = getattr(config, attr)
|
||||
for dict_k, dict_v in val.items():
|
||||
if dict_k not in d:
|
||||
d[dict_k] = dict_v
|
||||
return config
|
@ -4,8 +4,7 @@ import sys
|
||||
|
||||
import pyopenjtalk
|
||||
|
||||
|
||||
from text import symbols
|
||||
from text.symbols import punctuation
|
||||
# Regular expression matching Japanese without punctuation marks:
|
||||
_japanese_characters = re.compile(
|
||||
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
|
||||
@ -56,15 +55,23 @@ def post_replace_ph(ph):
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
}
|
||||
|
||||
if ph in rep_map.keys():
|
||||
ph = rep_map[ph]
|
||||
if ph in symbols:
|
||||
return ph
|
||||
if ph not in symbols:
|
||||
ph = "UNK"
|
||||
# if ph in symbols:
|
||||
# return ph
|
||||
# if ph not in symbols:
|
||||
# ph = "UNK"
|
||||
return ph
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
return result
|
||||
|
||||
|
||||
def symbols_to_japanese(text):
|
||||
for regex, replacement in _symbols_to_japanese:
|
||||
text = re.sub(regex, replacement, text)
|
||||
@ -94,6 +101,9 @@ def preprocess_jap(text, with_prosody=False):
|
||||
|
||||
def text_normalize(text):
|
||||
# todo: jap text normalize
|
||||
|
||||
# 避免重复标点引起的参考泄露
|
||||
text = replace_consecutive_punctuation(text)
|
||||
return text
|
||||
|
||||
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
|
||||
@ -179,7 +189,7 @@ def _numeric_feature_by_regex(regex, s):
|
||||
return -50
|
||||
return int(match.group(1))
|
||||
|
||||
def g2p(norm_text, with_prosody=False):
|
||||
def g2p(norm_text, with_prosody=True):
|
||||
phones = preprocess_jap(norm_text, with_prosody)
|
||||
phones = [post_replace_ph(i) for i in phones]
|
||||
# todo: implement tones and word2ph
|
||||
|
GPT_SoVITS/text/korean.py (new file, 265 lines)
@ -0,0 +1,265 @@
|
||||
# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py
|
||||
|
||||
import re
|
||||
from jamo import h2j, j2hcj
|
||||
import ko_pron
|
||||
from g2pk2 import G2p
|
||||
|
||||
from text.symbols2 import symbols
|
||||
|
||||
# This is a list of Korean classifiers preceded by pure Korean numerals.
|
||||
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
|
||||
|
||||
# List of (hangul, hangul divided) pairs:
|
||||
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
|
||||
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
|
||||
# ('ㄵ', 'ㄴㅈ'),
|
||||
# ('ㄶ', 'ㄴㅎ'),
|
||||
# ('ㄺ', 'ㄹㄱ'),
|
||||
# ('ㄻ', 'ㄹㅁ'),
|
||||
# ('ㄼ', 'ㄹㅂ'),
|
||||
# ('ㄽ', 'ㄹㅅ'),
|
||||
# ('ㄾ', 'ㄹㅌ'),
|
||||
# ('ㄿ', 'ㄹㅍ'),
|
||||
# ('ㅀ', 'ㄹㅎ'),
|
||||
# ('ㅄ', 'ㅂㅅ'),
|
||||
('ㅘ', 'ㅗㅏ'),
|
||||
('ㅙ', 'ㅗㅐ'),
|
||||
('ㅚ', 'ㅗㅣ'),
|
||||
('ㅝ', 'ㅜㅓ'),
|
||||
('ㅞ', 'ㅜㅔ'),
|
||||
('ㅟ', 'ㅜㅣ'),
|
||||
('ㅢ', 'ㅡㅣ'),
|
||||
('ㅑ', 'ㅣㅏ'),
|
||||
('ㅒ', 'ㅣㅐ'),
|
||||
('ㅕ', 'ㅣㅓ'),
|
||||
('ㅖ', 'ㅣㅔ'),
|
||||
('ㅛ', 'ㅣㅗ'),
|
||||
('ㅠ', 'ㅣㅜ')
|
||||
]]
|
||||
|
||||
# List of (Latin alphabet, hangul) pairs:
|
||||
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('a', '에이'),
|
||||
('b', '비'),
|
||||
('c', '시'),
|
||||
('d', '디'),
|
||||
('e', '이'),
|
||||
('f', '에프'),
|
||||
('g', '지'),
|
||||
('h', '에이치'),
|
||||
('i', '아이'),
|
||||
('j', '제이'),
|
||||
('k', '케이'),
|
||||
('l', '엘'),
|
||||
('m', '엠'),
|
||||
('n', '엔'),
|
||||
('o', '오'),
|
||||
('p', '피'),
|
||||
('q', '큐'),
|
||||
('r', '아르'),
|
||||
('s', '에스'),
|
||||
('t', '티'),
|
||||
('u', '유'),
|
||||
('v', '브이'),
|
||||
('w', '더블유'),
|
||||
('x', '엑스'),
|
||||
('y', '와이'),
|
||||
('z', '제트')
|
||||
]]
|
||||
|
||||
# List of (ipa, lazy ipa) pairs:
|
||||
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('t͡ɕ','ʧ'),
|
||||
('d͡ʑ','ʥ'),
|
||||
('ɲ','n^'),
|
||||
('ɕ','ʃ'),
|
||||
('ʷ','w'),
|
||||
('ɭ','l`'),
|
||||
('ʎ','ɾ'),
|
||||
('ɣ','ŋ'),
|
||||
('ɰ','ɯ'),
|
||||
('ʝ','j'),
|
||||
('ʌ','ə'),
|
||||
('ɡ','g'),
|
||||
('\u031a','#'),
|
||||
('\u0348','='),
|
||||
('\u031e',''),
|
||||
('\u0320',''),
|
||||
('\u0339','')
|
||||
]]
|
||||
|
||||
|
||||
def fix_g2pk2_error(text):
|
||||
new_text = ""
|
||||
i = 0
|
||||
while i < len(text) - 4:
|
||||
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
|
||||
new_text += text[i:i+3] + ' ' + 'ㄴ'
|
||||
i += 5
|
||||
else:
|
||||
new_text += text[i]
|
||||
i += 1
|
||||
|
||||
new_text += text[i:]
|
||||
return new_text
|
||||
|
||||
|
||||
def latin_to_hangul(text):
|
||||
for regex, replacement in _latin_to_hangul:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def divide_hangul(text):
|
||||
text = j2hcj(h2j(text))
|
||||
for regex, replacement in _hangul_divided:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def hangul_number(num, sino=True):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
num = re.sub(',', '', num)
|
||||
|
||||
if num == '0':
|
||||
return '영'
|
||||
if not sino and num == '20':
|
||||
return '스무'
|
||||
|
||||
digits = '123456789'
|
||||
names = '일이삼사오육칠팔구'
|
||||
digit2name = {d: n for d, n in zip(digits, names)}
|
||||
|
||||
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
|
||||
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
|
||||
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
|
||||
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
|
||||
|
||||
spelledout = []
|
||||
for i, digit in enumerate(num):
|
||||
i = len(num) - i - 1
|
||||
if sino:
|
||||
if i == 0:
|
||||
name = digit2name.get(digit, '')
|
||||
elif i == 1:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
else:
|
||||
if i == 0:
|
||||
name = digit2mod.get(digit, '')
|
||||
elif i == 1:
|
||||
name = digit2dec.get(digit, '')
|
||||
if digit == '0':
|
||||
if i % 4 == 0:
|
||||
last_three = spelledout[-min(3, len(spelledout)):]
|
||||
if ''.join(last_three) == '':
|
||||
spelledout.append('')
|
||||
continue
|
||||
else:
|
||||
spelledout.append('')
|
||||
continue
|
||||
if i == 2:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
elif i == 3:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
elif i == 4:
|
||||
name = digit2name.get(digit, '') + '만'
|
||||
name = name.replace('일만', '만')
|
||||
elif i == 5:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
elif i == 6:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
elif i == 7:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
elif i == 8:
|
||||
name = digit2name.get(digit, '') + '억'
|
||||
elif i == 9:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
elif i == 10:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
elif i == 11:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
elif i == 12:
|
||||
name = digit2name.get(digit, '') + '조'
|
||||
elif i == 13:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
elif i == 14:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
elif i == 15:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
spelledout.append(name)
|
||||
return ''.join(elem for elem in spelledout)
|
||||
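# A minimal sketch of the two reading paths above (Sino-Korean vs. native):
#   hangul_number('100')              # -> '백'
#   hangul_number('20', sino=False)   # -> '스무'
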
|
||||
|
||||
def number_to_hangul(text):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
|
||||
for token in tokens:
|
||||
num, classifier = token
|
||||
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
|
||||
spelledout = hangul_number(num, sino=False)
|
||||
else:
|
||||
spelledout = hangul_number(num, sino=True)
|
||||
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
|
||||
# digit by digit for remaining digits
|
||||
digits = '0123456789'
|
||||
names = '영일이삼사오육칠팔구'
|
||||
for d, n in zip(digits, names):
|
||||
text = text.replace(d, n)
|
||||
return text
|
||||
|
||||
|
||||
def korean_to_lazy_ipa(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = number_to_hangul(text)
|
||||
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
|
||||
for regex, replacement in _ipa_to_lazy_ipa:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
_g2p=G2p()
|
||||
def korean_to_ipa(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = number_to_hangul(text)
|
||||
text = _g2p(text)
|
||||
text = fix_g2pk2_error(text)
|
||||
text = korean_to_lazy_ipa(text)
|
||||
return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
|
||||
|
||||
def post_replace_ph(ph):
|
||||
rep_map = {
|
||||
":": ",",
|
||||
";": ",",
|
||||
",": ",",
|
||||
"。": ".",
|
||||
"!": "!",
|
||||
"?": "?",
|
||||
"\n": ".",
|
||||
"·": ",",
|
||||
"、": ",",
|
||||
"...": "…",
|
||||
" ": "空",
|
||||
}
|
||||
if ph in rep_map.keys():
|
||||
ph = rep_map[ph]
|
||||
if ph in symbols:
|
||||
return ph
|
||||
if ph not in symbols:
|
||||
ph = "停"
|
||||
return ph
|
||||
|
||||
def g2p(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = _g2p(text)
|
||||
text = divide_hangul(text)
|
||||
text = fix_g2pk2_error(text)
|
||||
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
|
||||
# text = "".join([post_replace_ph(i) for i in text])
|
||||
text = [post_replace_ph(i) for i in text]
|
||||
return text
|
GPT_SoVITS/text/symbols2.py (new file, 419 lines)
@ -0,0 +1,419 @@
|
||||
import os
|
||||
|
||||
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
|
||||
punctuation = ["!", "?", "…", ",", "."] # @是SP停顿
|
||||
punctuation.append("-")
|
||||
pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
|
||||
# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
|
||||
pad = "_"
|
||||
|
||||
c = [
|
||||
"AA",
|
||||
"EE",
|
||||
"OO",
|
||||
"b",
|
||||
"c",
|
||||
"ch",
|
||||
"d",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"sh",
|
||||
"t",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"zh",
|
||||
]
|
||||
v = [
|
||||
"E1",
|
||||
"En1",
|
||||
"a1",
|
||||
"ai1",
|
||||
"an1",
|
||||
"ang1",
|
||||
"ao1",
|
||||
"e1",
|
||||
"ei1",
|
||||
"en1",
|
||||
"eng1",
|
||||
"er1",
|
||||
"i1",
|
||||
"i01",
|
||||
"ia1",
|
||||
"ian1",
|
||||
"iang1",
|
||||
"iao1",
|
||||
"ie1",
|
||||
"in1",
|
||||
"ing1",
|
||||
"iong1",
|
||||
"ir1",
|
||||
"iu1",
|
||||
"o1",
|
||||
"ong1",
|
||||
"ou1",
|
||||
"u1",
|
||||
"ua1",
|
||||
"uai1",
|
||||
"uan1",
|
||||
"uang1",
|
||||
"ui1",
|
||||
"un1",
|
||||
"uo1",
|
||||
"v1",
|
||||
"van1",
|
||||
"ve1",
|
||||
"vn1",
|
||||
"E2",
|
||||
"En2",
|
||||
"a2",
|
||||
"ai2",
|
||||
"an2",
|
||||
"ang2",
|
||||
"ao2",
|
||||
"e2",
|
||||
"ei2",
|
||||
"en2",
|
||||
"eng2",
|
||||
"er2",
|
||||
"i2",
|
||||
"i02",
|
||||
"ia2",
|
||||
"ian2",
|
||||
"iang2",
|
||||
"iao2",
|
||||
"ie2",
|
||||
"in2",
|
||||
"ing2",
|
||||
"iong2",
|
||||
"ir2",
|
||||
"iu2",
|
||||
"o2",
|
||||
"ong2",
|
||||
"ou2",
|
||||
"u2",
|
||||
"ua2",
|
||||
"uai2",
|
||||
"uan2",
|
||||
"uang2",
|
||||
"ui2",
|
||||
"un2",
|
||||
"uo2",
|
||||
"v2",
|
||||
"van2",
|
||||
"ve2",
|
||||
"vn2",
|
||||
"E3",
|
||||
"En3",
|
||||
"a3",
|
||||
"ai3",
|
||||
"an3",
|
||||
"ang3",
|
||||
"ao3",
|
||||
"e3",
|
||||
"ei3",
|
||||
"en3",
|
||||
"eng3",
|
||||
"er3",
|
||||
"i3",
|
||||
"i03",
|
||||
"ia3",
|
||||
"ian3",
|
||||
"iang3",
|
||||
"iao3",
|
||||
"ie3",
|
||||
"in3",
|
||||
"ing3",
|
||||
"iong3",
|
||||
"ir3",
|
||||
"iu3",
|
||||
"o3",
|
||||
"ong3",
|
||||
"ou3",
|
||||
"u3",
|
||||
"ua3",
|
||||
"uai3",
|
||||
"uan3",
|
||||
"uang3",
|
||||
"ui3",
|
||||
"un3",
|
||||
"uo3",
|
||||
"v3",
|
||||
"van3",
|
||||
"ve3",
|
||||
"vn3",
|
||||
"E4",
|
||||
"En4",
|
||||
"a4",
|
||||
"ai4",
|
||||
"an4",
|
||||
"ang4",
|
||||
"ao4",
|
||||
"e4",
|
||||
"ei4",
|
||||
"en4",
|
||||
"eng4",
|
||||
"er4",
|
||||
"i4",
|
||||
"i04",
|
||||
"ia4",
|
||||
"ian4",
|
||||
"iang4",
|
||||
"iao4",
|
||||
"ie4",
|
||||
"in4",
|
||||
"ing4",
|
||||
"iong4",
|
||||
"ir4",
|
||||
"iu4",
|
||||
"o4",
|
||||
"ong4",
|
||||
"ou4",
|
||||
"u4",
|
||||
"ua4",
|
||||
"uai4",
|
||||
"uan4",
|
||||
"uang4",
|
||||
"ui4",
|
||||
"un4",
|
||||
"uo4",
|
||||
"v4",
|
||||
"van4",
|
||||
"ve4",
|
||||
"vn4",
|
||||
"E5",
|
||||
"En5",
|
||||
"a5",
|
||||
"ai5",
|
||||
"an5",
|
||||
"ang5",
|
||||
"ao5",
|
||||
"e5",
|
||||
"ei5",
|
||||
"en5",
|
||||
"eng5",
|
||||
"er5",
|
||||
"i5",
|
||||
"i05",
|
||||
"ia5",
|
||||
"ian5",
|
||||
"iang5",
|
||||
"iao5",
|
||||
"ie5",
|
||||
"in5",
|
||||
"ing5",
|
||||
"iong5",
|
||||
"ir5",
|
||||
"iu5",
|
||||
"o5",
|
||||
"ong5",
|
||||
"ou5",
|
||||
"u5",
|
||||
"ua5",
|
||||
"uai5",
|
||||
"uan5",
|
||||
"uang5",
|
||||
"ui5",
|
||||
"un5",
|
||||
"uo5",
|
||||
"v5",
|
||||
"van5",
|
||||
"ve5",
|
||||
"vn5",
|
||||
]
|
||||
|
||||
v_without_tone = [
|
||||
"E",
|
||||
"En",
|
||||
"a",
|
||||
"ai",
|
||||
"an",
|
||||
"ang",
|
||||
"ao",
|
||||
"e",
|
||||
"ei",
|
||||
"en",
|
||||
"eng",
|
||||
"er",
|
||||
"i",
|
||||
"i0",
|
||||
"ia",
|
||||
"ian",
|
||||
"iang",
|
||||
"iao",
|
||||
"ie",
|
||||
"in",
|
||||
"ing",
|
||||
"iong",
|
||||
"ir",
|
||||
"iu",
|
||||
"o",
|
||||
"ong",
|
||||
"ou",
|
||||
"u",
|
||||
"ua",
|
||||
"uai",
|
||||
"uan",
|
||||
"uang",
|
||||
"ui",
|
||||
"un",
|
||||
"uo",
|
||||
"v",
|
||||
"van",
|
||||
"ve",
|
||||
"vn",
|
||||
]
|
||||
|
||||
# japanese
|
||||
ja_symbols = [
|
||||
"I",
|
||||
"N",
|
||||
"U",
|
||||
"a",
|
||||
"b",
|
||||
"by",
|
||||
"ch",
|
||||
"cl",
|
||||
"d",
|
||||
"dy",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"gy",
|
||||
"h",
|
||||
"hy",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"ky",
|
||||
"m",
|
||||
"my",
|
||||
"n",
|
||||
"ny",
|
||||
"o",
|
||||
"p",
|
||||
"py",
|
||||
"r",
|
||||
"ry",
|
||||
"s",
|
||||
"sh",
|
||||
"t",
|
||||
"ts",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"y",
|
||||
"z",
|
||||
###楼下2个留到后面加
|
||||
# "[", #上升调型
|
||||
# "]", #下降调型
|
||||
# "$", #结束符
|
||||
# "^", #开始符
|
||||
]
|
||||
|
||||
arpa = {
|
||||
"AH0",
|
||||
"S",
|
||||
"AH1",
|
||||
"EY2",
|
||||
"AE2",
|
||||
"EH0",
|
||||
"OW2",
|
||||
"UH0",
|
||||
"NG",
|
||||
"B",
|
||||
"G",
|
||||
"AY0",
|
||||
"M",
|
||||
"AA0",
|
||||
"F",
|
||||
"AO0",
|
||||
"ER2",
|
||||
"UH1",
|
||||
"IY1",
|
||||
"AH2",
|
||||
"DH",
|
||||
"IY0",
|
||||
"EY1",
|
||||
"IH0",
|
||||
"K",
|
||||
"N",
|
||||
"W",
|
||||
"IY2",
|
||||
"T",
|
||||
"AA1",
|
||||
"ER1",
|
||||
"EH2",
|
||||
"OY0",
|
||||
"UH2",
|
||||
"UW1",
|
||||
"Z",
|
||||
"AW2",
|
||||
"AW1",
|
||||
"V",
|
||||
"UW2",
|
||||
"AA2",
|
||||
"ER",
|
||||
"AW0",
|
||||
"UW0",
|
||||
"R",
|
||||
"OW1",
|
||||
"EH1",
|
||||
"ZH",
|
||||
"AE0",
|
||||
"IH2",
|
||||
"IH",
|
||||
"Y",
|
||||
"JH",
|
||||
"P",
|
||||
"AY1",
|
||||
"EY0",
|
||||
"OY2",
|
||||
"TH",
|
||||
"HH",
|
||||
"D",
|
||||
"ER0",
|
||||
"CH",
|
||||
"AO1",
|
||||
"AE1",
|
||||
"AO2",
|
||||
"OY1",
|
||||
"AY2",
|
||||
"IH1",
|
||||
"OW0",
|
||||
"L",
|
||||
"SH",
|
||||
}
|
||||
|
||||
ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停'
|
||||
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
||||
|
||||
yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}

# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)### appending yue directly like this scrambles the ordering
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))
# print(len(symbols))
symbols += ["[", "]"]  ## newly added Japanese rising/falling pitch-contour marks
symbols += sorted(list(ko_symbols))
symbols += sorted(list(yue_symbols))  ## the newly added yue symbols all go at the end # verified: with the leading "Y" prefix there are no duplicates, and Korean obviously cannot collide
# print(len(symbols))
if __name__ == "__main__":
    print(len(symbols))
'''
Cantonese:
732-353=379
Korean + Cantonese:
732-322=410
'''
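The uniqueness claim in the comment above can be spot-checked with a throwaway snippet at the bottom of this module; it is illustrative only and relies on the names already defined here (pad, c, v, ja_symbols, pu_symbols, arpa, ko_symbols, yue_symbols):

# Sanity check (not part of the commit): the merged table must stay duplicate-free,
# otherwise phoneme IDs from different languages would collide.
assert len(symbols) == len(set(symbols)), "duplicate symbol across languages"
print(len(symbols))  # should reproduce the totals noted in the docstring above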
@ -681,7 +681,6 @@ class ToneSandhi:
and seg[i - 1][0] == "一"
and seg[i - 2][0] == word
and pos == "v"
and seg[i - 2][1] == "v"
):
continue
else:
@ -28,7 +28,7 @@ UNITS = OrderedDict({
    8: '亿',
})

COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'

# Fraction expressions
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
@ -107,8 +107,11 @@ def replace_default_num(match):

# Arithmetic: addition, subtraction, multiplication, division
# RE_ASMD = re.compile(
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
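For orientation (not part of the diff): the widened pattern above now also accepts single-letter operands and superscripted exponents, while plain digit arithmetic still matches. A quick spot-check inside this module:

# Illustrative matches for the new RE_ASMD; the print-outs are what the pattern implies.
print(bool(RE_ASMD.search("3+5")))        # True: plain digits, as before
print(bool(RE_ASMD.search("x²+y²=z²")))   # True: letter operands carrying superscripts
print(bool(RE_ASMD.search("你好")))        # False: no arithmetic expression present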
asmd_map = {
    '+': '加',
    '-': '减',
@ -117,7 +120,6 @@ asmd_map = {
    '=': '等于'
}


def replace_asmd(match) -> str:
    """
    Args:
@ -129,6 +131,39 @@ def replace_asmd(match) -> str:
    return result

# Dedicated handling for powers (superscript exponents)
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')

power_map = {
    '⁰': '0',
    '¹': '1',
    '²': '2',
    '³': '3',
    '⁴': '4',
    '⁵': '5',
    '⁶': '6',
    '⁷': '7',
    '⁸': '8',
    '⁹': '9',
    'ˣ': 'x',
    'ʸ': 'y',
    'ⁿ': 'n'
}

def replace_power(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    power_num = ""
    for m in match.group(0):
        power_num += power_map[m]
    result = "的" + power_num + "次方"
    return result
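A minimal usage sketch (illustrative, not from the commit) of how RE_POWER and replace_power cooperate:

# Superscript runs become "的…次方" readings.
print(RE_POWER.sub(replace_power, "m³"))   # -> m的3次方
print(RE_POWER.sub(replace_power, "xⁿ"))   # -> x的n次方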

# Number expressions
# Pure decimals
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
@ -35,6 +35,7 @@ from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE
from .num import RE_TO_RANGE
from .num import RE_ASMD
from .num import RE_POWER
from .num import replace_default_num
from .num import replace_frac
from .num import replace_negative_num
@ -44,6 +45,7 @@ from .num import replace_positive_quantifier
from .num import replace_range
from .num import replace_to_range
from .num import replace_asmd
from .num import replace_power
from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE
@ -114,6 +116,12 @@ class TextNormalizer():
sentence = sentence.replace('χ', '器')
sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
# Fallback verbalization of bare math operators; also tolerates casual shorthand
sentence = sentence.replace('+', '加')
sentence = sentence.replace('-', '减')
sentence = sentence.replace('×', '乘')
sentence = sentence.replace('÷', '除')
sentence = sentence.replace('=', '等')
# re filter special characters, have one more character "-" than line 68
sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|\\]', '', sentence)
return sentence
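In effect this block reads any operator characters still present at this point aloud, character for character; a toy illustration (not from the diff):

# A leftover "1+1=2" still becomes readable text:
s = "1+1=2"
for op, word in [('+', '加'), ('-', '减'), ('×', '乘'), ('÷', '除'), ('=', '等')]:
    s = s.replace(op, word)
print(s)  # 1加1等2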
@ -136,6 +144,12 @@ class TextNormalizer():
sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = replace_measure(sentence)

# Handle math expressions
while RE_ASMD.search(sentence):
    sentence = RE_ASMD.sub(replace_asmd, sentence)
sentence = RE_POWER.sub(replace_power, sentence)
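A rough trace of the ordering above (operators first, then exponents), assuming replace_asmd verbalizes the matched operator via asmd_map and keeps the operands; the while loop is needed because each pass consumes one operator per chained expression:

# Toy trace, not part of the commit:
s = "x²+y²=z²"
while RE_ASMD.search(s):
    s = RE_ASMD.sub(replace_asmd, s)    # "+" -> 加, "=" -> 等于
s = RE_POWER.sub(replace_power, s)      # "²" -> 的2次方
# expected reading: roughly x的2次方加y的2次方等于z的2次方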
sentence = RE_FRAC.sub(replace_frac, sentence)
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
@ -145,10 +159,6 @@ class TextNormalizer():
sentence = RE_RANGE.sub(replace_range, sentence)

# Handle addition/subtraction/multiplication/division
while RE_ASMD.search(sentence):
    sentence = RE_ASMD.sub(replace_asmd, sentence)

sentence = RE_INTEGER.sub(replace_negative_num, sentence)
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
@ -1,10 +1,12 @@
import json
import locale
import os
import pathlib

i18n_dir = pathlib.Path(os.path.dirname(__file__)).as_posix().replace("tools/","")


def load_language_list(language):
    with open(f"./i18n/locale/{language}.json", "r", encoding="utf-8") as f:
    with open(f"{i18n_dir}/locale/{language}.json", "r", encoding="utf-8") as f:
        language_list = json.load(f)
    return language_list

@ -15,7 +17,7 @@ class I18nAuto:
language = locale.getdefaultlocale()[
    0
] # getlocale can't identify the system's language ((None, None))
if not os.path.exists(f"./i18n/locale/{language}.json"):
if not os.path.exists(f"{i18n_dir}/locale/{language}.json"):
    language = "en_US"
self.language = language
self.language_map = load_language_list(language)
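A small usage sketch. Assumptions not shown in this hunk: I18nAuto() can be constructed without arguments, and the locale JSONs live under <repo>/i18n/locale, which is what the i18n_dir rewrite targets (presumably ".../tools/i18n" becomes ".../i18n" after the replace, so lookups no longer depend on the current working directory):

# Hypothetical usage; the printed values are examples only.
i18n = I18nAuto()
print(i18n.language)           # e.g. "zh_CN", or "en_US" if no matching locale file exists
print(len(i18n.language_map))  # number of translated UI strings that were loaded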