support v2 model.

CyberWon 2024-08-08 22:07:31 +08:00
parent 7c43b41e6d
commit 574f667c71
27 changed files with 47513 additions and 386 deletions


@ -17,6 +17,9 @@ from transformers import AutoTokenizer
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
version = os.environ.get('version', None)
# from config import exp_dir # from config import exp_dir
@ -125,7 +128,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len): for i in range(semantic_data_len):
# 先依次遍历 # 先依次遍历
# get str # get str
item_name = self.semantic_data.iloc[i,0] item_name = self.semantic_data.iloc[i, 0]
# print(self.phoneme_data) # print(self.phoneme_data)
try: try:
phoneme, word2ph, text = self.phoneme_data[item_name] phoneme, word2ph, text = self.phoneme_data[item_name]
@ -135,7 +138,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1 num_not_in += 1
continue continue
semantic_str = self.semantic_data.iloc[i,1] semantic_str = self.semantic_data.iloc[i, 1]
# get token list # get token list
semantic_ids = [int(idx) for idx in semantic_str.split(" ")] semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
@ -149,7 +152,7 @@ class Text2SemanticDataset(Dataset):
phoneme = phoneme.split(" ") phoneme = phoneme.split(" ")
try: try:
phoneme_ids = cleaned_text_to_sequence(phoneme) phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except: except:
traceback.print_exc() traceback.print_exc()
# print(f"{item_name} not in self.phoneme_data !") # print(f"{item_name} not in self.phoneme_data !")


@ -1,10 +1,12 @@
from copy import deepcopy from copy import deepcopy
import math import math
import os, sys, gc import os, sys
import random import random
import traceback import traceback
from tqdm import tqdm from tqdm import tqdm
from loguru import logger
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
import ffmpeg import ffmpeg
@ -26,6 +28,7 @@ from my_utils import load_audio
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor from TTS_infer_pack.TextPreprocessor import TextPreprocessor
i18n = I18nAuto() i18n = I18nAuto()
# configs/tts_infer.yaml # configs/tts_infer.yaml
@ -49,7 +52,8 @@ custom:
""" """
def set_seed(seed:int):
def set_seed(seed: int):
seed = int(seed) seed = int(seed)
seed = seed if seed != -1 else random.randrange(1 << 32) seed = seed if seed != -1 else random.randrange(1 << 32)
print(f"Set seed to {seed}") print(f"Set seed to {seed}")
@ -71,8 +75,9 @@ def set_seed(seed:int):
pass pass
return seed return seed
class TTS_Config: class TTS_Config:
default_configs={ default_configs = {
"device": "cpu", "device": "cpu",
"is_half": False, "is_half": False,
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
@ -80,31 +85,31 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
} }
configs:dict = None configs: dict = None
def __init__(self, configs: Union[dict, str]=None):
def __init__(self, configs: Union[dict, str] = None):
# 设置默认配置文件路径 # 设置默认配置文件路径
configs_base_path:str = "GPT_SoVITS/configs/" configs_base_path: str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True) os.makedirs(configs_base_path, exist_ok=True)
self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") self.configs_path: str = os.path.join(configs_base_path, "tts_infer.yaml")
if configs in ["", None]: if configs in ["", None]:
if not os.path.exists(self.configs_path): if not os.path.exists(self.configs_path):
self.save_configs() self.save_configs()
print(f"Create default config file at {self.configs_path}") print(f"Create default config file at {self.configs_path}")
configs:dict = {"default": deepcopy(self.default_configs)} configs: dict = {"default": deepcopy(self.default_configs)}
if isinstance(configs, str): if isinstance(configs, str):
self.configs_path = configs self.configs_path = configs
configs:dict = self._load_configs(self.configs_path) configs: dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict) assert isinstance(configs, dict)
default_configs:dict = configs.get("default", None) default_configs: dict = configs.get("default", None)
if default_configs is not None: if default_configs is not None:
self.default_configs = default_configs self.default_configs = default_configs
self.configs:dict = configs.get("custom", deepcopy(self.default_configs)) self.configs: dict = configs.get("custom", deepcopy(self.default_configs))
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
self.is_half = self.configs.get("is_half", False) self.is_half = self.configs.get("is_half", False)
@ -113,7 +118,6 @@ class TTS_Config:
self.bert_base_path = self.configs.get("bert_base_path", None) self.bert_base_path = self.configs.get("bert_base_path", None)
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
self.t2s_weights_path = self.default_configs['t2s_weights_path'] self.t2s_weights_path = self.default_configs['t2s_weights_path']
print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
@ -128,29 +132,27 @@ class TTS_Config:
print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
self.update_configs() self.update_configs()
self.max_sec = None self.max_sec = None
self.hz:int = 50 self.hz: int = 50
self.semantic_frame_rate:str = "25hz" self.semantic_frame_rate: str = "25hz"
self.segment_size:int = 20480 self.segment_size: int = 20480
self.filter_length:int = 2048 self.filter_length: int = 2048
self.sampling_rate:int = 32000 self.sampling_rate: int = 32000
self.hop_length:int = 640 self.hop_length: int = 640
self.win_length:int = 2048 self.win_length: int = 2048
self.n_speakers:int = 300 self.n_speakers: int = 300
self.languages:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] self.languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
def _load_configs(self, configs_path: str) -> dict:
def _load_configs(self, configs_path: str)->dict:
with open(configs_path, 'r') as f: with open(configs_path, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader) configs = yaml.load(f, Loader=yaml.FullLoader)
return configs return configs
def save_configs(self, configs_path:str=None)->None: def save_configs(self, configs_path: str = None) -> None:
configs={ configs = {
"default":self.default_configs, "default": self.default_configs,
} }
if self.configs is not None: if self.configs is not None:
configs["custom"] = self.update_configs() configs["custom"] = self.update_configs()
@ -162,11 +164,11 @@ class TTS_Config:
def update_configs(self): def update_configs(self):
self.config = { self.config = {
"device" : str(self.device), "device": str(self.device),
"is_half" : self.is_half, "is_half": self.is_half,
"t2s_weights_path" : self.t2s_weights_path, "t2s_weights_path": self.t2s_weights_path,
"vits_weights_path" : self.vits_weights_path, "vits_weights_path": self.vits_weights_path,
"bert_base_path" : self.bert_base_path, "bert_base_path": self.bert_base_path,
"cnhuhbert_base_path": self.cnhuhbert_base_path, "cnhuhbert_base_path": self.cnhuhbert_base_path,
} }
return self.config return self.config
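For orientation, a minimal usage sketch of this config class (the module path assumes the usual GPT_SoVITS layout with GPT_SoVITS/ on sys.path; the YAML path is the default one created above and is assumed to exist):

from TTS_infer_pack.TTS import TTS_Config

cfg = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # the "custom" block wins if present, else "default"
print(cfg.device, cfg.is_half, cfg.t2s_weights_path)
cfg.save_configs()  # rewrites the YAML as {"default": ..., "custom": update_configs()}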
@ -194,63 +196,58 @@ class TTS:
if isinstance(configs, TTS_Config): if isinstance(configs, TTS_Config):
self.configs = configs self.configs = configs
else: else:
self.configs:TTS_Config = TTS_Config(configs) self.configs: TTS_Config = TTS_Config(configs)
self.t2s_model:Text2SemanticLightningModule = None self.t2s_model: Text2SemanticLightningModule = None
self.vits_model:SynthesizerTrn = None self.vits_model: SynthesizerTrn = None
self.bert_tokenizer:AutoTokenizer = None self.bert_tokenizer: AutoTokenizer = None
self.bert_model:AutoModelForMaskedLM = None self.bert_model: AutoModelForMaskedLM = None
self.cnhuhbert_model:CNHubert = None self.cnhuhbert_model: CNHubert = None
self.version = "v1"
self._init_models() self._init_models()
self.text_preprocessor:TextPreprocessor = \ self.text_preprocessor: TextPreprocessor = \
TextPreprocessor(self.bert_model, TextPreprocessor(self.bert_model,
self.bert_tokenizer, self.bert_tokenizer,
self.configs.device) self.configs.device, version=self.version)
self.prompt_cache: dict = {
self.prompt_cache:dict = { "ref_audio_path": None,
"ref_audio_path" : None,
"prompt_semantic": None, "prompt_semantic": None,
"refer_spec" : None, "refer_spec": None,
"prompt_text" : None, "prompt_text": None,
"prompt_lang" : None, "prompt_lang": None,
"phones" : None, "phones": None,
"bert_features" : None, "bert_features": None,
"norm_text" : None, "norm_text": None,
} }
self.stop_flag: bool = False
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
self.stop_flag:bool = False def _init_models(self, ):
self.precision:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path) self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path) self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path) self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path) self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
# self.enable_half_precision(self.configs.is_half) # self.enable_half_precision(self.configs.is_half)
def init_cnhuhbert_weights(self, base_path: str): def init_cnhuhbert_weights(self, base_path: str):
print(f"Loading CNHuBERT weights from {base_path}") print(f"Loading CNHuBERT weights from {base_path}")
self.cnhuhbert_model = CNHubert(base_path) self.cnhuhbert_model = CNHubert(base_path)
self.cnhuhbert_model=self.cnhuhbert_model.eval() self.cnhuhbert_model = self.cnhuhbert_model.eval()
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device) self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.cnhuhbert_model = self.cnhuhbert_model.half() self.cnhuhbert_model = self.cnhuhbert_model.half()
def init_bert_weights(self, base_path: str): def init_bert_weights(self, base_path: str):
print(f"Loading BERT weights from {base_path}") print(f"Loading BERT weights from {base_path}")
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path) self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path) self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
self.bert_model=self.bert_model.eval() self.bert_model = self.bert_model.eval()
self.bert_model = self.bert_model.to(self.configs.device) self.bert_model = self.bert_model.to(self.configs.device)
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.bert_model = self.bert_model.half() self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str): def init_vits_weights(self, weights_path: str):
@ -266,6 +263,12 @@ class TTS:
self.configs.win_length = hps["data"]["win_length"] self.configs.win_length = hps["data"]["win_length"]
self.configs.n_speakers = hps["data"]["n_speakers"] self.configs.n_speakers = hps["data"]["n_speakers"]
self.configs.semantic_frame_rate = "25hz" self.configs.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps['model']['version'] = "v1"
else:
hps['model']['version'] = "v2"
self.version = "v2"
logger.debug(self.version)
kwargs = hps["model"] kwargs = hps["model"]
vits_model = SynthesizerTrn( vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1, self.configs.filter_length // 2 + 1,
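The added check infers the model version from the SoVITS checkpoint itself: a 322-row enc_p.text_embedding table marks the v1 symbol set, anything else is treated as v2. A standalone sketch of the same idea (the key name and the 322 threshold come from this diff; the helper name is illustrative):

import torch

def detect_sovits_version(weights_path: str) -> str:
    # v1 checkpoints carry a 322-row text-embedding table (the v1 symbol set);
    # larger tables belong to the extended v2 symbol set.
    dict_s2 = torch.load(weights_path, map_location="cpu")
    rows = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
    return "v1" if rows == 322 else "v2"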
@ -281,10 +284,9 @@ class TTS:
vits_model = vits_model.eval() vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False) vits_model.load_state_dict(dict_s2["weight"], strict=False)
self.vits_model = vits_model self.vits_model = vits_model
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str): def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}") print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path self.configs.t2s_weights_path = weights_path
@ -298,10 +300,10 @@ class TTS:
t2s_model = t2s_model.to(self.configs.device) t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval() t2s_model = t2s_model.eval()
self.t2s_model = t2s_model self.t2s_model = t2s_model
if self.configs.is_half and str(self.configs.device)!="cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half() self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True, save: bool = True): def enable_half_precision(self, enable: bool = True):
''' '''
To enable half precision for the TTS model. To enable half precision for the TTS model.
Args: Args:
@ -314,15 +316,14 @@ class TTS:
self.configs.is_half = enable self.configs.is_half = enable
self.precision = torch.float16 if enable else torch.float32 self.precision = torch.float16 if enable else torch.float32
if save:
self.configs.save_configs() self.configs.save_configs()
if enable: if enable:
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model =self.t2s_model.half() self.t2s_model = self.t2s_model.half()
if self.vits_model is not None: if self.vits_model is not None:
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
if self.bert_model is not None: if self.bert_model is not None:
self.bert_model =self.bert_model.half() self.bert_model = self.bert_model.half()
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.half() self.cnhuhbert_model = self.cnhuhbert_model.half()
else: else:
@ -335,14 +336,13 @@ class TTS:
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float() self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device, save: bool = True): def set_device(self, device: torch.device):
''' '''
To set the device for all models. To set the device for all models.
Args: Args:
device: torch.device, the device to use for all models. device: torch.device, the device to use for all models.
''' '''
self.configs.device = device self.configs.device = device
if save:
self.configs.save_configs() self.configs.save_configs()
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device) self.t2s_model = self.t2s_model.to(device)
@ -353,7 +353,7 @@ class TTS:
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device) self.cnhuhbert_model = self.cnhuhbert_model.to(device)
def set_ref_audio(self, ref_audio_path:str): def set_ref_audio(self, ref_audio_path: str):
''' '''
To set the reference audio for the TTS model, To set the reference audio for the TTS model,
including the prompt_semantic and refer_spepc. including the prompt_semantic and refer_spepc.
@ -362,10 +362,6 @@ class TTS:
''' '''
self._set_prompt_semantic(ref_audio_path) self._set_prompt_semantic(ref_audio_path)
self._set_ref_spec(ref_audio_path) self._set_ref_spec(ref_audio_path)
self._set_ref_audio_path(ref_audio_path)
def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path): def _set_ref_spec(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
@ -386,8 +382,7 @@ class TTS:
# self.refer_spec = spec # self.refer_spec = spec
self.prompt_cache["refer_spec"] = spec self.prompt_cache["refer_spec"] = spec
def _set_prompt_semantic(self, ref_wav_path: str):
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros( zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3), int(self.configs.sampling_rate * 0.3),
dtype=np.float16 if self.configs.is_half else np.float32, dtype=np.float16 if self.configs.is_half else np.float32,
@ -415,12 +410,12 @@ class TTS:
prompt_semantic = codes[0, 0].to(self.configs.device) prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None): def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
seq = sequences[0] seq = sequences[0]
ndim = seq.dim() ndim = seq.dim()
if axis < 0: if axis < 0:
axis += ndim axis += ndim
dtype:torch.dtype = seq.dtype dtype: torch.dtype = seq.dtype
pad_value = torch.tensor(pad_value, dtype=dtype) pad_value = torch.tensor(pad_value, dtype=dtype)
seq_lengths = [seq.shape[axis] for seq in sequences] seq_lengths = [seq.shape[axis] for seq in sequences]
if max_length is None: if max_length is None:
@ -436,15 +431,15 @@ class TTS:
batch = torch.stack(padded_sequences) batch = torch.stack(padded_sequences)
return batch return batch
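For reference, a small illustration of what batch_sequences produces (tts stands for an initialized TTS instance; the tensors are made up):

import torch

seqs = [torch.ones(3, dtype=torch.long), torch.ones(5, dtype=torch.long)]
# Each sequence is padded with pad_value along `axis` up to the longest one, then stacked:
batch = tts.batch_sequences(seqs, axis=0, pad_value=0)  # expected shape: (2, 5)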
def to_batch(self, data:list, def to_batch(self, data: list,
prompt_data:dict=None, prompt_data: dict = None,
batch_size:int=5, batch_size: int = 5,
threshold:float=0.75, threshold: float = 0.75,
split_bucket:bool=True, split_bucket: bool = True,
device:torch.device=torch.device("cpu"), device: torch.device = torch.device("cpu"),
precision:torch.dtype=torch.float32, precision: torch.dtype = torch.float32,
): ):
_data:list = [] _data: list = []
index_and_len_list = [] index_and_len_list = []
for idx, item in enumerate(data): for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"]) norm_text_len = len(item["norm_text"])
@ -457,29 +452,28 @@ class TTS:
batch_index_list_len = 0 batch_index_list_len = 0
pos = 0 pos = 0
while pos <index_and_len_list.shape[0]: while pos < index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))]) # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0]) pos_end = min(pos + batch_size, index_and_len_list.shape[0])
while pos < pos_end: while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32) batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8) score = batch[(pos_end - pos) // 2] / (batch.mean() + 1e-8)
if (score>=threshold) or (pos_end-pos==1): if (score >= threshold) or (pos_end - pos == 1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist() batch_index = index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index) batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index) batch_index_list.append(batch_index)
pos = pos_end pos = pos_end
break break
pos_end=pos_end-1 pos_end = pos_end - 1
assert batch_index_list_len == len(data) assert batch_index_list_len == len(data)
else: else:
for i in range(len(data)): for i in range(len(data)):
if i%batch_size == 0: if i % batch_size == 0:
batch_index_list.append([]) batch_index_list.append([])
batch_index_list[-1].append(i) batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list): for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list] item_list = [data[idx] for idx in index_list]
phones_list = [] phones_list = []
@ -493,13 +487,13 @@ class TTS:
phones_max_len = 0 phones_max_len = 0
for item in item_list: for item in item_list:
if prompt_data is not None: if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\ all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1) \
.to(dtype=precision, device=device) .to(dtype=precision, device=device)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device) all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device) phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"] # norm_text = prompt_data["norm_text"]+item["norm_text"]
else: else:
all_bert_features = item["bert_features"]\ all_bert_features = item["bert_features"] \
.to(dtype=precision, device=device) .to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device) phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones all_phones = phones
@ -519,7 +513,6 @@ class TTS:
all_phones_batch = all_phones_list all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list all_bert_features_batch = all_bert_features_list
max_len = max(bert_max_len, phones_max_len) max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### Pad phones and bert_features directly. The padding strategy affects what the T2S model generates, but it does not directly drive the repetition probability; the masking strategy is the main factor behind repetition.
@ -551,7 +544,7 @@ class TTS:
return _data, batch_index_list return _data, batch_index_list
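The split_bucket branch above groups sentences of similar length: it opens a window of up to batch_size items over the length-sorted list, scores the window by the ratio of its middle length to its mean length, and shrinks the window until the score clears threshold. A standalone sketch of that idea (names and the sorting step are illustrative):

import numpy as np

def bucket_by_length(lengths, batch_size=5, threshold=0.75):
    order = np.argsort(lengths)                 # ascending by length
    lens = np.asarray(lengths, dtype=np.float32)[order]
    batches, pos = [], 0
    while pos < len(lens):
        pos_end = min(pos + batch_size, len(lens))
        while pos < pos_end:
            window = lens[pos:pos_end]
            score = window[(pos_end - pos) // 2] / (window.mean() + 1e-8)
            if score >= threshold or pos_end - pos == 1:
                batches.append(order[pos:pos_end].tolist())
                pos = pos_end
                break
            pos_end -= 1                        # shrink the window and retry
    return batches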
def recovery_order(self, data:list, batch_index_list:list)->list: def recovery_order(self, data: list, batch_index_list: list) -> list:
''' '''
Recovery the order of the audio according to the batch_index_list. Recovery the order of the audio according to the batch_index_list.
@ -563,20 +556,20 @@ class TTS:
list (List[np.ndarray]): the data in the original order. list (List[np.ndarray]): the data in the original order.
''' '''
length = len(sum(batch_index_list, [])) length = len(sum(batch_index_list, []))
_data = [None]*length _data = [None] * length
for i, index_list in enumerate(batch_index_list): for i, index_list in enumerate(batch_index_list):
for j, index in enumerate(index_list): for j, index in enumerate(index_list):
_data[index] = data[i][j] _data[index] = data[i][j]
return _data return _data
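A concrete illustration of the reordering (tts stands for an initialized TTS instance; the data is made up):

# Two buckets produced results for original indices [2, 0] and [3, 1]:
data = [["c", "a"], ["d", "b"]]
assert tts.recovery_order(data, [[2, 0], [3, 1]]) == ["a", "b", "c", "d"]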
def stop(self,): def stop(self, ):
''' '''
Stop the inference process. Stop the inference process.
''' '''
self.stop_flag = True self.stop_flag = True
@torch.no_grad() @torch.no_grad()
def run(self, inputs:dict): def run(self, inputs: dict):
""" """
Text to speech inference. Text to speech inference.
@ -606,16 +599,16 @@ class TTS:
Tuple[int, np.ndarray]: sampling rate and audio data. Tuple[int, np.ndarray]: sampling rate and audio data.
""" """
########## variables initialization ########### ########## variables initialization ###########
self.stop_flag:bool = False self.stop_flag: bool = False
text:str = inputs.get("text", "") text: str = inputs.get("text", "")
text_lang:str = inputs.get("text_lang", "") text_lang: str = inputs.get("text_lang", "")
ref_audio_path:str = inputs.get("ref_audio_path", "") ref_audio_path: str = inputs.get("ref_audio_path", "")
prompt_text:str = inputs.get("prompt_text", "") prompt_text: str = inputs.get("prompt_text", "")
prompt_lang:str = inputs.get("prompt_lang", "") prompt_lang: str = inputs.get("prompt_lang", "")
top_k:int = inputs.get("top_k", 5) top_k: int = inputs.get("top_k", 5)
top_p:float = inputs.get("top_p", 1) top_p: float = inputs.get("top_p", 1)
temperature:float = inputs.get("temperature", 1) temperature: float = inputs.get("temperature", 1)
text_split_method:str = inputs.get("text_split_method", "cut0") text_split_method: str = inputs.get("text_split_method", "cut0")
batch_size = inputs.get("batch_size", 1) batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75) batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0) speed_factor = inputs.get("speed_factor", 1.0)
@ -644,7 +637,7 @@ class TTS:
if split_bucket: if split_bucket:
print(i18n("分桶处理模式已开启")) print(i18n("分桶处理模式已开启"))
if fragment_interval<0.01: if fragment_interval < 0.01:
fragment_interval = 0.01 fragment_interval = 0.01
print(i18n("分段间隔过小已自动设置为0.01")) print(i18n("分段间隔过小已自动设置为0.01"))
@ -658,13 +651,14 @@ class TTS:
if ref_audio_path in [None, ""] and \ if ref_audio_path in [None, ""] and \
((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)): ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] is None)):
raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()") raise ValueError(
"ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
###### setting reference audio and prompt text preprocessing ######## ###### setting reference audio and prompt text preprocessing ########
t0 = ttime() t0 = ttime()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]): if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
self.set_ref_audio(ref_audio_path) self.set_ref_audio(ref_audio_path)
self.text_preprocessor.version = self.version
if not no_prompt_text: if not no_prompt_text:
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
@ -682,7 +676,7 @@ class TTS:
###### text preprocessing ######## ###### text preprocessing ########
t1 = ttime() t1 = ttime()
data:list = None data: list = None
if not return_fragment: if not return_fragment:
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method) data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
if len(data) == 0: if len(data) == 0:
@ -690,7 +684,7 @@ class TTS:
dtype=np.int16) dtype=np.int16)
return return
batch_index_list:list = None batch_index_list: list = None
data, batch_index_list = self.to_batch(data, data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None, prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size, batch_size=batch_size,
@ -704,7 +698,7 @@ class TTS:
texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method) texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
data = [] data = []
for i in range(len(texts)): for i in range(len(texts)):
if i%batch_size == 0: if i % batch_size == 0:
data.append([]) data.append([])
data[-1].append(texts[i]) data[-1].append(texts[i])
@ -712,10 +706,11 @@ class TTS:
batch_data = [] batch_data = []
print(i18n("############ 提取文本Bert特征 ############")) print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(batch_texts): for text in tqdm(batch_texts):
phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang) phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text,
text_lang)
if phones is None: if phones is None:
continue continue
res={ res = {
"phones": phones, "phones": phones,
"bert_features": bert_features, "bert_features": bert_features,
"norm_text": norm_text, "norm_text": norm_text,
@ -733,7 +728,6 @@ class TTS:
) )
return batch[0] return batch[0]
t2 = ttime() t2 = ttime()
try: try:
print("############ 推理 ############") print("############ 推理 ############")
@ -748,21 +742,21 @@ class TTS:
if item is None: if item is None:
continue continue
batch_phones:List[torch.LongTensor] = item["phones"] batch_phones: List[torch.LongTensor] = item["phones"]
# batch_phones:torch.LongTensor = item["phones"] # batch_phones:torch.LongTensor = item["phones"]
batch_phones_len:torch.LongTensor = item["phones_len"] batch_phones_len: torch.LongTensor = item["phones_len"]
all_phoneme_ids:torch.LongTensor = item["all_phones"] all_phoneme_ids: torch.LongTensor = item["all_phones"]
all_phoneme_lens:torch.LongTensor = item["all_phones_len"] all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
all_bert_features:torch.LongTensor = item["all_bert_features"] all_bert_features: torch.LongTensor = item["all_bert_features"]
norm_text:str = item["norm_text"] norm_text: str = item["norm_text"]
max_len = item["max_len"] max_len = item["max_len"]
print(i18n("前端处理后的文本(每句):"), norm_text) print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text : if no_prompt_text:
prompt = None prompt = None
else: else:
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(
self.configs.device)
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
all_phoneme_ids, all_phoneme_ids,
@ -780,7 +774,7 @@ class TTS:
t4 = ttime() t4 = ttime()
t_34 += t4 - t3 t_34 += t4 - t3
refer_audio_spec:torch.Tensor = self.prompt_cache["refer_spec"]\ refer_audio_spec: torch.Tensor = self.prompt_cache["refer_spec"] \
.to(dtype=self.precision, device=self.configs.device) .to(dtype=self.precision, device=self.configs.device)
batch_audio_fragment = [] batch_audio_fragment = []
@ -804,15 +798,18 @@ class TTS:
# ## vits并行推理 method 2 # ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates) upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))] audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate for i in
audio_frag_end_idx = [ sum(audio_frag_idx[:i+1]) for i in range(0, len(audio_frag_idx))] range(0, len(pred_semantic_list))]
audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode( _batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec all_pred_semantic, _batch_phones, refer_audio_spec
).detach()[0, 0, :]) ).detach()[0, 0, :])
audio_frag_end_idx.insert(0, 0) audio_frag_end_idx.insert(0, 0)
batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))] batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]] for i in
range(1, len(audio_frag_end_idx))]
# ## vits串行推理 # ## vits串行推理
# for i, idx in enumerate(idx_list): # for i, idx in enumerate(idx_list):
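The slicing above works because each predicted semantic token expands to 2 * upsample_rate output samples, so cumulative sums of the per-fragment token counts give the cut points inside the single batched waveform. A worked illustration (the numbers are made up; 640 matches the hop length used elsewhere in this file):

semantic_lens = [10, 7, 12]          # tokens per fragment
upsample_rate = 640                  # math.prod(vits_model.upsample_rates) in the code above
frag_samples = [n * 2 * upsample_rate for n in semantic_lens]
end_idx = [0] + [sum(frag_samples[:i + 1]) for i in range(len(frag_samples))]
# waveform[end_idx[i - 1]:end_idx[i]] recovers fragment i from the batched decode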
@ -872,7 +869,6 @@ class TTS:
def empty_cache(self): def empty_cache(self):
try: try:
gc.collect() # 触发gc的垃圾回收。避免内存一直增长。
if "cuda" in str(self.configs.device): if "cuda" in str(self.configs.device):
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif str(self.configs.device) == "mps": elif str(self.configs.device) == "mps":
@ -881,13 +877,13 @@ class TTS:
pass pass
def audio_postprocess(self, def audio_postprocess(self,
audio:List[torch.Tensor], audio: List[torch.Tensor],
sr:int, sr: int,
batch_index_list:list=None, batch_index_list: list = None,
speed_factor:float=1.0, speed_factor: float = 1.0,
split_bucket:bool=True, split_bucket: bool = True,
fragment_interval:float=0.3 fragment_interval: float = 0.3
)->Tuple[int, np.ndarray]: ) -> Tuple[int, np.ndarray]:
zero_wav = torch.zeros( zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), int(self.configs.sampling_rate * fragment_interval),
dtype=self.precision, dtype=self.precision,
@ -896,19 +892,17 @@ class TTS:
for i, batch in enumerate(audio): for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch): for j, audio_fragment in enumerate(batch):
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音 max_audio = torch.abs(audio_fragment).max() # 简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio if max_audio > 1: audio_fragment /= max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0) audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy() audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket: if split_bucket:
audio = self.recovery_order(audio, batch_index_list) audio = self.recovery_order(audio, batch_index_list)
else: else:
# audio = [item for batch in audio for item in batch] # audio = [item for batch in audio for item in batch]
audio = sum(audio, []) audio = sum(audio, [])
audio = np.concatenate(audio, 0) audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16) audio = (audio * 32768).astype(np.int16)
@ -921,9 +915,7 @@ class TTS:
return sr, audio return sr, audio
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流 # 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes() raw_audio = input_audio.astype(np.int16).tobytes()


@ -1,7 +1,8 @@
from loguru import logger
import os, sys import os, sys
from tqdm import tqdm from tqdm import tqdm
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -18,14 +19,16 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
punctuation = set(['!', '?', '…', ',', '.', '-', " "])
def get_first(text:str) -> str:
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip() text = re.split(pattern, text)[0].strip()
return text return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
def merge_short_text_in_array(texts: str, threshold: int) -> list:
if (len(texts)) < 2: if (len(texts)) < 2:
return texts return texts
result = [] result = []
@ -43,28 +46,29 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result return result
class TextPreprocessor: class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM, def __init__(self, bert_model: AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device): tokenizer: AutoTokenizer, device: torch.device, version: str = "v2"):
self.bert_model = bert_model self.bert_model = bert_model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.device = device self.device = device
self.version = version
logger.debug(self.version)
def preprocess(self, text:str, lang:str, text_split_method:str)->List[Dict]: def preprocess(self, text: str, lang: str, text_split_method: str) -> List[Dict]:
print(i18n("############ 切分文本 ############")) print(i18n("############ 切分文本 ############"))
text = self.replace_consecutive_punctuation(text) # 变量命名应该是写错了 text = self.replace_consecutive_punctuation(text) # 变量命名应该是写错了
texts = self.pre_seg_text(text, lang, text_split_method) texts = self.pre_seg_text(text, lang, text_split_method)
result = [] result = []
print(i18n("############ 提取文本Bert特征 ############")) print(i18n("############ 提取文本Bert特征 ############"))
for text in tqdm(texts): for text in tqdm(texts):
if not re.sub(r"\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang) phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
if phones is None: if phones is None:
continue continue
res={ res = {
"phones": phones, "phones": phones,
"bert_features": bert_features, "bert_features": bert_features,
"norm_text": norm_text, "norm_text": norm_text,
@ -72,7 +76,7 @@ class TextPreprocessor:
result.append(res) result.append(res)
return result return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str): def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n") text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): if (text[0] not in splits and len(get_first(text)) < 4):
text = "" + text if lang != "en" else "." + text text = "" + text if lang != "en" else "." + text
@ -90,14 +94,10 @@ class TextPreprocessor:
_texts = merge_short_text_in_array(_texts, 5) _texts = merge_short_text_in_array(_texts, 5)
texts = [] texts = []
for text in _texts: for text in _texts:
# 解决输入目标文本的空行导致报错的问题 # 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0): if (len(text.strip()) == 0):
continue continue
if not re.sub("\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if (text[-1] not in splits): text += "。" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题 # 解决句子过长导致Bert报错的问题
@ -110,7 +110,7 @@ class TextPreprocessor:
print(texts) print(texts)
return texts return texts
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]: def segment_and_extract_feature_for_text(self, texts: list, language: str) -> Tuple[list, torch.Tensor, str]:
textlist, langlist = self.seg_text(texts, language) textlist, langlist = self.seg_text(texts, language)
if len(textlist) == 0: if len(textlist) == 0:
return None, None, None return None, None, None
@ -118,13 +118,12 @@ class TextPreprocessor:
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist) phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
return phones, bert_features, norm_text return phones, bert_features, norm_text
def seg_text(self, text: str, language: str) -> Tuple[list, list]:
def seg_text(self, text:str, language:str)->Tuple[list, list]: textlist = []
langlist = []
textlist=[]
langlist=[]
if language in ["auto", "zh", "ja"]: if language in ["auto", "zh", "ja"]:
LangSegment.setfilters(["zh","ja","en","ko"]) LangSegment.setfilters(["zh", "ja", "en", "ko"])
for tmp in LangSegment.getTexts(text): for tmp in LangSegment.getTexts(text):
if tmp["text"] == "": if tmp["text"] == "":
continue continue
@ -134,7 +133,7 @@ class TextPreprocessor:
langlist.append("en") langlist.append("en")
else: else:
# 因无法区别中日文汉字,以用户输入为准 # 因无法区别中日文汉字,以用户输入为准
langlist.append(language if language!="auto" else tmp["lang"]) langlist.append(language if language != "auto" else tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
elif language == "en": elif language == "en":
LangSegment.setfilters(["en"]) LangSegment.setfilters(["en"])
@ -145,14 +144,14 @@ class TextPreprocessor:
textlist.append(formattext) textlist.append(formattext)
langlist.append("en") langlist.append("en")
elif language in ["all_zh","all_ja"]: elif language in ["all_zh", "all_ja"]:
formattext = text formattext = text
while " " in formattext: while " " in formattext:
formattext = formattext.replace(" ", " ") formattext = formattext.replace(" ", " ")
language = language.replace("all_","") language = language.replace("all_", "")
if text == "": if text == "":
return [],[] return [], []
textlist.append(formattext) textlist.append(formattext)
langlist.append(language) langlist.append(language)
@ -161,8 +160,7 @@ class TextPreprocessor:
return textlist, langlist return textlist, langlist
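Roughly, the segmentation behaves as follows for mixed input (output is approximate and depends on LangSegment; preprocessor stands for a TextPreprocessor instance):

textlist, langlist = preprocessor.seg_text("你好 world", "zh")
# textlist ~ ["你好 ", "world"], langlist ~ ["zh", "en"]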
def extract_bert_feature(self, textlist: list, langlist: list):
def extract_bert_feature(self, textlist:list, langlist:list):
phones_list = [] phones_list = []
bert_feature_list = [] bert_feature_list = []
norm_text_list = [] norm_text_list = []
@ -179,8 +177,7 @@ class TextPreprocessor:
norm_text = ''.join(norm_text_list) norm_text = ''.join(norm_text_list)
return phones_list, bert_feature, norm_text return phones_list, bert_feature, norm_text
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
with torch.no_grad(): with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt") inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs: for i in inputs:
@ -195,13 +192,13 @@ class TextPreprocessor:
phone_level_feature = torch.cat(phone_level_feature, dim=0) phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T return phone_level_feature.T
def clean_text_inf(self, text:str, language:str): def clean_text_inf(self, text: str, language: str):
phones, word2ph, norm_text = clean_text(text, language) phones, word2ph, norm_text = clean_text(text, language, version=self.version)
phones = cleaned_text_to_sequence(phones) phones = cleaned_text_to_sequence(phones, self.version)
return phones, word2ph, norm_text return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str): def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language=language.replace("all_","") language = language.replace("all_", "")
if language == "zh": if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device) feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else: else:
@ -212,9 +209,9 @@ class TextPreprocessor:
return feature return feature
def process_text(self,texts): def process_text(self, texts):
_text=[] _text = []
if all(text in [None, " ", "\n",""] for text in texts): if all(text in [None, " ", "\n", ""] for text in texts):
raise ValueError(i18n("请输入有效文本")) raise ValueError(i18n("请输入有效文本"))
for text in texts: for text in texts:
if text in [None, " ", ""]: if text in [None, " ", ""]:
@ -223,12 +220,8 @@ class TextPreprocessor:
_text.append(text) _text.append(text)
return _text return _text
def replace_consecutive_punctuation(self, text):
def replace_consecutive_punctuation(self,text):
punctuations = ''.join(re.escape(p) for p in punctuation) punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+' pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text) result = re.sub(pattern, r'\1', text)
return result return result
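A quick example of the punctuation collapsing (preprocessor stands for a TextPreprocessor instance):

# Runs of punctuation collapse down to the first mark:
assert preprocessor.replace_consecutive_punctuation("好!!!,,") == "好!"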

GPT_SoVITS/text/.gitignore (new file, 3 lines):

@ -0,0 +1,3 @@
G2PWModel
__pycache__
*.zip


@ -1,15 +1,27 @@
from text.symbols import * import os
# if os.environ.get("version","v1")=="v1":
# from text.symbols import symbols
# else:
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
_symbol_to_id = {s: i for i, s in enumerate(symbols)} _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
def cleaned_text_to_sequence(cleaned_text): def cleaned_text_to_sequence(cleaned_text, version=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
text: string to convert to a sequence text: string to convert to a sequence
Returns: Returns:
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
''' '''
phones = [_symbol_to_id[symbol] for symbol in cleaned_text] if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
else:
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
return phones return phones
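Usage stays the same apart from the optional version argument; a quick sketch (the phoneme list is illustrative and assumed to exist in both symbol tables):

from text import cleaned_text_to_sequence

phones = ["n", "i2", "h", "ao3"]
ids_v1 = cleaned_text_to_sequence(phones, version="v1")
ids_v2 = cleaned_text_to_sequence(phones, version="v2")
# With version=None the function falls back to os.environ["version"], defaulting to "v2".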


@ -0,0 +1,209 @@
# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py
import sys
import re
import cn2an
from pyjyutping import jyutping
from text.symbols import punctuation
from text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn")
INITIALS = [
"aa",
"aai",
"aak",
"aap",
"aat",
"aau",
"ai",
"au",
"ap",
"at",
"ak",
"a",
"p",
"b",
"e",
"ts",
"t",
"dz",
"d",
"kw",
"k",
"gw",
"g",
"f",
"h",
"l",
"m",
"ng",
"n",
"s",
"y",
"w",
"c",
"z",
"j",
"ong",
"on",
"ou",
"oi",
"ok",
"o",
"uk",
"ung",
]
INITIALS += ["sp", "spl", "spn", "sil"]
rep_map = {
"：": ",",
"；": ",",
"，": ",",
"。": ".",
"！": "!",
"？": "?",
"\n": ".",
"·": ",",
"、": ",",
"...": "…",
"$": ".",
"“": "'",
"”": "'",
'"': "'",
"‘": "'",
"’": "'",
"（": "'",
"）": "'",
"(": "'",
")": "'",
"《": "'",
"》": "'",
"【": "'",
"】": "'",
"[": "'",
"]": "'",
"—": "-",
"～": "-",
"~": "-",
"「": "'",
"」": "'",
}
def replace_punctuation(text):
# text = text.replace("嗯", "恩").replace("呣", "母")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def text_normalize(text):
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation(sentence)
return dest_text
punctuation_set=set(punctuation)
def jyuping_to_initials_finals_tones(jyuping_syllables):
initials_finals = []
tones = []
word2ph = []
for syllable in jyuping_syllables:
if syllable in punctuation:
initials_finals.append(syllable)
tones.append(0)
word2ph.append(1) # Add 1 for punctuation
elif syllable == "_":
initials_finals.append(syllable)
tones.append(0)
word2ph.append(1) # Add 1 for underscore
else:
try:
tone = int(syllable[-1])
syllable_without_tone = syllable[:-1]
except ValueError:
tone = 0
syllable_without_tone = syllable
for initial in INITIALS:
if syllable_without_tone.startswith(initial):
if syllable_without_tone.startswith("nga"):
initials_finals.extend(
[
syllable_without_tone[:2],
syllable_without_tone[2:] or syllable_without_tone[-1],
]
)
# tones.extend([tone, tone])
tones.extend([-1, tone])
word2ph.append(2)
else:
final = syllable_without_tone[len(initial) :] or initial[-1]
initials_finals.extend([initial, final])
# tones.extend([tone, tone])
tones.extend([-1, tone])
word2ph.append(2)
break
assert len(initials_finals) == len(tones)
### Hacked representation: consonant + tone-carrying vowel.
### Non-punctuation symbols get a "Y" prefix so Cantonese does not collide with Mandarin; punctuation is left untouched.
phones = []
for a, b in zip(initials_finals, tones):
if b not in [-1, 0]:  # real tones get the digit appended; the placeholder (-1) and toneless entries (0) keep the bare symbol
todo = "%s%s" % (a, b)
else: todo = a
if todo not in punctuation_set: todo = "Y%s" % todo
phones.append(todo)
# return initials_finals, tones, word2ph
return phones, word2ph
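To make the "Y"-prefix convention concrete, a hand-traced example (the jyutping syllable is illustrative):

# "nei5" -> initial "n" (placeholder tone -1) + final "ei" (tone 5).
# Tones in {-1, 0} keep the bare symbol, every other tone is appended as a digit,
# and non-punctuation symbols get the "Y" prefix so Cantonese entries do not
# collide with the Mandarin symbol set.
phones, word2ph = jyuping_to_initials_finals_tones(["nei5"])
assert phones == ["Yn", "Yei5"]
assert word2ph == [2]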
def get_jyutping(text):
jp = jyutping.convert(text)
# print(1111111,jp)
for symbol in punctuation:
jp = jp.replace(symbol, " " + symbol + " ")
jp_array = jp.split()
return jp_array
def get_bert_feature(text, word2ph):
from text import chinese_bert
return chinese_bert.get_bert_feature(text, word2ph)
def g2p(text):
# word2ph = []
jyuping = get_jyutping(text)
# print(jyuping)
# phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping)
phones, word2ph = jyuping_to_initials_finals_tones(jyuping)
# phones = ["_"] + phones + ["_"]
# tones = [0] + tones + [0]
# word2ph = [1] + word2ph + [1]
return phones, word2ph
if __name__ == "__main__":
# text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
text = "佢個鋤頭太短啦。"
text = text_normalize(text)
# phones, tones, word2ph = g2p(text)
phones, word2ph = g2p(text)
# print(phones, tones, word2ph)
print(phones, word2ph)


@ -54,6 +54,26 @@ def replace_punctuation(text):
return replaced_text return replaced_text
def replace_punctuation_with_en(text):
text = text.replace("嗯", "恩").replace("呣", "母")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def g2p(text): def g2p(text):
pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip() != ""] sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
@ -158,6 +178,23 @@ def text_normalize(text):
dest_text = "" dest_text = ""
for sentence in sentences: for sentence in sentences:
dest_text += replace_punctuation(sentence) dest_text += replace_punctuation(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text return dest_text

GPT_SoVITS/text/chinese2.py (new file, 308 lines):

@ -0,0 +1,308 @@
import os
import pdb
import re
import cn2an
from pypinyin import lazy_pinyin, Style
from pypinyin.contrib.tone_convert import to_normal, to_finals_tone3, to_initials, to_finals
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi
from text.zh_normalization.text_normlization import TextNormalizer
normalizer = lambda x: cn2an.transform(x, "an2cn")
current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {
line.split("\t")[0]: line.strip().split("\t")[1]
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba_fast.posseg as psg
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
if is_g2pw:
print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
rep_map = {
"：": ",",
"；": ",",
"，": ",",
"。": ".",
"！": "!",
"？": "?",
"\n": ".",
"·": ",",
"、": ",",
"...": "…",
"$": ".",
"/": ",",
"—": "-",
"~": "…",
"～": "…",
}
tone_modifier = ToneSandhi()
def replace_punctuation(text):
text = text.replace("嗯", "恩").replace("呣", "母")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def g2p(text):
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
phones, word2ph = _g2p(sentences)
return phones, word2ph
def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)
return initials, finals
must_erhua = {
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
}
not_erhua = {
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
"狗儿", "少儿"
}
def _merge_erhua(initials: list[str],
finals: list[str],
word: str,
pos: str) -> list[list[str]]:
"""
Do erhua (rhotacization) merging.
"""
# fix er1
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
finals[i] = 'er2'
# 发音
if word not in must_erhua and (word in not_erhua or
pos in {"a", "j", "nr"}):
return initials, finals
# "……" 等情况直接返回
if len(finals) != len(word):
return initials, finals
assert len(finals) == len(word)
# 与前一个字发同音
new_initials = []
new_finals = []
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn in {
"er2", "er5"
} and word[-2:] not in not_erhua and new_finals:
phn = "er" + new_finals[-1][-1]
new_initials.append(initials[i])
new_finals.append(phn)
return new_initials, new_finals
def _g2p(segments):
phones_list = []
word2ph = []
for seg in segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
initials = []
finals = []
if not is_g2pw:
for word, pos in seg_cut:
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# 儿化
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
print("pypinyin结果",initials,finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
pre_word_length = 0
for word, pos in seg_cut:
sub_initials = []
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == 'eng':
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
# 多音字消歧
word_pinyins = correct_pronunciation(word,word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin,neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
pre_word_length = now_word_length
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
# 儿化
sub_initials, sub_finals = _merge_erhua(sub_initials, sub_finals, word, pos)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
# print("g2pw结果",initials,finals)
for c, v in zip(initials, finals):
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
assert c in punctuation
phone = [c]
word2ph.append(1)
else:
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# 多音节
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
# 单音节
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
phones_list += phone
return phones_list, word2ph
def replace_punctuation_with_en(text):
text = text.replace("嗯", "恩").replace("呣", "母")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
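# Illustrative sketch: consecutive punctuation collapses to its first mark, which
# is the behaviour text_normalize below relies on.
def _collapse_punct_example():
    assert replace_consecutive_punctuation("Hello!!!?") == "Hello!"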
def text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation(sentence)
# avoid reference-audio leakage caused by repeated punctuation
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
# text normalization that does not strip English
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# avoid reference-audio leakage caused by repeated punctuation
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?"
text = "你好"
text = text_normalize(text)
print(g2p(text))
# # example usage
# text = "这是一个示例文本:,你好!这是一个测试..."
# print(g2p_paddle(text)) # output: 这是一个示例文本你好这是一个测试

View File

@ -1,6 +1,15 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english from text import cleaned_text_to_sequence
import os
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
language_module_map = {"zh": chinese, "ja": japanese, "en": english}
special = [ special = [
# ("%", "zh", "SP"), # ("%", "zh", "SP"),
("", "zh", "SP2"), ("", "zh", "SP2"),
@ -9,34 +18,58 @@ special = [
] ]
def clean_text(text, language): def clean_text(text, language, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
if(language not in language_module_map): if(language not in language_module_map):
language="en" language="en"
text=" " text=" "
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol) return clean_special(text, language, special_s, target_symbol, version)
language_module = language_module_map[language] language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
if hasattr(language_module,"text_normalize"):
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
if language == "zh": else:
norm_text=text
if language == "zh" or language=="yue":##########
phones, word2ph = language_module.g2p(norm_text) phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph) assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph) assert len(norm_text) == len(word2ph)
elif language == "en":
phones = language_module.g2p(norm_text)
if len(phones) < 4:
phones = [','] * (4 - len(phones)) + phones
word2ph = None
else: else:
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
word2ph = None word2ph = None
for ph in phones: for ph in phones:
assert ph in symbols phones = ['UNK' if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol): def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
""" """
特殊静音段sp符号处理 特殊静音段sp符号处理
""" """
text = text.replace(special_s, ",") text = text.replace(special_s, ",")
language_module = language_module_map[language] language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
new_ph = [] new_ph = []
@ -49,9 +82,11 @@ def clean_special(text, language, special_s, target_symbol):
return new_ph, phones[1], norm_text return new_ph, phones[1], norm_text
def text_to_sequence(text, language): def text_to_sequence(text, language, version=None):
version = os.environ.get('version',version)
if version is None:version='v2'
phones = clean_text(text) phones = clean_text(text)
return cleaned_text_to_sequence(phones) return cleaned_text_to_sequence(phones, version)
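# Illustrative sketch of the version-dependent module lookup that clean_text performs
# via __import__; the mappings mirror the ones added in this diff.
def _resolve_language_module(language, version="v2"):
    if version == "v1":
        mapping = {"zh": "chinese", "ja": "japanese", "en": "english"}
    else:
        mapping = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
    name = mapping.get(language, "english")
    return __import__("text." + name, fromlist=[name])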
if __name__ == "__main__": if __name__ == "__main__":

Binary file not shown.

View File

@ -1,2 +1,3 @@
CHATGPT CH AE1 T JH IY1 P IY1 T IY1 CHATGPT CH AE1 T JH IY1 P IY1 T IY1
JSON JH EY1 S AH0 N JSON JH EY1 S AH0 N
CONDA K AA1 N D AH0

View File

@ -4,9 +4,9 @@ import re
import wordsegment import wordsegment
from g2p_en import G2p from g2p_en import G2p
from string import punctuation from text.symbols import punctuation
from text import symbols from text.symbols2 import symbols
import unicodedata import unicodedata
from builtins import str as unicode from builtins import str as unicode
@ -110,6 +110,13 @@ def replace_phs(phs):
return phs_new return phs_new
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def read_dict(): def read_dict():
g2p_dict = {} g2p_dict = {}
start_line = 49 start_line = 49
@ -234,6 +241,9 @@ def text_normalize(text):
text = re.sub(r"(?i)i\.e\.", "that is", text) text = re.sub(r"(?i)i\.e\.", "that is", text)
text = re.sub(r"(?i)e\.g\.", "for example", text) text = re.sub(r"(?i)e\.g\.", "for example", text)
# avoid reference-audio leakage caused by repeated punctuation
text = replace_consecutive_punctuation(text)
return text return text

View File

@ -0,0 +1 @@
from text.g2pw.g2pw import *

View File

@ -0,0 +1,166 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
from .utils import tokenize_and_map
ANCHOR_CHAR = ''
def prepare_onnx_input(tokenizer,
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
phoneme_masks = []
char_ids = []
position_ids = []
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
input_id = list(
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
char_id = chars.index(query_char)
position_id = text2token[
query_id] + 1  # the [CLS] token occupies the first position
input_ids.append(input_id)
token_type_ids.append(token_type_id)
attention_masks.append(attention_mask)
phoneme_masks.append(phoneme_mask)
char_ids.append(char_id)
position_ids.append(position_id)
outputs = {
'input_ids': np.array(input_ids).astype(np.int64),
'token_type_ids': np.array(token_type_ids).astype(np.int64),
'attention_masks': np.array(attention_masks).astype(np.int64),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64),
'position_ids': np.array(position_ids).astype(np.int64),
}
return outputs
def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
start = max(0, query_id - window_size // 2)
end = min(len(text), query_id + window_size // 2)
truncated_text = text[start:end]
truncated_texts.append(truncated_text)
truncated_query_id = query_id - start
truncated_query_ids.append(truncated_query_id)
return truncated_texts, truncated_query_ids
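# Illustrative sketch: a window of 4 characters centred on the query index.
def _truncate_texts_example():
    texts, query_ids = _truncate_texts(window_size=4, texts=["abcdefgh"], query_ids=[5])
    # texts == ["defg"], query_ids == [2]
    return texts, query_ids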
def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
token_position = text2token[query_id]
token_start = token_position - truncate_len // 2
token_end = token_start + truncate_len
font_exceed_dist = -token_start
back_exceed_dist = token_end - len(tokens)
if font_exceed_dist > 0:
token_start += font_exceed_dist
token_end += font_exceed_dist
elif back_exceed_dist > 0:
token_start -= back_exceed_dist
token_end -= back_exceed_dist
start = token2text[token_start][0]
end = token2text[token_end - 1][1]
return (text[start:end], query_id - start, tokens[token_start:token_end], [
i - token_start if i is not None else None
for i in text2token[start:end]
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(phoneme))
return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
return labels, char2phonemes
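# Illustrative sketch of how the label space and per-character candidate indices are
# built from a tiny hand-made polyphonic-character table (values are examples).
def _phoneme_labels_example():
    polyphonic_chars = [("行", "hang2"), ("行", "xing2"), ("好", "hao3"), ("好", "hao4")]
    labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
    # labels == ['hang2', 'hao3', 'hao4', 'xing2']
    # char2phonemes == {'行': [0, 3], '好': [1, 2]}
    return labels, char2phonemes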

View File

@ -0,0 +1,154 @@
# This code is modified from https://github.com/mozillazg/pypinyin-g2pW
import pickle
import os
from pypinyin.constants import RE_HANS
from pypinyin.core import Pinyin, Style
from pypinyin.seg.simpleseg import simple_seg
from pypinyin.converter import UltimateConverter
from pypinyin.contrib.tone_convert import to_tone
from .onnx_api import G2PWOnnxConverter
current_file_path = os.path.dirname(__file__)
CACHE_PATH = os.path.join(current_file_path, "polyphonic.pickle")
PP_DICT_PATH = os.path.join(current_file_path, "polyphonic.rep")
PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
class G2PWPinyin(Pinyin):
def __init__(self, model_dir='G2PWModel/', model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style='pinyin',
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self._converter = Converter(
self._g2pw, v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi,
)
def get_seg(self, **kwargs):
return simple_seg
class Converter(UltimateConverter):
def __init__(self, g2pw_instance, v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False, **kwargs):
super(Converter, self).__init__(
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi, **kwargs)
self._g2pw = g2pw_instance
def convert(self, words, style, heteronym, errors, strict, **kwargs):
pys = []
if RE_HANS.match(words):
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
errors=errors, strict=strict)
post_data = self.post_pinyin(words, heteronym, pys)
if post_data is not None:
pys = post_data
pys = self.convert_styles(
pys, words, style, heteronym, errors, strict)
else:
py = self.handle_nopinyin(words, style=style, errors=errors,
heteronym=heteronym, strict=strict)
if py:
pys.extend(py)
return _remove_dup_and_empty(pys)
def _to_pinyin(self, han, style, heteronym, errors, strict, **kwargs):
pinyins = []
g2pw_pinyin = self._g2pw(han)
if not g2pw_pinyin: # fall back to pypinyin's original logic for characters g2pw does not support
return super(Converter, self).convert(
han, Style.TONE, heteronym, errors, strict, **kwargs)
for i, item in enumerate(g2pw_pinyin[0]):
if item is None: # fall back to pypinyin's original logic for characters g2pw does not support
py = super(Converter, self).convert(
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
pinyins.extend(py)
else:
pinyins.append([to_tone(item)])
return pinyins
def _remove_dup_items(lst, remove_empty=False):
new_lst = []
for item in lst:
if remove_empty and not item:
continue
if item not in new_lst:
new_lst.append(item)
return new_lst
def _remove_dup_and_empty(lst_list):
new_lst_list = []
for lst in lst_list:
lst = _remove_dup_items(lst, remove_empty=True)
if lst:
new_lst_list.append(lst)
else:
new_lst_list.append([''])
return new_lst_list
def cache_dict(polyphonic_dict, file_path):
with open(file_path, "wb") as pickle_file:
pickle.dump(polyphonic_dict, pickle_file)
def get_dict():
if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, "rb") as pickle_file:
polyphonic_dict = pickle.load(pickle_file)
else:
polyphonic_dict = read_dict()
cache_dict(polyphonic_dict, CACHE_PATH)
return polyphonic_dict
def read_dict():
polyphonic_dict = {}
with open(PP_DICT_PATH) as f:
line = f.readline()
while line:
key, value_str = line.split(':')
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
with open(PP_FIX_DICT_PATH) as f:
line = f.readline()
while line:
key, value_str = line.split(':')
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
return polyphonic_dict
def correct_pronunciation(word,word_pinyins):
if word in pp_dict:
word_pinyins = pp_dict[word]
return word_pinyins
pp_dict = get_dict()
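# Minimal usage sketch (assumptions: the G2PW ONNX model has been downloaded to
# model_dir, and the BERT path below is only an example; both may differ locally).
def _g2pw_pinyin_example():
    g2pw = G2PWPinyin(model_dir="G2PWModel/",
                      model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                      v_to_u=False, neutral_tone_with_five=True)
    # lazy_pinyin is inherited from pypinyin's Pinyin; polyphonic characters are
    # resolved by the G2PW model, the rest by pypinyin.
    return g2pw.lazy_pinyin("银行行长", neutral_tone_with_five=True, style=Style.TONE3)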

View File

@ -0,0 +1,241 @@
# This code is modified from https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw
# This code is modified from https://github.com/GitYCC/g2pW
import warnings
warnings.filterwarnings("ignore")
import json
import os
import zipfile,requests
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
import onnxruntime
onnxruntime.set_default_logger_severity(3)
from opencc import OpenCC
from transformers import AutoTokenizer
from pypinyin import pinyin
from pypinyin import Style
from .dataset import get_char_phoneme_labels
from .dataset import get_phoneme_labels
from .dataset import prepare_onnx_input
from .utils import load_config
from ..zh_normalization.char_convert import tranditional_to_simplified
model_version = '1.1'
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
"input_ids": onnx_input['input_ids'],
"token_type_ids": onnx_input['token_type_ids'],
"attention_mask": onnx_input['attention_masks'],
"phoneme_mask": onnx_input['phoneme_masks'],
"char_ids": onnx_input['char_ids'],
"position_ids": onnx_input['position_ids']
})[0]
preds = np.argmax(probs, axis=1).tolist()
max_probs = []
for index, arr in zip(preds, probs.tolist()):
max_probs.append(arr[index])
all_preds += [labels[pred] for pred in preds]
all_confidences += max_probs
return all_preds, all_confidences
def download_and_decompress(model_dir: str='G2PWModel/'):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print("Extracting g2pw model...")
with zipfile.ZipFile(zip_dir, "r") as zip_ref:
zip_ref.extractall(parent_directory)
os.rename(extract_dir, extract_dir_new)
return model_dir
class G2PWOnnxConverter:
def __init__(self,
model_dir: str='G2PWModel/',
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
uncompress_path = download_and_decompress(model_dir)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, 'g2pW.onnx'),
sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# sess_options=sess_options)
self.config = load_config(
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(uncompress_path,
'POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(uncompress_path,
'MONOPHONIC_CHARS.txt')
self.polyphonic_chars = [
line.split('\t')
for line in open(polyphonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
]
self.non_polyphonic = {
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', ''
}
self.non_monophonic = {'', ''}
self.monophonic_chars = [
line.split('\t')
for line in open(monophonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
polyphonic_chars=self.polyphonic_chars)
self.chars = sorted(list(self.char2phonemes.keys()))
self.polyphonic_chars_new = set(self.chars)
for char in self.non_polyphonic:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)
self.monophonic_chars_dict = {
char: phoneme
for char, phoneme in self.monophonic_chars
}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)
self.pos_tags = [
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
]
with open(
os.path.join(uncompress_path,
'bopomofo_to_pinyin_wo_tune_dict.json'),
'r',
encoding='utf-8') as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
}[style]
with open(
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
'r',
encoding='utf-8') as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC('s2tw')
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences: List[str]) -> List[List[str]]:
if isinstance(sentences, str):
sentences = [sentences]
if self.enable_opencc:
translated_sentences = []
for sent in sentences:
translated_sent = self.cc.convert(sent)
assert len(translated_sent) == len(sent)
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences=sentences)
if len(texts) == 0:
# the sentences contain no polyphonic characters
return partial_results
onnx_input = prepare_onnx_input(
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
window_size=None)
preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works better on Simplified Chinese than on Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(
self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
else:
partial_result[i] = pypinyin_result[i][0]
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results
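# Minimal usage sketch (assumptions: the model files described in
# download_and_decompress are present, and the default 'bert-base-chinese'
# tokenizer from the shipped config can be loaded).
def _g2pw_converter_example():
    conv = G2PWOnnxConverter(model_dir="G2PWModel/", style="pinyin",
                             enable_non_tradional_chinese=True)
    # Returns one list of per-character pinyin strings for each input sentence;
    # the exact values depend on the model's predictions.
    return conv(["银行行长"])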

File diff suppressed because it is too large

Binary file not shown.

View File

@ -0,0 +1,53 @@
湖泊: ['hu2','po1']
地壳: ['di4','qiao4']
柏树: ['bai3','shu4']
曝光: ['bao4','guang1']
弹力: ['tan2','li4']
字帖: ['zi4','tie4']
口吃: ['kou3','chi1']
包扎: ['bao1','za1']
哪吒: ['ne2','zha1']
说服: ['shuo1','fu2']
识字: ['shi2','zi4']
骨头: ['gu3','tou5']
对称: ['dui4','chen4']
口供: ['kou3','gong4']
抹布: ['ma1','bu4']
露背: ['lu4','bei4']
圈养: ['juan4', 'yang3']
眼眶: ['yan3', 'kuang4']
品行: ['pin3','xing2']
颤抖: ['chan4','dou3']
差不多: ['cha4','bu5','duo1']
鸭绿江: ['ya1','lu4','jiang1']
撒切尔: ['sa4','qie4','er3']
比比皆是: ['bi3','bi3','jie1','shi4']
身无长物: ['shen1','wu2','chang2','wu4']
手里: ['shou2','li3']
关卡: ['guan1','qia3']
怀揣: ['huai2','chuai1']
挑剔: ['tiao1','ti4']
供称: ['gong4','cheng1']
作坊: ['zuo1', 'fang5']
中医: ['zhong1','yi1']
嚷嚷: ['rang1','rang5']
商厦: ['shang1','sha4']
大厦: ['da4','sha4']
刹车: ['sha1','che1']
嘚瑟: ['de4','se5']
朝鲜: ['chao2','xian3']
阿房宫: ['e1','pang2','gong1']
阿胶: ['e1','jiao1']
咖喱: ['ga1','li5']
时分: ['shi2','fen1']
蚌埠: ['beng4','bu4']
驯服: ['xun4','fu2']
幸免于难: ['xing4','mian3','yu2','nan4']
恶行: ['e4','xing2']
唉: ['ai4']
扎实: ['zha1','shi2']
干将: ['gan4','jiang4']
陈威行: ['chen2', 'wei1', 'hang2']
郭晟: ['guo1', 'sheng4']
中标: ['zhong4', 'biao1']
抗住: ['kang2', 'zhu4']

View File

@ -0,0 +1,145 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re
def wordize_and_map(text: str):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
continue
match_en = re.match(r'^[a-zA-Z0-9]+', text)
if match_en:
en_word = match_en.group(0)
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + len(en_word)
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word)
text = text[len(en_word):]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)]
words.append(text[0])
text = text[1:]
return words, index_map_from_text_to_word, index_map_from_word_to_text
def tokenize_and_map(tokenizer, text: str):
words, text2word, word2text = wordize_and_map(text=text)
tokens = []
index_map_from_token_to_text = []
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)
index_map_from_text_to_token = text2word
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
for token_pos in range(token_start, token_end):
index_map_from_text_to_token[token_pos] = i
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
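# Illustrative sketch: English runs stay whole words, other characters are split
# one by one, and the two index maps are mutually consistent.
def _wordize_example():
    words, text2word, word2text = wordize_and_map("ab 你好")
    # words     == ['ab', '你', '好']
    # text2word == [0, 0, None, 1, 2]   (the space maps to None)
    # word2text == [(0, 2), (3, 4), (4, 5)]
    return words, text2word, word2text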
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config
default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}
def load_config(config_path: os.PathLike, use_default: bool=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
if not hasattr(config, attr):
setattr(config, attr, val)
elif isinstance(val, dict):
d = getattr(config, attr)
for dict_k, dict_v in val.items():
if dict_k not in d:
d[dict_k] = dict_v
return config

View File

@ -4,8 +4,7 @@ import sys
import pyopenjtalk import pyopenjtalk
from text.symbols import punctuation
from text import symbols
# Regular expression matching Japanese without punctuation marks: # Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile( _japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@ -56,15 +55,23 @@ def post_replace_ph(ph):
"": ",", "": ",",
"...": "", "...": "",
} }
if ph in rep_map.keys(): if ph in rep_map.keys():
ph = rep_map[ph] ph = rep_map[ph]
if ph in symbols: # if ph in symbols:
return ph # return ph
if ph not in symbols: # if ph not in symbols:
ph = "UNK" # ph = "UNK"
return ph return ph
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def symbols_to_japanese(text): def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese: for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
@ -94,6 +101,9 @@ def preprocess_jap(text, with_prosody=False):
def text_normalize(text): def text_normalize(text):
# todo: jap text normalize # todo: jap text normalize
# avoid reference-audio leakage caused by repeated punctuation
text = replace_consecutive_punctuation(text)
return text return text
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
@ -179,7 +189,7 @@ def _numeric_feature_by_regex(regex, s):
return -50 return -50
return int(match.group(1)) return int(match.group(1))
def g2p(norm_text, with_prosody=False): def g2p(norm_text, with_prosody=True):
phones = preprocess_jap(norm_text, with_prosody) phones = preprocess_jap(norm_text, with_prosody)
phones = [post_replace_ph(i) for i in phones] phones = [post_replace_ph(i) for i in phones]
# todo: implement tones and word2ph # todo: implement tones and word2ph

265 GPT_SoVITS/text/korean.py Normal file
View File

@ -0,0 +1,265 @@
# reference: https://github.com/ORI-Muchim/MB-iSTFT-VITS-Korean/blob/main/text/korean.py
import re
from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p
from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
# ('ㄵ', 'ㄴㅈ'),
# ('ㄶ', 'ㄴㅎ'),
# ('ㄺ', 'ㄹㄱ'),
# ('ㄻ', 'ㄹㅁ'),
# ('ㄼ', 'ㄹㅂ'),
# ('ㄽ', 'ㄹㅅ'),
# ('ㄾ', 'ㄹㅌ'),
# ('ㄿ', 'ㄹㅍ'),
# ('ㅀ', 'ㄹㅎ'),
# ('ㅄ', 'ㅂㅅ'),
('', 'ㅗㅏ'),
('', 'ㅗㅐ'),
('', 'ㅗㅣ'),
('', 'ㅜㅓ'),
('', 'ㅜㅔ'),
('', 'ㅜㅣ'),
('', 'ㅡㅣ'),
('', 'ㅣㅏ'),
('', 'ㅣㅐ'),
('', 'ㅣㅓ'),
('', 'ㅣㅔ'),
('', 'ㅣㅗ'),
('', 'ㅣㅜ')
]]
# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', '에이'),
('b', ''),
('c', ''),
('d', ''),
('e', ''),
('f', '에프'),
('g', ''),
('h', '에이치'),
('i', '아이'),
('j', '제이'),
('k', '케이'),
('l', ''),
('m', ''),
('n', ''),
('o', ''),
('p', ''),
('q', ''),
('r', '아르'),
('s', '에스'),
('t', ''),
('u', ''),
('v', '브이'),
('w', '더블유'),
('x', '엑스'),
('y', '와이'),
('z', '제트')
]]
# List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('t͡ɕ','ʧ'),
('d͡ʑ','ʥ'),
('ɲ','n^'),
('ɕ','ʃ'),
('ʷ','w'),
('ɭ','l`'),
('ʎ','ɾ'),
('ɣ','ŋ'),
('ɰ','ɯ'),
('ʝ','j'),
('ʌ','ə'),
('ɡ','g'),
('\u031a','#'),
('\u0348','='),
('\u031e',''),
('\u0320',''),
('\u0339','')
]]
def fix_g2pk2_error(text):
new_text = ""
i = 0
while i < len(text) - 4:
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == '':
new_text += text[i:i+3] + ' ' + ''
i += 5
else:
new_text += text[i]
i += 1
new_text += text[i:]
return new_text
def latin_to_hangul(text):
for regex, replacement in _latin_to_hangul:
text = re.sub(regex, replacement, text)
return text
def divide_hangul(text):
text = j2hcj(h2j(text))
for regex, replacement in _hangul_divided:
text = re.sub(regex, replacement, text)
return text
def hangul_number(num, sino=True):
'''Reference https://github.com/Kyubyong/g2pK'''
num = re.sub(',', '', num)
if num == '0':
return '영'
if not sino and num == '20':
return '스무'
digits = '123456789'
names = '일이삼사오육칠팔구'
digit2name = {d: n for d, n in zip(digits, names)}
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
spelledout = []
for i, digit in enumerate(num):
i = len(num) - i - 1
if sino:
if i == 0:
name = digit2name.get(digit, '')
elif i == 1:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
else:
if i == 0:
name = digit2mod.get(digit, '')
elif i == 1:
name = digit2dec.get(digit, '')
if digit == '0':
if i % 4 == 0:
last_three = spelledout[-min(3, len(spelledout)):]
if ''.join(last_three) == '':
spelledout.append('')
continue
else:
spelledout.append('')
continue
if i == 2:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
elif i == 3:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
elif i == 4:
name = digit2name.get(digit, '') + ''
name = name.replace('일만', '')
elif i == 5:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
elif i == 6:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
elif i == 7:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
elif i == 8:
name = digit2name.get(digit, '') + ''
elif i == 9:
name = digit2name.get(digit, '') + ''
elif i == 10:
name = digit2name.get(digit, '') + ''
elif i == 11:
name = digit2name.get(digit, '') + ''
elif i == 12:
name = digit2name.get(digit, '') + ''
elif i == 13:
name = digit2name.get(digit, '') + ''
elif i == 14:
name = digit2name.get(digit, '') + ''
elif i == 15:
name = digit2name.get(digit, '') + ''
spelledout.append(name)
return ''.join(elem for elem in spelledout)
def number_to_hangul(text):
'''Reference https://github.com/Kyubyong/g2pK'''
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
for token in tokens:
num, classifier = token
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
spelledout = hangul_number(num, sino=False)
else:
spelledout = hangul_number(num, sino=True)
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
# digit by digit for remaining digits
digits = '0123456789'
names = '영일이삼사오육칠팔구'
for d, n in zip(digits, names):
text = text.replace(d, n)
return text
def korean_to_lazy_ipa(text):
text = latin_to_hangul(text)
text = number_to_hangul(text)
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
for regex, replacement in _ipa_to_lazy_ipa:
text = re.sub(regex, replacement, text)
return text
_g2p=G2p()
def korean_to_ipa(text):
text = latin_to_hangul(text)
text = number_to_hangul(text)
text = _g2p(text)
text = fix_g2pk2_error(text)
text = korean_to_lazy_ipa(text)
return text.replace('ʧ','').replace('ʥ','')
def post_replace_ph(ph):
rep_map = {
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"·": ",",
"": ",",
"...": "",
" ": "",
}
if ph in rep_map.keys():
ph = rep_map[ph]
if ph in symbols:
return ph
if ph not in symbols:
ph = ""
return ph
def g2p(text):
text = latin_to_hangul(text)
text = _g2p(text)
text = divide_hangul(text)
text = fix_g2pk2_error(text)
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
# text = "".join([post_replace_ph(i) for i in text])
text = [post_replace_ph(i) for i in text]
return text

419 GPT_SoVITS/text/symbols2.py Normal file
View File

@ -0,0 +1,419 @@
import os
# punctuation = ['!', '?', '…', ",", ".", "@"]  # '@' marks an SP pause
punctuation = ["!", "?", "…", ",", "."]  # '@' marks an SP pause
punctuation.append("-")
pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
pad = "_"
c = [
"AA",
"EE",
"OO",
"b",
"c",
"ch",
"d",
"f",
"g",
"h",
"j",
"k",
"l",
"m",
"n",
"p",
"q",
"r",
"s",
"sh",
"t",
"w",
"x",
"y",
"z",
"zh",
]
v = [
"E1",
"En1",
"a1",
"ai1",
"an1",
"ang1",
"ao1",
"e1",
"ei1",
"en1",
"eng1",
"er1",
"i1",
"i01",
"ia1",
"ian1",
"iang1",
"iao1",
"ie1",
"in1",
"ing1",
"iong1",
"ir1",
"iu1",
"o1",
"ong1",
"ou1",
"u1",
"ua1",
"uai1",
"uan1",
"uang1",
"ui1",
"un1",
"uo1",
"v1",
"van1",
"ve1",
"vn1",
"E2",
"En2",
"a2",
"ai2",
"an2",
"ang2",
"ao2",
"e2",
"ei2",
"en2",
"eng2",
"er2",
"i2",
"i02",
"ia2",
"ian2",
"iang2",
"iao2",
"ie2",
"in2",
"ing2",
"iong2",
"ir2",
"iu2",
"o2",
"ong2",
"ou2",
"u2",
"ua2",
"uai2",
"uan2",
"uang2",
"ui2",
"un2",
"uo2",
"v2",
"van2",
"ve2",
"vn2",
"E3",
"En3",
"a3",
"ai3",
"an3",
"ang3",
"ao3",
"e3",
"ei3",
"en3",
"eng3",
"er3",
"i3",
"i03",
"ia3",
"ian3",
"iang3",
"iao3",
"ie3",
"in3",
"ing3",
"iong3",
"ir3",
"iu3",
"o3",
"ong3",
"ou3",
"u3",
"ua3",
"uai3",
"uan3",
"uang3",
"ui3",
"un3",
"uo3",
"v3",
"van3",
"ve3",
"vn3",
"E4",
"En4",
"a4",
"ai4",
"an4",
"ang4",
"ao4",
"e4",
"ei4",
"en4",
"eng4",
"er4",
"i4",
"i04",
"ia4",
"ian4",
"iang4",
"iao4",
"ie4",
"in4",
"ing4",
"iong4",
"ir4",
"iu4",
"o4",
"ong4",
"ou4",
"u4",
"ua4",
"uai4",
"uan4",
"uang4",
"ui4",
"un4",
"uo4",
"v4",
"van4",
"ve4",
"vn4",
"E5",
"En5",
"a5",
"ai5",
"an5",
"ang5",
"ao5",
"e5",
"ei5",
"en5",
"eng5",
"er5",
"i5",
"i05",
"ia5",
"ian5",
"iang5",
"iao5",
"ie5",
"in5",
"ing5",
"iong5",
"ir5",
"iu5",
"o5",
"ong5",
"ou5",
"u5",
"ua5",
"uai5",
"uan5",
"uang5",
"ui5",
"un5",
"uo5",
"v5",
"van5",
"ve5",
"vn5",
]
v_without_tone = [
"E",
"En",
"a",
"ai",
"an",
"ang",
"ao",
"e",
"ei",
"en",
"eng",
"er",
"i",
"i0",
"ia",
"ian",
"iang",
"iao",
"ie",
"in",
"ing",
"iong",
"ir",
"iu",
"o",
"ong",
"ou",
"u",
"ua",
"uai",
"uan",
"uang",
"ui",
"un",
"uo",
"v",
"van",
"ve",
"vn",
]
# japanese
ja_symbols = [
"I",
"N",
"U",
"a",
"b",
"by",
"ch",
"cl",
"d",
"dy",
"e",
"f",
"g",
"gy",
"h",
"hy",
"i",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"p",
"py",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"u",
"v",
"w",
"y",
"z",
### the two symbols below are reserved to be added later
# "[",  # rising intonation
# "]",  # falling intonation
# "$",  # end marker
# "^",  # start marker
]
arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停'
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}
# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)  ### appending yue directly here scrambles the ordering
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))
# print(len(symbols))
symbols += ["[", "]"]  ## newly added Japanese rising/falling intonation marks
symbols += sorted(list(ko_symbols))
symbols += sorted(list(yue_symbols))  ## newly added yue symbols all go at the end; the leading "Y" was checked to avoid duplicates, and Korean obviously cannot collide
# print(len(symbols))
if __name__ == "__main__":
print(len(symbols))
'''
Cantonese:
732-353=379
Korean + Cantonese:
732-322=410
'''

View File

@ -681,7 +681,6 @@ class ToneSandhi:
and seg[i - 1][0] == "" and seg[i - 1][0] == ""
and seg[i - 2][0] == word and seg[i - 2][0] == word
and pos == "v" and pos == "v"
and seg[i - 2][1] == "v"
): ):
continue continue
else: else:

View File

@ -28,7 +28,7 @@ UNITS = OrderedDict({
8: '亿', 8: '亿',
}) })
COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)' COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
# 分数表达式 # 分数表达式
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
@ -107,8 +107,11 @@ def replace_default_num(match):
# 加减乘除 # 加减乘除
# RE_ASMD = re.compile(
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile( RE_ASMD = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))') r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
asmd_map = { asmd_map = {
'+': '', '+': '',
'-': '', '-': '',
@ -117,7 +120,6 @@ asmd_map = {
'=': '等于' '=': '等于'
} }
def replace_asmd(match) -> str: def replace_asmd(match) -> str:
""" """
Args: Args:
@ -129,6 +131,39 @@ def replace_asmd(match) -> str:
return result return result
# dedicated handling for superscript powers
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
power_map = {
'⁰': '0',
'¹': '1',
'²': '2',
'³': '3',
'⁴': '4',
'⁵': '5',
'⁶': '6',
'⁷': '7',
'⁸': '8',
'⁹': '9',
'ˣ': 'x',
'ʸ': 'y',
'ⁿ': 'n'
}
def replace_power(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
power_num = ""
for m in match.group(0):
power_num += power_map[m]
result = "的" + power_num + "次方"
return result
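# Illustrative sketch of the superscript handling added above.
def _replace_power_example():
    return RE_POWER.sub(replace_power, "x²")  # -> "x的2次方"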
# 数字表达式 # 数字表达式
# 纯小数 # 纯小数
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')

View File

@ -35,6 +35,7 @@ from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE from .num import RE_RANGE
from .num import RE_TO_RANGE from .num import RE_TO_RANGE
from .num import RE_ASMD from .num import RE_ASMD
from .num import RE_POWER
from .num import replace_default_num from .num import replace_default_num
from .num import replace_frac from .num import replace_frac
from .num import replace_negative_num from .num import replace_negative_num
@ -44,6 +45,7 @@ from .num import replace_positive_quantifier
from .num import replace_range from .num import replace_range
from .num import replace_to_range from .num import replace_to_range
from .num import replace_asmd from .num import replace_asmd
from .num import replace_power
from .phonecode import RE_MOBILE_PHONE from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE from .phonecode import RE_TELEPHONE
@ -114,6 +116,12 @@ class TextNormalizer():
sentence = sentence.replace('χ', '') sentence = sentence.replace('χ', '')
sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛') sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽') sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
# catch-all replacements for math operators; also covers casual shorthand usage
sentence = sentence.replace('+', '')
sentence = sentence.replace('-', '')
sentence = sentence.replace('×', '')
sentence = sentence.replace('÷', '')
sentence = sentence.replace('=', '')
# re filter special characters, have one more character "-" than line 68 # re filter special characters, have one more character "-" than line 68
sentence = re.sub(r'[-——《》【】<=>{}()#&@“”^_|\\]', '', sentence) sentence = re.sub(r'[-——《》【】<=>{}()#&@“”^_|\\]', '', sentence)
return sentence return sentence
@ -136,6 +144,12 @@ class TextNormalizer():
sentence = RE_TO_RANGE.sub(replace_to_range, sentence) sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence) sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = replace_measure(sentence) sentence = replace_measure(sentence)
# handle arithmetic expressions
while RE_ASMD.search(sentence):
sentence = RE_ASMD.sub(replace_asmd, sentence)
sentence = RE_POWER.sub(replace_power, sentence)
sentence = RE_FRAC.sub(replace_frac, sentence) sentence = RE_FRAC.sub(replace_frac, sentence)
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence) sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence) sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
@ -145,10 +159,6 @@ class TextNormalizer():
sentence = RE_RANGE.sub(replace_range, sentence) sentence = RE_RANGE.sub(replace_range, sentence)
# 处理加减乘除
while RE_ASMD.search(sentence):
sentence = RE_ASMD.sub(replace_asmd, sentence)
sentence = RE_INTEGER.sub(replace_negative_num, sentence) sentence = RE_INTEGER.sub(replace_negative_num, sentence)
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,

View File

@ -1,10 +1,12 @@
import json import json
import locale import locale
import os import os
import pathlib
i18n_dir = pathlib.Path(os.path.dirname(__file__)).as_posix().replace("tools/","")
def load_language_list(language): def load_language_list(language):
with open(f"./i18n/locale/{language}.json", "r", encoding="utf-8") as f: with open(f"{i18n_dir}/locale/{language}.json", "r", encoding="utf-8") as f:
language_list = json.load(f) language_list = json.load(f)
return language_list return language_list
@ -15,7 +17,7 @@ class I18nAuto:
language = locale.getdefaultlocale()[ language = locale.getdefaultlocale()[
0 0
] # getlocale can't identify the system's language ((None, None)) ] # getlocale can't identify the system's language ((None, None))
if not os.path.exists(f"./i18n/locale/{language}.json"): if not os.path.exists(f"{i18n_dir}/locale/{language}.json"):
language = "en_US" language = "en_US"
self.language = language self.language = language
self.language_map = load_language_list(language) self.language_map = load_language_list(language)