fix version bug, and more chores on version

switching
This commit is contained in:
XXXXRT666 2024-08-05 22:59:25 +08:00
parent fd370df259
commit 494c03107f
4 changed files with 13 additions and 18 deletions

View File

@ -14,7 +14,7 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import LangSegment, os, re, sys, json, importlib
import LangSegment, os, re, sys, json
import pdb
import torch
@ -37,7 +37,7 @@ with open(f"./weight.json", 'r', encoding="utf-8") as file:
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
sovits_path = os.environ.get(
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
# gpt_path = os.environ.get(
# "gpt_path", pretrained_gpt_name
# )
@ -370,7 +370,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
version = vq_model.version
if not ref_free:
prompt_text = prompt_text.strip("\n")
@ -644,14 +643,7 @@ def switch_version(version_, prompt_language, text_language,inp_ref):
os.environ['version']=version_
global pretrained_sovits_name, pretrained_gpt_name, version
version = version_
importlib.reload(sys.modules['text.symbols'])
importlib.reload(sys.modules['text'])
importlib.reload(sys.modules['text.english'])
importlib.reload(sys.modules['text.japanese'])
importlib.reload(sys.modules['text.korean'])
importlib.reload(sys.modules['text.cleaner'])
importlib.reload(sys.modules['AR.models'])
importlib.reload(sys.modules['module.models'])
print(version)
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
if prompt_language in list(dict_language.keys()):
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
@ -671,10 +663,9 @@ def switch_version(version_, prompt_language, text_language,inp_ref):
with open(f"./weight.json", 'r', encoding="utf-8") as file:
weight_data = file.read()
weight_data=json.loads(weight_data)
gpt_path = os.environ.get(
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
sovits_path = os.environ.get(
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
gpt_path = weight_data.get('GPT',{}).get(version,pretrained_gpt_name)
sovits_path = weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name)
print(gpt_path,sovits_path)
return {'__type__':'update', 'value':gpt_path, 'choices':GPT_names}, {'__type__':'update', 'value':sovits_path, 'choices':SoVITS_names}, {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update

View File

@ -10,13 +10,14 @@ from text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
def cleaned_text_to_sequence(cleaned_text, version):
def cleaned_text_to_sequence(cleaned_text, version=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
else:

View File

@ -20,7 +20,8 @@ special = [
]
def clean_text(text, language, version):
def clean_text(text, language, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}
@ -57,7 +58,8 @@ def clean_text(text, language, version):
return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol, version):
def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english}

1
api.py
View File

@ -522,6 +522,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_semantic = codes[0, 0]
t1 = ttime()
version = vq_model.version
os.environ['version'] = version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)