Merge branch 'fast_inference_' of https://github.com/KevinZhang19870314/GPT-SoVITS into fast_inference_dev

Author: kevin.zhang
Date: 2024-05-27 14:01:57 +08:00
Commit: c5dc7697a8
13 changed files with 43 additions and 37 deletions

.gitignore (vendored) · 5 changes

@@ -10,8 +10,3 @@ reference
 GPT_weights
 SoVITS_weights
 TEMP
-PortableGit
-ffmpeg.exe
-ffprobe.exe
-tmp_audio
-trained

Dockerfile

@@ -34,9 +34,6 @@ RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
     fi
-
-# Copy the rest of the application
-COPY . /workspace

 # Copy the rest of the application
 COPY . /workspace

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -9,7 +9,7 @@ now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
 import os
-from typing import Generator, List, Union
+from typing import Generator, List, Tuple, Union
 import numpy as np
 import torch
 import torch.nn.functional as F

@@ -597,7 +597,7 @@ class TTS:
                     "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
                 }
         returns:
-            tuple[int, np.ndarray]: sampling rate and audio data.
+            Tuple[int, np.ndarray]: sampling rate and audio data.
         """
         ########## variables initialization ###########
         self.stop_flag:bool = False

@@ -880,7 +880,7 @@ class TTS:
                       speed_factor:float=1.0,
                       split_bucket:bool=True,
                       fragment_interval:float=0.3
-                      )->tuple[int, np.ndarray]:
+                      )->Tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
             int(self.configs.sampling_rate * fragment_interval),
             dtype=self.precision,
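Note: the switch from builtin tuple[...] to typing.Tuple[...] is a compatibility fix. Subscripting builtin generics (PEP 585) only works on Python 3.9+; on 3.8 the annotation raises when the def executes. A minimal sketch of the compatible pattern (the synthesize function is illustrative, not from the repo):

from typing import Tuple

import numpy as np

def synthesize() -> Tuple[int, np.ndarray]:
    """Return the sampling rate and a (silent) audio buffer."""
    sr = 32000
    audio = np.zeros(sr, dtype=np.float32)  # one second of silence
    return sr, audio

# On Python 3.8, writing `-> tuple[int, np.ndarray]` instead raises
# "TypeError: 'type' object is not subscriptable" at definition time.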

GPT_SoVITS/module/models.py

@@ -16,6 +16,7 @@ from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
 from text import symbols
 from torch.cuda.amp import autocast
+import contextlib

 class StochasticDurationPredictor(nn.Module):

@@ -891,9 +892,10 @@ class SynthesizerTrn(nn.Module):
         self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
         self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
-        if freeze_quantizer:
-            self.ssl_proj.requires_grad_(False)
-            self.quantizer.requires_grad_(False)
+        self.freeze_quantizer = freeze_quantizer
+        # if freeze_quantizer:
+        #     self.ssl_proj.requires_grad_(False)
+        #     self.quantizer.requires_grad_(False)
             # self.quantizer.eval()
         # self.enc_p.text_embedding.requires_grad_(False)
         # self.enc_p.encoder_text.requires_grad_(False)

@@ -906,6 +908,11 @@ class SynthesizerTrn(nn.Module):
         ge = self.ref_enc(y * y_mask, y_mask)
         with autocast(enabled=False):
+            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
+            with maybe_no_grad:
+                if self.freeze_quantizer:
+                    self.ssl_proj.eval()
+                    self.quantizer.eval()
             ssl = self.ssl_proj(ssl)
             quantized, codes, commit_loss, quantized_list = self.quantizer(
                 ssl, layers=[0]
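The runtime choice between torch.no_grad() and contextlib.nullcontext() is a standard way to freeze a submodule conditionally without duplicating the forward logic. A self-contained sketch of the pattern (module and sizes are illustrative, not from the repo):

import contextlib

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, freeze: bool):
        super().__init__()
        self.freeze = freeze
        self.proj = nn.Conv1d(8, 8, 1)

    def forward(self, x):
        # Enter no_grad only when frozen; nullcontext() is a no-op otherwise.
        maybe_no_grad = torch.no_grad() if self.freeze else contextlib.nullcontext()
        with maybe_no_grad:
            if self.freeze:
                self.proj.eval()  # also pins dropout/batch-norm to eval behaviour
            return self.proj(x)

x = torch.randn(1, 8, 16)
print(Wrapper(freeze=True)(x).requires_grad)   # False: no graph recorded
print(Wrapper(freeze=False)(x).requires_grad)  # True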

GPT_SoVITS/prepare_datasets/1-get-text.py

@@ -117,9 +117,12 @@ if os.path.exists(txt_path) == False:
             try:
                 wav_name, spk_name, language, text = line.split("|")
                 # todo.append([name,text,"zh"])
-                todo.append(
-                    [wav_name, text, language_v1_to_language_v2.get(language, language)]
-                )
+                if language in language_v1_to_language_v2.keys():
+                    todo.append(
+                        [wav_name, text, language_v1_to_language_v2.get(language, language)]
+                    )
+                else:
+                    print(f"\033[33m[Warning] The {language = } of {wav_name} is not supported for training.\033[0m")
             except:
                 print(line, traceback.format_exc())
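For reference, the filtering pattern in isolation, with a made-up mapping table (the real language_v1_to_language_v2 lives in the repo):

# Hypothetical stand-in for the repo's language_v1_to_language_v2 table.
language_v1_to_language_v2 = {"zh": "all_zh", "en": "en", "ja": "all_ja"}

todo = []
for line in ["a.wav|spk|zh|你好", "b.wav|spk|ko|안녕"]:
    wav_name, spk_name, language, text = line.split("|")
    if language in language_v1_to_language_v2:
        # .get(language, language) leaves codes that are already v2 untouched.
        todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
    else:
        print(f"[Warning] The {language = } of {wav_name} is not supported for training.")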

GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py

@@ -82,7 +82,7 @@ def name2go(wav_name,wav_path):
     tensor_wav16 = tensor_wav16.to(device)
     ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()  # torch.Size([1, 768, 215])
     if np.isnan(ssl.detach().numpy()).sum() != 0:
-        nan_fails.append(wav_name)
+        nan_fails.append((wav_name, wav_path))
         print("nan filtered:%s" % wav_name)
         return
     wavfile.write(

@@ -113,8 +113,8 @@ for line in lines[int(i_part)::int(all_parts)]:
 if(len(nan_fails)>0 and is_half==True):
     is_half = False
     model = model.float()
-    for wav_name in nan_fails:
+    for wav in nan_fails:
         try:
-            name2go(wav_name)
+            name2go(wav[0], wav[1])
         except:
             print(wav_name, traceback.format_exc())
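Storing (wav_name, wav_path) tuples instead of bare names gives the half-precision retry loop both arguments that name2go now requires. The underlying recover-in-float32 pattern, sketched with a stand-in extractor:

import numpy as np
import torch

# Stand-in for the repo's name2go(wav_name, wav_path) feature extraction.
def extract(model: torch.nn.Module, wav: torch.Tensor) -> np.ndarray:
    return model(wav).detach().float().cpu().numpy()

model = torch.nn.Linear(16, 16)
batch = {"a.wav": torch.randn(16), "b.wav": torch.randn(16)}

nan_fails = []
for name, wav in batch.items():
    feats = extract(model, wav)
    if np.isnan(feats).sum() != 0:     # fp16 overflow shows up as NaN features
        nan_fails.append((name, wav))  # keep everything the retry will need

# Second pass: redo only the failures in full precision.
if nan_fails:
    model = model.float()
    for name, wav in nan_fails:
        feats = extract(model, wav.float())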

GPT_SoVITS/text/english.py

@@ -320,7 +320,7 @@ class en_G2p(G2p):
         # try to split off the possessive
         if re.match(r"^([a-z]+)('s)$", word):
-            phones = self.qryword(word[:-2])
+            phones = self.qryword(word[:-2])[:]
             # after voiceless consonants P T K F TH HH, possessive 's is pronounced ['S']
             if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
                 phones.extend(['S'])
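The added [:] takes a shallow copy before the later extend calls, so the possessive handling no longer mutates the list that qryword returns, which is typically a cached dictionary entry. A minimal illustration of the aliasing bug being fixed:

lexicon = {"cat": ["K", "AE1", "T"]}  # stand-in for a cached CMU dict entry

def qryword(word):
    return lexicon[word]  # returns the cached list itself, not a copy

phones = qryword("cat")   # aliases the cache
phones.extend(["S"])      # silently corrupts it
print(lexicon["cat"])     # ['K', 'AE1', 'T', 'S']  <- wrong

lexicon["cat"] = ["K", "AE1", "T"]  # reset
phones = qryword("cat")[:]          # shallow copy, as in the fix
phones.extend(["S"])
print(lexicon["cat"])               # ['K', 'AE1', 'T']  <- cache intact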

api.py · 13 changes

@@ -120,6 +120,11 @@ RESP: 无
 import argparse
 import os,re
 import sys
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+sys.path.append("%s/GPT_SoVITS" % (now_dir))
+
 import signal
 import LangSegment
 from time import time as ttime

@@ -381,7 +386,7 @@ def read_clean_buffer(audio_bytes):
 def cut_text(text, punc):
-    punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}]
+    punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}]
     if len(punc_list) > 0:
         punds = r"[" + "".join(punc_list) + r"]"
         text = text.strip("\n")

@@ -536,10 +541,6 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
 # --------------------------------
 # initialization
 # --------------------------------
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-sys.path.append("%s/GPT_SoVITS" % (now_dir))
-
 dict_language = {
     "中文": "all_zh",
     "英文": "en",

@@ -579,7 +580,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
 # at this point full_precision==True, half_precision==False
 parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="streaming mode: close / normal / keepalive")
 parser.add_argument("-mt", "--media_type", type=str, default="wav", help="audio codec: wav / ogg / aac")
-parser.add_argument("-cp", "--cut_punc", type=str, default="", help="text-splitting punctuation, any of ,.;?!、,。?!;:…")
+parser.add_argument("-cp", "--cut_punc", type=str, default="", help="text-splitting punctuation, any of ,.;?!、,。?!;:…")
 # e.g. split on common sentence-final punctuation: `python ./api.py -cp ".?!。?!"`
 parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="overrides config.cnhubert_path")
 parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="overrides config.bert_path")

api_v2.py

@@ -446,7 +446,7 @@ async def set_sovits_weights(weights_path: str = None):
 if __name__ == "__main__":
     try:
-        uvicorn.run(APP, host=host, port=port, workers=1)
+        uvicorn.run(app="api_v2:APP", host=host, port=port, workers=1)
     except Exception as e:
         traceback.print_exc()
         os.kill(os.getpid(), signal.SIGTERM)
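Passing the app as the import string "api_v2:APP" rather than the object lets uvicorn re-import the application in each worker process; with a bare object, uvicorn refuses to start whenever workers > 1 (or --reload) is requested. A minimal sketch, assuming a file named main.py:

# main.py (hypothetical)
import uvicorn
from fastapi import FastAPI

APP = FastAPI()

@APP.get("/ping")
def ping():
    return {"ok": True}

if __name__ == "__main__":
    # The "module:attribute" import string lets every worker re-import APP.
    uvicorn.run(app="main:APP", host="127.0.0.1", port=8000, workers=2)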

gpt-sovits_kaggle.ipynb

@@ -54,11 +54,11 @@
 "source": [
     "# @title Download pretrained models 下载预训练模型\n",
     "!mkdir -p /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
-    "!mkdir -p /kaggle/working/GPT-SoVITS/tools/damo_asr/models\n",
+    "!mkdir -p /kaggle/working/GPT-SoVITS/tools/asr/models\n",
     "!mkdir -p /kaggle/working/GPT-SoVITS/tools/uvr5\n",
     "%cd /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
     "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n",
-    "%cd /kaggle/working/GPT-SoVITS/tools/damo_asr/models\n",
+    "%cd /kaggle/working/GPT-SoVITS/tools/asr/models\n",
     "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n",
     "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n",
     "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n",

tools/my_utils.py

@@ -28,4 +28,4 @@ def load_audio(file, sr):
 def clean_path(path_str):
     if platform.system() == 'Windows':
         path_str = path_str.replace('/', '\\')
-    return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
+    return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ").strip("\u202a")
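The extra .strip("\u202a") drops U+202A (LEFT-TO-RIGHT EMBEDDING), an invisible directional control character that Windows can prepend when a path is copied from Explorer; left in place, an otherwise valid path fails os.path.exists. A quick demonstration:

import os

path = "\u202aC:\\data\\voice.wav"  # invisible U+202A prepended by copy-paste
print(os.path.exists(path))         # False: the control char breaks the lookup
print(len(path) - len(path.strip("\u202a")))  # 1: the character is really there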

webui.py

@@ -418,7 +418,10 @@ def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
         with open(path_text, "w", encoding="utf8") as f:
             f.write("\n".join(opt) + "\n")
         ps1a = []
-        yield "文本进程结束", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
+        if len("".join(opt)) > 0:
+            yield "文本进程成功", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
+        else:
+            yield "文本进程失败", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
     else:
         yield "已有正在进行的文本任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}

@@ -583,7 +586,7 @@ def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numb
             os.remove(txt_path)
         with open(path_text, "w", encoding="utf8") as f:
             f.write("\n".join(opt) + "\n")
+        assert len("".join(opt)) > 0, "1Aa-文本获取进程失败"
         yield "进度1a-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
         ps1abc = []
         #############################1b
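Both hunks gate on len("".join(opt)) > 0, i.e. whether text preprocessing produced any non-empty output: the webui path reports success or failure in its streamed status message, while the one-click pipeline fails fast with an assert so later stages never run on an empty transcript list. The shape of the check in isolation (values are hypothetical; the repo's messages are the Chinese UI strings above):

opt = ["a.wav|spk|zh|你好"]  # lines from the text step; empty means it failed

# Streaming UI path: report the outcome but keep the app alive.
status = "success" if len("".join(opt)) > 0 else "failed"

# One-click pipeline path: abort immediately so later stages never start.
assert len("".join(opt)) > 0, "1Aa: text preprocessing produced no output"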