Version Check (#1390)

* version check

* fix webui and symbols

* fix v1 language map
This commit is contained in:
KamioRinn 2024-08-05 17:24:42 +08:00 committed by GitHub
parent 0c25e57959
commit 4e34814c70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 157 additions and 78 deletions

View File

@ -152,6 +152,11 @@ def change_sovits_weights(sovits_path):
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
# print("sovits版本:",hps.model.version)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@ -231,9 +236,9 @@ dict_language = {
}
def clean_text_inf(text, language, version):
    """Run the text front end and turn the resulting phonemes into symbol ids.

    ``version`` is forwarded unchanged to both ``clean_text`` and
    ``cleaned_text_to_sequence``.

    Returns:
        ``(phones, word2ph, norm_text)`` where ``phones`` is the output of
        ``cleaned_text_to_sequence`` and the other two items come straight
        from ``clean_text``.
    """
    raw_phones, word2ph, norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(raw_phones, version), word2ph, norm_text
dtype=torch.float16 if is_half == True else torch.float32
@ -259,7 +264,7 @@ def get_first(text):
return text
from text import chinese
def get_phones_and_bert(text,language):
def get_phones_and_bert(text,language,version):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
@ -274,16 +279,16 @@ def get_phones_and_bert(text,language):
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh")
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"yue")
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
@ -317,7 +322,7 @@ def get_phones_and_bert(text,language):
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
@ -357,6 +362,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
version = vq_model.version
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
@ -413,7 +421,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
for i_text,text in enumerate(texts):
# 解决输入目标文本的空行导致报错的问题
@ -421,7 +429,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)

View File

