diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index e926192e..18db06ef 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -14,7 +14,7 @@ logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("asyncio").setLevel(logging.ERROR) logging.getLogger("charset_normalizer").setLevel(logging.ERROR) logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) -import LangSegment, os, re, sys, json, importlib +import LangSegment, os, re, sys, json import pdb import torch @@ -37,7 +37,7 @@ with open(f"./weight.json", 'r', encoding="utf-8") as file: "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name)) sovits_path = os.environ.get( "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name)) - + # gpt_path = os.environ.get( # "gpt_path", pretrained_gpt_name # ) @@ -370,7 +370,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] - version = vq_model.version if not ref_free: prompt_text = prompt_text.strip("\n") @@ -644,14 +643,7 @@ def switch_version(version_, prompt_language, text_language,inp_ref): os.environ['version']=version_ global pretrained_sovits_name, pretrained_gpt_name, version version = version_ - importlib.reload(sys.modules['text.symbols']) - importlib.reload(sys.modules['text']) - importlib.reload(sys.modules['text.english']) - importlib.reload(sys.modules['text.japanese']) - importlib.reload(sys.modules['text.korean']) - importlib.reload(sys.modules['text.cleaner']) - importlib.reload(sys.modules['AR.models']) - importlib.reload(sys.modules['module.models']) + print(version) dict_language = dict_language_v1 if version =='v1' else dict_language_v2 if prompt_language in list(dict_language.keys()): prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language} @@ -671,10 +663,9 @@ def switch_version(version_, prompt_language, text_language,inp_ref): with open(f"./weight.json", 'r', encoding="utf-8") as file: weight_data = file.read() weight_data=json.loads(weight_data) - gpt_path = os.environ.get( - "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name)) - sovits_path = os.environ.get( - "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name)) + gpt_path = weight_data.get('GPT',{}).get(version,pretrained_gpt_name) + sovits_path = weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name) + print(gpt_path,sovits_path) return {'__type__':'update', 'value':gpt_path, 'choices':GPT_names}, {'__type__':'update', 'value':sovits_path, 'choices':SoVITS_names}, {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update diff --git a/GPT_SoVITS/text/__init__.py b/GPT_SoVITS/text/__init__.py index e4d690e1..2791d7ab 100644 --- a/GPT_SoVITS/text/__init__.py +++ b/GPT_SoVITS/text/__init__.py @@ -10,13 +10,14 @@ from text import symbols2 as symbols_v2 _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)} _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)} -def cleaned_text_to_sequence(cleaned_text, version): +def cleaned_text_to_sequence(cleaned_text, version=None): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. Args: text: string to convert to a sequence Returns: List of integers corresponding to the symbols in the text ''' + if version is None:version=os.environ.get('version', 'v2') if version == "v1": phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text] else: diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index 4a4e4404..3f6d4fdd 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -20,7 +20,8 @@ special = [ ] -def clean_text(text, language, version): +def clean_text(text, language, version=None): + if version is None:version=os.environ.get('version', 'v2') if version == "v1": symbols = symbols_v1.symbols language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english} @@ -57,7 +58,8 @@ def clean_text(text, language, version): return phones, word2ph, norm_text -def clean_special(text, language, special_s, target_symbol, version): +def clean_special(text, language, special_s, target_symbol, version=None): + if version is None:version=os.environ.get('version', 'v2') if version == "v1": symbols = symbols_v1.symbols language_module_map = {"zh": chinese_v1, "ja": japanese, "en": english} diff --git a/api.py b/api.py index 6500f6f7..e510ab95 100644 --- a/api.py +++ b/api.py @@ -522,6 +522,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, prompt_semantic = codes[0, 0] t1 = ttime() version = vq_model.version + os.environ['version'] = version prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)