修正了并发错误

This commit is contained in:
XTer 2024-03-19 17:35:41 +08:00
parent dbde64234a
commit 7f61235ba9
2 changed files with 19 additions and 19 deletions

@ -1 +1 @@
Subproject commit 2de410df8b8a0725dbda04105509457f0163c18e Subproject commit 8eb8219e47e996b7c7f1a77d7b0a71420c861e82

30
app.py
View File

@ -1,4 +1,4 @@
frontend_version = "2.2.4 240318" frontend_version = "2.3.1 240320"
from datetime import datetime from datetime import datetime
import gradio as gr import gradio as gr
@ -28,7 +28,7 @@ if os.path.exists(config_path):
default_batch_size = _config.get("batch_size", 10) default_batch_size = _config.get("batch_size", 10)
default_word_count = _config.get("max_word_count", 80) default_word_count = _config.get("max_word_count", 80)
is_share = _config.get("is_share", "false").lower() == "true" is_share = _config.get("is_share", "false").lower() == "true"
is_classic = _config.get("classic_inference", "false").lower() == "true" is_classic = False
enable_auth = _config.get("enable_auth", "false").lower() == "true" enable_auth = _config.get("enable_auth", "false").lower() == "true"
users = _config.get("user", {}) users = _config.get("user", {})
try: try:
@ -61,8 +61,14 @@ def load_character_emotions(character_name, characters_and_emotions):
from Inference.src.TTS_Instance import TTS_instance
from Inference.src.config_manager import update_character_info, models_path, get_deflaut_character_name
text_count = {}
for character in update_character_info()['characters_and_emotions']:
text_count[character.lower()] = 0
tts_instance = TTS_instance()
from load_infer_info import get_wav_from_text_api, update_character_info, load_character, character_name, models_path
import soundfile as sf import soundfile as sf
import io import io
@ -83,7 +89,6 @@ def send_request(
seed, seed,
stream="False", stream="False",
): ):
global character_name
global models_path global models_path
text_language = language_dict[text_language] text_language = language_dict[text_language]
cut_method = cut_method_dict[cut_method] cut_method = cut_method_dict[cut_method]
@ -91,16 +96,11 @@ def send_request(
cut_method = f"{cut_method}_{word_count}" cut_method = f"{cut_method}_{word_count}"
# Using Template to fill in variables # Using Template to fill in variables
expected_path = os.path.join(models_path, cha_name) if cha_name else None
# 检查cha_name和路径 # 检查cha_name和路径
if cha_name and cha_name != character_name and expected_path and os.path.exists(expected_path): try:
character_name = cha_name tts_instance.load_character(cha_name)
print(f"Loading character {character_name}") except Exception as e:
load_character(character_name) gr.Warning(f"Fails to load character: {cha_name} with error: {e}")
elif expected_path and not os.path.exists(expected_path):
gr.Warning("Directory {expected_path} does not exist. Using the current character.")
@ -119,11 +119,11 @@ def send_request(
"stream": stream "stream": stream
} }
# 如果不是经典模式,则添加额外的参数 # 如果不是经典模式,则添加额外的参数
if not is_classic:
params["batch_size"] = batch_size params["batch_size"] = batch_size
params["speed_factor"] = speed_factor params["speed_factor"] = speed_factor
params["seed"] = seed params["seed"] = seed
gen = get_wav_from_text_api(**params) gen = tts_instance.get_wav_from_text_api(**params)
sampling_rate, audio_data = next(gen) sampling_rate, audio_data = next(gen)
wav = io.BytesIO() wav = io.BytesIO()
sf.write(wav, audio_data, sampling_rate, format="wav") sf.write(wav, audio_data, sampling_rate, format="wav")