Optimize code readability and formatted output.

This commit is contained in:
KamioRinn 2024-03-27 11:14:32 +08:00
parent 436032214a
commit 0eff854e3d

270
api.py
View File

@ -105,11 +105,6 @@ RESP: 无
import argparse import argparse
import os import os
import sys import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
import signal import signal
import LangSegment import LangSegment
from time import time as ttime from time import time as ttime
@ -130,35 +125,7 @@ from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from my_utils import load_audio from my_utils import load_audio
import config as global_config import config as global_config
import logging
g_config = global_config.Config()
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# bool值的用法为 `python ./api.py -fp ...`
# 此时 full_precision==True, half_precision==False
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
args = parser.parse_args()
sovits_path = args.sovits_path
gpt_path = args.gpt_path
class DefaultRefer: class DefaultRefer:
@ -171,50 +138,6 @@ class DefaultRefer:
return is_full(self.path, self.text, self.language) return is_full(self.path, self.text, self.language)
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
device = args.device
port = args.port
host = args.bind_addr
if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
default_refer.path, default_refer.text, default_refer.language = "", "", ""
print("[INFO] 未指定默认参考音频")
else:
print(f"[INFO] 默认参考音频路径: {default_refer.path}")
print(f"[INFO] 默认参考音频文本: {default_refer.text}")
print(f"[INFO] 默认参考音频语种: {default_refer.language}")
is_half = g_config.is_half
if args.full_precision:
is_half = False
if args.half_precision:
is_half = True
if args.full_precision and args.half_precision:
is_half = g_config.is_half # 炒饭fallback
print(f"[INFO] 半精: {is_half}")
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def is_empty(*items): # 任意一项不为空返回False def is_empty(*items): # 任意一项不为空返回False
for item in items: for item in items:
if item is not None and item != "": if item is not None and item != "":
@ -228,6 +151,7 @@ def is_full(*items): # 任意一项为空返回False
return False return False
return True return True
def change_sovits_weights(sovits_path): def change_sovits_weights(sovits_path):
global vq_model, hps global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu") dict_s2 = torch.load(sovits_path, map_location="cpu")
@ -247,9 +171,7 @@ def change_sovits_weights(sovits_path):
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) vq_model.load_state_dict(dict_s2["weight"], strict=False)
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
def change_gpt_weights(gpt_path): def change_gpt_weights(gpt_path):
@ -265,8 +187,7 @@ def change_gpt_weights(gpt_path):
t2s_model = t2s_model.to(device) t2s_model = t2s_model.to(device)
t2s_model.eval() t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()]) total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6)) logger.info("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
@ -344,8 +265,8 @@ def get_phones_and_bert(text,language):
# 因无法区别中日文汉字,以用户输入为准 # 因无法区别中日文汉字,以用户输入为准
langlist.append(language) langlist.append(language)
textlist.append(tmp["text"]) textlist.append(tmp["text"])
print(textlist) # logger.info(textlist)
print(langlist) # logger.info(langlist)
phones_list = [] phones_list = []
bert_list = [] bert_list = []
norm_text_list = [] norm_text_list = []
@ -363,11 +284,6 @@ def get_phones_and_bert(text,language):
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
n_semantic = 1024
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
class DictToAttrRecursive: class DictToAttrRecursive:
def __init__(self, input_dict): def __init__(self, input_dict):
for key, value in input_dict.items(): for key, value in input_dict.items():
@ -378,39 +294,6 @@ class DictToAttrRecursive:
setattr(self, key, value) setattr(self, key, value)
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if is_half:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config['data']['max_sec']
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate)) audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
@ -421,22 +304,6 @@ def get_spepc(hps, filename):
return spec return spec
dict_language = {
"中文": "all_zh",
"英文": "en",
"日文": "all_ja",
"中英混合": "zh",
"日英混合": "ja",
"多语种混合": "auto", #多语种启动切分识别语种
"all_zh": "all_zh",
"en": "en",
"all_ja": "all_ja",
"zh": "zh",
"ja": "ja",
"auto": "auto",
}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime() t0 = ttime()
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
@ -498,7 +365,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
t4 = ttime() t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
@ -521,10 +388,11 @@ def handle_change(path, text, language):
if language != "" or language is not None: if language != "" or language is not None:
default_refer.language = language default_refer.language = language
print(f"[INFO] 当前默认参考音频路径: {default_refer.path}") logger.info(f"当前默认参考音频路径: {default_refer.path}")
print(f"[INFO] 当前默认参考音频文本: {default_refer.text}") logger.info(f"当前默认参考音频文本: {default_refer.text}")
print(f"[INFO] 当前默认参考音频语种: {default_refer.language}") logger.info(f"当前默认参考音频语种: {default_refer.language}")
print(f"[INFO] is_ready: {default_refer.is_ready()}") logger.info(f"is_ready: {default_refer.is_ready()}")
return JSONResponse({"code": 0, "message": "Success"}, status_code=200) return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
@ -557,10 +425,116 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
return StreamingResponse(wav, media_type="audio/wav") return StreamingResponse(wav, media_type="audio/wav")
# --------------------------------
# Initialization
# --------------------------------
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

