diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index a7e775e..adf51b8 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path): hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" + if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + # print("sovits版本:",hps.model.version) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -231,9 +236,9 @@ dict_language = { } -def clean_text_inf(text, language): - phones, word2ph, norm_text = clean_text(text, language) - phones = cleaned_text_to_sequence(phones) +def clean_text_inf(text, language, version): + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text dtype=torch.float16 if is_half == True else torch.float32 @@ -259,7 +264,7 @@ def get_first(text): return text from text import chinese -def get_phones_and_bert(text,language): +def get_phones_and_bert(text,language,version): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: language = language.replace("all_","") if language == "en": @@ -274,16 +279,16 @@ def get_phones_and_bert(text,language): if re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = chinese.text_normalize(formattext) - return get_phones_and_bert(formattext,"zh") + return get_phones_and_bert(formattext,"zh",version) else: - phones, word2ph, norm_text = clean_text_inf(formattext, language) + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = get_bert_feature(norm_text, word2ph).to(device) elif language == "yue" and re.search(r'[A-Za-z]', formattext): formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), 
formattext) formattext = chinese.text_normalize(formattext) - return get_phones_and_bert(formattext,"yue") + return get_phones_and_bert(formattext,"yue",version) else: - phones, word2ph, norm_text = clean_text_inf(formattext, language) + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = torch.zeros( (1024, len(phones)), dtype=torch.float16 if is_half == True else torch.float32, @@ -317,7 +322,7 @@ def get_phones_and_bert(text,language): norm_text_list = [] for i in range(len(textlist)): lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) bert = get_bert_inf(phones, word2ph, norm_text, lang) phones_list.append(phones) norm_text_list.append(norm_text) @@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, t0 = ttime() prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] + + version = vq_model.version + if not ref_free: prompt_text = prompt_text.strip("\n") if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." @@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, texts = merge_short_text_in_array(texts, 5) audio_opt = [] if not ref_free: - phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language) + phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version) for i_text,text in enumerate(texts): # 解决输入目标文本的空行导致报错的问题 @@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, continue if (text[-1] not in splits): text += "。" if text_language != "en" else "." 
print(i18n("实际输入的目标文本(每句):"), text) - phones2,bert2,norm_text2=get_phones_and_bert(text, text_language) + phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version) print(i18n("前端处理后的文本(每句):"), norm_text2) if not ref_free: bert = torch.cat([bert1, bert2], 1) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 92e6634..3d23575 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -15,7 +15,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from module.commons import init_weights, get_padding from module.mrte_model import MRTE from module.quantize import ResidualVectorQuantizer -from text import symbols +# from text import symbols +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 from torch.cuda.amp import autocast import contextlib @@ -185,6 +187,7 @@ class TextEncoder(nn.Module): kernel_size, p_dropout, latent_channels=192, + version = "v2", ): super().__init__() self.out_channels = out_channels @@ -195,6 +198,7 @@ class TextEncoder(nn.Module): self.kernel_size = kernel_size self.p_dropout = p_dropout self.latent_channels = latent_channels + self.version = version self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) @@ -210,6 +214,11 @@ class TextEncoder(nn.Module): self.encoder_text = attentions.Encoder( hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout ) + + if self.version == "v1": + symbols = symbols_v1.symbols + else: + symbols = symbols_v2.symbols self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.mrte = MRTE() @@ -827,6 +836,7 @@ class SynthesizerTrn(nn.Module): use_sdp=True, semantic_frame_rate=None, freeze_quantizer=None, + version = "v2", **kwargs ): super().__init__() @@ -847,6 +857,7 @@ class SynthesizerTrn(nn.Module): self.segment_size = segment_size self.n_speakers = n_speakers self.gin_channels = gin_channels + self.version = version self.use_sdp = use_sdp self.enc_p = TextEncoder( @@ 
-857,6 +868,7 @@ class SynthesizerTrn(nn.Module): n_layers, kernel_size, p_dropout, + version = version, ) self.dec = Generator( inter_channels, @@ -881,7 +893,7 @@ class SynthesizerTrn(nn.Module): inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels ) - self.version=os.environ.get("version","v1") + # self.version=os.environ.get("version","v1") if(self.version=="v1"): self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) else: diff --git a/GPT_SoVITS/text/__init__.py b/GPT_SoVITS/text/__init__.py index 01afcf8..e4d690e 100644 --- a/GPT_SoVITS/text/__init__.py +++ b/GPT_SoVITS/text/__init__.py @@ -1,18 +1,26 @@ import os -if os.environ.get("version","v1")=="v1": - from text.symbols import symbols -else: - from text.symbols2 import symbols +# if os.environ.get("version","v1")=="v1": +# from text.symbols import symbols +# else: +# from text.symbols2 import symbols -_symbol_to_id = {s: i for i, s in enumerate(symbols)} +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 -def cleaned_text_to_sequence(cleaned_text): +_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)} +_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)} + +def cleaned_text_to_sequence(cleaned_text, version): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
Args: text: string to convert to a sequence Returns: List of integers corresponding to the symbols in the text ''' - phones = [_symbol_to_id[symbol] for symbol in cleaned_text] + if version == "v1": + phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text] + else: + phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text] + return phones diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index b1e5007..4a4e440 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -1,13 +1,17 @@ from text import japanese, cleaned_text_to_sequence, english,korean,cantonese import os -if os.environ.get("version","v1")=="v1": - from text import chinese - from text.symbols import symbols -else: - from text import chinese2 as chinese - from text.symbols2 import symbols +# if os.environ.get("version","v1")=="v1": +# from text import chinese +# from text.symbols import symbols +# else: +# from text import chinese2 as chinese +# from text.symbols2 import symbols + +from text import symbols as symbols_v1 +from text import symbols2 as symbols_v2 +from text import chinese as chinese_v1 +from text import chinese2 as chinese_v2 -language_module_map = {"zh": chinese, "ja": japanese, "en": english, "ko": korean,"yue":cantonese} special = [ # ("%", "zh", "SP"), ("¥", "zh", "SP2"), @@ -16,13 +20,20 @@ special = [ ] -def clean_text(text, language): +def clean_text(text, language, version): + if version == "v1": + symbols = symbols_v1.symbols + language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english} + else: + symbols = symbols_v2.symbols + language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese} + if(language not in language_module_map): language="en" text=" " for special_s, special_l, target_symbol in special: if special_s in text and language == special_l: - return clean_special(text, language, special_s, target_symbol) + return clean_special(text, language, special_s, target_symbol, 
version) language_module = language_module_map[language] if hasattr(language_module,"text_normalize"): norm_text = language_module.text_normalize(text) @@ -42,11 +53,18 @@ def clean_text(text, language): word2ph = None for ph in phones: - assert ph in symbols + phones = ['UNK' if ph not in symbols else ph for ph in phones] return phones, word2ph, norm_text -def clean_special(text, language, special_s, target_symbol): +def clean_special(text, language, special_s, target_symbol, version): + if version == "v1": + symbols = symbols_v1.symbols + language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english} + else: + symbols = symbols_v2.symbols + language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese} + """ 特殊静音段sp符号处理 """ diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py index d264925..ceee52b 100644 --- a/GPT_SoVITS/text/english.py +++ b/GPT_SoVITS/text/english.py @@ -6,10 +6,7 @@ from g2p_en import G2p from text.symbols import punctuation -if os.environ.get("version","v1")=="v1": - from text.symbols import symbols -else: - from text.symbols2 import symbols +from text.symbols2 import symbols import unicodedata from builtins import str as unicode diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py index fd77955..4c10720 100644 --- a/GPT_SoVITS/text/japanese.py +++ b/GPT_SoVITS/text/japanese.py @@ -4,12 +4,6 @@ import sys import pyopenjtalk - -import os -if os.environ.get("version","v1")=="v1": - from text.symbols import symbols -else: - from text.symbols2 import symbols from text.symbols import punctuation # Regular expression matching Japanese without punctuation marks: _japanese_characters = re.compile( @@ -61,12 +55,13 @@ def post_replace_ph(ph): "、": ",", "...": "…", } + if ph in rep_map.keys(): ph = rep_map[ph] - if ph in symbols: - return ph - if ph not in symbols: - ph = "UNK" + # if ph in symbols: + # return ph + # if ph not in symbols: + # ph = "UNK" return ph diff 
--git a/GPT_SoVITS/text/korean.py index a783305..23dea59 100644 --- a/GPT_SoVITS/text/korean.py +++ b/GPT_SoVITS/text/korean.py @@ -2,11 +2,8 @@ import re from jamo import h2j, j2hcj import ko_pron from g2pk2 import G2p -import os -if os.environ.get("version","v1")=="v1": - from text.symbols import symbols -else: - from text.symbols2 import symbols + +from text.symbols2 import symbols # This is a list of Korean classifiers preceded by pure Korean numerals. _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' diff --git a/api.py b/api.py index 69cbdb4..6500f6f 100644 --- a/api.py +++ b/api.py @@ -11,7 +11,7 @@ 调用请求缺少参考音频时使用 `-dr` - `默认参考音频路径` `-dt` - `默认参考音频文本` -`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"` +`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语","zh","en","ja","ko","yue"` `-d` - `推理设备, "cuda","cpu"` `-a` - `绑定地址, 默认"127.0.0.1"` @@ -201,6 +201,11 @@ def change_sovits_weights(sovits_path): hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" + if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + print("sovits版本:",hps.model.version) model_params_dict = vars(hps.model) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, @@ -251,9 +256,9 @@ def get_bert_feature(text, word2ph): return phone_level_feature.T -def clean_text_inf(text, language): - phones, word2ph, norm_text = clean_text(text, language) - phones = cleaned_text_to_sequence(phones) +def clean_text_inf(text, language, version): + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) return phones, word2ph, norm_text @@ -269,54 +274,64 @@ def get_bert_inf(phones, word2ph, norm_text, language): return bert - -def get_phones_and_bert(text,language): - if language in {"en","all_zh","all_ja"}: +from text import chinese +def
get_phones_and_bert(text,language,version): + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: language = language.replace("all_","") if language == "en": LangSegment.setfilters(["en"]) formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) else: - # 因无法区别中日文汉字,以用户输入为准 + # 因无法区别中日韩文汉字,以用户输入为准 formattext = text while " " in formattext: formattext = formattext.replace(" ", " ") - phones, word2ph, norm_text = clean_text_inf(formattext, language) if language == "zh": - bert = get_bert_feature(norm_text, word2ph).to(device) + if re.search(r'[A-Za-z]', formattext): + formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) + formattext = chinese.text_normalize(formattext) + return get_phones_and_bert(formattext,"zh",version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = get_bert_feature(norm_text, word2ph).to(device) + elif language == "yue" and re.search(r'[A-Za-z]', formattext): + formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) + formattext = chinese.text_normalize(formattext) + return get_phones_and_bert(formattext,"yue",version) else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) bert = torch.zeros( (1024, len(phones)), dtype=torch.float16 if is_half == True else torch.float32, ).to(device) - elif language in {"zh", "ja","auto"}: + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: textlist=[] langlist=[] LangSegment.setfilters(["zh","ja","en","ko"]) if language == "auto": for tmp in LangSegment.getTexts(text): - if tmp["lang"] == "ko": - langlist.append("zh") - textlist.append(tmp["text"]) - else: - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegment.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) else: 
for tmp in LangSegment.getTexts(text): if tmp["lang"] == "en": langlist.append(tmp["lang"]) else: - # 因无法区别中日文汉字,以用户输入为准 + # 因无法区别中日韩文汉字,以用户输入为准 langlist.append(language) textlist.append(tmp["text"]) - # logger.info(textlist) - # logger.info(langlist) phones_list = [] bert_list = [] norm_text_list = [] for i in range(len(textlist)): lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) bert = get_bert_inf(phones, word2ph, norm_text, lang) phones_list.append(phones) norm_text_list.append(norm_text) @@ -328,14 +343,32 @@ def get_phones_and_bert(text,language): return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text -class DictToAttrRecursive: +class DictToAttrRecursive(dict): def __init__(self, input_dict): + super().__init__(input_dict) for key, value in input_dict.items(): if isinstance(value, dict): - # 如果值是字典,递归调用构造函数 - setattr(self, key, DictToAttrRecursive(value)) - else: - setattr(self, key, value) + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") def get_spepc(hps, filename): @@ -488,9 +521,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] t1 = ttime() + version = vq_model.version prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = 
get_phones_and_bert(prompt_text, prompt_language) + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) texts = text.split("\n") audio_bytes = BytesIO() @@ -500,7 +534,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, continue audio_opt = [] - phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) @@ -606,17 +640,27 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu # -------------------------------- dict_language = { "中文": "all_zh", + "粤语": "all_yue", "英文": "en", "日文": "all_ja", + "韩文": "all_ko", "中英混合": "zh", + "粤英混合": "yue", "日英混合": "ja", + "韩英混合": "ko", "多语种混合": "auto", #多语种启动切分识别语种 + "多语种混合(粤语)": "auto_yue", "all_zh": "all_zh", + "all_yue": "all_yue", "en": "en", "all_ja": "all_ja", + "all_ko": "all_ko", "zh": "zh", + "yue": "yue", "ja": "ja", + "ko": "ko", "auto": "auto", + "auto_yue": "auto_yue", } # logger