diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py
new file mode 100644
index 00000000..bd987aaf
--- /dev/null
+++ b/GPT_SoVITS/inference_cli.py
@@ -0,0 +1,55 @@
+import argparse
+import os
+import soundfile as sf
+
+from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
+
+i18n = I18nAuto()
+
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
+    # Read reference text
+    with open(ref_text_path, 'r', encoding='utf-8') as file:
+        ref_text = file.read()
+
+    # Read target text
+    with open(target_text_path, 'r', encoding='utf-8') as file:
+        target_text = file.read()
+
+    # Change model weights
+    change_gpt_weights(gpt_path=GPT_model_path)
+    change_sovits_weights(sovits_path=SoVITS_model_path)
+
+    # Synthesize audio
+    synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
+                                   prompt_text=ref_text,
+                                   prompt_language=i18n(ref_language),
+                                   text=target_text,
+                                   text_language=i18n(target_language), top_p=1, temperature=1)
+
+    result_list = list(synthesis_result)
+
+    if result_list:
+        last_sampling_rate, last_audio_data = result_list[-1]
+        output_wav_path = os.path.join(output_path, "output.wav")
+        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
+        print(f"Audio saved to {output_wav_path}")
+
+def main():
+    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
+    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
+    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
+    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
+    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
+    parser.add_argument('--target_text', required=True, help="Path to the target text file")
+    parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+
+    args = parser.parse_args()
+
+    synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
+
+if __name__ == '__main__':
+    main()
+
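Reviewer note: the new CLI can be smoke-tested either through its shell entry point (`python GPT_SoVITS/inference_cli.py --gpt_model ... --output_path output`) or by calling `synthesize()` directly. A minimal sketch of the latter, assuming the repository root is on `sys.path` and using placeholder paths throughout:

```python
# Hypothetical smoke test for inference_cli.py -- every path below is a
# placeholder; substitute your own weights, reference clip, and text files.
from GPT_SoVITS.inference_cli import synthesize

synthesize(
    GPT_model_path="GPT_weights/example-e15.ckpt",    # placeholder .ckpt
    SoVITS_model_path="SoVITS_weights/example.pth",   # placeholder .pth
    ref_audio_path="ref/example.wav",                 # reference audio clip
    ref_text_path="ref/example.txt",                  # transcript of the clip
    ref_language="中文",                               # one of the --ref_language choices
    target_text_path="texts/target.txt",              # text to synthesize
    target_language="中英混合",                        # one of the --target_language choices
    output_path="output",                             # directory; output.wav lands here
)
```

Note that `synthesize()` always writes a file named `output.wav`, so successive runs into the same output directory overwrite each other.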
diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py
index f6cfdc5e..2059155d 100644
--- a/GPT_SoVITS/inference_gui.py
+++ b/GPT_SoVITS/inference_gui.py
@@ -1,3 +1,4 @@
+import os
 import sys
 from PyQt5.QtCore import QEvent
 from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
@@ -7,16 +8,16 @@ import soundfile as sf
 
 from tools.i18n.i18n import I18nAuto
 i18n = I18nAuto()
-from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
+from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
 
 
 class GPTSoVITSGUI(QMainWindow):
+    GPT_Path = gpt_path
+    SoVITS_Path = sovits_path
+
     def __init__(self):
         super().__init__()
 
-        self.init_ui()
-
-    def init_ui(self):
         self.setWindowTitle('GPT-SoVITS GUI')
         self.setGeometry(800, 450, 950, 850)
@@ -71,6 +72,7 @@ class GPTSoVITSGUI(QMainWindow):
         self.GPT_model_label = QLabel("选择GPT模型:")
         self.GPT_model_input = QLineEdit()
         self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
+        self.GPT_model_input.setText(self.GPT_Path)
         self.GPT_model_input.setReadOnly(True)
         self.GPT_model_button = QPushButton("选择GPT模型文件")
         self.GPT_model_button.clicked.connect(self.select_GPT_model)
@@ -78,6 +80,7 @@ class GPTSoVITSGUI(QMainWindow):
         self.SoVITS_model_label = QLabel("选择SoVITS模型:")
         self.SoVITS_model_input = QLineEdit()
         self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
+        self.SoVITS_model_input.setText(self.SoVITS_Path)
         self.SoVITS_model_input.setReadOnly(True)
         self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
         self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
@@ -91,25 +94,25 @@ class GPTSoVITSGUI(QMainWindow):
 
         self.ref_text_label = QLabel("参考音频文本:")
         self.ref_text_input = QLineEdit()
-        self.ref_text_input.setPlaceholderText("拖拽或选择文件")
-        self.ref_text_input.setReadOnly(True)
+        self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
         self.ref_text_button = QPushButton("上传文本")
         self.ref_text_button.clicked.connect(self.upload_ref_text)
 
-        self.language_label = QLabel("参考音频语言:")
-        self.language_combobox = QComboBox()
-        self.language_combobox.addItems(["中文", "英文", "日文"])
+        self.ref_language_label = QLabel("参考音频语言:")
+        self.ref_language_combobox = QComboBox()
+        self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
+        self.ref_language_combobox.setCurrentText("多语种混合")
 
         self.target_text_label = QLabel("合成目标文本:")
         self.target_text_input = QLineEdit()
-        self.target_text_input.setPlaceholderText("拖拽或选择文件")
-        self.target_text_input.setReadOnly(True)
+        self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
         self.target_text_button = QPushButton("上传文本")
         self.target_text_button.clicked.connect(self.upload_target_text)
 
-        self.language_label_02 = QLabel("合成音频语言:")
-        self.language_combobox_02 = QComboBox()
-        self.language_combobox_02.addItems(["中文", "英文", "日文"])
+        self.target_language_label = QLabel("合成音频语言:")
+        self.target_language_combobox = QComboBox()
+        self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
+        self.target_language_combobox.setCurrentText("多语种混合")
 
         self.output_label = QLabel("输出音频路径:")
         self.output_input = QLineEdit()
@@ -140,10 +143,8 @@ class GPTSoVITSGUI(QMainWindow):
 
         main_layout = QVBoxLayout()
 
-        input_layout = QGridLayout()
-        input_layout.setSpacing(10)
-
-        self.setLayout(input_layout)
+        input_layout = QGridLayout(self)
+        input_layout.setSpacing(10)
 
         input_layout.addWidget(license_label, 0, 0, 1, 3)
 
@@ -159,22 +160,22 @@ class GPTSoVITSGUI(QMainWindow):
         input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
         input_layout.addWidget(self.ref_audio_button, 6, 2)
 
-        input_layout.addWidget(self.language_label, 7, 0)
-        input_layout.addWidget(self.language_combobox, 8, 0, 1, 1)
+        input_layout.addWidget(self.ref_language_label, 7, 0)
+        input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
 
         input_layout.addWidget(self.ref_text_label, 9, 0)
         input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
         input_layout.addWidget(self.ref_text_button, 10, 2)
 
-        input_layout.addWidget(self.language_label_02, 11, 0)
-        input_layout.addWidget(self.language_combobox_02, 12, 0, 1, 1)
+        input_layout.addWidget(self.target_language_label, 11, 0)
+        input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
 
         input_layout.addWidget(self.target_text_label, 13, 0)
         input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
         input_layout.addWidget(self.target_text_button, 14, 2)
-
+
         input_layout.addWidget(self.output_label, 15, 0)
         input_layout.addWidget(self.output_input, 16, 0, 1, 2)
         input_layout.addWidget(self.output_button, 16, 2)
-
+
         main_layout.addLayout(input_layout)
 
         output_layout = QVBoxLayout()
@@ -198,10 +199,8 @@ class GPTSoVITSGUI(QMainWindow):
     def dropEvent(self, event):
         if event.mimeData().hasUrls():
             file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
-
             if len(file_paths) == 1:
                 self.update_ref_audio(file_paths[0])
-                self.update_input_paths(self.ref_audio_input, file_paths[0])
             else:
                 self.update_ref_audio(", ".join(file_paths))
 
@@ -211,23 +210,13 @@ class GPTSoVITSGUI(QMainWindow):
             widget.installEventFilter(self)
 
     def eventFilter(self, obj, event):
-        if event.type() == QEvent.DragEnter:
+        if event.type() in (QEvent.DragEnter, QEvent.Drop):
             mime_data = event.mimeData()
             if mime_data.hasUrls():
                 event.acceptProposedAction()
-
-        elif event.type() == QEvent.Drop:
-            mime_data = event.mimeData()
-            if mime_data.hasUrls():
-                file_paths = [url.toLocalFile() for url in mime_data.urls()]
-                if len(file_paths) == 1:
-                    self.update_input_paths(obj, file_paths[0])
-                else:
-                    self.update_input_paths(obj, ", ".join(file_paths))
-                event.acceptProposedAction()
 
         return super().eventFilter(obj, event)
-
+
     def select_GPT_model(self):
         file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
         if file_path:
@@ -239,24 +228,9 @@ class GPTSoVITSGUI(QMainWindow):
             self.SoVITS_model_input.setText(file_path)
 
     def select_ref_audio(self):
-        options = QFileDialog.Options()
-        options |= QFileDialog.DontUseNativeDialog
-        options |= QFileDialog.ShowDirsOnly
-
-        file_dialog = QFileDialog()
-        file_dialog.setOptions(options)
-
-        file_dialog.setFileMode(QFileDialog.AnyFile)
-        file_dialog.setNameFilter("Audio Files (*.wav *.mp3)")
-
-        if file_dialog.exec_():
-            file_paths = file_dialog.selectedFiles()
-
-            if len(file_paths) == 1:
-                self.update_ref_audio(file_paths[0])
-                self.update_input_paths(self.ref_audio_input, file_paths[0])
-            else:
-                self.update_ref_audio(", ".join(file_paths))
+        file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
+        if file_path:
+            self.update_ref_audio(file_path)
 
     def upload_ref_text(self):
         file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
@@ -264,7 +238,6 @@
             with open(file_path, 'r', encoding='utf-8') as file:
                 content = file.read()
                 self.ref_text_input.setText(content)
-                self.update_input_paths(self.ref_text_input, file_path)
 
     def upload_target_text(self):
         file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
@@ -272,7 +245,6 @@
             with open(file_path, 'r', encoding='utf-8') as file:
                 content = file.read()
                 self.target_text_input.setText(content)
-                self.update_input_paths(self.target_text_input, file_path)
 
     def select_output_path(self):
         options = QFileDialog.Options()
@@ -290,9 +262,6 @@ class GPTSoVITSGUI(QMainWindow):
     def update_ref_audio(self, file_path):
         self.ref_audio_input.setText(file_path)
 
-    def update_input_paths(self, input_box, file_path):
-        input_box.setText(file_path)
-
    def clear_output(self):
         self.output_text.clear()
 
@@ -300,23 +269,27 @@
         GPT_model_path = self.GPT_model_input.text()
         SoVITS_model_path = self.SoVITS_model_input.text()
         ref_audio_path = self.ref_audio_input.text()
-        language_combobox = self.language_combobox.currentText()
+        language_combobox = self.ref_language_combobox.currentText()
         language_combobox = i18n(language_combobox)
         ref_text = self.ref_text_input.text()
-        language_combobox_02 = self.language_combobox_02.currentText()
-        language_combobox_02 = i18n(language_combobox_02)
+        target_language_combobox = self.target_language_combobox.currentText()
+        target_language_combobox = i18n(target_language_combobox)
         target_text = self.target_text_input.text()
         output_path = self.output_input.text()
 
-        change_gpt_weights(gpt_path=GPT_model_path)
-        change_sovits_weights(sovits_path=SoVITS_model_path)
+        if GPT_model_path != self.GPT_Path:
+            change_gpt_weights(gpt_path=GPT_model_path)
+            self.GPT_Path = GPT_model_path
+        if SoVITS_model_path != self.SoVITS_Path:
+            change_sovits_weights(sovits_path=SoVITS_model_path)
+            self.SoVITS_Path = SoVITS_model_path
 
         synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
                                        prompt_text=ref_text,
                                        prompt_language=language_combobox,
                                        text=target_text,
-                                       text_language=language_combobox_02)
-
+                                       text_language=target_language_combobox)
+
         result_list = list(synthesis_result)
 
         if result_list:
@@ -329,12 +302,9 @@ class GPTSoVITSGUI(QMainWindow):
         self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
         self.output_text.append("处理结果:\n" + result)
 
-def main():
+
+if __name__ == '__main__':
     app = QApplication(sys.argv)
     mainWin = GPTSoVITSGUI()
     mainWin.show()
-    sys.exit(app.exec_())
-
-
-if __name__ == '__main__':
-    main()
+    sys.exit(app.exec_())
\ No newline at end of file
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 4fe8045d..44c6d0eb 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -50,6 +50,7 @@ is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+punctuation = set(['!', '?', '…', ',', '.', '-'," "])
 import gradio as gr
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
@@ -64,7 +65,7 @@ from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from time import time as ttime
 from module.mel_processing import spectrogram_torch
-from my_utils import load_audio
+from tools.my_utils import load_audio
 from tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto()
@@ -322,6 +323,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
         print(i18n("实际输入的参考文本:"), prompt_text)
     text = text.strip("\n")
+    text = replace_consecutive_punctuation(text)
     if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
 
     print(i18n("实际输入的目标文本:"), text)
@@ -366,6 +368,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     text = text.replace("\n\n", "\n")
     print(i18n("实际输入的目标文本(切句后):"), text)
     texts = text.split("\n")
+    texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
     if not ref_free:
@@ -463,6 +466,7 @@ def cut1(inp):
             opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
     else:
         opts = [inp]
+    opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)
 
 
@@ -487,17 +491,21 @@ def cut2(inp):
     if len(opts) > 1 and len(opts[-1]) < 50:  ##如果最后一个太短了,和前一个合一起
         opts[-2] = opts[-2] + opts[-1]
         opts = opts[:-1]
+    opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)
 
 
 def cut3(inp):
     inp = inp.strip("\n")
-    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
-
+    opts = ["%s" % item for item in inp.strip("。").split("。")]
+    opts = [item for item in opts if not set(item).issubset(punctuation)]
+    return "\n".join(opts)
 
 def cut4(inp):
     inp = inp.strip("\n")
-    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
+    opts = ["%s" % item for item in inp.strip(".").split(".")]
+    opts = [item for item in opts if not set(item).issubset(punctuation)]
+    return "\n".join(opts)
 
 
 # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@@ -511,8 +519,8 @@ def cut5(inp):
     # 在句子不存在符号或句尾无符号的时候保证文本完整
     if len(items)%2 == 1:
         mergeitems.append(items[-1])
-    opt = "\n".join(mergeitems)
-    return opt
+    opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
+    return "\n".join(opt)
 
 
 def custom_sort_key(s):
@@ -522,6 +530,24 @@ def custom_sort_key(s):
     parts = [int(part) if part.isdigit() else part for part in parts]
     return parts
 
+def process_text(texts):
+    _text=[]
+    if all(text in [None, " ", "\n",""] for text in texts):
+        raise ValueError(i18n("请输入有效文本"))
+    for text in texts:
+        if text in [None, " ", ""]:
+            pass
+        else:
+            _text.append(text)
+    return _text
+
+
+def replace_consecutive_punctuation(text):
+    punctuations = ''.join(re.escape(p) for p in punctuation)
+    pattern = f'([{punctuations}])([{punctuations}])+'
+    result = re.sub(pattern, r'\1', text)
+    return result
+
 
 def change_choices():
     SoVITS_names, GPT_names = get_weights_names()
@@ -613,10 +639,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
         button5.click(cut5, [text_inp], [text_opt])
         gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
 
-app.queue(concurrency_count=511, max_size=1022).launch(
-    server_name="0.0.0.0",
-    inbrowser=True,
-    share=is_share,
-    server_port=infer_ttswebui,
-    quiet=True,
-)
+if __name__ == '__main__':
+    app.queue(concurrency_count=511, max_size=1022).launch(
+        server_name="0.0.0.0",
+        inbrowser=True,
+        share=is_share,
+        server_port=infer_ttswebui,
+        quiet=True,
+    )
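The two helpers added to inference_webui.py are easiest to understand in isolation: `replace_consecutive_punctuation` collapses any run of the listed punctuation marks down to its first character, and the `issubset(punctuation)` filters appended to `cut1()`-`cut5()` then drop segments that consist of punctuation only (`process_text` plays a similar role, dropping empty segments and raising `ValueError` when every segment is empty). A standalone sketch of the combined effect; the helper is copied from the diff, the inputs are made up:

```python
# Standalone reproduction of the new text sanitizing in inference_webui.py.
import re

punctuation = set(['!', '?', '…', ',', '.', '-', ' '])

def replace_consecutive_punctuation(text):
    # Collapse a run of listed marks down to its first character.
    punctuations = ''.join(re.escape(p) for p in punctuation)
    pattern = f'([{punctuations}])([{punctuations}])+'
    return re.sub(pattern, r'\1', text)

segments = ["Okay...", "…!?", "a, b"]
cleaned = [replace_consecutive_punctuation(s) for s in segments]
print(cleaned)  # ['Okay.', '…', 'a,b']  (", " collapses because ' ' is in the set)

# cut1()-cut5() then drop anything that is punctuation-only:
kept = [s for s in cleaned if not set(s).issubset(punctuation)]
print(kept)     # ['Okay.', 'a,b']
```

Only the ASCII marks plus '…' are covered, so fullwidth Chinese punctuation (。!?) passes through both steps untouched.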
diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py
index ff4c4f43..72c80555 100644
--- a/GPT_SoVITS/module/data_utils.py
+++ b/GPT_SoVITS/module/data_utils.py
@@ -17,7 +17,7 @@ from functools import lru_cache
 import requests
 from scipy.io import wavfile
 from io import BytesIO
-from my_utils import load_audio
+from tools.my_utils import load_audio
 
 # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
diff --git a/GPT_SoVITS/my_utils.py b/GPT_SoVITS/my_utils.py
deleted file mode 100644
index 776939dd..00000000
--- a/GPT_SoVITS/my_utils.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import ffmpeg
-import numpy as np
-
-
-def load_audio(file, sr):
-    try:
-        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
-        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-        file = (
-            file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-        )  # 防止小白拷路径头尾带了空格和"和回车
-        out, _ = (
-            ffmpeg.input(file, threads=0)
-            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
-            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
-        )
-    except Exception as e:
-        raise RuntimeError(f"Failed to load audio: {e}")
-
-    return np.frombuffer(out, np.float32).flatten()
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index b82e987f..ab457d75 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -9,7 +9,7 @@ cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
 import soundfile
-from my_utils import load_audio
+from tools.my_utils import load_audio
 import os
 import json
diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
index 61c933a4..17394ee4 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
@@ -17,7 +17,7 @@ from scipy.io import wavfile
 import librosa,torch
 now_dir = os.getcwd()
 sys.path.append(now_dir)
-from my_utils import load_audio
+from tools.my_utils import load_audio
 
 # from config import cnhubert_base_path
 # cnhubert.cnhubert_base_path=cnhubert_base_path
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 43cfa19a..ece295d3 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -79,15 +79,17 @@ class my_model_ckpt(ModelCheckpoint):
                     to_save_od["config"] = self.config
                     to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
                     # torch.save(
-                    my_save(
-                        to_save_od,
-                        "%s/%s-e%s.ckpt"
-                        % (
-                            self.half_weights_save_dir,
-                            self.exp_name,
-                            trainer.current_epoch + 1,
-                        ),
-                    )
+                    # print(os.environ)
+                    if(os.environ.get("LOCAL_RANK","0")=="0"):
+                        my_save(
+                            to_save_od,
+                            "%s/%s-e%s.ckpt"
+                            % (
+                                self.half_weights_save_dir,
+                                self.exp_name,
+                                trainer.current_epoch + 1,
+                            ),
+                        )
             self._save_last_checkpoint(trainer, monitor_candidates)
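The s1_train.py change above gates the half-precision checkpoint write on `LOCAL_RANK`. Under a `torch.distributed` launcher each GPU runs in its own process with its own `LOCAL_RANK`, so previously every process raced to write the same `.ckpt` file; the guard leaves exactly one writer per node. A minimal sketch of the pattern in isolation (`save_fn` is a stand-in, not a project function):

```python
import os

def save_on_local_rank0(save_fn, *args, **kwargs):
    # torch.distributed launchers (torchrun, etc.) export LOCAL_RANK per
    # process; single-process runs fall back to "0" and always save.
    if os.environ.get("LOCAL_RANK", "0") == "0":
        save_fn(*args, **kwargs)
```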
diff --git a/api.py b/api.py
index 041fa349..aa822ca7 100644
--- a/api.py
+++ b/api.py
@@ -143,7 +143,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
 from module.mel_processing import spectrogram_torch
-from my_utils import load_audio
+from tools.my_utils import load_audio
 import config as global_config
 import logging
 import subprocess
@@ -339,8 +339,46 @@ def pack_audio(audio_bytes, data, rate):
 
 
 def pack_ogg(audio_bytes, data, rate):
-    with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
-        audio_file.write(data)
+    # Author: AkagawaTsurunaki
+    # Issue:
+    #   Stack overflow probabilistically occurs
+    #   when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
+    #   using the Python library `soundfile`.
+    # Note:
+    #   This is an issue related to `libsndfile`, not this project itself.
+    #   It happens when you generate a large audio tensor (about 499804 frames on my PC)
+    #   and try to convert it to an ogg file.
+    # Related:
+    #   https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
+    #   https://github.com/libsndfile/libsndfile/issues/1023
+    #   https://github.com/bastibe/python-soundfile/issues/396
+    # Suggestion:
+    #   Alternatively, split the whole audio into smaller segments to avoid the stack overflow.
+
+    def handle_pack_ogg():
+        with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
+            audio_file.write(data)
+
+    import threading
+    # See: https://docs.python.org/3/library/threading.html
+    # The stack size of this thread is at least 32768.
+    # If a stack overflow still occurs, just increase `stack_size`.
+    # stack_size = n * 4096, where n should be a positive integer.
+    # Here we chose n = 4096.
+    stack_size = 4096 * 4096
+    try:
+        threading.stack_size(stack_size)
+        pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
+        pack_ogg_thread.start()
+        pack_ogg_thread.join()
+    except RuntimeError as e:
+        # If changing the thread stack size is unsupported, a RuntimeError is raised.
+        print("RuntimeError: {}".format(e))
+        print("Changing the thread stack size is unsupported.")
+    except ValueError as e:
+        # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
+        print("ValueError: {}".format(e))
+        print("The specified stack size is invalid.")
 
     return audio_bytes
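The `pack_ogg` workaround above relies on the fact that `threading.stack_size()` only affects threads created *after* the call, so the `libsndfile` write is pushed onto a fresh worker thread with an enlarged stack. A generic sketch of the same pattern (function name and default size are illustrative):

```python
import threading

def run_with_big_stack(fn, stack_bytes=4096 * 4096):
    # Must be called before the Thread object is created; raises RuntimeError
    # where unsupported and ValueError for sizes the platform rejects.
    old = threading.stack_size(stack_bytes)
    try:
        worker = threading.Thread(target=fn)
        worker.start()
        worker.join()
    finally:
        threading.stack_size(old)  # restore the previous setting
```

One observation on the patch: both `except` branches in `pack_ogg` only log, so on a platform where the stack size cannot be changed the OGG data is never written; a fallback write on the calling thread might be worth considering.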
diff --git a/docs/cn/Changelog_CN.md b/docs/cn/Changelog_CN.md
index 36c1db45..abd7263f 100644
--- a/docs/cn/Changelog_CN.md
+++ b/docs/cn/Changelog_CN.md
@@ -169,6 +169,21 @@
 
 6-nan自动转fp32阶段的hubert提取bug修复
 
+### 20240610
+
+小问题修复:
+
+1-完善纯标点、多标点文本输入的判断逻辑 https://github.com/RVC-Boss/GPT-SoVITS/pull/1168 https://github.com/RVC-Boss/GPT-SoVITS/pull/1169
+
+2-uvr5中的mdxnet去混响cmd格式修复,兼容路径带空格 [#501a74a](https://github.com/RVC-Boss/GPT-SoVITS/commit/501a74ae96789a26b48932babed5eb4e9483a232)
+
+3-s2训练进度条逻辑修复 https://github.com/RVC-Boss/GPT-SoVITS/pull/1159
+
+大问题修复:
+
+4-修复了webui的GPT中文微调没读到bert导致和推理不一致,训练太多可能效果还会变差的问题。如果大量数据微调的建议重新微调模型得到质量优化 [#99f09c8](https://github.com/RVC-Boss/GPT-SoVITS/commit/99f09c8bdc155c1f4272b511940717705509582a)
+
+
 todolist:
 
 1-中文多音字推理优化(有没有人来测试的,欢迎把测试结果写在pr评论区里) https://github.com/RVC-Boss/GPT-SoVITS/pull/488
@@ -177,3 +192,5 @@ todolist:
 
 2-正在尝试解决低音质参考音频导致音质较差的问题,v2再试试如果能解决就发了,节点暂定高考后吧
+
+
diff --git a/requirements.txt b/requirements.txt
index 73912d01..bf2e28a8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-numpy
+numpy==1.23.4
 scipy
 tensorboard
 librosa==0.9.2
diff --git a/tools/slice_audio.py b/tools/slice_audio.py
index 46ee408a..8a06292d 100644
--- a/tools/slice_audio.py
+++ b/tools/slice_audio.py
@@ -3,7 +3,7 @@ import traceback
 from scipy.io import wavfile
 # parent_directory = os.path.dirname(os.path.abspath(__file__))
 # sys.path.append(parent_directory)
-from my_utils import load_audio
+from tools.my_utils import load_audio
 from slicer2 import Slicer
 
 def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,i_part,all_part):
diff --git a/tools/uvr5/mdxnet.py b/tools/uvr5/mdxnet.py
index 0d609c41..372db25b 100644
--- a/tools/uvr5/mdxnet.py
+++ b/tools/uvr5/mdxnet.py
@@ -220,7 +220,7 @@ class Predictor:
         opt_path_other = path_other[:-4] + ".%s" % format
         if os.path.exists(path_vocal):
             os.system(
-                "ffmpeg -i %s -vn %s -q:a 2 -y" % (path_vocal, opt_path_vocal)
+                "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
             )
             if os.path.exists(opt_path_vocal):
                 try:
@@ -229,7 +229,7 @@ class Predictor:
                 pass
         if os.path.exists(path_other):
             os.system(
-                "ffmpeg -i %s -vn %s -q:a 2 -y" % (path_other, opt_path_other)
+                "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
             )
             if os.path.exists(opt_path_other):
                 try:
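One caveat on the mdxnet.py quoting fix: single quotes are only quoting characters for POSIX shells, so `os.system("ffmpeg -i '%s' ...")` would likely still break on paths with spaces under Windows cmd.exe, and on any path containing a single quote. Passing an argument list to `subprocess.run` sidesteps shell parsing entirely; a sketch of the equivalent call, keeping the same flag order as the patch (function name is illustrative):

```python
import subprocess

def extract_audio(src: str, dst: str) -> None:
    # List-form arguments reach ffmpeg verbatim -- no shell, so no quoting
    # rules to satisfy on either POSIX or Windows.
    subprocess.run(["ffmpeg", "-i", src, "-vn", dst, "-q:a", "2", "-y"], check=True)
```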
diff --git a/webui.py b/webui.py
index c71c1ca4..a200a747 100644
--- a/webui.py
+++ b/webui.py
@@ -85,7 +85,7 @@ if if_gpu_ok and len(gpu_infos) > 0:
 else:
     gpu_info = ("%s\t%s" % ("0", "CPU"))
     gpu_infos.append("%s\t%s" % ("0", "CPU"))
-    default_batch_size = psutil.virtual_memory().total/ 1024 / 1024 / 1024 / 2
+    default_batch_size = int(psutil.virtual_memory().total/ 1024 / 1024 / 1024 / 2)
 gpus = "-".join([i[0] for i in gpu_infos])
 
 pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"