@ -15,7 +15,9 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from text import symbols
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib
@ -185,6 +187,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
):
super().__init__()
self.out_channels = out_channels
@ -195,6 +198,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@ -210,6 +214,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@ -827,6 +836,7 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
):
super().__init__()
@ -847,6 +857,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
@ -857,6 +868,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
)
self.dec = Generator(
inter_channels,
@ -881,7 +893,7 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.version=os.environ.get("version","v1")
# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:

View File

@ -1,18 +1,26 @@
import os
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
# if os.environ.get("version","v1")=="v1":
# from text.symbols import symbols
# else:
# from text.symbols2 import symbols
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
# One symbol -> index lookup table per symbol set, built once at import time.
_symbol_to_id_v1 = {symbol: index for index, symbol in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {symbol: index for index, symbol in enumerate(symbols_v2.symbols)}


def cleaned_text_to_sequence(cleaned_text, version):
    """Map each symbol of ``cleaned_text`` to its integer id.

    Args:
        cleaned_text: iterable of symbols to convert.
        version: ``"v1"`` selects the v1 symbol table; any other value
            selects the v2 table.

    Returns:
        List of integer ids, one per symbol, in input order.

    Raises:
        KeyError: if a symbol is missing from the selected table.
    """
    symbol_to_id = _symbol_to_id_v1 if version == "v1" else _symbol_to_id_v2
    return [symbol_to_id[symbol] for symbol in cleaned_text]

View File

@ -1,13 +1,17 @@
from text import japanese, cleaned_text_to_sequence, english,korean,cantonese
import os
if os.environ.get("version","v1")=="v1":
from text import chinese
from text.symbols import symbols
else:
from text import chinese2 as chinese
from text.symbols2 import symbols
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from text import chinese as chinese_v1
from text import chinese2 as chinese_v2
language_module_map = {"zh": chinese, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
special = [
# ("%", "zh", "SP"),
("", "zh", "SP2"),
@ -16,13 +20,20 @@ special = [
]
def clean_text(text, language):
def clean_text(text, language, version):
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
if(language not in language_module_map):
language="en"
text=" "
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol)
return clean_special(text, language, special_s, target_symbol, version)
language_module = language_module_map[language]
if hasattr(language_module,"text_normalize"):
norm_text = language_module.text_normalize(text)
@ -42,11 +53,18 @@ def clean_text(text, language):
word2ph = None
for ph in phones:
assert ph in symbols
phones = ['UNK' if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol):
def clean_special(text, language, special_s, target_symbol, version):
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": chinese_v2, "ja": japanese, "en": english, "ko": korean,"yue":cantonese}
"""
特殊静音段sp符号处理
"""

View File

@ -6,10 +6,7 @@ from g2p_en import G2p
from text.symbols import punctuation
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols2 import symbols
import unicodedata
from builtins import str as unicode

View File

@ -4,12 +4,6 @@ import sys
import pyopenjtalk
import os
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
@ -61,12 +55,13 @@ def post_replace_ph(ph):
"": ",",
"...": "",
}
if ph in rep_map.keys():
ph = rep_map[ph]
if ph in symbols:
return ph
if ph not in symbols:
ph = "UNK"
# if ph in symbols:
# return ph
# if ph not in symbols:
# ph = "UNK"
return ph

View File

@ -2,11 +2,8 @@ import re
from jamo import h2j, j2hcj
import ko_pron
from g2pk2 import G2p
import os
if os.environ.get("version","v1")=="v1":
from text.symbols import symbols
else:
from text.symbols2 import symbols
from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'

100
api.py
View File

@ -11,7 +11,7 @@
调用请求缺少参考音频时使用
`-dr` - `默认参考音频路径`
`-dt` - `默认参考音频文本`
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语","zh","en","ja","ko","yue"`
`-d` - `推理设备, "cuda","cpu"`
`-a` - `绑定地址, 默认"127.0.0.1"`
@ -201,6 +201,11 @@ def change_sovits_weights(sovits_path):
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
print("sovits版本:",hps.model.version)
model_params_dict = vars(hps.model)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
@ -251,9 +256,9 @@ def get_bert_feature(text, word2ph):
return phone_level_feature.T
def clean_text_inf(text, language, version):
    """Run the text front end and turn the resulting phonemes into symbol ids.

    ``version`` is forwarded unchanged to both ``clean_text`` and
    ``cleaned_text_to_sequence``.

    Returns:
        ``(phones, word2ph, norm_text)`` where ``phones`` is the output of
        ``cleaned_text_to_sequence`` and the other two items come straight
        from ``clean_text``.
    """
    raw_phones, word2ph, norm_text = clean_text(text, language, version)
    return cleaned_text_to_sequence(raw_phones, version), word2ph, norm_text
@ -269,54 +274,64 @@ def get_bert_inf(phones, word2ph, norm_text, language):
return bert
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
from text import chinese
def get_phones_and_bert(text,language,version):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日文汉字,以用户输入为准
# 因无法区别中日文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
bert = get_bert_feature(norm_text, word2ph).to(device)
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja","auto"}:
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "ko":
langlist.append("zh")
textlist.append(tmp["text"])
else:
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日文汉字,以用户输入为准
# 因无法区别中日文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# logger.info(textlist)
# logger.info(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
@ -328,14 +343,32 @@ def get_phones_and_bert(text,language):
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
class DictToAttrRecursive(dict):
    """Dict whose entries are also reachable as attributes, recursively.

    Every entry is stored twice: as a regular dict item and as an instance
    attribute. Nested dicts are wrapped in ``DictToAttrRecursive`` so chained
    access such as ``hps.model.version`` works. Because the instance
    attribute is really set, ``vars(obj)`` also exposes the entries.

    Note: plain item assignment (``obj[key] = value``) updates only the dict
    item; use attribute assignment to keep both views in sync.
    """

    def __init__(self, input_dict):
        """Populate the mapping from ``input_dict``, wrapping nested dicts."""
        super().__init__()
        for key, value in input_dict.items():
            # __setattr__ wraps nested dicts and fills both the dict item
            # and the instance attribute.
            setattr(self, key, value)

    def __getattr__(self, item):
        """Fall back to the dict item when normal attribute lookup fails.

        Raises:
            AttributeError: if ``item`` is not a key of the mapping.
        """
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found") from None

    def __setattr__(self, key, value):
        """Store ``value`` as both a dict item and an instance attribute."""
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        """Remove ``item`` from both the dict and the instance attributes.

        Raises:
            AttributeError: if ``item`` is not a key of the mapping.
        """
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found") from None
        # __setattr__ mirrors every entry into the instance __dict__; remove
        # that copy as well, otherwise ``obj.item`` would keep resolving to
        # the deleted value through normal attribute lookup.
        self.__dict__.pop(item, None)
def get_spepc(hps, filename):
@ -488,9 +521,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
version = vq_model.version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
texts = text.split("\n")
audio_bytes = BytesIO()
@ -500,7 +534,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
continue
audio_opt = []
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@ -606,17 +640,27 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
# --------------------------------
# Display names accepted from callers, mapped to canonical language codes.
dict_language = {
    "中文": "all_zh",
    "粤语": "all_yue",
    "英文": "en",
    "日文": "all_ja",
    "韩文": "all_ko",
    "中英混合": "zh",
    "粤英混合": "yue",
    "日英混合": "ja",
    "韩英混合": "ko",
    "多语种混合": "auto",  # multilingual: enables segmentation and per-segment language detection
    "多语种混合(粤语)": "auto_yue",
}
# Every canonical code is also accepted verbatim, so each one maps to itself.
dict_language.update({code: code for code in list(dict_language.values())})
# logger