# Map of user-facing language labels (and raw codes) to internal language codes.
dict_language = {
    "中文": "all_zh",
    "英文": "en",
    "日文": "all_ja",
    "中英混合": "zh",
    "日英混合": "ja",
    "多语种混合": "auto",  # "auto" enables per-segment language detection
    "all_zh": "all_zh",
    "en": "en",
    "all_ja": "all_ja",
    "zh": "zh",
    "ja": "ja",
    "auto": "auto",
}

# Logger: reuse uvicorn's logging configuration so API logs share its format.
# NOTE(review): this relies on `logging.config` being importable as an attribute of
# `logging` — that only works because uvicorn imports `logging.config` as a side
# effect; an explicit `import logging.config` at the top of the file would be safer.
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
logger = logging.getLogger('uvicorn')

# Load configuration
g_config = global_config.Config()

# Command-line arguments
parser = argparse.ArgumentParser(description="GPT-SoVITS api")

parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")

parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")

parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")

parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# Boolean flags are used as `python ./api.py -fp ...`;
# then full_precision==True and half_precision==False.

parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")

args = parser.parse_args()

sovits_path = args.sovits_path
gpt_path = args.gpt_path
device = args.device
port = args.port
host = args.bind_addr
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path

# Apply argument configuration
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)

# Model path checks: fall back to the bundled pretrained weights when unspecified.
# Use logger.warning — Logger.warn is a deprecated alias (deprecated since Python 3.3).
if sovits_path == "":
    sovits_path = g_config.pretrained_sovits_path
    logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
    gpt_path = g_config.pretrained_gpt_path
    logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")

# Default reference audio, used when the caller omits (some of) the reference
# parameters. A partially-specified default is normalized to fully empty.
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
    default_refer.path, default_refer.text, default_refer.language = "", "", ""
    logger.info("未指定默认参考音频")
else:
    logger.info(f"默认参考音频路径: {default_refer.path}")
    logger.info(f"默认参考音频文本: {default_refer.text}")
    logger.info(f"默认参考音频语种: {default_refer.language}")

# Resolve half precision: CLI flags override config; if both -fp and -hp are
# given, fall back to the config value.
is_half = g_config.is_half
if args.full_precision:
    is_half = False
if args.half_precision:
    is_half = True
if args.full_precision and args.half_precision:
    is_half = g_config.is_half  # contradictory flags: fall back to config
logger.info(f"半精: {is_half}")

# Initialize models (BERT tokenizer/model, cnhubert SSL model, then the
# SoVITS and GPT weights via the change_*_weights helpers).
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
ssl_model = cnhubert.get_model()
if is_half:
    bert_model = bert_model.half().to(device)
    ssl_model = ssl_model.half().to(device)
else:
    bert_model = bert_model.to(device)
    ssl_model = ssl_model.to(device)
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
app = FastAPI() app = FastAPI()
#clark新增-----2024-02-21
#可在启动后动态修改模型以此满足同一个api不同的朗读者请求
@app.post("/set_model") @app.post("/set_model")
async def set_model(request: Request): async def set_model(request: Request):
json_post_raw = await request.json() json_post_raw = await request.json()
@ -568,11 +542,11 @@ async def set_model(request: Request):
gpt_path=json_post_raw.get("gpt_model_path") gpt_path=json_post_raw.get("gpt_model_path")
global sovits_path global sovits_path
sovits_path=json_post_raw.get("sovits_model_path") sovits_path=json_post_raw.get("sovits_model_path")
print("gptpath"+gpt_path+";vitspath"+sovits_path) logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
change_sovits_weights(sovits_path) change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path) change_gpt_weights(gpt_path)
return "ok" return "ok"
# 新增-----end------
@app.post("/control") @app.post("/control")
async def control(request: Request): async def control(request: Request):