mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
API修复优化 (#1503)
* model control * Mix timbre * Fix some detail problems * Optimize detail * Add int32 * Add example * Add aac pcm32 support
This commit is contained in:
parent
b3e8eb40c2
commit
6ca4aecea2
184
api.py
184
api.py
@ -20,6 +20,7 @@
|
|||||||
`-hp` - `覆盖 config.py 使用半精度`
|
`-hp` - `覆盖 config.py 使用半精度`
|
||||||
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
|
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
|
||||||
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
|
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
|
||||||
|
·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
|
||||||
·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入`
|
·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入`
|
||||||
|
|
||||||
`-hb` - `cnhubert路径`
|
`-hb` - `cnhubert路径`
|
||||||
@ -74,7 +75,7 @@ RESP:
|
|||||||
|
|
||||||
手动指定当次推理所使用的参考音频,并提供参数:
|
手动指定当次推理所使用的参考音频,并提供参数:
|
||||||
GET:
|
GET:
|
||||||
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1`
|
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
|
||||||
POST:
|
POST:
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@ -86,7 +87,8 @@ POST:
|
|||||||
"top_k": 20,
|
"top_k": 20,
|
||||||
"top_p": 0.6,
|
"top_p": 0.6,
|
||||||
"temperature": 0.6,
|
"temperature": 0.6,
|
||||||
"speed": 1
|
"speed": 1,
|
||||||
|
"inp_refs": ["456.wav","789.wav"]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -153,7 +155,7 @@ from time import time as ttime
|
|||||||
import torch
|
import torch
|
||||||
import librosa
|
import librosa
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import FastAPI, Request, Query, HTTPException
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
@ -195,8 +197,24 @@ def is_full(*items): # 任意一项为空返回False
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def change_sovits_weights(sovits_path):
|
class Speaker:
|
||||||
global vq_model, hps
|
def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
|
||||||
|
self.name = name
|
||||||
|
self.sovits = sovits
|
||||||
|
self.gpt = gpt
|
||||||
|
self.phones = phones
|
||||||
|
self.bert = bert
|
||||||
|
self.prompt = prompt
|
||||||
|
|
||||||
|
speaker_list = {}
|
||||||
|
|
||||||
|
|
||||||
|
class Sovits:
|
||||||
|
def __init__(self, vq_model, hps):
|
||||||
|
self.vq_model = vq_model
|
||||||
|
self.hps = hps
|
||||||
|
|
||||||
|
def get_sovits_weights(sovits_path):
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
hps = dict_s2["config"]
|
hps = dict_s2["config"]
|
||||||
hps = DictToAttrRecursive(hps)
|
hps = DictToAttrRecursive(hps)
|
||||||
@ -205,7 +223,7 @@ def change_sovits_weights(sovits_path):
|
|||||||
hps.model.version = "v1"
|
hps.model.version = "v1"
|
||||||
else:
|
else:
|
||||||
hps.model.version = "v2"
|
hps.model.version = "v2"
|
||||||
print("sovits版本:",hps.model.version)
|
logger.info(f"模型版本: {hps.model.version}")
|
||||||
model_params_dict = vars(hps.model)
|
model_params_dict = vars(hps.model)
|
||||||
vq_model = SynthesizerTrn(
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
@ -222,10 +240,17 @@ def change_sovits_weights(sovits_path):
|
|||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
|
||||||
|
sovits = Sovits(vq_model, hps)
|
||||||
|
return sovits
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
class Gpt:
|
||||||
global hz, max_sec, t2s_model, config
|
def __init__(self, max_sec, t2s_model):
|
||||||
hz = 50
|
self.max_sec = max_sec
|
||||||
|
self.t2s_model = t2s_model
|
||||||
|
|
||||||
|
global hz
|
||||||
|
hz = 50
|
||||||
|
def get_gpt_weights(gpt_path):
|
||||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||||
config = dict_s1["config"]
|
config = dict_s1["config"]
|
||||||
max_sec = config["data"]["max_sec"]
|
max_sec = config["data"]["max_sec"]
|
||||||
@ -238,6 +263,19 @@ def change_gpt_weights(gpt_path):
|
|||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
logger.info("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
|
|
||||||
|
gpt = Gpt(max_sec, t2s_model)
|
||||||
|
return gpt
|
||||||
|
|
||||||
|
def change_gpt_sovits_weights(gpt_path,sovits_path):
|
||||||
|
try:
|
||||||
|
gpt = get_gpt_weights(gpt_path)
|
||||||
|
sovits = get_sovits_weights(sovits_path)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
|
||||||
|
|
||||||
|
speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
|
||||||
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -289,14 +327,14 @@ def get_phones_and_bert(text,language,version,final=False):
|
|||||||
if language == "zh":
|
if language == "zh":
|
||||||
if re.search(r'[A-Za-z]', formattext):
|
if re.search(r'[A-Za-z]', formattext):
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||||
formattext = chinese.text_normalize(formattext)
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
return get_phones_and_bert(formattext,"zh",version)
|
return get_phones_and_bert(formattext,"zh",version)
|
||||||
else:
|
else:
|
||||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||||
bert = get_bert_feature(norm_text, word2ph).to(device)
|
bert = get_bert_feature(norm_text, word2ph).to(device)
|
||||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||||
formattext = chinese.text_normalize(formattext)
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
return get_phones_and_bert(formattext,"yue",version)
|
return get_phones_and_bert(formattext,"yue",version)
|
||||||
else:
|
else:
|
||||||
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
||||||
@ -375,8 +413,11 @@ class DictToAttrRecursive(dict):
|
|||||||
|
|
||||||
|
|
||||||
def get_spepc(hps, filename):
|
def get_spepc(hps, filename):
|
||||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
audio,_ = librosa.load(filename, int(hps.data.sampling_rate))
|
||||||
audio = torch.FloatTensor(audio)
|
audio = torch.FloatTensor(audio)
|
||||||
|
maxx=audio.abs().max()
|
||||||
|
if(maxx>1):
|
||||||
|
audio/=min(2,maxx)
|
||||||
audio_norm = audio
|
audio_norm = audio
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
|
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
|
||||||
@ -448,22 +489,32 @@ def pack_raw(audio_bytes, data, rate):
|
|||||||
|
|
||||||
|
|
||||||
def pack_wav(audio_bytes, rate):
|
def pack_wav(audio_bytes, rate):
|
||||||
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
|
if is_int32:
|
||||||
wav_bytes = BytesIO()
|
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
|
||||||
sf.write(wav_bytes, data, rate, format='wav')
|
wav_bytes = BytesIO()
|
||||||
|
sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
|
||||||
|
else:
|
||||||
|
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
|
||||||
|
wav_bytes = BytesIO()
|
||||||
|
sf.write(wav_bytes, data, rate, format='WAV')
|
||||||
return wav_bytes
|
return wav_bytes
|
||||||
|
|
||||||
|
|
||||||
def pack_aac(audio_bytes, data, rate):
|
def pack_aac(audio_bytes, data, rate):
|
||||||
|
if is_int32:
|
||||||
|
pcm = 's32le'
|
||||||
|
bit_rate = '256k'
|
||||||
|
else:
|
||||||
|
pcm = 's16le'
|
||||||
|
bit_rate = '128k'
|
||||||
process = subprocess.Popen([
|
process = subprocess.Popen([
|
||||||
'ffmpeg',
|
'ffmpeg',
|
||||||
'-f', 's16le', # 输入16位有符号小端整数PCM
|
'-f', pcm, # 输入16位有符号小端整数PCM
|
||||||
'-ar', str(rate), # 设置采样率
|
'-ar', str(rate), # 设置采样率
|
||||||
'-ac', '1', # 单声道
|
'-ac', '1', # 单声道
|
||||||
'-i', 'pipe:0', # 从管道读取输入
|
'-i', 'pipe:0', # 从管道读取输入
|
||||||
'-c:a', 'aac', # 音频编码器为AAC
|
'-c:a', 'aac', # 音频编码器为AAC
|
||||||
'-b:a', '192k', # 比特率
|
'-b:a', bit_rate, # 比特率
|
||||||
'-vn', # 不包含视频
|
'-vn', # 不包含视频
|
||||||
'-f', 'adts', # 输出AAC数据流格式
|
'-f', 'adts', # 输出AAC数据流格式
|
||||||
'pipe:1' # 将输出写入管道
|
'pipe:1' # 将输出写入管道
|
||||||
@ -504,10 +555,21 @@ def only_punc(text):
|
|||||||
return not any(t.isalnum() or t.isalpha() for t in text)
|
return not any(t.isalnum() or t.isalpha() for t in text)
|
||||||
|
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 20, top_p = 0.6, temperature = 0.6, speed = 1):
|
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||||
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, spk = "default"):
|
||||||
|
infer_sovits = speaker_list[spk].sovits
|
||||||
|
vq_model = infer_sovits.vq_model
|
||||||
|
hps = infer_sovits.hps
|
||||||
|
|
||||||
|
infer_gpt = speaker_list[spk].gpt
|
||||||
|
t2s_model = infer_gpt.t2s_model
|
||||||
|
max_sec = infer_gpt.max_sec
|
||||||
|
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
|
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
prompt_language, text = prompt_language, text.strip("\n")
|
||||||
|
dtype = torch.float16 if is_half == True else torch.float32
|
||||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||||
@ -523,6 +585,19 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
refers=[]
|
||||||
|
if(inp_refs):
|
||||||
|
for path in inp_refs:
|
||||||
|
try:
|
||||||
|
refer = get_spepc(hps, path).to(dtype).to(device)
|
||||||
|
refers.append(refer)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
if(len(refers)==0):
|
||||||
|
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||||
|
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
version = vq_model.version
|
version = vq_model.version
|
||||||
os.environ['version'] = version
|
os.environ['version'] = version
|
||||||
@ -538,16 +613,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
|
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||||
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
|
|
||||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||||
bert = bert.to(device).unsqueeze(0)
|
bert = bert.to(device).unsqueeze(0)
|
||||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# pred_semantic = t2s_model.model.infer(
|
|
||||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||||
all_phoneme_ids,
|
all_phoneme_ids,
|
||||||
all_phoneme_len,
|
all_phoneme_len,
|
||||||
@ -558,23 +632,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
top_p = top_p,
|
top_p = top_p,
|
||||||
temperature = temperature,
|
temperature = temperature,
|
||||||
early_stop_num=hz * max_sec)
|
early_stop_num=hz * max_sec)
|
||||||
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
|
||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
|
||||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
|
||||||
if (is_half == True):
|
|
||||||
refer = refer.half().to(device)
|
|
||||||
else:
|
|
||||||
refer = refer.to(device)
|
|
||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
|
||||||
audio = \
|
audio = \
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||||
refer,speed=speed).detach().cpu().numpy()[
|
refers,speed=speed).detach().cpu().numpy()[
|
||||||
0, 0] ###试试重建不带上prompt部分
|
0, 0] ###试试重建不带上prompt部分
|
||||||
|
max_audio=np.abs(audio).max()
|
||||||
|
if max_audio>1:
|
||||||
|
audio/=max_audio
|
||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
if is_int32:
|
||||||
|
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
|
||||||
|
else:
|
||||||
|
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
||||||
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
if stream_mode == "normal":
|
if stream_mode == "normal":
|
||||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||||
@ -615,7 +688,7 @@ def handle_change(path, text, language):
|
|||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
|
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
|
||||||
if (
|
if (
|
||||||
refer_wav_path == "" or refer_wav_path is None
|
refer_wav_path == "" or refer_wav_path is None
|
||||||
or prompt_text == "" or prompt_text is None
|
or prompt_text == "" or prompt_text is None
|
||||||
@ -634,7 +707,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
|
|||||||
else:
|
else:
|
||||||
text = cut_text(text,cut_punc)
|
text = cut_text(text,cut_punc)
|
||||||
|
|
||||||
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
|
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/"+media_type)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -691,6 +764,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
|
|||||||
# 此时 full_precision==True, half_precision==False
|
# 此时 full_precision==True, half_precision==False
|
||||||
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
|
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
|
||||||
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
|
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
|
||||||
|
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
|
||||||
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
|
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
|
||||||
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
|
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
|
||||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||||
@ -752,6 +826,14 @@ else:
|
|||||||
media_type = "ogg"
|
media_type = "ogg"
|
||||||
logger.info(f"编码格式: {media_type}")
|
logger.info(f"编码格式: {media_type}")
|
||||||
|
|
||||||
|
# 音频数据类型
|
||||||
|
if args.sub_type.lower() == 'int32':
|
||||||
|
is_int32 = True
|
||||||
|
logger.info(f"数据类型: int32")
|
||||||
|
else:
|
||||||
|
is_int32 = False
|
||||||
|
logger.info(f"数据类型: int16")
|
||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
@ -763,9 +845,7 @@ if is_half:
|
|||||||
else:
|
else:
|
||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(device)
|
||||||
ssl_model = ssl_model.to(device)
|
ssl_model = ssl_model.to(device)
|
||||||
change_sovits_weights(sovits_path)
|
change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
|
||||||
change_gpt_weights(gpt_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -777,14 +857,18 @@ app = FastAPI()
|
|||||||
@app.post("/set_model")
|
@app.post("/set_model")
|
||||||
async def set_model(request: Request):
|
async def set_model(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
global gpt_path
|
return change_gpt_sovits_weights(
|
||||||
gpt_path=json_post_raw.get("gpt_model_path")
|
gpt_path = json_post_raw.get("gpt_model_path"),
|
||||||
global sovits_path
|
sovits_path = json_post_raw.get("sovits_model_path")
|
||||||
sovits_path=json_post_raw.get("sovits_model_path")
|
)
|
||||||
logger.info("gptpath"+gpt_path+";vitspath"+sovits_path)
|
|
||||||
change_sovits_weights(sovits_path)
|
|
||||||
change_gpt_weights(gpt_path)
|
@app.get("/set_model")
|
||||||
return "ok"
|
async def set_model(
|
||||||
|
gpt_model_path: str = None,
|
||||||
|
sovits_model_path: str = None,
|
||||||
|
):
|
||||||
|
return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/control")
|
@app.post("/control")
|
||||||
@ -827,10 +911,11 @@ async def tts_endpoint(request: Request):
|
|||||||
json_post_raw.get("text"),
|
json_post_raw.get("text"),
|
||||||
json_post_raw.get("text_language"),
|
json_post_raw.get("text_language"),
|
||||||
json_post_raw.get("cut_punc"),
|
json_post_raw.get("cut_punc"),
|
||||||
json_post_raw.get("top_k", 10),
|
json_post_raw.get("top_k", 15),
|
||||||
json_post_raw.get("top_p", 1.0),
|
json_post_raw.get("top_p", 1.0),
|
||||||
json_post_raw.get("temperature", 1.0),
|
json_post_raw.get("temperature", 1.0),
|
||||||
json_post_raw.get("speed", 1.0)
|
json_post_raw.get("speed", 1.0),
|
||||||
|
json_post_raw.get("inp_refs", [])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -842,12 +927,13 @@ async def tts_endpoint(
|
|||||||
text: str = None,
|
text: str = None,
|
||||||
text_language: str = None,
|
text_language: str = None,
|
||||||
cut_punc: str = None,
|
cut_punc: str = None,
|
||||||
top_k: int = 10,
|
top_k: int = 15,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
speed: float = 1.0
|
speed: float = 1.0,
|
||||||
|
inp_refs: list = Query(default=[])
|
||||||
):
|
):
|
||||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
|
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user