diff --git a/.gitignore b/.gitignore index 28b8a7a5..754b06b7 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,3 @@ reference GPT_weights SoVITS_weights TEMP -PortableGit -ffmpeg.exe -ffprobe.exe -tmp_audio -trained diff --git a/Dockerfile b/Dockerfile index 74e282c4..80cd9f3a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,9 +34,6 @@ RUN if [ "$IMAGE_TYPE" != "elite" ]; then \ fi -# Copy the rest of the application -COPY . /workspace - # Copy the rest of the application COPY . /workspace diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index ef9662ad..eaacb529 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -9,7 +9,7 @@ now_dir = os.getcwd() sys.path.append(now_dir) import ffmpeg import os -from typing import Generator, List, Union +from typing import Generator, List, Tuple, Union import numpy as np import torch import torch.nn.functional as F @@ -597,7 +597,7 @@ class TTS: "repetition_penalty": 1.35 # float. repetition penalty for T2S model. } returns: - tuple[int, np.ndarray]: sampling rate and audio data. + Tuple[int, np.ndarray]: sampling rate and audio data. """ ########## variables initialization ########### self.stop_flag:bool = False @@ -880,7 +880,7 @@ class TTS: speed_factor:float=1.0, split_bucket:bool=True, fragment_interval:float=0.3 - )->tuple[int, np.ndarray]: + )->Tuple[int, np.ndarray]: zero_wav = torch.zeros( int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index b14e7c81..26840ccc 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -16,6 +16,7 @@ from module.mrte_model import MRTE from module.quantize import ResidualVectorQuantizer from text import symbols from torch.cuda.amp import autocast +import contextlib class StochasticDurationPredictor(nn.Module): @@ -891,9 +892,10 @@ class SynthesizerTrn(nn.Module): self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) - if freeze_quantizer: - self.ssl_proj.requires_grad_(False) - self.quantizer.requires_grad_(False) + self.freeze_quantizer = freeze_quantizer + # if freeze_quantizer: + # self.ssl_proj.requires_grad_(False) + # self.quantizer.requires_grad_(False) #self.quantizer.eval() # self.enc_p.text_embedding.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False) @@ -906,6 +908,11 @@ class SynthesizerTrn(nn.Module): ge = self.ref_enc(y * y_mask, y_mask) with autocast(enabled=False): + maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() + with maybe_no_grad: + if self.freeze_quantizer: + self.ssl_proj.eval() + self.quantizer.eval() ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer( ssl, layers=[0] diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py index b2413826..e01a63b9 100644 --- a/GPT_SoVITS/prepare_datasets/1-get-text.py +++ b/GPT_SoVITS/prepare_datasets/1-get-text.py @@ -117,9 +117,12 @@ if os.path.exists(txt_path) == False: try: wav_name, spk_name, language, text = line.split("|") # todo.append([name,text,"zh"]) - todo.append( - [wav_name, text, language_v1_to_language_v2.get(language, language)] - ) + if language in language_v1_to_language_v2.keys(): + todo.append( + [wav_name, text, language_v1_to_language_v2.get(language, language)] + ) + else: + print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m") except: print(line, traceback.format_exc()) diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py index 9a2f73c0..61c933a4 100644 --- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py +++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py @@ -82,7 +82,7 @@ def name2go(wav_name,wav_path): tensor_wav16 = tensor_wav16.to(device) ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) if np.isnan(ssl.detach().numpy()).sum()!= 0: - nan_fails.append(wav_name) + nan_fails.append((wav_name,wav_path)) print("nan filtered:%s"%wav_name) return wavfile.write( @@ -90,7 +90,7 @@ def name2go(wav_name,wav_path): 32000, tmp_audio32.astype("int16"), ) - my_save(ssl,hubert_path ) + my_save(ssl,hubert_path) with open(inp_text,"r",encoding="utf8")as f: lines=f.read().strip("\n").split("\n") @@ -113,8 +113,8 @@ for line in lines[int(i_part)::int(all_parts)]: if(len(nan_fails)>0 and is_half==True): is_half=False model=model.float() - for wav_name in nan_fails: + for wav in nan_fails: try: - name2go(wav_name) + name2go(wav[0],wav[1]) except: print(wav_name,traceback.format_exc()) diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py index 68ce7896..30fafb51 100644 --- a/GPT_SoVITS/text/english.py +++ b/GPT_SoVITS/text/english.py @@ -320,7 +320,7 @@ class en_G2p(G2p): # 尝试分离所有格 if re.match(r"^([a-z]+)('s)$", word): - phones = self.qryword(word[:-2]) + phones = self.qryword(word[:-2])[:] # P T K F TH HH 无声辅音结尾 's 发 ['S'] if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']: phones.extend(['S']) @@ -359,4 +359,4 @@ def g2p(text): if __name__ == "__main__": print(g2p("hello")) print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture."))) - print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))) \ No newline at end of file + print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))) diff --git a/api.py b/api.py index ea0e39d0..ea3e123f 100644 --- a/api.py +++ b/api.py @@ -120,6 +120,11 @@ RESP: 无 import argparse import os,re import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + import signal import LangSegment from time import time as ttime @@ -381,7 +386,7 @@ def read_clean_buffer(audio_bytes): def cut_text(text, punc): - punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] + punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] if len(punc_list) > 0: punds = r"[" + "".join(punc_list) + r"]" text = text.strip("\n") @@ -536,10 +541,6 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu # -------------------------------- # 初始化部分 # -------------------------------- -now_dir = os.getcwd() -sys.path.append(now_dir) -sys.path.append("%s/GPT_SoVITS" % (now_dir)) - dict_language = { "中文": "all_zh", "英文": "en", @@ -579,7 +580,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals # 此时 full_precision==True, half_precision==False parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") -parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") +parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") # 切割常用分句符为 `python ./api.py -cp ".?!。?!"` parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") diff --git a/api_v2.py b/api_v2.py index 9f45ac53..aaa56e0b 100644 --- a/api_v2.py +++ b/api_v2.py @@ -446,8 +446,8 @@ async def set_sovits_weights(weights_path: str = None): if __name__ == "__main__": try: - uvicorn.run(APP, host=host, port=port, workers=1) + uvicorn.run(app="api_v2:APP", host=host, port=port, workers=1) except Exception as e: traceback.print_exc() os.kill(os.getpid(), signal.SIGTERM) - exit(0) \ No newline at end of file + exit(0) diff --git a/docs/ja/README.md b/docs/ja/README.md index 02d1b836..a910f94d 100644 --- a/docs/ja/README.md +++ b/docs/ja/README.md @@ -159,7 +159,7 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|en|I like playing Genshin. - [ ] **優先度 高:** - [x] 日本語と英語でのローカライズ。 - - [] ユーザーガイド。 + - [ ] ユーザーガイド。 - [x] 日本語データセットと英語データセットのファインチューニングトレーニング。 - [ ] **機能:** diff --git a/gpt-sovits_kaggle.ipynb b/gpt-sovits_kaggle.ipynb index 1980a77a..84ecd89c 100644 --- a/gpt-sovits_kaggle.ipynb +++ b/gpt-sovits_kaggle.ipynb @@ -54,11 +54,11 @@ "source": [ "# @title Download pretrained models 下载预训练模型\n", "!mkdir -p /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", - "!mkdir -p /kaggle/working/GPT-SoVITS/tools/damo_asr/models\n", + "!mkdir -p /kaggle/working/GPT-SoVITS/tools/asr/models\n", "!mkdir -p /kaggle/working/GPT-SoVITS/tools/uvr5\n", "%cd /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n", - "%cd /kaggle/working/GPT-SoVITS/tools/damo_asr/models\n", + "%cd /kaggle/working/GPT-SoVITS/tools/asr/models\n", "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n", "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n", "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n", diff --git a/tools/my_utils.py b/tools/my_utils.py index a7755d6d..de79f3b5 100644 --- a/tools/my_utils.py +++ b/tools/my_utils.py @@ -28,4 +28,4 @@ def load_audio(file, sr): def clean_path(path_str): if platform.system() == 'Windows': path_str = path_str.replace('/', '\\') - return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ") + return path_str.strip(" ").strip('"').strip("\n").strip('"').strip(" ").strip("\u202a") diff --git a/webui.py b/webui.py index e1c36e1e..c71c1ca4 100644 --- a/webui.py +++ b/webui.py @@ -418,7 +418,10 @@ def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir): with open(path_text, "w", encoding="utf8") as f: f.write("\n".join(opt) + "\n") ps1a=[] - yield "文本进程结束",{"__type__":"update","visible":True},{"__type__":"update","visible":False} + if len("".join(opt)) > 0: + yield "文本进程成功", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False} + else: + yield "文本进程失败", {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False} else: yield "已有正在进行的文本任务,需先终止才能开启下一次任务", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True} @@ -583,7 +586,7 @@ def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numb os.remove(txt_path) with open(path_text, "w",encoding="utf8") as f: f.write("\n".join(opt) + "\n") - + assert len("".join(opt)) > 0, "1Aa-文本获取进程失败" yield "进度:1a-done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True} ps1abc=[] #############################1b