mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
Optimize code readability and formatted output.
This commit is contained in:
parent
436032214a
commit
0eff854e3d
270
api.py
270
api.py
@ -105,11 +105,6 @@ RESP: 无
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
|
||||||
sys.path.append(now_dir)
|
|
||||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
import LangSegment
|
import LangSegment
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
@ -130,35 +125,7 @@ from text.cleaner import clean_text
|
|||||||
from module.mel_processing import spectrogram_torch
|
from module.mel_processing import spectrogram_torch
|
||||||
from my_utils import load_audio
|
from my_utils import load_audio
|
||||||
import config as global_config
|
import config as global_config
|
||||||
|
import logging
|
||||||
g_config = global_config.Config()
|
|
||||||
|
|
||||||
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
|
||||||
|
|
||||||
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
|
||||||
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
|
|
||||||
|
|
||||||
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
|
|
||||||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
|
||||||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
|
||||||
|
|
||||||
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
|
||||||
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
|
|
||||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
|
||||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
|
||||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
|
||||||
# bool值的用法为 `python ./api.py -fp ...`
|
|
||||||
# 此时 full_precision==True, half_precision==False
|
|
||||||
|
|
||||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
|
||||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
sovits_path = args.sovits_path
|
|
||||||
gpt_path = args.gpt_path
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultRefer:
|
class DefaultRefer:
|
||||||
@ -171,50 +138,6 @@ class DefaultRefer:
|
|||||||
return is_full(self.path, self.text, self.language)
|
return is_full(self.path, self.text, self.language)
|
||||||
|
|
||||||
|
|
||||||
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
|
|
||||||
|
|
||||||
device = args.device
|
|
||||||
port = args.port
|
|
||||||
host = args.bind_addr
|
|
||||||
|
|
||||||
if sovits_path == "":
|
|
||||||
sovits_path = g_config.pretrained_sovits_path
|
|
||||||
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
|
||||||
if gpt_path == "":
|
|
||||||
gpt_path = g_config.pretrained_gpt_path
|
|
||||||
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
|
||||||
|
|
||||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
|
||||||
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
|
||||||
default_refer.path, default_refer.text, default_refer.language = "", "", ""
|
|
||||||
print("[INFO] 未指定默认参考音频")
|
|
||||||
else:
|
|
||||||
print(f"[INFO] 默认参考音频路径: {default_refer.path}")
|
|
||||||
print(f"[INFO] 默认参考音频文本: {default_refer.text}")
|
|
||||||
print(f"[INFO] 默认参考音频语种: {default_refer.language}")
|
|
||||||
|
|
||||||
is_half = g_config.is_half
|
|
||||||
if args.full_precision:
|
|
||||||
is_half = False
|
|
||||||
if args.half_precision:
|
|
||||||
is_half = True
|
|
||||||
if args.full_precision and args.half_precision:
|
|
||||||
is_half = g_config.is_half # 炒饭fallback
|
|
||||||
|
|
||||||
print(f"[INFO] 半精: {is_half}")
|
|
||||||
|
|
||||||
cnhubert_base_path = args.hubert_path
|
|
||||||
bert_path = args.bert_path
|
|
||||||
|
|
||||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
|
||||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
|
||||||
if is_half:
|
|
||||||
bert_model = bert_model.half().to(device)
|
|
||||||
else:
|
|
||||||
bert_model = bert_model.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
def is_empty(*items): # 任意一项不为空返回False
|
def is_empty(*items): # 任意一项不为空返回False
|
||||||
for item in items:
|
for item in items:
|
||||||
if item is not None and item != "":
|
if item is not None and item != "":
|
||||||
@ -228,6 +151,7 @@ def is_full(*items): # 任意一项为空返回False
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def change_sovits_weights(sovits_path):
|
def change_sovits_weights(sovits_path):
|
||||||
global vq_model, hps
|
global vq_model, hps
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
@ -247,9 +171,7 @@ def change_sovits_weights(sovits_path):
|
|||||||
else:
|
else:
|
||||||
vq_model = vq_model.to(device)
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
|
||||||
f.write(sovits_path)
|
|
||||||
|
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
def change_gpt_weights(gpt_path):
|
||||||
@ -265,8 +187,7 @@ def change_gpt_weights(gpt_path):
|
|||||||
t2s_model = t2s_model.to(device)
|
t2s_model = t2s_model.to(device)
|
||||||
t2s_model.eval()
|
t2s_model.eval()
|
||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)
|
|
||||||
|
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
@ -344,8 +265,8 @@ def get_phones_and_bert(text,language):
|
|||||||
# 因无法区别中日文汉字,以用户输入为准
|
# 因无法区别中日文汉字,以用户输入为准
|
||||||
langlist.append(language)
|
langlist.append(language)
|
||||||
textlist.append(tmp["text"])
|
textlist.append(tmp["text"])
|
||||||
print(textlist)
|
# logger.info(textlist)
|
||||||
print(langlist)
|
# logger.info(langlist)
|
||||||
phones_list = []
|
phones_list = []
|
||||||
bert_list = []
|
bert_list = []
|
||||||
norm_text_list = []
|
norm_text_list = []
|
||||||
@ -363,11 +284,6 @@ def get_phones_and_bert(text,language):
|
|||||||
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
|
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
|
||||||
|
|
||||||
|
|
||||||
n_semantic = 1024
|
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
|
||||||
hps = dict_s2["config"]
|
|
||||||
|
|
||||||
|
|
||||||
class DictToAttrRecursive:
|
class DictToAttrRecursive:
|
||||||
def __init__(self, input_dict):
|
def __init__(self, input_dict):
|
||||||
for key, value in input_dict.items():
|
for key, value in input_dict.items():
|
||||||
@ -378,39 +294,6 @@ class DictToAttrRecursive:
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
hps = DictToAttrRecursive(hps)
|
|
||||||
hps.model.semantic_frame_rate = "25hz"
|
|
||||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
|
||||||
config = dict_s1["config"]
|
|
||||||
ssl_model = cnhubert.get_model()
|
|
||||||
if is_half:
|
|
||||||
ssl_model = ssl_model.half().to(device)
|
|
||||||
else:
|
|
||||||
ssl_model = ssl_model.to(device)
|
|
||||||
|
|
||||||
vq_model = SynthesizerTrn(
|
|
||||||
hps.data.filter_length // 2 + 1,
|
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
|
||||||
n_speakers=hps.data.n_speakers,
|
|
||||||
**hps.model)
|
|
||||||
if is_half:
|
|
||||||
vq_model = vq_model.half().to(device)
|
|
||||||
else:
|
|
||||||
vq_model = vq_model.to(device)
|
|
||||||
vq_model.eval()
|
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
|
||||||
hz = 50
|
|
||||||
max_sec = config['data']['max_sec']
|
|
||||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
|
||||||
if is_half:
|
|
||||||
t2s_model = t2s_model.half()
|
|
||||||
t2s_model = t2s_model.to(device)
|
|
||||||
t2s_model.eval()
|
|
||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
|
||||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
|
||||||
|
|
||||||
|
|
||||||
def get_spepc(hps, filename):
|
def get_spepc(hps, filename):
|
||||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||||
audio = torch.FloatTensor(audio)
|
audio = torch.FloatTensor(audio)
|
||||||
@ -421,22 +304,6 @@ def get_spepc(hps, filename):
|
|||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
dict_language = {
|
|
||||||
"中文": "all_zh",
|
|
||||||
"英文": "en",
|
|
||||||
"日文": "all_ja",
|
|
||||||
"中英混合": "zh",
|
|
||||||
"日英混合": "ja",
|
|
||||||
"多语种混合": "auto", #多语种启动切分识别语种
|
|
||||||
"all_zh": "all_zh",
|
|
||||||
"en": "en",
|
|
||||||
"all_ja": "all_ja",
|
|
||||||
"zh": "zh",
|
|
||||||
"ja": "ja",
|
|
||||||
"auto": "auto",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
@ -498,7 +365,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||||
|
|
||||||
|
|
||||||
@ -521,10 +388,11 @@ def handle_change(path, text, language):
|
|||||||
if language != "" or language is not None:
|
if language != "" or language is not None:
|
||||||
default_refer.language = language
|
default_refer.language = language
|
||||||
|
|
||||||
print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
|
logger.info(f"当前默认参考音频路径: {default_refer.path}")
|
||||||
print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
|
logger.info(f"当前默认参考音频文本: {default_refer.text}")
|
||||||
print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
|
logger.info(f"当前默认参考音频语种: {default_refer.language}")
|
||||||
print(f"[INFO] is_ready: {default_refer.is_ready()}")
|
logger.info(f"is_ready: {default_refer.is_ready()}")
|
||||||
|
|
||||||
|
|
||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
@ -557,10 +425,116 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
|||||||
return StreamingResponse(wav, media_type="audio/wav")
|
return StreamingResponse(wav, media_type="audio/wav")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------
|
||||||
|
# 初始化部分
|
||||||
|
# --------------------------------
|
||||||
|
now_dir = os.getcwd()
|
||||||
|
sys.path.append(now_dir)
|
||||||
|
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||||
|
|
||||||
|
dict_language = {
|
||||||
|
"中文": "all_zh",
|
||||||
|
"英文": "en",
|
||||||
|
"日文": "all_ja",
|
||||||
|
"中英混合": "zh",
|
||||||
|
"日英混合": "ja",
|
||||||
|
"多语种混合": "auto", #多语种启动切分识别语种
|
||||||
|
"all_zh": "all_zh",
|
||||||
|
"en": "en",
|
||||||
|
"all_ja": "all_ja",
|
||||||
|
"zh": "zh",
|
||||||
|
"ja": "ja",
|
||||||
|
"auto": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
# logger
|
||||||
|
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
||||||
|
logger = logging.getLogger('uvicorn')
|
||||||
|
|
||||||
|
# 获取配置
|
||||||
|
g_config = global_config.Config()
|
||||||
|
|
||||||
|
# 获取参数
|
||||||
|
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||||
|
|
||||||
|
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
||||||
|
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
|
||||||
|
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
|
||||||
|
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
||||||
|
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
||||||
|
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
||||||
|
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
|
||||||
|
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||||
|
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||||
|
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||||
|
# bool值的用法为 `python ./api.py -fp ...`
|
||||||
|
# 此时 full_precision==True, half_precision==False
|
||||||
|
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||||
|
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
sovits_path = args.sovits_path
|
||||||
|
gpt_path = args.gpt_path
|
||||||
|
device = args.device
|
||||||
|
port = args.port
|
||||||
|
host = args.bind_addr
|
||||||
|
cnhubert_base_path = args.hubert_path
|
||||||
|
bert_path = args.bert_path
|
||||||
|
|
||||||
|
# 应用参数配置
|
||||||
|
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
|
||||||
|
|
||||||
|
# 模型路径检查
|
||||||
|
if sovits_path == "":
|
||||||
|
sovits_path = g_config.pretrained_sovits_path
|
||||||
|
logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
||||||
|
if gpt_path == "":
|
||||||
|
gpt_path = g_config.pretrained_gpt_path
|
||||||
|
logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||||
|
|
||||||
|
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||||
|
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||||
|
default_refer.path, default_refer.text, default_refer.language = "", "", ""
|
||||||
|
logger.info("未指定默认参考音频")
|
||||||
|
else:
|
||||||
|
logger.info(f"默认参考音频路径: {default_refer.path}")
|
||||||
|
logger.info(f"默认参考音频文本: {default_refer.text}")
|
||||||
|
logger.info(f"默认参考音频语种: {default_refer.language}")
|
||||||
|
|
||||||
|
# 获取半精度
|
||||||
|
is_half = g_config.is_half
|
||||||
|
if args.full_precision:
|
||||||
|
is_half = False
|
||||||
|
if args.half_precision:
|
||||||
|
is_half = True
|
||||||
|
if args.full_precision and args.half_precision:
|
||||||
|
is_half = g_config.is_half # 炒饭fallback
|
||||||
|
logger.info(f"半精: {is_half}")
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
|
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||||
|
ssl_model = cnhubert.get_model()
|
||||||
|
if is_half:
|
||||||
|
bert_model = bert_model.half().to(device)
|
||||||
|
ssl_model = ssl_model.half().to(device)
|
||||||
|
else:
|
||||||
|
bert_model = bert_model.to(device)
|
||||||
|
ssl_model = ssl_model.to(device)
|
||||||
|
change_sovits_weights(sovits_path)
|
||||||
|
change_gpt_weights(gpt_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------
|
||||||
|
# 接口部分
|
||||||
|
# --------------------------------
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
#clark新增-----2024-02-21
|
|
||||||
#可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
|
|
||||||
@app.post("/set_model")
|
@app.post("/set_model")
|
||||||
async def set_model(request: Request):
|
async def set_model(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
@ -568,11 +542,11 @@ async def set_model(request: Request):
|
|||||||
gpt_path=json_post_raw.get("gpt_model_path")
|
gpt_path=json_post_raw.get("gpt_model_path")
|
||||||
global sovits_path
|
global sovits_path
|
||||||
sovits_path=json_post_raw.get("sovits_model_path")
|
sovits_path=json_post_raw.get("sovits_model_path")
|
||||||
print("gptpath"+gpt_path+";vitspath"+sovits_path)
|
logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
|
||||||
change_sovits_weights(sovits_path)
|
change_sovits_weights(sovits_path)
|
||||||
change_gpt_weights(gpt_path)
|
change_gpt_weights(gpt_path)
|
||||||
return "ok"
|
return "ok"
|
||||||
# 新增-----end------
|
|
||||||
|
|
||||||
@app.post("/control")
|
@app.post("/control")
|
||||||
async def control(request: Request):
|
async def control(request: Request):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user