diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py
deleted file mode 100644
index bd987aa..0000000
--- a/GPT_SoVITS/inference_cli.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import argparse
-import os
-import soundfile as sf
-
-from tools.i18n.i18n import I18nAuto
-from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
-
-i18n = I18nAuto()
-
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
- # Read reference text
- with open(ref_text_path, 'r', encoding='utf-8') as file:
- ref_text = file.read()
-
- # Read target text
- with open(target_text_path, 'r', encoding='utf-8') as file:
- target_text = file.read()
-
- # Change model weights
- change_gpt_weights(gpt_path=GPT_model_path)
- change_sovits_weights(sovits_path=SoVITS_model_path)
-
- # Synthesize audio
- synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
- prompt_text=ref_text,
- prompt_language=i18n(ref_language),
- text=target_text,
- text_language=i18n(target_language), top_p=1, temperature=1)
-
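-    # get_tts_wav is a generator of (sampling_rate, audio) pairs; the last pair holds the fully concatenated result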
- result_list = list(synthesis_result)
-
- if result_list:
- last_sampling_rate, last_audio_data = result_list[-1]
- output_wav_path = os.path.join(output_path, "output.wav")
- sf.write(output_wav_path, last_audio_data, last_sampling_rate)
- print(f"Audio saved to {output_wav_path}")
-
-def main():
- parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
- parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
- parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
- parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
- parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
- parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
- parser.add_argument('--target_text', required=True, help="Path to the target text file")
- parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
- parser.add_argument('--output_path', required=True, help="Path to the output directory")
-
- args = parser.parse_args()
-
- synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
-
-if __name__ == '__main__':
- main()
-
diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py
deleted file mode 100644
index 2059155..0000000
--- a/GPT_SoVITS/inference_gui.py
+++ /dev/null
@@ -1,310 +0,0 @@
-import os
-import sys
-from PyQt5.QtCore import QEvent
-from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
-from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
-import soundfile as sf
-
-from tools.i18n.i18n import I18nAuto
-i18n = I18nAuto()
-
-from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
-
-
-class GPTSoVITSGUI(QMainWindow):
- GPT_Path = gpt_path
- SoVITS_Path = sovits_path
-
- def __init__(self):
- super().__init__()
-
- self.setWindowTitle('GPT-SoVITS GUI')
- self.setGeometry(800, 450, 950, 850)
-
- self.setStyleSheet("""
- QWidget {
- background-color: #a3d3b1;
- }
-
- QTabWidget::pane {
- background-color: #a3d3b1;
- }
-
- QTabWidget::tab-bar {
- alignment: left;
- }
-
- QTabBar::tab {
- background: #8da4bf;
- color: #ffffff;
- padding: 8px;
- }
-
- QTabBar::tab:selected {
- background: #2a3f54;
- }
-
- QLabel {
- color: #000000;
- }
-
- QPushButton {
- background-color: #4CAF50;
- color: white;
- padding: 8px;
- border: 1px solid #4CAF50;
- border-radius: 4px;
- }
-
- QPushButton:hover {
- background-color: #45a049;
- border: 1px solid #45a049;
- box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
- }
- """)
-
- license_text = (
- "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
- "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
- license_label = QLabel(license_text)
- license_label.setWordWrap(True)
-
- self.GPT_model_label = QLabel("选择GPT模型:")
- self.GPT_model_input = QLineEdit()
- self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
- self.GPT_model_input.setText(self.GPT_Path)
- self.GPT_model_input.setReadOnly(True)
- self.GPT_model_button = QPushButton("选择GPT模型文件")
- self.GPT_model_button.clicked.connect(self.select_GPT_model)
-
- self.SoVITS_model_label = QLabel("选择SoVITS模型:")
- self.SoVITS_model_input = QLineEdit()
- self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
- self.SoVITS_model_input.setText(self.SoVITS_Path)
- self.SoVITS_model_input.setReadOnly(True)
- self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
- self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
-
- self.ref_audio_label = QLabel("上传参考音频:")
- self.ref_audio_input = QLineEdit()
- self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
- self.ref_audio_input.setReadOnly(True)
- self.ref_audio_button = QPushButton("选择音频文件")
- self.ref_audio_button.clicked.connect(self.select_ref_audio)
-
- self.ref_text_label = QLabel("参考音频文本:")
- self.ref_text_input = QLineEdit()
- self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
- self.ref_text_button = QPushButton("上传文本")
- self.ref_text_button.clicked.connect(self.upload_ref_text)
-
- self.ref_language_label = QLabel("参考音频语言:")
- self.ref_language_combobox = QComboBox()
- self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
- self.ref_language_combobox.setCurrentText("多语种混合")
-
- self.target_text_label = QLabel("合成目标文本:")
- self.target_text_input = QLineEdit()
- self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
- self.target_text_button = QPushButton("上传文本")
- self.target_text_button.clicked.connect(self.upload_target_text)
-
- self.target_language_label = QLabel("合成音频语言:")
- self.target_language_combobox = QComboBox()
- self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
- self.target_language_combobox.setCurrentText("多语种混合")
-
- self.output_label = QLabel("输出音频路径:")
- self.output_input = QLineEdit()
- self.output_input.setPlaceholderText("拖拽或选择文件")
- self.output_input.setReadOnly(True)
- self.output_button = QPushButton("选择文件夹")
- self.output_button.clicked.connect(self.select_output_path)
-
- self.output_text = QTextEdit()
- self.output_text.setReadOnly(True)
-
- self.add_drag_drop_events([
- self.GPT_model_input,
- self.SoVITS_model_input,
- self.ref_audio_input,
- self.ref_text_input,
- self.target_text_input,
- self.output_input,
- ])
-
- self.synthesize_button = QPushButton("合成")
- self.synthesize_button.clicked.connect(self.synthesize)
-
- self.clear_output_button = QPushButton("清空输出")
- self.clear_output_button.clicked.connect(self.clear_output)
-
- self.status_bar = QStatusBar()
-
- main_layout = QVBoxLayout()
-
- input_layout = QGridLayout(self)
- input_layout.setSpacing(10)
-
- input_layout.addWidget(license_label, 0, 0, 1, 3)
-
- input_layout.addWidget(self.GPT_model_label, 1, 0)
- input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
- input_layout.addWidget(self.GPT_model_button, 2, 2)
-
- input_layout.addWidget(self.SoVITS_model_label, 3, 0)
- input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
- input_layout.addWidget(self.SoVITS_model_button, 4, 2)
-
- input_layout.addWidget(self.ref_audio_label, 5, 0)
- input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
- input_layout.addWidget(self.ref_audio_button, 6, 2)
-
- input_layout.addWidget(self.ref_language_label, 7, 0)
- input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
- input_layout.addWidget(self.ref_text_label, 9, 0)
- input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
- input_layout.addWidget(self.ref_text_button, 10, 2)
-
- input_layout.addWidget(self.target_language_label, 11, 0)
- input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
- input_layout.addWidget(self.target_text_label, 13, 0)
- input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
- input_layout.addWidget(self.target_text_button, 14, 2)
-
- input_layout.addWidget(self.output_label, 15, 0)
- input_layout.addWidget(self.output_input, 16, 0, 1, 2)
- input_layout.addWidget(self.output_button, 16, 2)
-
- main_layout.addLayout(input_layout)
-
- output_layout = QVBoxLayout()
- output_layout.addWidget(self.output_text)
- main_layout.addLayout(output_layout)
-
- main_layout.addWidget(self.synthesize_button)
-
- main_layout.addWidget(self.clear_output_button)
-
- main_layout.addWidget(self.status_bar)
-
- self.central_widget = QWidget()
- self.central_widget.setLayout(main_layout)
- self.setCentralWidget(self.central_widget)
-
- def dragEnterEvent(self, event):
- if event.mimeData().hasUrls():
- event.acceptProposedAction()
-
- def dropEvent(self, event):
- if event.mimeData().hasUrls():
- file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
- if len(file_paths) == 1:
- self.update_ref_audio(file_paths[0])
- else:
- self.update_ref_audio(", ".join(file_paths))
-
- def add_drag_drop_events(self, widgets):
- for widget in widgets:
- widget.setAcceptDrops(True)
- widget.installEventFilter(self)
-
- def eventFilter(self, obj, event):
- if event.type() in (QEvent.DragEnter, QEvent.Drop):
- mime_data = event.mimeData()
- if mime_data.hasUrls():
- event.acceptProposedAction()
-
- return super().eventFilter(obj, event)
-
- def select_GPT_model(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
- if file_path:
- self.GPT_model_input.setText(file_path)
-
- def select_SoVITS_model(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
- if file_path:
- self.SoVITS_model_input.setText(file_path)
-
- def select_ref_audio(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
- if file_path:
- self.update_ref_audio(file_path)
-
- def upload_ref_text(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
- if file_path:
- with open(file_path, 'r', encoding='utf-8') as file:
- content = file.read()
- self.ref_text_input.setText(content)
-
- def upload_target_text(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
- if file_path:
- with open(file_path, 'r', encoding='utf-8') as file:
- content = file.read()
- self.target_text_input.setText(content)
-
- def select_output_path(self):
- options = QFileDialog.Options()
- options |= QFileDialog.DontUseNativeDialog
- options |= QFileDialog.ShowDirsOnly
-
- folder_dialog = QFileDialog()
- folder_dialog.setOptions(options)
- folder_dialog.setFileMode(QFileDialog.Directory)
-
- if folder_dialog.exec_():
- folder_path = folder_dialog.selectedFiles()[0]
- self.output_input.setText(folder_path)
-
- def update_ref_audio(self, file_path):
- self.ref_audio_input.setText(file_path)
-
- def clear_output(self):
- self.output_text.clear()
-
- def synthesize(self):
- GPT_model_path = self.GPT_model_input.text()
- SoVITS_model_path = self.SoVITS_model_input.text()
- ref_audio_path = self.ref_audio_input.text()
- language_combobox = self.ref_language_combobox.currentText()
- language_combobox = i18n(language_combobox)
- ref_text = self.ref_text_input.text()
- target_language_combobox = self.target_language_combobox.currentText()
- target_language_combobox = i18n(target_language_combobox)
- target_text = self.target_text_input.text()
- output_path = self.output_input.text()
-
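-        # only reload model weights when the selected paths differ from the ones currently loaded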
- if GPT_model_path != self.GPT_Path:
- change_gpt_weights(gpt_path=GPT_model_path)
- self.GPT_Path = GPT_model_path
- if SoVITS_model_path != self.SoVITS_Path:
- change_sovits_weights(sovits_path=SoVITS_model_path)
- self.SoVITS_Path = SoVITS_model_path
-
- synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
- prompt_text=ref_text,
- prompt_language=language_combobox,
- text=target_text,
- text_language=target_language_combobox)
-
- result_list = list(synthesis_result)
-
- if result_list:
- last_sampling_rate, last_audio_data = result_list[-1]
- output_wav_path = os.path.join(output_path, "output.wav")
- sf.write(output_wav_path, last_audio_data, last_sampling_rate)
-
- result = "Audio saved to " + output_wav_path
-
- self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
- self.output_text.append("处理结果:\n" + result)
-
-
-if __name__ == '__main__':
- app = QApplication(sys.argv)
- mainWin = GPTSoVITSGUI()
- mainWin.show()
- sys.exit(app.exec_())
\ No newline at end of file
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
deleted file mode 100644
index afae2cf..0000000
--- a/GPT_SoVITS/inference_webui.py
+++ /dev/null
@@ -1,952 +0,0 @@
-'''
-按中英混合识别
-按日英混合识别
-多语种启动切分识别语种
-全部按中文识别
-全部按英文识别
-全部按日文识别
-'''
-import logging
-import traceback,torchaudio,warnings
-logging.getLogger("markdown_it").setLevel(logging.ERROR)
-logging.getLogger("urllib3").setLevel(logging.ERROR)
-logging.getLogger("httpcore").setLevel(logging.ERROR)
-logging.getLogger("httpx").setLevel(logging.ERROR)
-logging.getLogger("asyncio").setLevel(logging.ERROR)
-logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
-logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
-warnings.simplefilter(action='ignore', category=FutureWarning)
-
-import os, re, sys, json
-import pdb
-import torch
-from text.LangSegmenter import LangSegmenter
-
-try:
- import gradio.analytics as analytics
- analytics.version_check = lambda:None
-except:...
-version=model_version=os.environ.get("version","v2")
-path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
-is_exist_s2gv3=os.path.exists(path_sovits_v3)
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
-
-
-
-_ =[[],[]]
-for i in range(3):
- if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
- if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
-pretrained_gpt_name,pretrained_sovits_name = _
-
-
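-# weight.json remembers the last-used GPT/SoVITS weight paths per model version; create an empty one on first run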
-if os.path.exists(f"./weight.json"):
- pass
-else:
- with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
-
-with open(f"./weight.json", 'r', encoding="utf-8") as file:
- weight_data = file.read()
- weight_data=json.loads(weight_data)
- gpt_path = os.environ.get(
- "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
- sovits_path = os.environ.get(
- "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
- if isinstance(gpt_path,list):
- gpt_path = gpt_path[0]
- if isinstance(sovits_path,list):
- sovits_path = sovits_path[0]
-
-# gpt_path = os.environ.get(
-# "gpt_path", pretrained_gpt_name
-# )
-# sovits_path = os.environ.get("sovits_path", pretrained_sovits_name)
-cnhubert_base_path = os.environ.get(
- "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
-)
-bert_path = os.environ.get(
- "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
-)
-infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
-infer_ttswebui = int(infer_ttswebui)
-is_share = os.environ.get("is_share", "False")
-is_share = eval(is_share)
-if "_CUDA_VISIBLE_DEVICES" in os.environ:
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
-# is_half=False
-punctuation = set(['!', '?', '…', ',', '.', '-'," "])
-import gradio as gr
-from transformers import AutoModelForMaskedLM, AutoTokenizer
-import numpy as np
-import librosa
-from feature_extractor import cnhubert
-
-cnhubert.cnhubert_base_path = cnhubert_base_path
-
-from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3
-import numpy as np
-import random
-def set_seed(seed):
- if seed == -1:
- seed = random.randint(0, 1000000)
- seed = int(seed)
- random.seed(seed)
- os.environ["PYTHONHASHSEED"] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
-# set_seed(42)
-
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from text import cleaned_text_to_sequence
-from text.cleaner import clean_text
-from time import time as ttime
-from tools.my_utils import load_audio
-from tools.i18n.i18n import I18nAuto, scan_language_list
-from peft import LoraConfig, PeftModel, get_peft_model
-
-language=os.environ.get("language","Auto")
-language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
-i18n = I18nAuto(language=language)
-
-# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
-
-if torch.cuda.is_available():
- device = "cuda"
-else:
- device = "cpu"
-
-dict_language_v1 = {
- i18n("中文"): "all_zh",#全部按中文识别
- i18n("英文"): "en",#全部按英文识别#######不变
- i18n("日文"): "all_ja",#全部按日文识别
- i18n("中英混合"): "zh",#按中英混合识别####不变
- i18n("日英混合"): "ja",#按日英混合识别####不变
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
-}
-dict_language_v2 = {
- i18n("中文"): "all_zh",#全部按中文识别
- i18n("英文"): "en",#全部按英文识别#######不变
- i18n("日文"): "all_ja",#全部按日文识别
- i18n("粤语"): "all_yue",#全部按中文识别
- i18n("韩文"): "all_ko",#全部按韩文识别
- i18n("中英混合"): "zh",#按中英混合识别####不变
- i18n("日英混合"): "ja",#按日英混合识别####不变
- i18n("粤英混合"): "yue",#按粤英混合识别####不变
- i18n("韩英混合"): "ko",#按韩英混合识别####不变
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
- i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
-}
-dict_language = dict_language_v1 if version =='v1' else dict_language_v2
-
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-if is_half == True:
- bert_model = bert_model.half().to(device)
-else:
- bert_model = bert_model.to(device)
-
-
-def get_bert_feature(text, word2ph):
- with torch.no_grad():
- inputs = tokenizer(text, return_tensors="pt")
- for i in inputs:
- inputs[i] = inputs[i].to(device)
- res = bert_model(**inputs, output_hidden_states=True)
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
- assert len(word2ph) == len(text)
- phone_level_feature = []
- for i in range(len(word2ph)):
- repeat_feature = res[i].repeat(word2ph[i], 1)
- phone_level_feature.append(repeat_feature)
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
- return phone_level_feature.T
-
-
-class DictToAttrRecursive(dict):
- def __init__(self, input_dict):
- super().__init__(input_dict)
- for key, value in input_dict.items():
- if isinstance(value, dict):
- value = DictToAttrRecursive(value)
- self[key] = value
- setattr(self, key, value)
-
- def __getattr__(self, item):
- try:
- return self[item]
- except KeyError:
- raise AttributeError(f"Attribute {item} not found")
-
- def __setattr__(self, key, value):
- if isinstance(value, dict):
- value = DictToAttrRecursive(value)
- super(DictToAttrRecursive, self).__setitem__(key, value)
- super().__setattr__(key, value)
-
- def __delattr__(self, item):
- try:
- del self[item]
- except KeyError:
- raise AttributeError(f"Attribute {item} not found")
-
-
-ssl_model = cnhubert.get_model()
-if is_half == True:
- ssl_model = ssl_model.half().to(device)
-else:
- ssl_model = ssl_model.to(device)
-
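-# cache one torchaudio Resample transform per source sample rate (target is 24000 Hz for v3 vocoding)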
-resample_transform_dict={}
-def resample(audio_tensor, sr0):
- global resample_transform_dict
- if sr0 not in resample_transform_dict:
- resample_transform_dict[sr0] = torchaudio.transforms.Resample(
- sr0, 24000
- ).to(device)
- return resample_transform_dict[sr0](audio_tensor)
-
-###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
-#symbol_version-model_version-if_lora_v3
-from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
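-# change_sovits_weights: detect the checkpoint version, yield gradio UI updates, then rebuild and load the SoVITS model (merging v3 LoRA weights when present)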
-def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
- global vq_model, hps, version, model_version, dict_language,if_lora_v3
- version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
- # print(sovits_path,version, model_version, if_lora_v3)
- if if_lora_v3==True and is_exist_s2gv3==False:
- info= "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
- gr.Warning(info)
- raise FileExistsError(info)
- dict_language = dict_language_v1 if version =='v1' else dict_language_v2
- if prompt_language is not None and text_language is not None:
- if prompt_language in list(dict_language.keys()):
- prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
- else:
- prompt_text_update = {'__type__':'update', 'value':''}
- prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
- if text_language in list(dict_language.keys()):
- text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
- else:
- text_update = {'__type__':'update', 'value':''}
- text_language_update = {'__type__':'update', 'value':i18n("中文")}
- if model_version=="v3":
- visible_sample_steps=True
- visible_inp_refs=False
- else:
- visible_sample_steps=False
- visible_inp_refs=True
- yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
-
- dict_s2 = load_sovits_new(sovits_path)
- hps = dict_s2["config"]
- hps = DictToAttrRecursive(hps)
- hps.model.semantic_frame_rate = "25hz"
- if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
- hps.model.version = "v2"#v3model,v2sybomls
- elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
- hps.model.version = "v1"
- else:
- hps.model.version = "v2"
- version=hps.model.version
- # print("sovits版本:",hps.model.version)
- if model_version!="v3":
- vq_model = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model
- )
- model_version=version
- else:
- vq_model = SynthesizerTrnV3(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model
- )
- if ("pretrained" not in sovits_path):
- try:
- del vq_model.enc_q
- except:pass
- if is_half == True:
- vq_model = vq_model.half().to(device)
- else:
- vq_model = vq_model.to(device)
- vq_model.eval()
- if if_lora_v3==False:
- print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
- else:
- print("loading sovits_v3pretrained_G", vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False))
- lora_rank=dict_s2["lora_rank"]
- lora_config = LoraConfig(
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- r=lora_rank,
- lora_alpha=lora_rank,
- init_lora_weights=True,
- )
- vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
- print("loading sovits_v3_lora%s"%(lora_rank))
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
- vq_model.cfm = vq_model.cfm.merge_and_unload()
- # torch.save(vq_model.state_dict(),"merge_win.pth")
- vq_model.eval()
-
- with open("./weight.json")as f:
- data=f.read()
- data=json.loads(data)
- data["SoVITS"][version]=sovits_path
- with open("./weight.json","w")as f:f.write(json.dumps(data))
-
-
-try:next(change_sovits_weights(sovits_path))
-except:pass
-
-def change_gpt_weights(gpt_path):
- global hz, max_sec, t2s_model, config
- hz = 50
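-    # 50 semantic tokens per second; hz * max_sec later caps AR decoding length via early_stop_num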
- dict_s1 = torch.load(gpt_path, map_location="cpu")
- config = dict_s1["config"]
- max_sec = config["data"]["max_sec"]
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
- t2s_model.load_state_dict(dict_s1["weight"])
- if is_half == True:
- t2s_model = t2s_model.half()
- t2s_model = t2s_model.to(device)
- t2s_model.eval()
- # total = sum([param.nelement() for param in t2s_model.parameters()])
- # print("Number of parameter: %.2fM" % (total / 1e6))
- with open("./weight.json")as f:
- data=f.read()
- data=json.loads(data)
- data["GPT"][version]=gpt_path
- with open("./weight.json","w")as f:f.write(json.dumps(data))
-
-
-change_gpt_weights(gpt_path)
-os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-import torch,soundfile
-now_dir = os.getcwd()
-import soundfile
-
-def init_bigvgan():
- global bigvgan_model
- from BigVGAN import bigvgan
- bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval()
- if is_half == True:
- bigvgan_model = bigvgan_model.half().to(device)
- else:
- bigvgan_model = bigvgan_model.to(device)
-
-if model_version!="v3":bigvgan_model=None
-else:init_bigvgan()
-
-
-def get_spepc(hps, filename):
- # audio = load_audio(filename, int(hps.data.sampling_rate))
- audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
- audio = torch.FloatTensor(audio)
- maxx=audio.abs().max()
- if(maxx>1):audio/=min(2,maxx)
- audio_norm = audio
- audio_norm = audio_norm.unsqueeze(0)
- spec = spectrogram_torch(
- audio_norm,
- hps.data.filter_length,
- hps.data.sampling_rate,
- hps.data.hop_length,
- hps.data.win_length,
- center=False,
- )
- return spec
-
-def clean_text_inf(text, language, version):
- language = language.replace("all_","")
- phones, word2ph, norm_text = clean_text(text, language, version)
- phones = cleaned_text_to_sequence(phones, version)
- return phones, word2ph, norm_text
-
-dtype=torch.float16 if is_half == True else torch.float32
-def get_bert_inf(phones, word2ph, norm_text, language):
- language=language.replace("all_","")
- if language == "zh":
- bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
- else:
- bert = torch.zeros(
- (1024, len(phones)),
- dtype=torch.float16 if is_half == True else torch.float32,
- ).to(device)
-
- return bert
-
-
-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-
-
-def get_first(text):
- pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
- text = re.split(pattern, text)[0].strip()
- return text
-
-from text import chinese
-def get_phones_and_bert(text,language,version,final=False):
- if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
- formattext = text
- while " " in formattext:
- formattext = formattext.replace(" ", " ")
- if language == "all_zh":
- if re.search(r'[A-Za-z]', formattext):
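-                # uppercase embedded Latin letters, normalize the mixed text, then re-run it through the "zh" (Chinese-English) branch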
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
- formattext = chinese.mix_text_normalize(formattext)
- return get_phones_and_bert(formattext,"zh",version)
- else:
- phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
- bert = get_bert_feature(norm_text, word2ph).to(device)
- elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
- formattext = chinese.mix_text_normalize(formattext)
- return get_phones_and_bert(formattext,"yue",version)
- else:
- phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
- bert = torch.zeros(
- (1024, len(phones)),
- dtype=torch.float16 if is_half == True else torch.float32,
- ).to(device)
- elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
- textlist=[]
- langlist=[]
- if language == "auto":
- for tmp in LangSegmenter.getTexts(text):
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- elif language == "auto_yue":
- for tmp in LangSegmenter.getTexts(text):
- if tmp["lang"] == "zh":
- tmp["lang"] = "yue"
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- else:
- for tmp in LangSegmenter.getTexts(text):
- if tmp["lang"] == "en":
- langlist.append(tmp["lang"])
- else:
- # 因无法区别中日韩文汉字,以用户输入为准
- langlist.append(language)
- textlist.append(tmp["text"])
- print(textlist)
- print(langlist)
- phones_list = []
- bert_list = []
- norm_text_list = []
- for i in range(len(textlist)):
- lang = langlist[i]
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
- bert = get_bert_inf(phones, word2ph, norm_text, lang)
- phones_list.append(phones)
- norm_text_list.append(norm_text)
- bert_list.append(bert)
- bert = torch.cat(bert_list, dim=1)
- phones = sum(phones_list, [])
- norm_text = ''.join(norm_text_list)
-
- if not final and len(phones) < 6:
- return get_phones_and_bert("." + text,language,version,final=True)
-
- return phones,bert.to(dtype),norm_text
-
-from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
-spec_min = -12
-spec_max = 2
-def norm_spec(x):
- return (x - spec_min) / (spec_max - spec_min) * 2 - 1
-def denorm_spec(x):
- return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram_torch(x, **{
- "n_fft": 1024,
- "win_size": 1024,
- "hop_size": 256,
- "num_mels": 100,
- "sampling_rate": 24000,
- "fmin": 0,
- "fmax": None,
- "center": False
-})
-
-def merge_short_text_in_array(texts, threshold):
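-    # concatenate consecutive segments until each merged piece reaches at least 'threshold' characters; leftovers attach to the last piece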
- if (len(texts)) < 2:
- return texts
- result = []
- text = ""
- for ele in texts:
- text += ele
- if len(text) >= threshold:
- result.append(text)
- text = ""
- if (len(text) > 0):
- if len(result) == 0:
- result.append(text)
- else:
- result[len(result) - 1] += text
- return result
-
-sr_model=None
-def audio_sr(audio,sr):
- global sr_model
- if sr_model==None:
- from tools.audio_sr import AP_BWE
- try:
- sr_model=AP_BWE(device,DictToAttrRecursive)
- except FileNotFoundError:
- gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
- return audio.cpu().detach().numpy(),sr
- return sr_model(audio,sr)
-
-
-##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
-# cache_tokens={}#暂未实现清理机制
-cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8,if_sr=False,pause_second=0.3):
- global cache
- if ref_wav_path:pass
- else:gr.Warning(i18n('请上传参考音频'))
- if text:pass
- else:gr.Warning(i18n('请填入推理文本'))
- t = []
- if prompt_text is None or len(prompt_text) == 0:
- ref_free = True
- if model_version=="v3":
- ref_free=False#s2v3暂不支持ref_free
- else:
- if_sr=False
- t0 = ttime()
- prompt_language = dict_language[prompt_language]
- text_language = dict_language[text_language]
-
-
- if not ref_free:
- prompt_text = prompt_text.strip("\n")
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
- print(i18n("实际输入的参考文本:"), prompt_text)
- text = text.strip("\n")
- # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
-
- print(i18n("实际输入的目标文本:"), text)
- zero_wav = np.zeros(
- int(hps.data.sampling_rate * pause_second),
- dtype=np.float16 if is_half == True else np.float32,
- )
- zero_wav_torch = torch.from_numpy(zero_wav)
- if is_half == True:
- zero_wav_torch = zero_wav_torch.half().to(device)
- else:
- zero_wav_torch = zero_wav_torch.to(device)
- if not ref_free:
- with torch.no_grad():
- wav16k, sr = librosa.load(ref_wav_path, sr=16000)
- if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
- gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
- raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
- wav16k = torch.from_numpy(wav16k)
- if is_half == True:
- wav16k = wav16k.half().to(device)
- else:
- wav16k = wav16k.to(device)
- wav16k = torch.cat([wav16k, zero_wav_torch])
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
- "last_hidden_state"
- ].transpose(
- 1, 2
- ) # .float()
- codes = vq_model.extract_latent(ssl_content)
- prompt_semantic = codes[0, 0]
- prompt = prompt_semantic.unsqueeze(0).to(device)
-
- t1 = ttime()
- t.append(t1-t0)
-
- if (how_to_cut == i18n("凑四句一切")):
- text = cut1(text)
- elif (how_to_cut == i18n("凑50字一切")):
- text = cut2(text)
- elif (how_to_cut == i18n("按中文句号。切")):
- text = cut3(text)
- elif (how_to_cut == i18n("按英文句号.切")):
- text = cut4(text)
- elif (how_to_cut == i18n("按标点符号切")):
- text = cut5(text)
- while "\n\n" in text:
- text = text.replace("\n\n", "\n")
- print(i18n("实际输入的目标文本(切句后):"), text)
- texts = text.split("\n")
- texts = process_text(texts)
- texts = merge_short_text_in_array(texts, 5)
- audio_opt = []
- ###s2v3暂不支持ref_free
- if not ref_free:
- phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
-
- for i_text,text in enumerate(texts):
- # 解决输入目标文本的空行导致报错的问题
- if (len(text.strip()) == 0):
- continue
- if (text[-1] not in splits): text += "。" if text_language != "en" else "."
- print(i18n("实际输入的目标文本(每句):"), text)
- phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
- print(i18n("前端处理后的文本(每句):"), norm_text2)
- if not ref_free:
- bert = torch.cat([bert1, bert2], 1)
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
- else:
- bert = bert2
- all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
-
- bert = bert.to(device).unsqueeze(0)
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-
- t2 = ttime()
- # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
- # print(cache.keys(),if_freeze)
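-        # if_freeze: reuse the cached semantic tokens from the previous run so the result stays deterministic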
- if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
- else:
- with torch.no_grad():
- pred_semantic, idx = t2s_model.model.infer_panel(
- all_phoneme_ids,
- all_phoneme_len,
- None if ref_free else prompt,
- bert,
- # prompt_phone_len=ph_offset,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- early_stop_num=hz * max_sec,
- )
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
- cache[i_text]=pred_semantic
- t3 = ttime()
- ###v3不存在以下逻辑和inp_refs
- if model_version!="v3":
- refers=[]
- if(inp_refs):
- for path in inp_refs:
- try:
- refer = get_spepc(hps, path.name).to(dtype).to(device)
- refers.append(refer)
- except:
- traceback.print_exc()
- if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
- else:
- refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
- phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
- phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
- # print(11111111, phoneme_ids0, phoneme_ids1)
- fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
- ref_audio, sr = torchaudio.load(ref_wav_path)
- ref_audio=ref_audio.to(device).float()
- if (ref_audio.shape[0] == 2):
- ref_audio = ref_audio.mean(0).unsqueeze(0)
- if sr!=24000:
- ref_audio=resample(ref_audio,sr)
- # print("ref_audio",ref_audio.abs().mean())
- mel2 = mel_fn(ref_audio)
- mel2 = norm_spec(mel2)
- T_min = min(mel2.shape[2], fea_ref.shape[2])
- mel2 = mel2[:, :, :T_min]
- fea_ref = fea_ref[:, :, :T_min]
- if (T_min > 468):
- mel2 = mel2[:, :, -468:]
- fea_ref = fea_ref[:, :, -468:]
- T_min = 468
- chunk_len = 934 - T_min
- # print("fea_ref",fea_ref,fea_ref.shape)
- # print("mel2",mel2)
- mel2=mel2.to(dtype)
- fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
- # print("fea_todo",fea_todo)
- # print("ge",ge.abs().mean())
- cfm_resss = []
- idx = 0
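-            # generate the mel in chunks; each chunk is conditioned on the last T_min frames of the previous output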
- while (1):
- fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
- if (fea_todo_chunk.shape[-1] == 0): break
- idx += chunk_len
- fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
- # set_seed(123)
- cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
- cfm_res = cfm_res[:, :, mel2.shape[2]:]
- mel2 = cfm_res[:, :, -T_min:]
- # print("fea", fea)
- # print("mel2in", mel2)
- fea_ref = fea_todo_chunk[:, :, -T_min:]
- cfm_resss.append(cfm_res)
- cmf_res = torch.cat(cfm_resss, 2)
- cmf_res = denorm_spec(cmf_res)
- if bigvgan_model==None:init_bigvgan()
- with torch.inference_mode():
- wav_gen = bigvgan_model(cmf_res)
- audio=wav_gen[0][0]#.cpu().detach().numpy()
- max_audio=torch.abs(audio).max()#简单防止16bit爆音
- if max_audio>1:audio=audio/max_audio
- audio_opt.append(audio)
- audio_opt.append(zero_wav_torch)#zero_wav
- t4 = ttime()
- t.extend([t2 - t1,t3 - t2, t4 - t3])
- t1 = ttime()
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
- audio_opt=torch.cat(audio_opt, 0)#np.concatenate
- sr=hps.data.sampling_rate if model_version!="v3"else 24000
- if if_sr==True and sr==24000:
- print(i18n("音频超分中"))
- audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
- max_audio=np.abs(audio_opt).max()
- if max_audio > 1: audio_opt /= max_audio
- else:
- audio_opt=audio_opt.cpu().detach().numpy()
- yield sr, (audio_opt * 32767).astype(np.int16)
-
-
-def split(todo_text):
- todo_text = todo_text.replace("……", "。").replace("——", ",")
- if todo_text[-1] not in splits:
- todo_text += "。"
- i_split_head = i_split_tail = 0
- len_text = len(todo_text)
- todo_texts = []
- while 1:
- if i_split_head >= len_text:
- break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
- if todo_text[i_split_head] in splits:
- i_split_head += 1
- todo_texts.append(todo_text[i_split_tail:i_split_head])
- i_split_tail = i_split_head
- else:
- i_split_head += 1
- return todo_texts
-
-
-def cut1(inp):
- inp = inp.strip("\n")
- inps = split(inp)
- split_idx = list(range(0, len(inps), 4))
- split_idx[-1] = None
- if len(split_idx) > 1:
- opts = []
- for idx in range(len(split_idx) - 1):
- opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
- else:
- opts = [inp]
- opts = [item for item in opts if not set(item).issubset(punctuation)]
- return "\n".join(opts)
-
-
-def cut2(inp):
- inp = inp.strip("\n")
- inps = split(inp)
- if len(inps) < 2:
- return inp
- opts = []
- summ = 0
- tmp_str = ""
- for i in range(len(inps)):
- summ += len(inps[i])
- tmp_str += inps[i]
- if summ > 50:
- summ = 0
- opts.append(tmp_str)
- tmp_str = ""
- if tmp_str != "":
- opts.append(tmp_str)
- # print(opts)
- if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
- opts[-2] = opts[-2] + opts[-1]
- opts = opts[:-1]
- opts = [item for item in opts if not set(item).issubset(punctuation)]
- return "\n".join(opts)
-
-
-def cut3(inp):
- inp = inp.strip("\n")
- opts = ["%s" % item for item in inp.strip("。").split("。")]
- opts = [item for item in opts if not set(item).issubset(punctuation)]
- return "\n".join(opts)
-
-def cut4(inp):
- inp = inp.strip("\n")
-    opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
-    opts = [item for item in opts if not set(item).issubset(punctuation)]
-    return "\n".join(opts)
-
-
-def cut5(inp):
-    inp = inp.strip("\n")
-    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
-    mergeitems = []
-    items = []
-
-    for i, char in enumerate(inp):
-        if char in punds:
-            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
-                items.append(char)
-            else:
-                items.append(char)
-                mergeitems.append("".join(items))
-                items = []
-        else:
-            items.append(char)
-
-    if items:
-        mergeitems.append("".join(items))
-
-    opt = [item for item in mergeitems if not set(item).issubset(punds)]
-    return "\n".join(opt)
-
-
-def custom_sort_key(s):
- # 使用正则表达式提取字符串中的数字部分和非数字部分
- parts = re.split('(\d+)', s)
- # 将数字部分转换为整数,非数字部分保持不变
- parts = [int(part) if part.isdigit() else part for part in parts]
- return parts
-
-def process_text(texts):
- _text=[]
- if all(text in [None, " ", "\n",""] for text in texts):
- raise ValueError(i18n("请输入有效文本"))
- for text in texts:
- if text in [None, " ", ""]:
- pass
- else:
- _text.append(text)
- return _text
-
-
-def change_choices():
- SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
- return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
-
-
-SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
-GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
-for path in SoVITS_weight_root+GPT_weight_root:
- os.makedirs(path,exist_ok=True)
-
-
-def get_weights_names(GPT_weight_root, SoVITS_weight_root):
- SoVITS_names = [i for i in pretrained_sovits_name]
- for path in SoVITS_weight_root:
- for name in os.listdir(path):
- if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
- GPT_names = [i for i in pretrained_gpt_name]
- for path in GPT_weight_root:
- for name in os.listdir(path):
- if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
- return SoVITS_names, GPT_names
-
-
-SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
-
-def html_center(text, label='p'):
- return f"""
- <{label} style="margin: 0; padding: 0;">{text}{label}>
-
"""
-
-def html_left(text, label='p'):
- return f"""
- <{label} style="margin: 0; padding: 0;">{text}{label}>
-
"""
-
-
-with gr.Blocks(title="GPT-SoVITS WebUI") as app:
- gr.Markdown(
-        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
- )
- with gr.Group():
- gr.Markdown(html_center(i18n("模型切换"),'h3'))
- with gr.Row():
- GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True, scale=14)
- SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True, scale=14)
- refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
- refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
- gr.Markdown(html_center(i18n("*请上传并填写参考信息"),'h3'))
- with gr.Row():
- inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13)
- with gr.Column(scale=13):
- ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")+i18n("v3暂不支持该模式,使用了会报错。"), value=False, interactive=True, show_label=True,scale=1)
-                    gr.Markdown(html_left(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")))
- prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5,scale=1)
- with gr.Column(scale=14):
- prompt_language = gr.Dropdown(
- label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文"),
- )
- inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple")if model_version!="v3"else gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple",visible=False)
- sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True)if model_version=="v3"else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),choices=[4,8,16,32],visible=False,value=32)
- if_sr_Checkbox=gr.Checkbox(label=i18n("v3输出如果觉得闷可以试试开超分"), value=False, interactive=True, show_label=True,visible=False if model_version!="v3"else True)
- gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
- with gr.Row():
- with gr.Column(scale=13):
- text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
- with gr.Column(scale=7):
- text_language = gr.Dropdown(
- label=i18n("需要合成的语种")+i18n(".限制范围越小判别效果越好。"), choices=list(dict_language.keys()), value=i18n("中文"), scale=1
- )
- how_to_cut = gr.Dropdown(
- label=i18n("怎么切"),
- choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
- value=i18n("凑四句一切"),
- interactive=True, scale=1
- )
- gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
- if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True, scale=1)
- with gr.Row():
- speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True, scale=1)
- pause_second_slider = gr.Slider(minimum=0.1,maximum=0.5,step=0.01,label=i18n("句间停顿秒数"),value=0.3,interactive=True, scale=1)
- gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
- top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True, scale=1)
- top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True, scale=1)
- temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True, scale=1)
- # with gr.Column():
- # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
- # phoneme=gr.Textbox(label=i18n("音素框"), value="")
- # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
- with gr.Row():
- inference_button = gr.Button(i18n("合成语音"), variant="primary", size='lg', scale=25)
- output = gr.Audio(label=i18n("输出的语音"), scale=14)
-
- inference_button.click(
- get_tts_wav,
- [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps,if_sr_Checkbox,pause_second_slider],
- [output],
- )
- SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free,if_sr_Checkbox])
- GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
-
- # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
- # with gr.Row():
- # text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
- # button1 = gr.Button(i18n("凑四句一切"), variant="primary")
- # button2 = gr.Button(i18n("凑50字一切"), variant="primary")
- # button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
- # button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
- # button5 = gr.Button(i18n("按标点符号切"), variant="primary")
- # text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
- # button1.click(cut1, [text_inp], [text_opt])
- # button2.click(cut2, [text_inp], [text_opt])
- # button3.click(cut3, [text_inp], [text_opt])
- # button4.click(cut4, [text_inp], [text_opt])
- # button5.click(cut5, [text_inp], [text_opt])
- # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
-
-if __name__ == '__main__':
- app.queue().launch(#concurrency_count=511, max_size=1022
- server_name="0.0.0.0",
- inbrowser=True,
- share=is_share,
- server_port=infer_ttswebui,
- quiet=True,
- )
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
deleted file mode 100644
index 5a6910d..0000000
--- a/GPT_SoVITS/inference_webui_fast.py
+++ /dev/null
@@ -1,336 +0,0 @@
-'''
-按中英混合识别
-按日英混合识别
-多语种启动切分识别语种
-全部按中文识别
-全部按英文识别
-全部按日文识别
-'''
-import random
-import os, re, logging
-import sys
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-sys.path.append("%s/GPT_SoVITS" % (now_dir))
-
-logging.getLogger("markdown_it").setLevel(logging.ERROR)
-logging.getLogger("urllib3").setLevel(logging.ERROR)
-logging.getLogger("httpcore").setLevel(logging.ERROR)
-logging.getLogger("httpx").setLevel(logging.ERROR)
-logging.getLogger("asyncio").setLevel(logging.ERROR)
-logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
-logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-import pdb
-import torch
-
-try:
- import gradio.analytics as analytics
- analytics.version_check = lambda:None
-except:...
-
-
-infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
-infer_ttswebui = int(infer_ttswebui)
-is_share = os.environ.get("is_share", "False")
-is_share = eval(is_share)
-if "_CUDA_VISIBLE_DEVICES" in os.environ:
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
-gpt_path = os.environ.get("gpt_path", None)
-sovits_path = os.environ.get("sovits_path", None)
-cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
-bert_path = os.environ.get("bert_path", None)
-version=os.environ.get("version","v2")
-
-import gradio as gr
-from TTS_infer_pack.TTS import TTS, TTS_Config
-from TTS_infer_pack.text_segmentation_method import get_method
-from tools.i18n.i18n import I18nAuto, scan_language_list
-
-language=os.environ.get("language","Auto")
-language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
-i18n = I18nAuto(language=language)
-
-
-# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
-
-if torch.cuda.is_available():
- device = "cuda"
-# elif torch.backends.mps.is_available():
-# device = "mps"
-else:
- device = "cpu"
-
-dict_language_v1 = {
- i18n("中文"): "all_zh",#全部按中文识别
- i18n("英文"): "en",#全部按英文识别#######不变
- i18n("日文"): "all_ja",#全部按日文识别
- i18n("中英混合"): "zh",#按中英混合识别####不变
- i18n("日英混合"): "ja",#按日英混合识别####不变
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
-}
-dict_language_v2 = {
- i18n("中文"): "all_zh",#全部按中文识别
- i18n("英文"): "en",#全部按英文识别#######不变
- i18n("日文"): "all_ja",#全部按日文识别
- i18n("粤语"): "all_yue",#全部按中文识别
- i18n("韩文"): "all_ko",#全部按韩文识别
- i18n("中英混合"): "zh",#按中英混合识别####不变
- i18n("日英混合"): "ja",#按日英混合识别####不变
- i18n("粤英混合"): "yue",#按粤英混合识别####不变
- i18n("韩英混合"): "ko",#按韩英混合识别####不变
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
- i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
-}
-dict_language = dict_language_v1 if version =='v1' else dict_language_v2
-
-cut_method = {
- i18n("不切"):"cut0",
- i18n("凑四句一切"): "cut1",
- i18n("凑50字一切"): "cut2",
- i18n("按中文句号。切"): "cut3",
- i18n("按英文句号.切"): "cut4",
- i18n("按标点符号切"): "cut5",
-}
-
-tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
-tts_config.device = device
-tts_config.is_half = is_half
-tts_config.version = version
-if gpt_path is not None:
- tts_config.t2s_weights_path = gpt_path
-if sovits_path is not None:
- tts_config.vits_weights_path = sovits_path
-if cnhubert_base_path is not None:
- tts_config.cnhuhbert_base_path = cnhubert_base_path
-if bert_path is not None:
- tts_config.bert_base_path = bert_path
-
-print(tts_config)
-tts_pipeline = TTS(tts_config)
-gpt_path = tts_config.t2s_weights_path
-sovits_path = tts_config.vits_weights_path
-version = tts_config.version
-
-def inference(text, text_lang,
- ref_audio_path,
- aux_ref_audio_paths,
- prompt_text,
- prompt_lang, top_k,
- top_p, temperature,
- text_split_method, batch_size,
- speed_factor, ref_text_free,
- split_bucket,fragment_interval,
- seed, keep_random, parallel_infer,
- repetition_penalty
- ):
-
- seed = -1 if keep_random else seed
- actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
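-    # assemble the request dict consumed by tts_pipeline.run(); UI labels are mapped to internal codes via dict_language / cut_method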
- inputs={
- "text": text,
- "text_lang": dict_language[text_lang],
- "ref_audio_path": ref_audio_path,
- "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
- "prompt_text": prompt_text if not ref_text_free else "",
- "prompt_lang": dict_language[prompt_lang],
- "top_k": top_k,
- "top_p": top_p,
- "temperature": temperature,
- "text_split_method": cut_method[text_split_method],
- "batch_size":int(batch_size),
- "speed_factor":float(speed_factor),
- "split_bucket":split_bucket,
- "return_fragment":False,
- "fragment_interval":fragment_interval,
- "seed":actual_seed,
- "parallel_infer": parallel_infer,
- "repetition_penalty": repetition_penalty,
- }
- for item in tts_pipeline.run(inputs):
- yield item, actual_seed
-
-def custom_sort_key(s):
- # 使用正则表达式提取字符串中的数字部分和非数字部分
- parts = re.split('(\d+)', s)
- # 将数字部分转换为整数,非数字部分保持不变
- parts = [int(part) if part.isdigit() else part for part in parts]
- return parts
-
-
-def change_choices():
- SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
- return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
-
-
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
-_ =[[],[]]
-for i in range(2):
- if os.path.exists(pretrained_gpt_name[i]):
- _[0].append(pretrained_gpt_name[i])
- if os.path.exists(pretrained_sovits_name[i]):
- _[-1].append(pretrained_sovits_name[i])
-pretrained_gpt_name,pretrained_sovits_name = _
-
-SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
-GPT_weight_root=["GPT_weights_v2","GPT_weights"]
-for path in SoVITS_weight_root+GPT_weight_root:
- os.makedirs(path,exist_ok=True)
-
-def get_weights_names(GPT_weight_root, SoVITS_weight_root):
- SoVITS_names = [i for i in pretrained_sovits_name]
- for path in SoVITS_weight_root:
- for name in os.listdir(path):
- if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
- GPT_names = [i for i in pretrained_gpt_name]
- for path in GPT_weight_root:
- for name in os.listdir(path):
- if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
- return SoVITS_names, GPT_names
-
-
-SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
-
-
-
-def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
- tts_pipeline.init_vits_weights(sovits_path)
- global version, dict_language
- dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
- if prompt_language is not None and text_language is not None:
- if prompt_language in list(dict_language.keys()):
- prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
- else:
- prompt_text_update = {'__type__':'update', 'value':''}
- prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
- if text_language in list(dict_language.keys()):
- text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
- else:
- text_update = {'__type__':'update', 'value':''}
- text_language_update = {'__type__':'update', 'value':i18n("中文")}
- return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
-
-
-
-with gr.Blocks(title="GPT-SoVITS WebUI") as app:
- gr.Markdown(
-        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
- )
-
- with gr.Column():
- # with gr.Group():
- gr.Markdown(value=i18n("模型切换"))
- with gr.Row():
- GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
- SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
- refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
- refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
-
-
- with gr.Row():
- with gr.Column():
- gr.Markdown(value=i18n("*请上传并填写参考信息"))
- with gr.Row():
- inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
- inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
- prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
- with gr.Row():
- prompt_language = gr.Dropdown(
- label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
- )
- with gr.Column():
- ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
-                        gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
-
- with gr.Column():
- gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
- text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
- text_language = gr.Dropdown(
- label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
- )
-
-
- with gr.Group():
- gr.Markdown(value=i18n("推理设置"))
- with gr.Row():
-
- with gr.Column():
- batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
- fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
- speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
- top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
- top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
- temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
- repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
- with gr.Column():
- with gr.Row():
- how_to_cut = gr.Dropdown(
- label=i18n("怎么切"),
- choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
- value=i18n("凑四句一切"),
- interactive=True, scale=1
- )
- parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
- split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
-
- with gr.Row():
- seed = gr.Number(label=i18n("随机种子"),value=-1)
- keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
-
- output = gr.Audio(label=i18n("输出的语音"))
- with gr.Row():
- inference_button = gr.Button(i18n("合成语音"), variant="primary")
- stop_infer = gr.Button(i18n("终止合成"), variant="primary")
-
-
- inference_button.click(
- inference,
- [
- text,text_language, inp_ref, inp_refs,
- prompt_text, prompt_language,
- top_k, top_p, temperature,
- how_to_cut, batch_size,
- speed_factor, ref_text_free,
- split_bucket,fragment_interval,
- seed, keep_random, parallel_infer,
- repetition_penalty
- ],
- [output, seed],
- )
- stop_infer.click(tts_pipeline.stop, [], [])
- SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
- GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
-
- with gr.Group():
- gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
- with gr.Row():
- text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
- with gr.Column():
- _how_to_cut = gr.Radio(
- label=i18n("怎么切"),
- choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
- value=i18n("凑四句一切"),
- interactive=True,
- )
- cut_text= gr.Button(i18n("切分"), variant="primary")
-
- def to_cut(text_inp, how_to_cut):
- if len(text_inp.strip()) == 0 or text_inp==[]:
- return ""
- method = get_method(cut_method[how_to_cut])
- return method(text_inp)
-
- text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
- cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
- gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
-
-if __name__ == '__main__':
- app.queue().launch(#concurrency_count=511, max_size=1022
- server_name="0.0.0.0",
- inbrowser=True,
- share=is_share,
- server_port=infer_ttswebui,
- quiet=True,
- )
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
deleted file mode 100644
index 4311db9..0000000
--- a/GPT_SoVITS/s1_train.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
-import os
-import pdb
-
-if "_CUDA_VISIBLE_DEVICES" in os.environ:
- os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-import argparse
-import logging
-from pathlib import Path
-
-import torch, platform
-from pytorch_lightning import seed_everything
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
-from pytorch_lightning.strategies import DDPStrategy
-from AR.data.data_module import Text2SemanticDataModule
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from AR.utils.io import load_yaml_config
-
-logging.getLogger("numba").setLevel(logging.WARNING)
-logging.getLogger("matplotlib").setLevel(logging.WARNING)
-torch.set_float32_matmul_precision("high")
-from AR.utils import get_newest_ckpt
-
-from collections import OrderedDict
-from time import time as ttime
-import shutil
-from process_ckpt import my_save
-
-
-class my_model_ckpt(ModelCheckpoint):
- def __init__(
- self,
- config,
- if_save_latest,
- if_save_every_weights,
- half_weights_save_dir,
- exp_name,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.if_save_latest = if_save_latest
- self.if_save_every_weights = if_save_every_weights
- self.half_weights_save_dir = half_weights_save_dir
- self.exp_name = exp_name
- self.config = config
-
- def on_train_epoch_end(self, trainer, pl_module):
- # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
- if self._should_save_on_train_epoch_end(trainer):
- monitor_candidates = self._monitor_candidates(trainer)
- if (
- self._every_n_epochs >= 1
- and (trainer.current_epoch + 1) % self._every_n_epochs == 0
- ):
- if (
- self.if_save_latest == True
- ): #### if set to keep only the latest ckpt, remove all previous ckpts after the next one is saved
- to_clean = list(os.listdir(self.dirpath))
- self._save_topk_checkpoint(trainer, monitor_candidates)
- if self.if_save_latest == True:
- for name in to_clean:
- try:
- os.remove("%s/%s" % (self.dirpath, name))
- except:
- pass
- if self.if_save_every_weights == True:
- to_save_od = OrderedDict()
- to_save_od["weight"] = OrderedDict()
- dictt = trainer.strategy._lightning_module.state_dict()
- for key in dictt:
- to_save_od["weight"][key] = dictt[key].half()
- to_save_od["config"] = self.config
- to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
- # torch.save(
- # print(os.environ)
- if(os.environ.get("LOCAL_RANK","0")=="0"):
- my_save(
- to_save_od,
- "%s/%s-e%s.ckpt"
- % (
- self.half_weights_save_dir,
- self.exp_name,
- trainer.current_epoch + 1,
- ),
- )
- self._save_last_checkpoint(trainer, monitor_candidates)
-
-
-def main(args):
- config = load_yaml_config(args.config_file)
-
- output_dir = Path(config["output_dir"])
- output_dir.mkdir(parents=True, exist_ok=True)
-
- ckpt_dir = output_dir / "ckpt"
- ckpt_dir.mkdir(parents=True, exist_ok=True)
-
- seed_everything(config["train"]["seed"], workers=True)
- ckpt_callback: ModelCheckpoint = my_model_ckpt(
- config=config,
- if_save_latest=config["train"]["if_save_latest"],
- if_save_every_weights=config["train"]["if_save_every_weights"],
- half_weights_save_dir=config["train"]["half_weights_save_dir"],
- exp_name=config["train"]["exp_name"],
- save_top_k=-1,
- monitor="top_3_acc",
- mode="max",
- save_on_train_epoch_end=True,
- every_n_epochs=config["train"]["save_every_n_epoch"],
- dirpath=ckpt_dir,
- )
- logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
- os.environ["MASTER_ADDR"]="localhost"
- os.environ["USE_LIBUV"] = "0"
- trainer: Trainer = Trainer(
- max_epochs=config["train"]["epochs"],
- accelerator="gpu" if torch.cuda.is_available() else "cpu",
- # val_check_interval=9999999999999999999999, ### do not run validation
- # check_val_every_n_epoch=None,
- limit_val_batches=0,
- devices=-1 if torch.cuda.is_available() else 1,
- benchmark=False,
- fast_dev_run=False,
- strategy = DDPStrategy(
- process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
- ) if torch.cuda.is_available() else "auto",
- precision=config["train"]["precision"],
- logger=logger,
- num_sanity_val_steps=0,
- callbacks=[ckpt_callback],
- use_distributed_sampler=False, # a very simple change, but it fixes the inconsistent step counts when using the custom bucket_sampler!
- )
-
- model: Text2SemanticLightningModule = Text2SemanticLightningModule(
- config, output_dir
- )
-
- data_module: Text2SemanticDataModule = Text2SemanticDataModule(
- config,
- train_semantic_path=config["train_semantic_path"],
- train_phoneme_path=config["train_phoneme_path"],
- # dev_semantic_path=args.dev_semantic_path,
- # dev_phoneme_path=args.dev_phoneme_path
- )
-
- try:
- # match the numeric part of the checkpoint filenames with a regex and sort numerically
- newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
- ckpt_path = ckpt_dir / newest_ckpt_name
- except Exception:
- ckpt_path = None
- print("ckpt_path:", ckpt_path)
- trainer.fit(model, data_module, ckpt_path=ckpt_path)
-
-
-# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-c",
- "--config_file",
- type=str,
- default="configs/s1longer.yaml",
- help="path of config file",
- )
- # args for dataset
- # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
- # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
-
- # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv')
- # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy')
- # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results')
- # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results')
-
- args = parser.parse_args()
- logging.info(str(args))
- main(args)
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
deleted file mode 100644
index 4d88ee8..0000000
--- a/GPT_SoVITS/s2_train.py
+++ /dev/null
@@ -1,604 +0,0 @@
-import warnings
-warnings.filterwarnings("ignore")
-import utils, os
-hps = utils.get_hparams(stage=2)
-os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
-import torch
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
-from tqdm import tqdm
-import logging, traceback
-
-logging.getLogger("matplotlib").setLevel(logging.INFO)
-logging.getLogger("h5py").setLevel(logging.INFO)
-logging.getLogger("numba").setLevel(logging.INFO)
-from random import randint
-from module import commons
-
-from module.data_utils import (
- TextAudioSpeakerLoader,
- TextAudioSpeakerCollate,
- DistributedBucketSampler,
-)
-from module.models import (
- SynthesizerTrn,
- MultiPeriodDiscriminator,
-)
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from process_ckpt import savee
-
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = False
-### fp32 is faster on A100 anyway, so let's try tf32
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision("medium") # lowest precision but fastest (only marginally faster), with no impact on the results
-# from config import pretrained_s2G,pretrained_s2D
-global_step = 0
-
-device = "cpu" # cuda以外的设备,等mps优化后加入
-
-
-def main():
-
- if torch.cuda.is_available():
- n_gpus = torch.cuda.device_count()
- else:
- n_gpus = 1
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
- mp.spawn(
- run,
- nprocs=n_gpus,
- args=(
- n_gpus,
- hps,
- ),
- )
-
-
-def run(rank, n_gpus, hps):
- global global_step
- if rank == 0:
- logger = utils.get_logger(hps.data.exp_dir)
- logger.info(hps)
- # utils.check_git_hash(hps.s2_ckpt_dir)
- writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
-
- dist.init_process_group(
- backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
- init_method="env://?use_libuv=False",
- world_size=n_gpus,
- rank=rank,
- )
- torch.manual_seed(hps.train.seed)
- if torch.cuda.is_available():
- torch.cuda.set_device(rank)
-
- train_dataset = TextAudioSpeakerLoader(hps.data) ########
- train_sampler = DistributedBucketSampler(
- train_dataset,
- hps.train.batch_size,
- [
- 32,
- 300,
- 400,
- 500,
- 600,
- 700,
- 800,
- 900,
- 1000,
- 1100,
- 1200,
- 1300,
- 1400,
- 1500,
- 1600,
- 1700,
- 1800,
- 1900,
- ],
- num_replicas=n_gpus,
- rank=rank,
- shuffle=True,
- )
- collate_fn = TextAudioSpeakerCollate()
- train_loader = DataLoader(
- train_dataset,
- num_workers=6,
- shuffle=False,
- pin_memory=True,
- collate_fn=collate_fn,
- batch_sampler=train_sampler,
- persistent_workers=True,
- prefetch_factor=4,
- )
- # if rank == 0:
- # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
- # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
- # batch_size=1, pin_memory=True,
- # drop_last=False, collate_fn=collate_fn)
-
- net_g = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model,
- ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model,
- ).to(device)
-
- net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
- for name, param in net_g.named_parameters():
- if not param.requires_grad:
- print(name, "not requires_grad")
-
- te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
- et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
- mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
- base_params = filter(
- lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
- net_g.parameters(),
- )
-
- # te_p=net_g.enc_p.text_embedding.parameters()
- # et_p=net_g.enc_p.encoder_text.parameters()
- # mrte_p=net_g.enc_p.mrte.parameters()
-
- optim_g = torch.optim.AdamW(
- # filter(lambda p: p.requires_grad, net_g.parameters()), ### by default all layers share the same lr
- [
- {"params": base_params, "lr": hps.train.learning_rate},
- {
- "params": net_g.enc_p.text_embedding.parameters(),
- "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
- },
- {
- "params": net_g.enc_p.encoder_text.parameters(),
- "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
- },
- {
- "params": net_g.enc_p.mrte.parameters(),
- "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
- },
- ],
- hps.train.learning_rate,
- betas=hps.train.betas,
- eps=hps.train.eps,
- )
- optim_d = torch.optim.AdamW(
- net_d.parameters(),
- hps.train.learning_rate,
- betas=hps.train.betas,
- eps=hps.train.eps,
- )
- if torch.cuda.is_available():
- net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
- net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
- else:
- net_g = net_g.to(device)
- net_d = net_d.to(device)
-
- try: # auto-resume if a checkpoint can be loaded
- _, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
- net_d,
- optim_d,
- ) # loading D usually works fine
- if rank == 0:
- logger.info("loaded D")
- # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
- _, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
- net_g,
- optim_g,
- )
- epoch_str+=1
- global_step = (epoch_str - 1) * len(train_loader)
- # epoch_str = 1
- # global_step = 0
- except: # if nothing can be loaded on the first run, load the pretrained weights instead
- # traceback.print_exc()
- epoch_str = 1
- global_step = 0
- if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G is not None and os.path.exists(hps.train.pretrained_s2G):
- if rank == 0:
- logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
- print("loaded pretrained %s" % hps.train.pretrained_s2G,
- net_g.module.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
- strict=False,
- ) if torch.cuda.is_available() else net_g.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
- strict=False,
- )
- ) ## testing: do not load the optimizer
- if hps.train.pretrained_s2D != "" and hps.train.pretrained_s2D is not None and os.path.exists(hps.train.pretrained_s2D):
- if rank == 0:
- logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
- print("loaded pretrained %s" % hps.train.pretrained_s2D,
- net_d.module.load_state_dict(
- torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
- ) if torch.cuda.is_available() else net_d.load_state_dict(
- torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
- )
- )
-
- # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
- # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
-
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
- optim_g, gamma=hps.train.lr_decay, last_epoch=-1
- )
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
- optim_d, gamma=hps.train.lr_decay, last_epoch=-1
- )
- for _ in range(epoch_str):
- scheduler_g.step()
- scheduler_d.step()
-
- scaler = GradScaler(enabled=hps.train.fp16_run)
-
- print("start training from epoch %s" % epoch_str)
- for epoch in range(epoch_str, hps.train.epochs + 1):
- if rank == 0:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- # [train_loader, eval_loader], logger, [writer, writer_eval])
- [train_loader, None],
- logger,
- [writer, writer_eval],
- )
- else:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- [train_loader, None],
- None,
- None,
- )
- scheduler_g.step()
- scheduler_d.step()
- print("training done")
-
-
-def train_and_evaluate(
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
- net_g, net_d = nets
- optim_g, optim_d = optims
- # scheduler_g, scheduler_d = schedulers
- train_loader, eval_loader = loaders
- if writers is not None:
- writer, writer_eval = writers
-
- train_loader.batch_sampler.set_epoch(epoch)
- global global_step
-
- net_g.train()
- net_d.train()
- for batch_idx, (
- ssl,
- ssl_lengths,
- spec,
- spec_lengths,
- y,
- y_lengths,
- text,
- text_lengths,
- ) in enumerate(tqdm(train_loader)):
- if torch.cuda.is_available():
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
- rank, non_blocking=True
- )
- y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
- rank, non_blocking=True
- )
- ssl = ssl.cuda(rank, non_blocking=True)
- ssl.requires_grad = False
- # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
- text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
- rank, non_blocking=True
- )
- else:
- spec, spec_lengths = spec.to(device), spec_lengths.to(device)
- y, y_lengths = y.to(device), y_lengths.to(device)
- ssl = ssl.to(device)
- ssl.requires_grad = False
- # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
- text, text_lengths = text.to(device), text_lengths.to(device)
-
- with autocast(enabled=hps.train.fp16_run):
- (
- y_hat,
- kl_ssl,
- ids_slice,
- x_mask,
- z_mask,
- (z, z_p, m_p, logs_p, m_q, logs_q),
- stats_ssl,
- ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
-
- mel = spec_to_mel_torch(
- spec,
- hps.data.filter_length,
- hps.data.n_mel_channels,
- hps.data.sampling_rate,
- hps.data.mel_fmin,
- hps.data.mel_fmax,
- )
- y_mel = commons.slice_segments(
- mel, ids_slice, hps.train.segment_size // hps.data.hop_length
- )
- y_hat_mel = mel_spectrogram_torch(
- y_hat.squeeze(1),
- hps.data.filter_length,
- hps.data.n_mel_channels,
- hps.data.sampling_rate,
- hps.data.hop_length,
- hps.data.win_length,
- hps.data.mel_fmin,
- hps.data.mel_fmax,
- )
-
- y = commons.slice_segments(
- y, ids_slice * hps.data.hop_length, hps.train.segment_size
- ) # slice
-
- # Discriminator
- y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
- with autocast(enabled=False):
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
- y_d_hat_r, y_d_hat_g
- )
- loss_disc_all = loss_disc
- optim_d.zero_grad()
- scaler.scale(loss_disc_all).backward()
- scaler.unscale_(optim_d)
- grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
- scaler.step(optim_d)
-
- with autocast(enabled=hps.train.fp16_run):
- # Generator
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
- with autocast(enabled=False):
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
- loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
-
- loss_fm = feature_loss(fmap_r, fmap_g)
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
- loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
-
- optim_g.zero_grad()
- scaler.scale(loss_gen_all).backward()
- scaler.unscale_(optim_g)
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
- scaler.step(optim_g)
- scaler.update()
-
- if rank == 0:
- if global_step % hps.train.log_interval == 0:
- lr = optim_g.param_groups[0]["lr"]
- losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
- logger.info(
- "Train Epoch: {} [{:.0f}%]".format(
- epoch, 100.0 * batch_idx / len(train_loader)
- )
- )
- logger.info([x.item() for x in losses] + [global_step, lr])
-
- scalar_dict = {
- "loss/g/total": loss_gen_all,
- "loss/d/total": loss_disc_all,
- "learning_rate": lr,
- "grad_norm_d": grad_norm_d,
- "grad_norm_g": grad_norm_g,
- }
- scalar_dict.update(
- {
- "loss/g/fm": loss_fm,
- "loss/g/mel": loss_mel,
- "loss/g/kl_ssl": kl_ssl,
- "loss/g/kl": loss_kl,
- }
- )
-
- # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
- # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
- # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
- image_dict = {
- "slice/mel_org": utils.plot_spectrogram_to_numpy(
- y_mel[0].data.cpu().numpy()
- ),
- "slice/mel_gen": utils.plot_spectrogram_to_numpy(
- y_hat_mel[0].data.cpu().numpy()
- ),
- "all/mel": utils.plot_spectrogram_to_numpy(
- mel[0].data.cpu().numpy()
- ),
- "all/stats_ssl": utils.plot_spectrogram_to_numpy(
- stats_ssl[0].data.cpu().numpy()
- ),
- }
- utils.summarize(
- writer=writer,
- global_step=global_step,
- images=image_dict,
- scalars=scalar_dict,
- )
- global_step += 1
- if epoch % hps.train.save_every_epoch == 0 and rank == 0:
- if hps.train.if_save_latest == 0:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
- ),
- )
- utils.save_checkpoint(
- net_d,
- optim_d,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
- ),
- )
- else:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
- ),
- )
- utils.save_checkpoint(
- net_d,
- optim_d,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
- ),
- )
- if rank == 0 and hps.train.if_save_every_weights == True:
- if hasattr(net_g, "module"):
- ckpt = net_g.module.state_dict()
- else:
- ckpt = net_g.state_dict()
- logger.info(
- "saving ckpt %s_e%s:%s"
- % (
- hps.name,
- epoch,
- savee(
- ckpt,
- hps.name + "_e%s_s%s" % (epoch, global_step),
- epoch,
- global_step,
- hps,
- ),
- )
- )
-
- if rank == 0:
- logger.info("====> Epoch: {}".format(epoch))
-
-
-def evaluate(hps, generator, eval_loader, writer_eval):
- generator.eval()
- image_dict = {}
- audio_dict = {}
- print("Evaluating ...")
- with torch.no_grad():
- for batch_idx, (
- ssl,
- ssl_lengths,
- spec,
- spec_lengths,
- y,
- y_lengths,
- text,
- text_lengths,
- ) in enumerate(eval_loader):
- print(111)
- if torch.cuda.is_available():
- spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
- y, y_lengths = y.cuda(), y_lengths.cuda()
- ssl = ssl.cuda()
- text, text_lengths = text.cuda(), text_lengths.cuda()
- else:
- spec, spec_lengths = spec.to(device), spec_lengths.to(device)
- y, y_lengths = y.to(device), y_lengths.to(device)
- ssl = ssl.to(device)
- text, text_lengths = text.to(device), text_lengths.to(device)
- for test in [0, 1]:
- y_hat, mask, *_ = generator.module.infer(
- ssl, spec, spec_lengths, text, text_lengths, test=test
- ) if torch.cuda.is_available() else generator.infer(
- ssl, spec, spec_lengths, text, text_lengths, test=test
- )
- y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
-
- mel = spec_to_mel_torch(
- spec,
- hps.data.filter_length,
- hps.data.n_mel_channels,
- hps.data.sampling_rate,
- hps.data.mel_fmin,
- hps.data.mel_fmax,
- )
- y_hat_mel = mel_spectrogram_torch(
- y_hat.squeeze(1).float(),
- hps.data.filter_length,
- hps.data.n_mel_channels,
- hps.data.sampling_rate,
- hps.data.hop_length,
- hps.data.win_length,
- hps.data.mel_fmin,
- hps.data.mel_fmax,
- )
- image_dict.update(
- {
- f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
- y_hat_mel[0].cpu().numpy()
- )
- }
- )
- audio_dict.update(
- {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
- )
- image_dict.update(
- {
- f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
- mel[0].cpu().numpy()
- )
- }
- )
- audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
-
- # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
- # audio_dict.update({
- # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
- # })
-
- utils.summarize(
- writer=writer_eval,
- global_step=global_step,
- images=image_dict,
- audios=audio_dict,
- audio_sampling_rate=hps.data.sampling_rate,
- )
- generator.train()
-
-
-if __name__ == "__main__":
- main()
diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py
deleted file mode 100644
index 9933dee..0000000
--- a/GPT_SoVITS/s2_train_v3.py
+++ /dev/null
@@ -1,416 +0,0 @@
-import warnings
-warnings.filterwarnings("ignore")
-import utils, os
-hps = utils.get_hparams(stage=2)
-os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
-import torch
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
-from tqdm import tqdm
-import logging, traceback
-
-logging.getLogger("matplotlib").setLevel(logging.INFO)
-logging.getLogger("h5py").setLevel(logging.INFO)
-logging.getLogger("numba").setLevel(logging.INFO)
-from random import randint
-from module import commons
-
-from module.data_utils import (
- TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
- TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
- DistributedBucketSampler,
-)
-from module.models import (
- SynthesizerTrnV3 as SynthesizerTrn,
- MultiPeriodDiscriminator,
-)
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from process_ckpt import savee
-
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = False
-### fp32 is faster on A100 anyway, so let's try tf32
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision("medium") # lowest precision but fastest (only marginally faster), with no impact on the results
-# from config import pretrained_s2G,pretrained_s2D
-global_step = 0
-
-device = "cpu" # cuda以外的设备,等mps优化后加入
-
-
-def main():
-
- if torch.cuda.is_available():
- n_gpus = torch.cuda.device_count()
- else:
- n_gpus = 1
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
- mp.spawn(
- run,
- nprocs=n_gpus,
- args=(
- n_gpus,
- hps,
- ),
- )
-
-
-def run(rank, n_gpus, hps):
- global global_step
- if rank == 0:
- logger = utils.get_logger(hps.data.exp_dir)
- logger.info(hps)
- # utils.check_git_hash(hps.s2_ckpt_dir)
- writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
-
- dist.init_process_group(
- backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
- init_method="env://?use_libuv=False",
- world_size=n_gpus,
- rank=rank,
- )
- torch.manual_seed(hps.train.seed)
- if torch.cuda.is_available():
- torch.cuda.set_device(rank)
-
- train_dataset = TextAudioSpeakerLoader(hps.data) ########
- train_sampler = DistributedBucketSampler(
- train_dataset,
- hps.train.batch_size,
- [
- 32,
- 300,
- 400,
- 500,
- 600,
- 700,
- 800,
- 900,
- 1000,
- # 1100,
- # 1200,
- # 1300,
- # 1400,
- # 1500,
- # 1600,
- # 1700,
- # 1800,
- # 1900,
- ],
- num_replicas=n_gpus,
- rank=rank,
- shuffle=True,
- )
- collate_fn = TextAudioSpeakerCollate()
- train_loader = DataLoader(
- train_dataset,
- num_workers=6,
- shuffle=False,
- pin_memory=True,
- collate_fn=collate_fn,
- batch_sampler=train_sampler,
- persistent_workers=True,
- prefetch_factor=4,
- )
- # if rank == 0:
- # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
- # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
- # batch_size=1, pin_memory=True,
- # drop_last=False, collate_fn=collate_fn)
-
- net_g = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model,
- ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model,
- ).to(device)
-
- # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
- # for name, param in net_g.named_parameters():
- # if not param.requires_grad:
- # print(name, "not requires_grad")
-
- optim_g = torch.optim.AdamW(
- filter(lambda p: p.requires_grad, net_g.parameters()), ### by default all layers share the same lr
- hps.train.learning_rate,
- betas=hps.train.betas,
- eps=hps.train.eps,
- )
- # optim_d = torch.optim.AdamW(
- # net_d.parameters(),
- # hps.train.learning_rate,
- # betas=hps.train.betas,
- # eps=hps.train.eps,
- # )
- if torch.cuda.is_available():
- net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
- # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
- else:
- net_g = net_g.to(device)
- # net_d = net_d.to(device)
-
- try: # auto-resume if a checkpoint can be loaded
- # _, _, _, epoch_str = utils.load_checkpoint(
- # utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
- # net_d,
- # optim_d,
- # ) # D多半加载没事
- # if rank == 0:
- # logger.info("loaded D")
- # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
- _, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
- net_g,
- optim_g,
- )
- epoch_str+=1
- global_step = (epoch_str - 1) * len(train_loader)
- # epoch_str = 1
- # global_step = 0
- except: # if nothing can be loaded on the first run, load the pretrained weights instead
- # traceback.print_exc()
- epoch_str = 1
- global_step = 0
- if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G is not None and os.path.exists(hps.train.pretrained_s2G):
- if rank == 0:
- logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
- print("loaded pretrained %s" % hps.train.pretrained_s2G,
- net_g.module.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
- strict=False,
- ) if torch.cuda.is_available() else net_g.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
- strict=False,
- )
- ) ## testing: do not load the optimizer
- # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
- # if rank == 0:
- # logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
- # print(
- # net_d.module.load_state_dict(
- # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
- # ) if torch.cuda.is_available() else net_d.load_state_dict(
- # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
- # )
- # )
-
- # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
- # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
-
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
- optim_g, gamma=hps.train.lr_decay, last_epoch=-1
- )
- # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
- # optim_d, gamma=hps.train.lr_decay, last_epoch=-1
- # )
- for _ in range(epoch_str):
- scheduler_g.step()
- # scheduler_d.step()
-
- scaler = GradScaler(enabled=hps.train.fp16_run)
-
- net_d=optim_d=scheduler_d=None
- print("start training from epoch %s" % epoch_str)
- for epoch in range(epoch_str, hps.train.epochs + 1):
- if rank == 0:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- # [train_loader, eval_loader], logger, [writer, writer_eval])
- [train_loader, None],
- logger,
- [writer, writer_eval],
- )
- else:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- [train_loader, None],
- None,
- None,
- )
- scheduler_g.step()
- # scheduler_d.step()
- print("training done")
-
-
-def train_and_evaluate(
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
- net_g, net_d = nets
- optim_g, optim_d = optims
- # scheduler_g, scheduler_d = schedulers
- train_loader, eval_loader = loaders
- if writers is not None:
- writer, writer_eval = writers
-
- train_loader.batch_sampler.set_epoch(epoch)
- global global_step
-
- net_g.train()
- # net_d.train()
- # for batch_idx, (
- # ssl,
- # ssl_lengths,
- # spec,
- # spec_lengths,
- # y,
- # y_lengths,
- # text,
- # text_lengths,
- # ) in enumerate(tqdm(train_loader)):
- for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
- if torch.cuda.is_available():
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
- rank, non_blocking=True
- )
- mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
- rank, non_blocking=True
- )
- ssl = ssl.cuda(rank, non_blocking=True)
- ssl.requires_grad = False
- # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
- text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
- rank, non_blocking=True
- )
- else:
- spec, spec_lengths = spec.to(device), spec_lengths.to(device)
- mel, mel_lengths = mel.to(device), mel_lengths.to(device)
- ssl = ssl.to(device)
- ssl.requires_grad = False
- # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
- text, text_lengths = text.to(device), text_lengths.to(device)
-
- with autocast(enabled=hps.train.fp16_run):
- cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
- loss_gen_all=cfm_loss
- optim_g.zero_grad()
- scaler.scale(loss_gen_all).backward()
- scaler.unscale_(optim_g)
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
- scaler.step(optim_g)
- scaler.update()
-
- if rank == 0:
- if global_step % hps.train.log_interval == 0:
- lr = optim_g.param_groups[0]['lr']
- # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
- losses = [cfm_loss]
- logger.info('Train Epoch: {} [{:.0f}%]'.format(
- epoch,
- 100. * batch_idx / len(train_loader)))
- logger.info([x.item() for x in losses] + [global_step, lr])
-
- scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
- # image_dict = {
- # "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
- # "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
- # "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
- # "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
- # }
- utils.summarize(
- writer=writer,
- global_step=global_step,
- # images=image_dict,
- scalars=scalar_dict)
-
- # if global_step % hps.train.eval_interval == 0:
- # # evaluate(hps, net_g, eval_loader, writer_eval)
- # utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler)
- # # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler)
- # # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3)
- # # if keep_ckpts > 0:
- # # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
-
-
- global_step += 1
- if epoch % hps.train.save_every_epoch == 0 and rank == 0:
- if hps.train.if_save_latest == 0:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
- ),
- )
- # utils.save_checkpoint(
- # net_d,
- # optim_d,
- # hps.train.learning_rate,
- # epoch,
- # os.path.join(
- # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
- # ),
- # )
- else:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
- ),
- )
- # utils.save_checkpoint(
- # net_d,
- # optim_d,
- # hps.train.learning_rate,
- # epoch,
- # os.path.join(
- # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
- # ),
- # )
- if rank == 0 and hps.train.if_save_every_weights == True:
- if hasattr(net_g, "module"):
- ckpt = net_g.module.state_dict()
- else:
- ckpt = net_g.state_dict()
- logger.info(
- "saving ckpt %s_e%s:%s"
- % (
- hps.name,
- epoch,
- savee(
- ckpt,
- hps.name + "_e%s_s%s" % (epoch, global_step),
- epoch,
- global_step,
- hps,
- ),
- )
- )
-
- if rank == 0:
- logger.info("====> Epoch: {}".format(epoch))
-
-
-if __name__ == "__main__":
- main()
diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py
deleted file mode 100644
index 75b3415..0000000
--- a/GPT_SoVITS/s2_train_v3_lora.py
+++ /dev/null
@@ -1,345 +0,0 @@
-import warnings
-warnings.filterwarnings("ignore")
-import utils, os
-hps = utils.get_hparams(stage=2)
-os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
-import torch
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
-from tqdm import tqdm
-import logging, traceback
-
-logging.getLogger("matplotlib").setLevel(logging.INFO)
-logging.getLogger("h5py").setLevel(logging.INFO)
-logging.getLogger("numba").setLevel(logging.INFO)
-from random import randint
-from module import commons
-from peft import LoraConfig, PeftModel, get_peft_model
-from module.data_utils import (
- TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
- TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
- DistributedBucketSampler,
-)
-from module.models import (
- SynthesizerTrnV3 as SynthesizerTrn,
- MultiPeriodDiscriminator,
-)
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from process_ckpt import savee
-from collections import OrderedDict as od
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = False
-### fp32 is faster on A100 anyway, so let's try tf32
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision("medium") # lowest precision but fastest (only marginally faster), with no impact on the results
-# from config import pretrained_s2G,pretrained_s2D
-global_step = 0
-
-device = "cpu" # cuda以外的设备,等mps优化后加入
-
-
-def main():
-
- if torch.cuda.is_available():
- n_gpus = torch.cuda.device_count()
- else:
- n_gpus = 1
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
- mp.spawn(
- run,
- nprocs=n_gpus,
- args=(
- n_gpus,
- hps,
- ),
- )
-
-
-def run(rank, n_gpus, hps):
- global global_step,no_grad_names,save_root,lora_rank
- if rank == 0:
- logger = utils.get_logger(hps.data.exp_dir)
- logger.info(hps)
- # utils.check_git_hash(hps.s2_ckpt_dir)
- writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
- writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
-
- dist.init_process_group(
- backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
- init_method="env://?use_libuv=False",
- world_size=n_gpus,
- rank=rank,
- )
- torch.manual_seed(hps.train.seed)
- if torch.cuda.is_available():
- torch.cuda.set_device(rank)
-
- train_dataset = TextAudioSpeakerLoader(hps.data) ########
- train_sampler = DistributedBucketSampler(
- train_dataset,
- hps.train.batch_size,
- [
- 32,
- 300,
- 400,
- 500,
- 600,
- 700,
- 800,
- 900,
- 1000,
- # 1100,
- # 1200,
- # 1300,
- # 1400,
- # 1500,
- # 1600,
- # 1700,
- # 1800,
- # 1900,
- ],
- num_replicas=n_gpus,
- rank=rank,
- shuffle=True,
- )
- collate_fn = TextAudioSpeakerCollate()
- train_loader = DataLoader(
- train_dataset,
- num_workers=6,
- shuffle=False,
- pin_memory=True,
- collate_fn=collate_fn,
- batch_sampler=train_sampler,
- persistent_workers=True,
- prefetch_factor=4,
- )
- save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
- os.makedirs(save_root,exist_ok=True)
- lora_rank=int(hps.train.lora_rank)
- lora_config = LoraConfig(
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- r=lora_rank,
- lora_alpha=lora_rank,
- init_lora_weights=True,
- )
- def get_model(hps):return SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model,
- )
- def get_optim(net_g):
- return torch.optim.AdamW(
- filter(lambda p: p.requires_grad, net_g.parameters()), ### by default all layers share the same lr
- hps.train.learning_rate,
- betas=hps.train.betas,
- eps=hps.train.eps,
- )
- def model2cuda(net_g,rank):
- if torch.cuda.is_available():
- net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
- else:
- net_g = net_g.to(device)
- return net_g
- try: # auto-resume if a checkpoint can be loaded
- net_g = get_model(hps)
- net_g.cfm = get_peft_model(net_g.cfm, lora_config)
- net_g=model2cuda(net_g,rank)
- optim_g=get_optim(net_g)
- # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
- _, _, _, epoch_str = utils.load_checkpoint(
- utils.latest_checkpoint_path(save_root, "G_*.pth"),
- net_g,
- optim_g,
- )
- epoch_str+=1
- global_step = (epoch_str - 1) * len(train_loader)
- except: # if nothing can be loaded on the first run, load the pretrained weights instead
- # traceback.print_exc()
- epoch_str = 1
- global_step = 0
- net_g = get_model(hps)
- if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G is not None and os.path.exists(hps.train.pretrained_s2G):
- if rank == 0:
- logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
- print("loaded pretrained %s" % hps.train.pretrained_s2G,
- net_g.load_state_dict(
- torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
- strict=False,
- )
- )
- net_g.cfm = get_peft_model(net_g.cfm, lora_config)
- net_g=model2cuda(net_g,rank)
- optim_g = get_optim(net_g)
-
- no_grad_names=set()
- for name, param in net_g.named_parameters():
- if not param.requires_grad:
- no_grad_names.add(name.replace("module.",""))
- # print(name, "not requires_grad")
- # print(no_grad_names)
- # os._exit(233333)
-
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
- optim_g, gamma=hps.train.lr_decay, last_epoch=-1
- )
- for _ in range(epoch_str):
- scheduler_g.step()
-
- scaler = GradScaler(enabled=hps.train.fp16_run)
-
- net_d=optim_d=scheduler_d=None
- print("start training from epoch %s"%epoch_str)
- for epoch in range(epoch_str, hps.train.epochs + 1):
- if rank == 0:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- # [train_loader, eval_loader], logger, [writer, writer_eval])
- [train_loader, None],
- logger,
- [writer, writer_eval],
- )
- else:
- train_and_evaluate(
- rank,
- epoch,
- hps,
- [net_g, net_d],
- [optim_g, optim_d],
- [scheduler_g, scheduler_d],
- scaler,
- [train_loader, None],
- None,
- None,
- )
- scheduler_g.step()
- print("training done")
-
-def train_and_evaluate(
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
- net_g, net_d = nets
- optim_g, optim_d = optims
- # scheduler_g, scheduler_d = schedulers
- train_loader, eval_loader = loaders
- if writers is not None:
- writer, writer_eval = writers
-
- train_loader.batch_sampler.set_epoch(epoch)
- global global_step
-
- net_g.train()
- for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
- if torch.cuda.is_available():
- spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
- rank, non_blocking=True
- )
- mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
- rank, non_blocking=True
- )
- ssl = ssl.cuda(rank, non_blocking=True)
- ssl.requires_grad = False
- text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
- rank, non_blocking=True
- )
- else:
- spec, spec_lengths = spec.to(device), spec_lengths.to(device)
- mel, mel_lengths = mel.to(device), mel_lengths.to(device)
- ssl = ssl.to(device)
- ssl.requires_grad = False
- text, text_lengths = text.to(device), text_lengths.to(device)
-
- with autocast(enabled=hps.train.fp16_run):
- cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
- loss_gen_all=cfm_loss
- optim_g.zero_grad()
- scaler.scale(loss_gen_all).backward()
- scaler.unscale_(optim_g)
- grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
- scaler.step(optim_g)
- scaler.update()
-
- if rank == 0:
- if global_step % hps.train.log_interval == 0:
- lr = optim_g.param_groups[0]['lr']
- losses = [cfm_loss]
- logger.info('Train Epoch: {} [{:.0f}%]'.format(
- epoch,
- 100. * batch_idx / len(train_loader)))
- logger.info([x.item() for x in losses] + [global_step, lr])
-
- scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
- utils.summarize(
- writer=writer,
- global_step=global_step,
- scalars=scalar_dict)
-
- global_step += 1
- if epoch % hps.train.save_every_epoch == 0 and rank == 0:
- if hps.train.if_save_latest == 0:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- save_root, "G_{}.pth".format(global_step)
- ),
- )
- else:
- utils.save_checkpoint(
- net_g,
- optim_g,
- hps.train.learning_rate,
- epoch,
- os.path.join(
- save_root, "G_{}.pth".format(233333333333)
- ),
- )
- if rank == 0 and hps.train.if_save_every_weights == True:
- if hasattr(net_g, "module"):
- ckpt = net_g.module.state_dict()
- else:
- ckpt = net_g.state_dict()
- sim_ckpt=od()
- for key in ckpt:
- # if "cfm"not in key:
- # print(key)
- if key not in no_grad_names:
- sim_ckpt[key]=ckpt[key].half().cpu()
- logger.info(
- "saving ckpt %s_e%s:%s"
- % (
- hps.name,
- epoch,
- savee(
- sim_ckpt,
- hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
- epoch,
- global_step,
- hps,lora_rank=lora_rank
- ),
- )
- )
-
- if rank == 0:
- logger.info("====> Epoch: {}".format(epoch))
-
-
-if __name__ == "__main__":
- main()
diff --git a/GPT_SoVITS_Inference.ipynb b/GPT_SoVITS_Inference.ipynb
deleted file mode 100644
index a5b5532..0000000
--- a/GPT_SoVITS_Inference.ipynb
+++ /dev/null
@@ -1,152 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "# Credits for bubarino giving me the huggingface import code (感谢 bubarino 给了我 huggingface 导入代码)"
- ],
- "metadata": {
- "id": "himHYZmra7ix"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "e9b7iFV3dm1f"
- },
- "source": [
- "!git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
- "%cd GPT-SoVITS\n",
- "!apt-get update && apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && git lfs install\n",
- "!pip install -r requirements.txt"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title Download pretrained models 下载预训练模型\n",
- "!mkdir -p /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
- "!mkdir -p /content/GPT-SoVITS/tools/damo_asr/models\n",
- "!mkdir -p /content/GPT-SoVITS/tools/uvr5\n",
- "%cd /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
- "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n",
- "%cd /content/GPT-SoVITS/tools/damo_asr/models\n",
- "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n",
- "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n",
- "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n",
- "# @title UVR5 pretrains 安装uvr5模型\n",
- "%cd /content/GPT-SoVITS/tools/uvr5\n",
- "!git clone https://huggingface.co/Delik/uvr5_weights\n",
- "!git config core.sparseCheckout true\n",
- "!mv /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/"
- ],
- "metadata": {
- "id": "0NgxXg5sjv7z",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Create folder models 创建文件夹模型\n",
- "import os\n",
- "base_directory = \"/content/GPT-SoVITS\"\n",
- "folder_names = [\"SoVITS_weights\", \"GPT_weights\"]\n",
- "\n",
- "for folder_name in folder_names:\n",
- " if os.path.exists(os.path.join(base_directory, folder_name)):\n",
- " print(f\"The folder '{folder_name}' already exists. (文件夹'{folder_name}'已经存在。)\")\n",
- " else:\n",
- " os.makedirs(os.path.join(base_directory, folder_name))\n",
- " print(f\"The folder '{folder_name}' was created successfully! (文件夹'{folder_name}'已成功创建!)\")\n",
- "\n",
- "print(\"All folders have been created. (所有文件夹均已创建。)\")"
- ],
- "metadata": {
- "cellView": "form",
- "id": "cPDEH-9czOJF"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "import requests\n",
- "import zipfile\n",
- "import shutil\n",
- "import os\n",
- "\n",
- "#@title Import model 导入模型 (HuggingFace)\n",
- "hf_link = 'https://huggingface.co/modelloosrvcc/Nagisa_Shingetsu_GPT-SoVITS/resolve/main/Nagisa.zip' #@param {type: \"string\"}\n",
- "\n",
- "output_path = '/content/'\n",
- "\n",
- "response = requests.get(hf_link)\n",
- "with open(output_path + 'file.zip', 'wb') as file:\n",
- " file.write(response.content)\n",
- "\n",
- "with zipfile.ZipFile(output_path + 'file.zip', 'r') as zip_ref:\n",
- " zip_ref.extractall(output_path)\n",
- "\n",
- "os.remove(output_path + \"file.zip\")\n",
- "\n",
- "source_directory = output_path\n",
- "SoVITS_destination_directory = '/content/GPT-SoVITS/SoVITS_weights'\n",
- "GPT_destination_directory = '/content/GPT-SoVITS/GPT_weights'\n",
- "\n",
- "for filename in os.listdir(source_directory):\n",
- " if filename.endswith(\".pth\"):\n",
- " source_path = os.path.join(source_directory, filename)\n",
- " destination_path = os.path.join(SoVITS_destination_directory, filename)\n",
- " shutil.move(source_path, destination_path)\n",
- "\n",
- "for filename in os.listdir(source_directory):\n",
- " if filename.endswith(\".ckpt\"):\n",
- " source_path = os.path.join(source_directory, filename)\n",
- " destination_path = os.path.join(GPT_destination_directory, filename)\n",
- " shutil.move(source_path, destination_path)\n",
- "\n",
- "print(f'Model downloaded. (模型已下载。)')"
- ],
- "metadata": {
- "cellView": "form",
- "id": "vbZY-LnM0tzq"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title launch WebUI 启动WebUI\n",
- "!/usr/local/bin/pip install ipykernel\n",
- "!sed -i '10s/False/True/' /content/GPT-SoVITS/config.py\n",
- "%cd /content/GPT-SoVITS/\n",
- "!/usr/local/bin/python webui.py"
- ],
- "metadata": {
- "id": "4oRGUzkrk8C7",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
-}
\ No newline at end of file
diff --git a/api.py b/api.py
deleted file mode 100644
index d92d9c8..0000000
--- a/api.py
+++ /dev/null
@@ -1,1105 +0,0 @@
-"""
-# api.py usage
-
-` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `
-
-## Command-line arguments:
-
-`-s` - `SoVITS model path, can also be set in config.py`
-`-g` - `GPT model path, can also be set in config.py`
-
-Used when a request does not provide its own reference audio:
-`-dr` - `default reference audio path`
-`-dt` - `default reference audio text`
-`-dl` - `default reference audio language, "中文","英文","日文","韩文","粤语","zh","en","ja","ko","yue"`
-
-`-d` - `inference device, "cuda","cpu"`
-`-a` - `bind address, default "127.0.0.1"`
-`-p` - `bind port, default 9880, can also be set in config.py`
-`-fp` - `override config.py and use full precision`
-`-hp` - `override config.py and use half precision`
-`-sm` - `streaming response mode, disabled by default, "close","c", "normal","n", "keepalive","k"`
-`-mt` - `returned audio encoding, ogg by default when streaming, otherwise wav, "wav", "ogg", "aac"`
-`-st` - `returned audio sample type, default int16, "int16", "int32"`
-`-cp` - `text split punctuation, empty by default, passed as a string such as ",.,。"`
-
-`-hb` - `cnhubert path`
-`-b` - `bert path`
-
-## API calls:
-
-### Inference
-
-endpoint: `/`
-
-Using the reference audio specified by the command-line arguments:
-GET:
- `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
-POST:
-```json
-{
- "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
- "text_language": "zh"
-}
-```
-
-Using the reference audio specified by the command-line arguments, with custom split punctuation:
-GET:
- `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。`
-POST:
-```json
-{
- "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
- "text_language": "zh",
- "cut_punc": ",。",
-}
-```
-
-Manually specifying the reference audio for this request:
-GET:
- `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
-POST:
-```json
-{
- "refer_wav_path": "123.wav",
- "prompt_text": "一二三。",
- "prompt_language": "zh",
- "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
- "text_language": "zh"
-}
-```
-
-RESP:
-Success: returns the wav audio stream directly, http code 200
-Failure: returns json with an error message, http code 400
-
-Manually specifying the reference audio for this request, with additional parameters:
-GET:
- `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
-POST:
-```json
-{
- "refer_wav_path": "123.wav",
- "prompt_text": "一二三。",
- "prompt_language": "zh",
- "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
- "text_language": "zh",
- "top_k": 20,
- "top_p": 0.6,
- "temperature": 0.6,
- "speed": 1,
- "inp_refs": ["456.wav","789.wav"]
-}
-```
-
-RESP:
-Success: returns the wav audio stream directly, http code 200
-Failure: returns json with an error message, http code 400
-
-
-### Changing the default reference audio
-
-endpoint: `/change_refer`
-
-The keys are the same as for the inference endpoint
-
-GET:
- `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh`
-POST:
-```json
-{
- "refer_wav_path": "123.wav",
- "prompt_text": "一二三。",
- "prompt_language": "zh"
-}
-```
-
-RESP:
-Success: json, http code 200
-Failure: json, http code 400
-
-
-### Command control
-
-endpoint: `/control`
-
-command:
-"restart": restart the program
-"exit": exit the program
-
-GET:
- `http://127.0.0.1:9880/control?command=restart`
-POST:
-```json
-{
- "command": "restart"
-}
-```
-
-RESP: none
-
-"""
-
-
-import argparse
-import os,re
-import sys
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-sys.path.append("%s/GPT_SoVITS" % (now_dir))
-
-import signal
-from text.LangSegmenter import LangSegmenter
-from time import time as ttime
-import torch, torchaudio
-import librosa
-import soundfile as sf
-from fastapi import FastAPI, Request, Query, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
-import uvicorn
-from transformers import AutoModelForMaskedLM, AutoTokenizer
-import numpy as np
-from feature_extractor import cnhubert
-from io import BytesIO
-from module.models import SynthesizerTrn, SynthesizerTrnV3
-from peft import LoraConfig, PeftModel, get_peft_model
-from AR.models.t2s_lightning_module import Text2SemanticLightningModule
-from text import cleaned_text_to_sequence
-from text.cleaner import clean_text
-from module.mel_processing import spectrogram_torch
-from tools.my_utils import load_audio
-import config as global_config
-import logging
-import subprocess
-
-
-class DefaultRefer:
- def __init__(self, path, text, language):
- self.path = args.default_refer_path
- self.text = args.default_refer_text
- self.language = args.default_refer_language
-
- def is_ready(self) -> bool:
- return is_full(self.path, self.text, self.language)
-
-
-def is_empty(*items): # returns False if any item is non-empty
- for item in items:
- if item is not None and item != "":
- return False
- return True
-
-
-def is_full(*items): # 任意一项为空返回False
- for item in items:
- if item is None or item == "":
- return False
- return True
-
-
-def init_bigvgan():
- global bigvgan_model
- from BigVGAN import bigvgan
- bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
- # remove weight norm in the model and set to eval mode
- bigvgan_model.remove_weight_norm()
- bigvgan_model = bigvgan_model.eval()
- if is_half == True:
- bigvgan_model = bigvgan_model.half().to(device)
- else:
- bigvgan_model = bigvgan_model.to(device)
-
-
-resample_transform_dict={}
-def resample(audio_tensor, sr0):
- global resample_transform_dict
- if sr0 not in resample_transform_dict:
- resample_transform_dict[sr0] = torchaudio.transforms.Resample(
- sr0, 24000
- ).to(device)
- return resample_transform_dict[sr0](audio_tensor)
-
-
-from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
-spec_min = -12
-spec_max = 2
-def norm_spec(x):
- return (x - spec_min) / (spec_max - spec_min) * 2 - 1
-def denorm_spec(x):
- return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram_torch(x, **{
- "n_fft": 1024,
- "win_size": 1024,
- "hop_size": 256,
- "num_mels": 100,
- "sampling_rate": 24000,
- "fmin": 0,
- "fmax": None,
- "center": False
-})
-
-
-sr_model=None
-def audio_sr(audio,sr):
- global sr_model
- if sr_model==None:
- from tools.audio_sr import AP_BWE
- try:
- sr_model=AP_BWE(device,DictToAttrRecursive)
- except FileNotFoundError:
- logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载")
- return audio.cpu().detach().numpy(),sr
- return sr_model(audio,sr)
-
-
-class Speaker:
- def __init__(self, name, gpt, sovits, phones = None, bert = None, prompt = None):
- self.name = name
- self.sovits = sovits
- self.gpt = gpt
- self.phones = phones
- self.bert = bert
- self.prompt = prompt
-
-speaker_list = {}
-
-
-class Sovits:
- def __init__(self, vq_model, hps):
- self.vq_model = vq_model
- self.hps = hps
-
-from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
-def get_sovits_weights(sovits_path):
- path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
- is_exist_s2gv3=os.path.exists(path_sovits_v3)
-
- version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
- if if_lora_v3==True and is_exist_s2gv3==False:
- logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
-
- dict_s2 = load_sovits_new(sovits_path)
- hps = dict_s2["config"]
- hps = DictToAttrRecursive(hps)
- hps.model.semantic_frame_rate = "25hz"
- if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
-        hps.model.version = "v2"  # v3 model, v2 symbols
- elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
- hps.model.version = "v1"
- else:
- hps.model.version = "v2"
-
- if model_version == "v3":
- hps.model.version = "v3"
-
- model_params_dict = vars(hps.model)
- if model_version!="v3":
- vq_model = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **model_params_dict
- )
- else:
- vq_model = SynthesizerTrnV3(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **model_params_dict
- )
- init_bigvgan()
- model_version=hps.model.version
- logger.info(f"模型版本: {model_version}")
- if ("pretrained" not in sovits_path):
- try:
- del vq_model.enc_q
- except:pass
- if is_half == True:
- vq_model = vq_model.half().to(device)
- else:
- vq_model = vq_model.to(device)
- vq_model.eval()
- if if_lora_v3 == False:
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
- else:
- vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
- lora_rank=dict_s2["lora_rank"]
- lora_config = LoraConfig(
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- r=lora_rank,
- lora_alpha=lora_rank,
- init_lora_weights=True,
- )
- vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
- vq_model.cfm = vq_model.cfm.merge_and_unload()
- # torch.save(vq_model.state_dict(),"merge_win.pth")
- vq_model.eval()
-
- sovits = Sovits(vq_model, hps)
- return sovits
-
-class Gpt:
- def __init__(self, max_sec, t2s_model):
- self.max_sec = max_sec
- self.t2s_model = t2s_model
-
-global hz
-hz = 50
-def get_gpt_weights(gpt_path):
- dict_s1 = torch.load(gpt_path, map_location="cpu")
- config = dict_s1["config"]
- max_sec = config["data"]["max_sec"]
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
- t2s_model.load_state_dict(dict_s1["weight"])
- if is_half == True:
- t2s_model = t2s_model.half()
- t2s_model = t2s_model.to(device)
- t2s_model.eval()
- # total = sum([param.nelement() for param in t2s_model.parameters()])
- # logger.info("Number of parameter: %.2fM" % (total / 1e6))
-
- gpt = Gpt(max_sec, t2s_model)
- return gpt
-
-def change_gpt_sovits_weights(gpt_path,sovits_path):
- try:
- gpt = get_gpt_weights(gpt_path)
- sovits = get_sovits_weights(sovits_path)
- except Exception as e:
- return JSONResponse({"code": 400, "message": str(e)}, status_code=400)
-
- speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits)
- return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
-
-
-def get_bert_feature(text, word2ph):
- with torch.no_grad():
- inputs = tokenizer(text, return_tensors="pt")
- for i in inputs:
- inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
- res = bert_model(**inputs, output_hidden_states=True)
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
- assert len(word2ph) == len(text)
- phone_level_feature = []
- for i in range(len(word2ph)):
- repeat_feature = res[i].repeat(word2ph[i], 1)
- phone_level_feature.append(repeat_feature)
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
- return phone_level_feature.T
-
-
-def clean_text_inf(text, language, version):
- language = language.replace("all_","")
- phones, word2ph, norm_text = clean_text(text, language, version)
- phones = cleaned_text_to_sequence(phones, version)
- return phones, word2ph, norm_text
-
-
-def get_bert_inf(phones, word2ph, norm_text, language):
- language=language.replace("all_","")
- if language == "zh":
- bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
- else:
- bert = torch.zeros(
- (1024, len(phones)),
- dtype=torch.float16 if is_half == True else torch.float32,
- ).to(device)
-
- return bert
-
-from text import chinese
-def get_phones_and_bert(text,language,version,final=False):
- if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
- formattext = text
- while " " in formattext:
- formattext = formattext.replace(" ", " ")
- if language == "all_zh":
- if re.search(r'[A-Za-z]', formattext):
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
- formattext = chinese.mix_text_normalize(formattext)
- return get_phones_and_bert(formattext,"zh",version)
- else:
- phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
- bert = get_bert_feature(norm_text, word2ph).to(device)
- elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
- formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
- formattext = chinese.mix_text_normalize(formattext)
- return get_phones_and_bert(formattext,"yue",version)
- else:
- phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
- bert = torch.zeros(
- (1024, len(phones)),
- dtype=torch.float16 if is_half == True else torch.float32,
- ).to(device)
- elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
- textlist=[]
- langlist=[]
- if language == "auto":
- for tmp in LangSegmenter.getTexts(text):
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- elif language == "auto_yue":
- for tmp in LangSegmenter.getTexts(text):
- if tmp["lang"] == "zh":
- tmp["lang"] = "yue"
- langlist.append(tmp["lang"])
- textlist.append(tmp["text"])
- else:
- for tmp in LangSegmenter.getTexts(text):
- if tmp["lang"] == "en":
- langlist.append(tmp["lang"])
- else:
- # 因无法区别中日韩文汉字,以用户输入为准
- langlist.append(language)
- textlist.append(tmp["text"])
- phones_list = []
- bert_list = []
- norm_text_list = []
- for i in range(len(textlist)):
- lang = langlist[i]
- phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
- bert = get_bert_inf(phones, word2ph, norm_text, lang)
- phones_list.append(phones)
- norm_text_list.append(norm_text)
- bert_list.append(bert)
- bert = torch.cat(bert_list, dim=1)
- phones = sum(phones_list, [])
- norm_text = ''.join(norm_text_list)
-
- if not final and len(phones) < 6:
- return get_phones_and_bert("." + text,language,version,final=True)
-
- return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
-
-
-class DictToAttrRecursive(dict):
- def __init__(self, input_dict):
- super().__init__(input_dict)
- for key, value in input_dict.items():
- if isinstance(value, dict):
- value = DictToAttrRecursive(value)
- self[key] = value
- setattr(self, key, value)
-
- def __getattr__(self, item):
- try:
- return self[item]
- except KeyError:
- raise AttributeError(f"Attribute {item} not found")
-
- def __setattr__(self, key, value):
- if isinstance(value, dict):
- value = DictToAttrRecursive(value)
- super(DictToAttrRecursive, self).__setitem__(key, value)
- super().__setattr__(key, value)
-
- def __delattr__(self, item):
- try:
- del self[item]
- except KeyError:
- raise AttributeError(f"Attribute {item} not found")
-
-
-def get_spepc(hps, filename):
-    audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
- audio = torch.FloatTensor(audio)
- maxx=audio.abs().max()
- if(maxx>1):
- audio/=min(2,maxx)
- audio_norm = audio
- audio_norm = audio_norm.unsqueeze(0)
- spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
- hps.data.win_length, center=False)
- return spec
-
-
-def pack_audio(audio_bytes, data, rate):
- if media_type == "ogg":
- audio_bytes = pack_ogg(audio_bytes, data, rate)
- elif media_type == "aac":
- audio_bytes = pack_aac(audio_bytes, data, rate)
- else:
- # wav无法流式, 先暂存raw
- audio_bytes = pack_raw(audio_bytes, data, rate)
-
- return audio_bytes
-
-
-def pack_ogg(audio_bytes, data, rate):
- # Author: AkagawaTsurunaki
- # Issue:
- # Stack overflow probabilistically occurs
- # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called
- # using the Python library `soundfile`
- # Note:
- # This is an issue related to `libsndfile`, not this project itself.
- # It happens when you generate a large audio tensor (about 499804 frames in my PC)
- # and try to convert it to an ogg file.
- # Related:
- # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199
- # https://github.com/libsndfile/libsndfile/issues/1023
- # https://github.com/bastibe/python-soundfile/issues/396
- # Suggestion:
-    #   Splitting the whole audio data into smaller segments would also avoid the stack overflow.
-
- def handle_pack_ogg():
- with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
- audio_file.write(data)
-
- import threading
- # See: https://docs.python.org/3/library/threading.html
- # The stack size of this thread is at least 32768
- # If stack overflow error still occurs, just modify the `stack_size`.
- # stack_size = n * 4096, where n should be a positive integer.
- # Here we chose n = 4096.
- stack_size = 4096 * 4096
- try:
- threading.stack_size(stack_size)
- pack_ogg_thread = threading.Thread(target=handle_pack_ogg)
- pack_ogg_thread.start()
- pack_ogg_thread.join()
- except RuntimeError as e:
- # If changing the thread stack size is unsupported, a RuntimeError is raised.
- print("RuntimeError: {}".format(e))
- print("Changing the thread stack size is unsupported.")
- except ValueError as e:
- # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified.
- print("ValueError: {}".format(e))
- print("The specified stack size is invalid.")
-
- return audio_bytes
-
-
-def pack_raw(audio_bytes, data, rate):
- audio_bytes.write(data.tobytes())
-
- return audio_bytes
-
-
-def pack_wav(audio_bytes, rate):
- if is_int32:
- data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
- wav_bytes = BytesIO()
- sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
- else:
- data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
- wav_bytes = BytesIO()
- sf.write(wav_bytes, data, rate, format='WAV')
- return wav_bytes
-
-
-def pack_aac(audio_bytes, data, rate):
- if is_int32:
- pcm = 's32le'
- bit_rate = '256k'
- else:
- pcm = 's16le'
- bit_rate = '128k'
- process = subprocess.Popen([
- 'ffmpeg',
-        '-f', pcm,               # input signed little-endian PCM (s16le or s32le, per is_int32)
- '-ar', str(rate), # 设置采样率
- '-ac', '1', # 单声道
- '-i', 'pipe:0', # 从管道读取输入
- '-c:a', 'aac', # 音频编码器为AAC
- '-b:a', bit_rate, # 比特率
- '-vn', # 不包含视频
- '-f', 'adts', # 输出AAC数据流格式
- 'pipe:1' # 将输出写入管道
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- out, _ = process.communicate(input=data.tobytes())
- audio_bytes.write(out)
-
- return audio_bytes
-
-
-def read_clean_buffer(audio_bytes):
- audio_chunk = audio_bytes.getvalue()
- audio_bytes.truncate(0)
- audio_bytes.seek(0)
-
- return audio_bytes, audio_chunk
-
-
-def cut_text(text, punc):
- punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}]
- if len(punc_list) > 0:
- punds = r"[" + "".join(punc_list) + r"]"
- text = text.strip("\n")
- items = re.split(f"({punds})", text)
- mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
- # 在句子不存在符号或句尾无符号的时候保证文本完整
- if len(items)%2 == 1:
- mergeitems.append(items[-1])
- text = "\n".join(mergeitems)
-
- while "\n\n" in text:
- text = text.replace("\n\n", "\n")
-
- return text
-
-
-def only_punc(text):
- return not any(t.isalnum() or t.isalpha() for t in text)
-
-
-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k= 15, top_p = 0.6, temperature = 0.6, speed = 1, inp_refs = None, sample_steps = 32, if_sr = False, spk = "default"):
- infer_sovits = speaker_list[spk].sovits
- vq_model = infer_sovits.vq_model
- hps = infer_sovits.hps
- version = vq_model.version
-
- infer_gpt = speaker_list[spk].gpt
- t2s_model = infer_gpt.t2s_model
- max_sec = infer_gpt.max_sec
-
- t0 = ttime()
- prompt_text = prompt_text.strip("\n")
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
- prompt_language, text = prompt_language, text.strip("\n")
- dtype = torch.float16 if is_half == True else torch.float32
- zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
- with torch.no_grad():
- wav16k, sr = librosa.load(ref_wav_path, sr=16000)
- wav16k = torch.from_numpy(wav16k)
- zero_wav_torch = torch.from_numpy(zero_wav)
- if (is_half == True):
- wav16k = wav16k.half().to(device)
- zero_wav_torch = zero_wav_torch.half().to(device)
- else:
- wav16k = wav16k.to(device)
- zero_wav_torch = zero_wav_torch.to(device)
- wav16k = torch.cat([wav16k, zero_wav_torch])
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
- codes = vq_model.extract_latent(ssl_content)
- prompt_semantic = codes[0, 0]
- prompt = prompt_semantic.unsqueeze(0).to(device)
-
- if version != "v3":
- refers=[]
- if(inp_refs):
- for path in inp_refs:
- try:
- refer = get_spepc(hps, path).to(dtype).to(device)
- refers.append(refer)
- except Exception as e:
- logger.error(e)
- if(len(refers)==0):
- refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
- else:
- refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
-
- t1 = ttime()
- # os.environ['version'] = version
- prompt_language = dict_language[prompt_language.lower()]
- text_language = dict_language[text_language.lower()]
- phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
- texts = text.split("\n")
- audio_bytes = BytesIO()
-
- for text in texts:
- # 简单防止纯符号引发参考音频泄露
- if only_punc(text):
- continue
-
- audio_opt = []
- if (text[-1] not in splits): text += "。" if text_language != "en" else "."
- phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
- bert = torch.cat([bert1, bert2], 1)
-
- all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
- bert = bert.to(device).unsqueeze(0)
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
- t2 = ttime()
- with torch.no_grad():
- pred_semantic, idx = t2s_model.model.infer_panel(
- all_phoneme_ids,
- all_phoneme_len,
- prompt,
- bert,
- # prompt_phone_len=ph_offset,
- top_k = top_k,
- top_p = top_p,
- temperature = temperature,
- early_stop_num=hz * max_sec)
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
- t3 = ttime()
-
- if version != "v3":
- audio = \
- vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
- refers,speed=speed).detach().cpu().numpy()[
- 0, 0] ###试试重建不带上prompt部分
- else:
- phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
- phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
- # print(11111111, phoneme_ids0, phoneme_ids1)
- fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
- ref_audio, sr = torchaudio.load(ref_wav_path)
- ref_audio=ref_audio.to(device).float()
- if (ref_audio.shape[0] == 2):
- ref_audio = ref_audio.mean(0).unsqueeze(0)
- if sr!=24000:
- ref_audio=resample(ref_audio,sr)
- # print("ref_audio",ref_audio.abs().mean())
- mel2 = mel_fn(ref_audio)
- mel2 = norm_spec(mel2)
- T_min = min(mel2.shape[2], fea_ref.shape[2])
- mel2 = mel2[:, :, :T_min]
- fea_ref = fea_ref[:, :, :T_min]
- if (T_min > 468):
- mel2 = mel2[:, :, -468:]
- fea_ref = fea_ref[:, :, -468:]
- T_min = 468
- chunk_len = 934 - T_min
- # print("fea_ref",fea_ref,fea_ref.shape)
- # print("mel2",mel2)
- mel2=mel2.to(dtype)
- fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
- # print("fea_todo",fea_todo)
- # print("ge",ge.abs().mean())
- cfm_resss = []
- idx = 0
- while (1):
- fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
- if (fea_todo_chunk.shape[-1] == 0): break
- idx += chunk_len
- fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
- # set_seed(123)
- cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
- cfm_res = cfm_res[:, :, mel2.shape[2]:]
- mel2 = cfm_res[:, :, -T_min:]
- # print("fea", fea)
- # print("mel2in", mel2)
- fea_ref = fea_todo_chunk[:, :, -T_min:]
- cfm_resss.append(cfm_res)
-                cfm_res = torch.cat(cfm_resss, 2)
-                cfm_res = denorm_spec(cfm_res)
-                if bigvgan_model is None: init_bigvgan()
-                with torch.inference_mode():
-                    wav_gen = bigvgan_model(cfm_res)
- audio=wav_gen[0][0].cpu().detach().numpy()
-
- max_audio=np.abs(audio).max()
- if max_audio>1:
- audio/=max_audio
- audio_opt.append(audio)
- audio_opt.append(zero_wav)
- audio_opt = np.concatenate(audio_opt, 0)
- t4 = ttime()
-
- sr = hps.data.sampling_rate if version != "v3" else 24000
- if if_sr and sr == 24000:
- audio_opt = torch.from_numpy(audio_opt).float().to(device)
- audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
- max_audio=np.abs(audio_opt).max()
- if max_audio > 1: audio_opt /= max_audio
- sr = 48000
-
- if is_int32:
- audio_bytes = pack_audio(audio_bytes,(audio_opt * 2147483647).astype(np.int32),sr)
- else:
- audio_bytes = pack_audio(audio_bytes,(audio_opt * 32768).astype(np.int16),sr)
- # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
- if stream_mode == "normal":
- audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
- yield audio_chunk
-
- if not stream_mode == "normal":
- if media_type == "wav":
- sr = 48000 if if_sr else 24000
- sr = hps.data.sampling_rate if version != "v3" else sr
- audio_bytes = pack_wav(audio_bytes,sr)
- yield audio_bytes.getvalue()
-
-
-
-def handle_control(command):
- if command == "restart":
- os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
- elif command == "exit":
- os.kill(os.getpid(), signal.SIGTERM)
- exit(0)
-
-
-def handle_change(path, text, language):
- if is_empty(path, text, language):
- return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
-
- if path != "" or path is not None:
- default_refer.path = path
- if text != "" or text is not None:
- default_refer.text = text
- if language != "" or language is not None:
- default_refer.language = language
-
- logger.info(f"当前默认参考音频路径: {default_refer.path}")
- logger.info(f"当前默认参考音频文本: {default_refer.text}")
- logger.info(f"当前默认参考音频语种: {default_refer.language}")
- logger.info(f"is_ready: {default_refer.is_ready()}")
-
-
- return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
-
-
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
- if (
- refer_wav_path == "" or refer_wav_path is None
- or prompt_text == "" or prompt_text is None
- or prompt_language == "" or prompt_language is None
- ):
- refer_wav_path, prompt_text, prompt_language = (
- default_refer.path,
- default_refer.text,
- default_refer.language,
- )
- if not default_refer.is_ready():
- return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
-
-    if sample_steps not in [4, 8, 16, 32]:
-        sample_steps = 32
-
-    if cut_punc is None:
-        text = cut_text(text, default_cut_punc)
-    else:
-        text = cut_text(text, cut_punc)
-
- return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/"+media_type)
-
-
-
-
-# --------------------------------
-# 初始化部分
-# --------------------------------
-dict_language = {
- "中文": "all_zh",
- "粤语": "all_yue",
- "英文": "en",
- "日文": "all_ja",
- "韩文": "all_ko",
- "中英混合": "zh",
- "粤英混合": "yue",
- "日英混合": "ja",
- "韩英混合": "ko",
- "多语种混合": "auto", #多语种启动切分识别语种
- "多语种混合(粤语)": "auto_yue",
- "all_zh": "all_zh",
- "all_yue": "all_yue",
- "en": "en",
- "all_ja": "all_ja",
- "all_ko": "all_ko",
- "zh": "zh",
- "yue": "yue",
- "ja": "ja",
- "ko": "ko",
- "auto": "auto",
- "auto_yue": "auto_yue",
-}
-
-# logger
-logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
-logger = logging.getLogger('uvicorn')
-
-# 获取配置
-g_config = global_config.Config()
-
-# 获取参数
-parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-
-parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
-parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
-parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
-parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
-parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
-parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
-parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
-parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
-parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
-parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
-# bool值的用法为 `python ./api.py -fp ...`
-# 此时 full_precision==True, half_precision==False
-parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
-parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
-parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
-parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
-# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
-parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
-parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
-
-args = parser.parse_args()
-sovits_path = args.sovits_path
-gpt_path = args.gpt_path
-device = args.device
-port = args.port
-host = args.bind_addr
-cnhubert_base_path = args.hubert_path
-bert_path = args.bert_path
-default_cut_punc = args.cut_punc
-
-# 应用参数配置
-default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
-
-# 模型路径检查
-if sovits_path == "":
- sovits_path = g_config.pretrained_sovits_path
- logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
-if gpt_path == "":
- gpt_path = g_config.pretrained_gpt_path
- logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
-
-# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
-if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
- default_refer.path, default_refer.text, default_refer.language = "", "", ""
- logger.info("未指定默认参考音频")
-else:
- logger.info(f"默认参考音频路径: {default_refer.path}")
- logger.info(f"默认参考音频文本: {default_refer.text}")
- logger.info(f"默认参考音频语种: {default_refer.language}")
-
-# 获取半精度
-is_half = g_config.is_half
-if args.full_precision:
- is_half = False
-if args.half_precision:
- is_half = True
-if args.full_precision and args.half_precision:
-    is_half = g_config.is_half  # both flags were given, fall back to the config default
-logger.info(f"半精: {is_half}")
-
-# 流式返回模式
-if args.stream_mode.lower() in ["normal","n"]:
- stream_mode = "normal"
- logger.info("流式返回已开启")
-else:
- stream_mode = "close"
-
-# 音频编码格式
-if args.media_type.lower() in ["aac","ogg"]:
- media_type = args.media_type.lower()
-elif stream_mode == "close":
- media_type = "wav"
-else:
- media_type = "ogg"
-logger.info(f"编码格式: {media_type}")
-
-# 音频数据类型
-if args.sub_type.lower() == 'int32':
- is_int32 = True
- logger.info(f"数据类型: int32")
-else:
- is_int32 = False
- logger.info(f"数据类型: int16")
-
-# 初始化模型
-cnhubert.cnhubert_base_path = cnhubert_base_path
-tokenizer = AutoTokenizer.from_pretrained(bert_path)
-bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
-ssl_model = cnhubert.get_model()
-if is_half:
- bert_model = bert_model.half().to(device)
- ssl_model = ssl_model.half().to(device)
-else:
- bert_model = bert_model.to(device)
- ssl_model = ssl_model.to(device)
-change_gpt_sovits_weights(gpt_path = gpt_path, sovits_path = sovits_path)
-
-
-
-# --------------------------------
-# 接口部分
-# --------------------------------
-app = FastAPI()
-
-@app.post("/set_model")
-async def set_model(request: Request):
- json_post_raw = await request.json()
- return change_gpt_sovits_weights(
- gpt_path = json_post_raw.get("gpt_model_path"),
- sovits_path = json_post_raw.get("sovits_model_path")
- )
-
-
-@app.get("/set_model")
-async def set_model(
- gpt_model_path: str = None,
- sovits_model_path: str = None,
-):
- return change_gpt_sovits_weights(gpt_path = gpt_model_path, sovits_path = sovits_model_path)
-
-
-@app.post("/control")
-async def control(request: Request):
- json_post_raw = await request.json()
- return handle_control(json_post_raw.get("command"))
-
-
-@app.get("/control")
-async def control(command: str = None):
- return handle_control(command)
-
-
-@app.post("/change_refer")
-async def change_refer(request: Request):
- json_post_raw = await request.json()
- return handle_change(
- json_post_raw.get("refer_wav_path"),
- json_post_raw.get("prompt_text"),
- json_post_raw.get("prompt_language")
- )
-
-
-@app.get("/change_refer")
-async def change_refer(
- refer_wav_path: str = None,
- prompt_text: str = None,
- prompt_language: str = None
-):
- return handle_change(refer_wav_path, prompt_text, prompt_language)
-
-
-@app.post("/")
-async def tts_endpoint(request: Request):
- json_post_raw = await request.json()
- return handle(
- json_post_raw.get("refer_wav_path"),
- json_post_raw.get("prompt_text"),
- json_post_raw.get("prompt_language"),
- json_post_raw.get("text"),
- json_post_raw.get("text_language"),
- json_post_raw.get("cut_punc"),
- json_post_raw.get("top_k", 15),
- json_post_raw.get("top_p", 1.0),
- json_post_raw.get("temperature", 1.0),
- json_post_raw.get("speed", 1.0),
- json_post_raw.get("inp_refs", []),
- json_post_raw.get("sample_steps", 32),
- json_post_raw.get("if_sr", False)
- )
-
-
-@app.get("/")
-async def tts_endpoint(
- refer_wav_path: str = None,
- prompt_text: str = None,
- prompt_language: str = None,
- text: str = None,
- text_language: str = None,
- cut_punc: str = None,
- top_k: int = 15,
- top_p: float = 1.0,
- temperature: float = 1.0,
- speed: float = 1.0,
- inp_refs: list = Query(default=[]),
- sample_steps: int = 32,
- if_sr: bool = False
-):
- return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)
-
-
-if __name__ == "__main__":
- uvicorn.run(app, host=host, port=port, workers=1)
diff --git a/api_v2.py b/api_v2.py
deleted file mode 100644
index 92a18f3..0000000
--- a/api_v2.py
+++ /dev/null
@@ -1,460 +0,0 @@
-"""
-# WebAPI文档
-
-` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
-
-## 执行参数:
- `-a` - `绑定地址, 默认"127.0.0.1"`
- `-p` - `绑定端口, 默认9880`
- `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"`
-
-## 调用:
-
-### 推理
-
-endpoint: `/tts`
-GET:
-```
-http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true
-```
-
-POST:
-```json
-{
- "text": "", # str.(required) text to be synthesized
- "text_lang: "", # str.(required) language of the text to be synthesized
- "ref_audio_path": "", # str.(required) reference audio path
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
- "prompt_text": "", # str.(optional) prompt text for the reference audio
- "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
- "top_k": 5, # int. top k sampling
- "top_p": 1, # float. top p sampling
- "temperature": 1, # float. temperature for sampling
- "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
- "batch_size": 1, # int. batch size for inference
- "batch_threshold": 0.75, # float. threshold for batch splitting.
- "split_bucket: True, # bool. whether to split the batch into multiple buckets.
- "speed_factor":1.0, # float. control the speed of the synthesized audio.
- "streaming_mode": False, # bool. whether to return a streaming response.
- "seed": -1, # int. random seed for reproducibility.
- "parallel_infer": True, # bool. whether to use parallel inference.
- "repetition_penalty": 1.35 # float. repetition penalty for T2S model.
-}
-```
-
-RESP:
-成功: 直接返回 wav 音频流, http code 200
-失败: 返回包含错误信息的 json, http code 400
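-
-A minimal streaming client sketch (not part of the project; assumes the server is running on 127.0.0.1:9880, `requests` is installed, and the reference audio exists on the server side):
-```python
-import requests
-
-params = {
-    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
-    "text_lang": "zh",
-    "ref_audio_path": "archive_jingyuan_1.wav",
-    "prompt_text": "我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可",
-    "prompt_lang": "zh",
-    "text_split_method": "cut5",
-    "media_type": "wav",
-    "streaming_mode": "true",
-}
-with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
-    resp.raise_for_status()
-    with open("output.wav", "wb") as f:
-        for chunk in resp.iter_content(chunk_size=4096):
-            f.write(chunk)
-```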
-
-### 命令控制
-
-endpoint: `/control`
-
-command:
-"restart": 重新运行
-"exit": 结束运行
-
-GET:
-```
-http://127.0.0.1:9880/control?command=restart
-```
-POST:
-```json
-{
- "command": "restart"
-}
-```
-
-RESP: 无
-
-
-### 切换GPT模型
-
-endpoint: `/set_gpt_weights`
-
-GET:
-```
-http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-```
-RESP:
-成功: 返回"success", http code 200
-失败: 返回包含错误信息的 json, http code 400
-
-
-### 切换Sovits模型
-
-endpoint: `/set_sovits_weights`
-
-GET:
-```
-http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth
-```
-
-RESP:
-成功: 返回"success", http code 200
-失败: 返回包含错误信息的 json, http code 400
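-
-Both weight-switching endpoints can be driven from a simple client (sketch only; the weight paths are the ones shown above):
-```python
-import requests
-
-base = "http://127.0.0.1:9880"
-r1 = requests.get(f"{base}/set_gpt_weights", params={
-    "weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
-})
-r2 = requests.get(f"{base}/set_sovits_weights", params={
-    "weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
-})
-print(r1.text, r2.text)  # {"message": "success"} on success
-```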
-
-"""
-import os
-import sys
-import traceback
-from typing import Generator
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-sys.path.append("%s/GPT_SoVITS" % (now_dir))
-
-import argparse
-import subprocess
-import wave
-import signal
-import numpy as np
-import soundfile as sf
-from fastapi import FastAPI, Request, HTTPException, Response
-from fastapi.responses import StreamingResponse, JSONResponse
-from fastapi import FastAPI, UploadFile, File
-import uvicorn
-from io import BytesIO
-from tools.i18n.i18n import I18nAuto
-from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
-from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-# print(sys.path)
-i18n = I18nAuto()
-cut_method_names = get_cut_method_names()
-
-parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
-parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
-parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
-args = parser.parse_args()
-config_path = args.tts_config
-# device = args.device
-port = args.port
-host = args.bind_addr
-argv = sys.argv
-
-if config_path in [None, ""]:
- config_path = "GPT-SoVITS/configs/tts_infer.yaml"
-
-tts_config = TTS_Config(config_path)
-print(tts_config)
-tts_pipeline = TTS(tts_config)
-
-APP = FastAPI()
-class TTS_Request(BaseModel):
- text: str = None
- text_lang: str = None
- ref_audio_path: str = None
- aux_ref_audio_paths: list = None
- prompt_lang: str = None
- prompt_text: str = ""
- top_k:int = 5
- top_p:float = 1
- temperature:float = 1
- text_split_method:str = "cut5"
- batch_size:int = 1
- batch_threshold:float = 0.75
- split_bucket:bool = True
- speed_factor:float = 1.0
- fragment_interval:float = 0.3
- seed:int = -1
- media_type:str = "wav"
- streaming_mode:bool = False
- parallel_infer:bool = True
- repetition_penalty:float = 1.35
-
-### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
-def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
- with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
- audio_file.write(data)
- return io_buffer
-
-
-def pack_raw(io_buffer:BytesIO, data:np.ndarray, rate:int):
- io_buffer.write(data.tobytes())
- return io_buffer
-
-
-def pack_wav(io_buffer:BytesIO, data:np.ndarray, rate:int):
- io_buffer = BytesIO()
- sf.write(io_buffer, data, rate, format='wav')
- return io_buffer
-
-def pack_aac(io_buffer:BytesIO, data:np.ndarray, rate:int):
- process = subprocess.Popen([
- 'ffmpeg',
- '-f', 's16le', # 输入16位有符号小端整数PCM
- '-ar', str(rate), # 设置采样率
- '-ac', '1', # 单声道
- '-i', 'pipe:0', # 从管道读取输入
- '-c:a', 'aac', # 音频编码器为AAC
- '-b:a', '192k', # 比特率
- '-vn', # 不包含视频
- '-f', 'adts', # 输出AAC数据流格式
- 'pipe:1' # 将输出写入管道
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- out, _ = process.communicate(input=data.tobytes())
- io_buffer.write(out)
- return io_buffer
-
-def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
- if media_type == "ogg":
- io_buffer = pack_ogg(io_buffer, data, rate)
- elif media_type == "aac":
- io_buffer = pack_aac(io_buffer, data, rate)
- elif media_type == "wav":
- io_buffer = pack_wav(io_buffer, data, rate)
- else:
- io_buffer = pack_raw(io_buffer, data, rate)
- io_buffer.seek(0)
- return io_buffer
-
-
-
-# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
-def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
-    # This creates a wave header and appends the frame input.
-    # It must be the first chunk of a streaming wav file;
-    # later chunks should not include it (otherwise you will hear artifacts at the start of each chunk).
- wav_buf = BytesIO()
- with wave.open(wav_buf, "wb") as vfout:
- vfout.setnchannels(channels)
- vfout.setsampwidth(sample_width)
- vfout.setframerate(sample_rate)
- vfout.writeframes(frame_input)
-
- wav_buf.seek(0)
- return wav_buf.read()
-
-
-def handle_control(command:str):
- if command == "restart":
- os.execl(sys.executable, sys.executable, *argv)
- elif command == "exit":
- os.kill(os.getpid(), signal.SIGTERM)
- exit(0)
-
-
-def check_params(req:dict):
- text:str = req.get("text", "")
- text_lang:str = req.get("text_lang", "")
- ref_audio_path:str = req.get("ref_audio_path", "")
- streaming_mode:bool = req.get("streaming_mode", False)
- media_type:str = req.get("media_type", "wav")
- prompt_lang:str = req.get("prompt_lang", "")
- text_split_method:str = req.get("text_split_method", "cut5")
-
- if ref_audio_path in [None, ""]:
- return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
- if text in [None, ""]:
- return JSONResponse(status_code=400, content={"message": "text is required"})
- if (text_lang in [None, ""]) :
- return JSONResponse(status_code=400, content={"message": "text_lang is required"})
- elif text_lang.lower() not in tts_config.languages:
- return JSONResponse(status_code=400, content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"})
- if (prompt_lang in [None, ""]) :
- return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
- elif prompt_lang.lower() not in tts_config.languages:
- return JSONResponse(status_code=400, content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"})
- if media_type not in ["wav", "raw", "ogg", "aac"]:
- return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"})
- elif media_type == "ogg" and not streaming_mode:
- return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
-
- if text_split_method not in cut_method_names:
- return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
-
- return None
-
-async def tts_handle(req:dict):
- """
- Text to speech handler.
-
- Args:
- req (dict):
- {
- "text": "", # str.(required) text to be synthesized
- "text_lang: "", # str.(required) language of the text to be synthesized
- "ref_audio_path": "", # str.(required) reference audio path
- "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
- "prompt_text": "", # str.(optional) prompt text for the reference audio
- "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
- "top_k": 5, # int. top k sampling
- "top_p": 1, # float. top p sampling
- "temperature": 1, # float. temperature for sampling
- "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
- "batch_size": 1, # int. batch size for inference
- "batch_threshold": 0.75, # float. threshold for batch splitting.
- "split_bucket: True, # bool. whether to split the batch into multiple buckets.
- "speed_factor":1.0, # float. control the speed of the synthesized audio.
- "fragment_interval":0.3, # float. to control the interval of the audio fragment.
- "seed": -1, # int. random seed for reproducibility.
- "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
- "streaming_mode": False, # bool. whether to return a streaming response.
- "parallel_infer": True, # bool.(optional) whether to use parallel inference.
- "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
- }
- returns:
- StreamingResponse: audio stream response.
- """
-
- streaming_mode = req.get("streaming_mode", False)
- return_fragment = req.get("return_fragment", False)
- media_type = req.get("media_type", "wav")
-
- check_res = check_params(req)
- if check_res is not None:
- return check_res
-
- if streaming_mode or return_fragment:
- req["return_fragment"] = True
-
- try:
- tts_generator=tts_pipeline.run(req)
-
- if streaming_mode:
- def streaming_generator(tts_generator:Generator, media_type:str):
- if media_type == "wav":
- yield wave_header_chunk()
- media_type = "raw"
- for sr, chunk in tts_generator:
- yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
- # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
- return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
-
- else:
- sr, audio_data = next(tts_generator)
- audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
- return Response(audio_data, media_type=f"audio/{media_type}")
- except Exception as e:
- return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
-
-
-
-
-
-
-@APP.get("/control")
-async def control(command: str = None):
- if command is None:
- return JSONResponse(status_code=400, content={"message": "command is required"})
- handle_control(command)
-
-
-
-@APP.get("/tts")
-async def tts_get_endpoint(
- text: str = None,
- text_lang: str = None,
- ref_audio_path: str = None,
- aux_ref_audio_paths:list = None,
- prompt_lang: str = None,
- prompt_text: str = "",
- top_k:int = 5,
- top_p:float = 1,
- temperature:float = 1,
- text_split_method:str = "cut0",
- batch_size:int = 1,
- batch_threshold:float = 0.75,
- split_bucket:bool = True,
- speed_factor:float = 1.0,
- fragment_interval:float = 0.3,
- seed:int = -1,
- media_type:str = "wav",
- streaming_mode:bool = False,
- parallel_infer:bool = True,
- repetition_penalty:float = 1.35
- ):
- req = {
- "text": text,
- "text_lang": text_lang.lower(),
- "ref_audio_path": ref_audio_path,
- "aux_ref_audio_paths": aux_ref_audio_paths,
- "prompt_text": prompt_text,
- "prompt_lang": prompt_lang.lower(),
- "top_k": top_k,
- "top_p": top_p,
- "temperature": temperature,
- "text_split_method": text_split_method,
- "batch_size":int(batch_size),
- "batch_threshold":float(batch_threshold),
- "speed_factor":float(speed_factor),
- "split_bucket":split_bucket,
- "fragment_interval":fragment_interval,
- "seed":seed,
- "media_type":media_type,
- "streaming_mode":streaming_mode,
- "parallel_infer":parallel_infer,
- "repetition_penalty":float(repetition_penalty)
- }
- return await tts_handle(req)
-
-
-@APP.post("/tts")
-async def tts_post_endpoint(request: TTS_Request):
- req = request.dict()
- return await tts_handle(req)
-
-
-@APP.get("/set_refer_audio")
-async def set_refer_audio(refer_audio_path: str = None):
- try:
- tts_pipeline.set_ref_audio(refer_audio_path)
- except Exception as e:
- return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
- return JSONResponse(status_code=200, content={"message": "success"})
-
-
-# @APP.post("/set_refer_audio")
-# async def set_refer_aduio_post(audio_file: UploadFile = File(...)):
-# try:
-# # 检查文件类型,确保是音频文件
-# if not audio_file.content_type.startswith("audio/"):
-# return JSONResponse(status_code=400, content={"message": "file type is not supported"})
-
-# os.makedirs("uploaded_audio", exist_ok=True)
-# save_path = os.path.join("uploaded_audio", audio_file.filename)
-# # 保存音频文件到服务器上的一个目录
-# with open(save_path , "wb") as buffer:
-# buffer.write(await audio_file.read())
-
-# tts_pipeline.set_ref_audio(save_path)
-# except Exception as e:
-# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)})
-# return JSONResponse(status_code=200, content={"message": "success"})
-
-@APP.get("/set_gpt_weights")
-async def set_gpt_weights(weights_path: str = None):
- try:
- if weights_path in ["", None]:
- return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
- tts_pipeline.init_t2s_weights(weights_path)
- except Exception as e:
- return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
-
- return JSONResponse(status_code=200, content={"message": "success"})
-
-
-@APP.get("/set_sovits_weights")
-async def set_sovits_weights(weights_path: str = None):
- try:
- if weights_path in ["", None]:
- return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
- tts_pipeline.init_vits_weights(weights_path)
- except Exception as e:
- return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
- return JSONResponse(status_code=200, content={"message": "success"})
-
-
-
-if __name__ == "__main__":
- try:
- if host == 'None': # 在调用时使用 -a None 参数,可以让api监听双栈
- host = None
- uvicorn.run(app=APP, host=host, port=port, workers=1)
- except Exception as e:
- traceback.print_exc()
- os.kill(os.getpid(), signal.SIGTERM)
- exit(0)
diff --git a/colab_webui.ipynb b/colab_webui.ipynb
deleted file mode 100644
index 838f826..0000000
--- a/colab_webui.ipynb
+++ /dev/null
@@ -1,97 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "环境配置 environment"
- ],
- "metadata": {
- "id": "_o6a8GS2lWQM"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "e9b7iFV3dm1f"
- },
- "source": [
- "!pip install -q condacolab\n",
- "# Setting up condacolab and installing packages\n",
- "import condacolab\n",
- "condacolab.install_from_url(\"https://repo.anaconda.com/miniconda/Miniconda3-py39_23.11.0-2-Linux-x86_64.sh\")\n",
- "%cd -q /content\n",
- "!git clone https://github.com/RVC-Boss/GPT-SoVITS\n",
- "!conda install -y -q -c pytorch -c nvidia cudatoolkit\n",
- "%cd -q /content/GPT-SoVITS\n",
- "!conda install -y -q -c conda-forge gcc gxx ffmpeg cmake -c pytorch -c nvidia\n",
- "!/usr/local/bin/pip install -r requirements.txt"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title Download pretrained models 下载预训练模型\n",
- "!mkdir -p /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
- "!mkdir -p /content/GPT-SoVITS/tools/damo_asr/models\n",
- "!mkdir -p /content/GPT-SoVITS/tools/uvr5\n",
- "%cd /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
- "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n",
- "%cd /content/GPT-SoVITS/tools/damo_asr/models\n",
- "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n",
- "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n",
- "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n",
- "# @title UVR5 pretrains 安装uvr5模型\n",
- "%cd /content/GPT-SoVITS/tools/uvr5\n",
- "%rm -r uvr5_weights\n",
- "!git clone https://huggingface.co/Delik/uvr5_weights\n",
- "!git config core.sparseCheckout true\n",
- "!mv /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/"
- ],
- "metadata": {
- "id": "0NgxXg5sjv7z"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# @title launch WebUI 启动WebUI\n",
- "!/usr/local/bin/pip install ipykernel\n",
- "!sed -i '10s/False/True/' /content/GPT-SoVITS/config.py\n",
- "%cd /content/GPT-SoVITS/\n",
- "!/usr/local/bin/python webui.py"
- ],
- "metadata": {
- "id": "4oRGUzkrk8C7"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
-}
diff --git a/go-webui.bat b/go-webui.bat
deleted file mode 100644
index 398f6d9..0000000
--- a/go-webui.bat
+++ /dev/null
@@ -1,2 +0,0 @@
-runtime\python.exe webui.py zh_CN
-pause
diff --git a/go-webui.ps1 b/go-webui.ps1
deleted file mode 100644
index 6e8dce2..0000000
--- a/go-webui.ps1
+++ /dev/null
@@ -1,4 +0,0 @@
-$ErrorActionPreference = "SilentlyContinue"
-chcp 65001
-& "$PSScriptRoot\runtime\python.exe" "$PSScriptRoot\webui.py" zh_CN
-pause
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..0408176
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,265 @@
+import os
+import sys
+import traceback
+from typing import Generator
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+sys.path.append("%s/GPT_SoVITS" % (now_dir))
+import argparse
+import subprocess
+import wave
+import signal
+import numpy as np
+import soundfile as sf
+from io import BytesIO
+from tools.i18n.i18n import I18nAuto
+from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
+from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
+from pydantic import BaseModel
+
+i18n = I18nAuto()
+cut_method_names = get_cut_method_names()
+
+parser = argparse.ArgumentParser(description="GPT-SoVITS api")
+parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
+parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
+parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
+args = parser.parse_args()
+config_path = args.tts_config
+# device = args.device
+port = args.port
+host = args.bind_addr
+argv = sys.argv
+
+if config_path in [None, ""]:
+ config_path = "GPT-SoVITS/configs/tts_infer.yaml"
+
+tts_config = TTS_Config(config_path)
+print(tts_config)
+tts_pipeline = TTS(tts_config)
+
+# speaker configuration
+speakers = {
+ "firefly":{
+ "gpt_model" : "/root/autodl-tmp/GPT-SoVITS/models/GPT_models/firefly_312-e15.ckpt",
+ "sovits_model" : "/root/autodl-tmp/GPT-SoVITS/models/VITS_models/firefly_312_e8_s504.pth",
+ "ref_audio" : "/root/autodl-tmp/GPT-SoVITS/firefly/chapter3_2_firefly_103.wav",
+ "ref_text" : "谢谢,如果没有您出手相助,我真的不知道该怎么办",
+ "ref_language" : "zh",
+ "target_language" : "zh"
+ },
+ "keele":{
+ "gpt_model" : "/root/autodl-tmp/GPT-SoVITS/models/GPT_models/Keele-e15.ckpt",
+ "sovits_model" : "/root/autodl-tmp/GPT-SoVITS/models/VITS_models/Keele_e8_s656.pth",
+ "ref_audio" : "/root/autodl-tmp/GPT-SoVITS/keele/vo_dialog_KLLQ003_klee_03.wav",
+ "ref_text" : "我听说,冒险家协会也有一套冒险的守则,是不是,应该去拜托他们",
+ "ref_language" : "zh",
+ "target_language" : "zh"
+ },
+}
+
+# process the output audio type
+def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
+ with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
+ audio_file.write(data)
+ return io_buffer
+
+def pack_raw(io_buffer:BytesIO, data:np.ndarray, rate:int):
+ io_buffer.write(data.tobytes())
+ return io_buffer
+
+def pack_wav(io_buffer:BytesIO, data:np.ndarray, rate:int):
+ io_buffer = BytesIO()
+ sf.write(io_buffer, data, rate, format='wav')
+ return io_buffer
+
+def pack_aac(io_buffer:BytesIO, data:np.ndarray, rate:int):
+ process = subprocess.Popen([
+ 'ffmpeg',
+ '-f', 's16le', # 输入16位有符号小端整数PCM
+ '-ar', str(rate), # 设置采样率
+ '-ac', '1', # 单声道
+ '-i', 'pipe:0', # 从管道读取输入
+ '-c:a', 'aac', # 音频编码器为AAC
+ '-b:a', '192k', # 比特率
+ '-vn', # 不包含视频
+ '-f', 'adts', # 输出AAC数据流格式
+ 'pipe:1' # 将输出写入管道
+ ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, _ = process.communicate(input=data.tobytes())
+ io_buffer.write(out)
+ return io_buffer
+
+def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
+ if media_type == "ogg":
+ io_buffer = pack_ogg(io_buffer, data, rate)
+ elif media_type == "aac":
+ io_buffer = pack_aac(io_buffer, data, rate)
+ elif media_type == "wav":
+ io_buffer = pack_wav(io_buffer, data, rate)
+ else:
+ io_buffer = pack_raw(io_buffer, data, rate)
+ io_buffer.seek(0)
+ return io_buffer
+
+# validate the synthesis request before running the pipeline
+def check_params(req:dict):
+ text:str = req.get("text", "")
+ text_lang:str = req.get("text_lang", "")
+ ref_audio_path:str = req.get("ref_audio_path", "")
+ streaming_mode:bool = req.get("streaming_mode", False)
+ media_type:str = req.get("media_type", "wav")
+ prompt_lang:str = req.get("prompt_lang", "")
+ text_split_method:str = req.get("text_split_method", "cut5")
+
+ if ref_audio_path in [None, ""]:
+ print("ref_audio_path is required")
+ return False
+ if text in [None, ""]:
+ print("text is required")
+ return False
+
+ if (text_lang in [None, ""]) :
+ print("text_lang is required")
+ return False
+ elif text_lang.lower() not in tts_config.languages:
+ print(f"text_lang: {text_lang} is not supported in version {tts_config.version}")
+ return False
+
+ if (prompt_lang in [None, ""]) :
+ print("prompt_lang is required")
+ return False
+ elif prompt_lang.lower() not in tts_config.languages:
+ print(f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}")
+ return False
+
+ if media_type not in ["wav", "raw", "ogg", "aac"]:
+ print(f"media_type: {media_type} is not supported")
+ return False
+ elif media_type == "ogg" and not streaming_mode:
+ print("ogg format is not supported in non-streaming mode")
+ return False
+
+ if text_split_method not in cut_method_names:
+ print(f"text_split_method:{text_split_method} is not supported")
+ return False
+
+ return True
+
+def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
+    # This creates a wave header and appends the frame input.
+    # It must be the first chunk of a streaming wav file;
+    # later chunks should not include it (otherwise you will hear artifacts at the start of each chunk).
+ wav_buf = BytesIO()
+ with wave.open(wav_buf, "wb") as vfout:
+ vfout.setnchannels(channels)
+ vfout.setsampwidth(sample_width)
+ vfout.setframerate(sample_rate)
+ vfout.writeframes(frame_input)
+
+ wav_buf.seek(0)
+ return wav_buf.read()
+
+def tts_handle(req:dict):
+ """
+ Text to speech handler.
+
+ Args:
+ req (dict):
+ {
+ "text": "", # str.(required) text to be synthesized
+ "text_lang: "", # str.(required) language of the text to be synthesized
+ "ref_audio_path": "", # str.(required) reference audio path
+ "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
+ "prompt_text": "", # str.(optional) prompt text for the reference audio
+ "prompt_lang": "", # str.(required) language of the prompt text for the reference audio
+ "top_k": 5, # int. top k sampling
+ "top_p": 1, # float. top p sampling
+ "temperature": 1, # float. temperature for sampling
+ "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details.
+ "batch_size": 1, # int. batch size for inference
+ "batch_threshold": 0.75, # float. threshold for batch splitting.
+ "split_bucket: True, # bool. whether to split the batch into multiple buckets.
+ "speed_factor":1.0, # float. control the speed of the synthesized audio.
+ "fragment_interval":0.3, # float. to control the interval of the audio fragment.
+ "seed": -1, # int. random seed for reproducibility.
+ "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
+ "streaming_mode": False, # bool. whether to return a streaming response.
+ "parallel_infer": True, # bool.(optional) whether to use parallel inference.
+ "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model.
+ }
+ returns:
+ StreamingResponse: audio stream response.
+ """
+
+ streaming_mode = req.get("streaming_mode", False)
+ return_fragment = req.get("return_fragment", False)
+ media_type = req.get("media_type", "wav")
+
+ check_res = check_params(req)
+ if not check_res:
+ return None
+
+ if streaming_mode or return_fragment:
+ req["return_fragment"] = True
+
+ try:
+ tts_generator=tts_pipeline.run(req)
+
+ if streaming_mode:
+ def streaming_generator(tts_generator:Generator, media_type:str):
+ if media_type == "wav":
+ yield wave_header_chunk()
+ media_type = "raw"
+ for sr, chunk in tts_generator:
+ yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
+ # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
+ return streaming_generator(tts_generator, media_type, )
+
+ else:
+ sr, audio_data = next(tts_generator)
+ audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+ return audio_data
+ except Exception as e:
+ print("tts failed, Exception", str(e))
+ return None
+
+def tts_get_endpoint(text, speaker_id="firefly"):
+ speaker = speakers[speaker_id]
+ tts_pipeline.init_vits_weights(speaker["sovits_model"])
+ tts_pipeline.init_t2s_weights(speaker["gpt_model"])
+ req = {
+ "text" : text,
+ "text_lang" : speaker["target_language"],
+ "ref_audio_path" : speaker["ref_audio"],
+ "aux_ref_audio_paths" : None,
+ "prompt_text" : speaker["ref_text"],
+ "prompt_lang" : speaker["ref_language"],
+ "top_k" : int(15),
+ "top_p" : float(1.0),
+ "temperature" : float(1),
+ "text_split_method" : "cut0",
+ "batch_size" : int(1),
+ "batch_threshold" : float(0.75),
+ "speed_factor" : float(0.75),
+ "split_bucket" : True,
+ "fragment_interval" : float(0.3),
+ "seed" : int(-1),
+ "media_type" : "wav",
+ "streaming_mode" : False,
+ "parallel_infer" : True,
+ "repetition_penalty" : float(1.35)
+ }
+ return tts_handle(req)
+
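+ # tts_get_endpoint assumes a module-level `speakers` dict defined earlier in this file.
+ # A sketch of the expected entry shape (names and paths below are illustrative
+ # placeholders, not files shipped with the repository):
+ #
+ #   speakers = {
+ #       "firefly": {
+ #           "gpt_model": "GPT_weights_v2/firefly.ckpt",
+ #           "sovits_model": "SoVITS_weights_v2/firefly.pth",
+ #           "ref_audio": "ref/firefly.wav",
+ #           "ref_text": "transcript of the reference audio",
+ #           "ref_language": "中文",
+ #           "target_language": "中文",
+ #       },
+ #   }
+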
+def save_wav(filename, audio_data, sample_rate):
+ with wave.open(filename, 'wb') as wav_file:
+ wav_file.setnchannels(1)
+ wav_file.setsampwidth(2)
+ wav_file.setframerate(sample_rate)
+ wav_file.writeframes(audio_data)
+
+if __name__ == "__main__":
+ audio = tts_get_endpoint("我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可")
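+ # Note: if pack_audio(..., media_type="wav") already emits a complete WAV byte stream
+ # (header included), `audio` is a full file here, and writing it out directly with
+ # open("./output.wav", "wb") would avoid wrapping it in a second header via save_wav.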
+ save_wav("./output.wav", audio, 32000)
+
\ No newline at end of file
diff --git a/webui.py b/webui.py
deleted file mode 100644
index b73ed89..0000000
--- a/webui.py
+++ /dev/null
@@ -1,1132 +0,0 @@
-import os,sys
-if len(sys.argv)==1:sys.argv.append('v2')
-version="v1"if sys.argv[1]=="v1" else"v2"
-os.environ["version"]=version
-now_dir = os.getcwd()
-sys.path.insert(0, now_dir)
-import warnings
-warnings.filterwarnings("ignore")
-import json,yaml,torch,pdb,re,shutil
-import platform
-import psutil
-import signal
-os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
-torch.manual_seed(233333)
-tmp = os.path.join(now_dir, "TEMP")
-os.makedirs(tmp, exist_ok=True)
-os.environ["TEMP"] = tmp
-if(os.path.exists(tmp)):
- for name in os.listdir(tmp):
- if(name=="jieba.cache"):continue
- path="%s/%s"%(tmp,name)
- delete=os.remove if os.path.isfile(path) else shutil.rmtree
- try:
- delete(path)
- except Exception as e:
- print(str(e))
- pass
-import site
-import traceback
-site_packages_roots = []
-for path in site.getsitepackages():
- if "packages" in path:
- site_packages_roots.append(path)
-if(site_packages_roots==[]):site_packages_roots=["%s/runtime/Lib/site-packages" % now_dir]
-#os.environ["OPENBLAS_NUM_THREADS"] = "4"
-os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
-os.environ["all_proxy"] = ""
-for site_packages_root in site_packages_roots:
- if os.path.exists(site_packages_root):
- try:
- with open("%s/users.pth" % (site_packages_root), "w") as f:
- f.write(
- # "%s\n%s/runtime\n%s/tools\n%s/tools/asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
- "%s\n%s/GPT_SoVITS/BigVGAN\n%s/tools\n%s/tools/asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
- % (now_dir, now_dir, now_dir, now_dir, now_dir, now_dir)
- )
- break
- except PermissionError as e:
- traceback.print_exc()
-from tools import my_utils
-import shutil
-import pdb
-import subprocess
-from subprocess import Popen
-import signal
-from config import python_exec,infer_device,is_half,exp_root,webui_port_main,webui_port_infer_tts,webui_port_uvr5,webui_port_subfix,is_share
-from tools.i18n.i18n import I18nAuto, scan_language_list
-language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
-os.environ["language"]=language
-i18n = I18nAuto(language=language)
-from scipy.io import wavfile
-from tools.my_utils import load_audio, check_for_existance, check_details
-from multiprocessing import cpu_count
-# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
-try:
- import gradio.analytics as analytics
- analytics.version_check = lambda:None
-except:...
-import gradio as gr
-n_cpu=cpu_count()
-
-ngpu = torch.cuda.device_count()
-gpu_infos = []
-mem = []
-if_gpu_ok = False
-
-# 判断是否有能用来训练和加速推理的N卡
-ok_gpu_keywords={"10","16","20","30","40","A2","A3","A4","P4","A50","500","A60","70","80","90","M4","T4","TITAN","L4","4060","H","600","506","507","508","509"}
-set_gpu_numbers=set()
-if torch.cuda.is_available() or ngpu != 0:
- for i in range(ngpu):
- gpu_name = torch.cuda.get_device_name(i)
- if any(value in gpu_name.upper()for value in ok_gpu_keywords):
- # A10#A100#V100#A40#P40#M40#K80#A4500
- if_gpu_ok = True # 至少有一张能用的N卡
- gpu_infos.append("%s\t%s" % (i, gpu_name))
- set_gpu_numbers.add(i)
- mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
-# # 判断是否支持mps加速
-# if torch.backends.mps.is_available():
-# if_gpu_ok = True
-# gpu_infos.append("%s\t%s" % ("0", "Apple GPU"))
-# mem.append(psutil.virtual_memory().total/ 1024 / 1024 / 1024) # 实测使用系统内存作为显存不会爆显存
-
-def set_default():
- global default_batch_size,default_max_batch_size,gpu_info,default_sovits_epoch,default_sovits_save_every_epoch,max_sovits_epoch,max_sovits_save_every_epoch,default_batch_size_s1,if_force_ckpt
- if_force_ckpt = False
- if if_gpu_ok and len(gpu_infos) > 0:
- gpu_info = "\n".join(gpu_infos)
- minmem = min(mem)
- # if version == "v3" and minmem < 14:
- # # API读取不到共享显存,直接填充确认
- # try:
- # torch.zeros((1024,1024,1024,14),dtype=torch.int8,device="cuda")
- # torch.cuda.empty_cache()
- # minmem = 14
- # except RuntimeError as _:
- # # 强制梯度检查只需要12G显存
- # if minmem >= 12 :
- # if_force_ckpt = True
- # minmem = 14
- # else:
- # try:
- # torch.zeros((1024,1024,1024,12),dtype=torch.int8,device="cuda")
- # torch.cuda.empty_cache()
- # if_force_ckpt = True
- # minmem = 14
- # except RuntimeError as _:
- # print("显存不足以开启V3训练")
- default_batch_size = minmem // 2 if version!="v3"else minmem//8
- default_batch_size_s1=minmem // 2
- else:
- gpu_info = ("%s\t%s" % ("0", "CPU"))
- gpu_infos.append("%s\t%s" % ("0", "CPU"))
- set_gpu_numbers.add(0)
- default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total/ 1024 / 1024 / 1024 / 4)
- if version!="v3":
- default_sovits_epoch=8
- default_sovits_save_every_epoch=4
- max_sovits_epoch=25#40
- max_sovits_save_every_epoch=25#10
- else:
- default_sovits_epoch=2
- default_sovits_save_every_epoch=1
- max_sovits_epoch=3#40
- max_sovits_save_every_epoch=3#10
-
- default_batch_size = max(1, default_batch_size)
- default_batch_size_s1 = max(1, default_batch_size_s1)
- default_max_batch_size = default_batch_size * 3
-
-set_default()
-
-gpus = "-".join([i[0] for i in gpu_infos])
-default_gpu_numbers=str(sorted(list(set_gpu_numbers))[0])
-def fix_gpu_number(input):#将越界的number强制改到界内
- try:
- if(int(input)not in set_gpu_numbers):return default_gpu_numbers
- except:return input
- return input
-def fix_gpu_numbers(inputs):
- output=[]
- try:
- for input in inputs.split(","):output.append(str(fix_gpu_number(input)))
- return ",".join(output)
- except:
- return inputs
-
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth","GPT_SoVITS/pretrained_models/s2Gv3.pth"]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
-
-pretrained_model_list = (pretrained_sovits_name[int(version[-1])-1],pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D"),pretrained_gpt_name[int(version[-1])-1],"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large","GPT_SoVITS/pretrained_models/chinese-hubert-base")
-
-_ = ''
-for i in pretrained_model_list:
- if "s2Dv3" not in i and os.path.exists(i) == False:
- _ += f'\n {i}'
-if _:
- print("warning: ", i18n('以下模型不存在:') + _)
-
-_ = [[],[]]
-for i in range(3):
- if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
- else:_[0].append("")##没有下pretrained模型的,说不定他们是想自己从零训底模呢
- if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
- else:_[-1].append("")
-pretrained_gpt_name,pretrained_sovits_name = _
-
-SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
-GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
-for root in SoVITS_weight_root+GPT_weight_root:
- os.makedirs(root,exist_ok=True)
-def get_weights_names():
- SoVITS_names = [name for name in pretrained_sovits_name if name!=""]
- for path in SoVITS_weight_root:
- for name in os.listdir(path):
- if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
- GPT_names = [name for name in pretrained_gpt_name if name!=""]
- for path in GPT_weight_root:
- for name in os.listdir(path):
- if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
- return SoVITS_names, GPT_names
-
-SoVITS_names,GPT_names = get_weights_names()
-for path in SoVITS_weight_root+GPT_weight_root:
- os.makedirs(path,exist_ok=True)
-
-def custom_sort_key(s):
- # 使用正则表达式提取字符串中的数字部分和非数字部分
- parts = re.split('(\d+)', s)
- # 将数字部分转换为整数,非数字部分保持不变
- parts = [int(part) if part.isdigit() else part for part in parts]
- return parts
-
-def change_choices():
- SoVITS_names, GPT_names = get_weights_names()
- return {"choices": sorted(SoVITS_names,key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names,key=custom_sort_key), "__type__": "update"}
-
-p_label=None
-p_uvr5=None
-p_asr=None
-p_denoise=None
-p_tts_inference=None
-
-def kill_proc_tree(pid, including_parent=True):
- try:
- parent = psutil.Process(pid)
- except psutil.NoSuchProcess:
- # Process already terminated
- return
-
- children = parent.children(recursive=True)
- for child in children:
- try:
- os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
- except OSError:
- pass
- if including_parent:
- try:
- os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
- except OSError:
- pass
-
-system=platform.system()
-def kill_process(pid, process_name=""):
- if(system=="Windows"):
- cmd = "taskkill /t /f /pid %s" % pid
- # os.system(cmd)
- subprocess.run(cmd,shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
- else:
- kill_proc_tree(pid)
- print(process_name + i18n("进程已终止"))
-
-def process_info(process_name="", indicator=""):
- if indicator == "opened":
- return process_name + i18n("已开启")
- elif indicator == "open":
- return i18n("开启") + process_name
- elif indicator == "closed":
- return process_name + i18n("已关闭")
- elif indicator == "close":
- return i18n("关闭") + process_name
- elif indicator == "running":
- return process_name + i18n("运行中")
- elif indicator == "occupy":
- return process_name + i18n("占用中") + "," + i18n("需先终止才能开启下一次任务")
- elif indicator == "finish":
- return process_name + i18n("已完成")
- elif indicator == "failed":
- return process_name + i18n("失败")
- elif indicator == "info":
- return process_name + i18n("进程输出信息")
- else:
- return process_name
-
-process_name_subfix = i18n("音频标注WebUI")
-def change_label(path_list):
- global p_label
- if p_label is None:
- check_for_existance([path_list])
- path_list = my_utils.clean_path(path_list)
- cmd = '"%s" tools/subfix_webui.py --load_list "%s" --webui_port %s --is_share %s'%(python_exec,path_list,webui_port_subfix,is_share)
- yield process_info(process_name_subfix, "opened"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True}
- print(cmd)
- p_label = Popen(cmd, shell=True)
- else:
- kill_process(p_label.pid, process_name_subfix)
- p_label = None
- yield process_info(process_name_subfix, "closed"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False}
-
-process_name_uvr5 = i18n("人声分离WebUI")
-def change_uvr5():
- global p_uvr5
- if p_uvr5 is None:
- cmd = '"%s" tools/uvr5/webui.py "%s" %s %s %s'%(python_exec,infer_device,is_half,webui_port_uvr5,is_share)
- yield process_info(process_name_uvr5, "opened"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True}
- print(cmd)
- p_uvr5 = Popen(cmd, shell=True)
- else:
- kill_process(p_uvr5.pid, process_name_uvr5)
- p_uvr5 = None
- yield process_info(process_name_uvr5, "closed"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False}
-
-process_name_tts = i18n("TTS推理WebUI")
-def change_tts_inference(bert_path,cnhubert_base_path,gpu_number,gpt_path,sovits_path, batched_infer_enabled):
- global p_tts_inference
- if batched_infer_enabled:
- cmd = '"%s" GPT_SoVITS/inference_webui_fast.py "%s"'%(python_exec, language)
- else:
- cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
- #####v3暂不支持加速推理
- if version=="v3":
- cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
- if p_tts_inference is None:
- os.environ["gpt_path"]=gpt_path if "/" in gpt_path else "%s/%s"%(GPT_weight_root,gpt_path)
- os.environ["sovits_path"]=sovits_path if "/"in sovits_path else "%s/%s"%(SoVITS_weight_root,sovits_path)
- os.environ["cnhubert_base_path"]=cnhubert_base_path
- os.environ["bert_path"]=bert_path
- os.environ["_CUDA_VISIBLE_DEVICES"]=fix_gpu_number(gpu_number)
- os.environ["is_half"]=str(is_half)
- os.environ["infer_ttswebui"]=str(webui_port_infer_tts)
- os.environ["is_share"]=str(is_share)
- yield process_info(process_name_tts, "opened"), {'__type__':'update','visible':False}, {'__type__':'update','visible':True}
- print(cmd)
- p_tts_inference = Popen(cmd, shell=True)
- else:
- kill_process(p_tts_inference.pid, process_name_tts)
- p_tts_inference = None
- yield process_info(process_name_tts, "closed"), {'__type__':'update','visible':True}, {'__type__':'update','visible':False}
-
-from tools.asr.config import asr_dict
-
-process_name_asr = i18n("语音识别")
-def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang, asr_precision):
- global p_asr
- if p_asr is None:
- asr_inp_dir=my_utils.clean_path(asr_inp_dir)
- asr_opt_dir=my_utils.clean_path(asr_opt_dir)
- check_for_existance([asr_inp_dir])
- cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
- cmd += f' -i "{asr_inp_dir}"'
- cmd += f' -o "{asr_opt_dir}"'
- cmd += f' -s {asr_model_size}'
- cmd += f' -l {asr_lang}'
- cmd += f" -p {asr_precision}"
- output_file_name = os.path.basename(asr_inp_dir)
- output_folder = asr_opt_dir or "output/asr_opt"
- output_file_path = os.path.abspath(f'{output_folder}/{output_file_name}.list')
- yield process_info(process_name_asr, "opened"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
- print(cmd)
- p_asr = Popen(cmd, shell=True)
- p_asr.wait()
- p_asr = None
- yield process_info(process_name_asr, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}, {"__type__": "update", "value": output_file_path}, {"__type__": "update", "value": output_file_path}, {"__type__": "update", "value": asr_inp_dir}
- else:
- yield process_info(process_name_asr, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
-
-def close_asr():
- global p_asr
- if p_asr is not None:
- kill_process(p_asr.pid, process_name_asr)
- p_asr = None
- return process_info(process_name_asr, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-process_name_denoise = i18n("语音降噪")
-def open_denoise(denoise_inp_dir, denoise_opt_dir):
- global p_denoise
- if(p_denoise==None):
- denoise_inp_dir=my_utils.clean_path(denoise_inp_dir)
- denoise_opt_dir=my_utils.clean_path(denoise_opt_dir)
- check_for_existance([denoise_inp_dir])
- cmd = '"%s" tools/cmd-denoise.py -i "%s" -o "%s" -p %s'%(python_exec,denoise_inp_dir,denoise_opt_dir,"float16"if is_half==True else "float32")
-
- yield process_info(process_name_denoise, "opened"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}
- print(cmd)
- p_denoise = Popen(cmd, shell=True)
- p_denoise.wait()
- p_denoise=None
- yield process_info(process_name_denoise, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}, {"__type__": "update", "value": denoise_opt_dir}, {"__type__": "update", "value": denoise_opt_dir}
- else:
- yield process_info(process_name_denoise, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}
-
-def close_denoise():
- global p_denoise
- if p_denoise is not None:
- kill_process(p_denoise.pid, process_name_denoise)
- p_denoise = None
- return process_info(process_name_denoise, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-p_train_SoVITS=None
-process_name_sovits = i18n("SoVITS训练")
-def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt,lora_rank):
- global p_train_SoVITS
- if(p_train_SoVITS==None):
- with open("GPT_SoVITS/configs/s2.json")as f:
- data=f.read()
- data=json.loads(data)
- s2_dir="%s/%s"%(exp_root,exp_name)
- os.makedirs("%s/logs_s2_%s"%(s2_dir,version),exist_ok=True)
- if check_for_existance([s2_dir],is_train=True):
- check_details([s2_dir],is_train=True)
- if(is_half==False):
- data["train"]["fp16_run"]=False
- batch_size=max(1,batch_size//2)
- data["train"]["batch_size"]=batch_size
- data["train"]["epochs"]=total_epoch
- data["train"]["text_low_lr_rate"]=text_low_lr_rate
- data["train"]["pretrained_s2G"]=pretrained_s2G
- data["train"]["pretrained_s2D"]=pretrained_s2D
- data["train"]["if_save_latest"]=if_save_latest
- data["train"]["if_save_every_weights"]=if_save_every_weights
- data["train"]["save_every_epoch"]=save_every_epoch
- data["train"]["gpu_numbers"]=gpu_numbers1Ba
- data["train"]["grad_ckpt"]=if_grad_ckpt
- data["train"]["lora_rank"]=lora_rank
- data["model"]["version"]=version
- data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir
- data["save_weight_dir"]=SoVITS_weight_root[int(version[-1])-1]
- data["name"]=exp_name
- data["version"]=version
- tmp_config_path="%s/tmp_s2.json"%tmp
- with open(tmp_config_path,"w")as f:f.write(json.dumps(data))
- if version in ["v1","v2"]:
- cmd = '"%s" GPT_SoVITS/s2_train.py --config "%s"'%(python_exec,tmp_config_path)
- else:
- cmd = '"%s" GPT_SoVITS/s2_train_v3_lora.py --config "%s"'%(python_exec,tmp_config_path)
- yield process_info(process_name_sovits, "opened"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- print(cmd)
- p_train_SoVITS = Popen(cmd, shell=True)
- p_train_SoVITS.wait()
- p_train_SoVITS = None
- yield process_info(process_name_sovits, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_sovits, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1Ba():
- global p_train_SoVITS
- if p_train_SoVITS is not None:
- kill_process(p_train_SoVITS.pid, process_name_sovits)
- p_train_SoVITS = None
- return process_info(process_name_sovits, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-p_train_GPT=None
-process_name_gpt = i18n("GPT训练")
-def open1Bb(batch_size,total_epoch,exp_name,if_dpo,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers,pretrained_s1):
- global p_train_GPT
- if(p_train_GPT==None):
- with open("GPT_SoVITS/configs/s1longer.yaml"if version=="v1"else "GPT_SoVITS/configs/s1longer-v2.yaml")as f:
- data=f.read()
- data=yaml.load(data, Loader=yaml.FullLoader)
- s1_dir="%s/%s"%(exp_root,exp_name)
- os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
- if check_for_existance([s1_dir],is_train=True):
- check_details([s1_dir],is_train=True)
- if(is_half==False):
- data["train"]["precision"]="32"
- batch_size = max(1, batch_size // 2)
- data["train"]["batch_size"]=batch_size
- data["train"]["epochs"]=total_epoch
- data["pretrained_s1"]=pretrained_s1
- data["train"]["save_every_n_epoch"]=save_every_epoch
- data["train"]["if_save_every_weights"]=if_save_every_weights
- data["train"]["if_save_latest"]=if_save_latest
- data["train"]["if_dpo"]=if_dpo
- data["train"]["half_weights_save_dir"]=GPT_weight_root[int(version[-1])-1]
- data["train"]["exp_name"]=exp_name
- data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
- data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
- data["output_dir"]="%s/logs_s1_%s"%(s1_dir,version)
- # data["version"]=version
-
- os.environ["_CUDA_VISIBLE_DEVICES"]=fix_gpu_numbers(gpu_numbers.replace("-",","))
- os.environ["hz"]="25hz"
- tmp_config_path="%s/tmp_s1.yaml"%tmp
- with open(tmp_config_path, "w") as f:f.write(yaml.dump(data, default_flow_style=False))
- # cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" --train_semantic_path "%s/6-name2semantic.tsv" --train_phoneme_path "%s/2-name2text.txt" --output_dir "%s/logs_s1"'%(python_exec,tmp_config_path,s1_dir,s1_dir,s1_dir)
- cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" '%(python_exec,tmp_config_path)
- yield process_info(process_name_gpt, "opened"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- print(cmd)
- p_train_GPT = Popen(cmd, shell=True)
- p_train_GPT.wait()
- p_train_GPT = None
- yield process_info(process_name_gpt, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_gpt, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1Bb():
- global p_train_GPT
- if p_train_GPT is not None:
- kill_process(p_train_GPT.pid, process_name_gpt)
- p_train_GPT = None
- return process_info(process_name_gpt, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-ps_slice=[]
-process_name_slice = i18n("语音切分")
-def open_slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_parts):
- global ps_slice
- inp = my_utils.clean_path(inp)
- opt_root = my_utils.clean_path(opt_root)
- check_for_existance([inp])
- if(os.path.exists(inp)==False):
- yield i18n("输入路径不存在"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
- return
- if os.path.isfile(inp):n_parts=1
- elif os.path.isdir(inp):pass
- else:
- yield i18n("输入路径存在但不可用"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
- return
- if (ps_slice == []):
- for i_part in range(n_parts):
- cmd = '"%s" tools/slice_audio.py "%s" "%s" %s %s %s %s %s %s %s %s %s''' % (python_exec,inp, opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, i_part, n_parts)
- print(cmd)
- p = Popen(cmd, shell=True)
- ps_slice.append(p)
- yield process_info(process_name_slice, "opened"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
- for p in ps_slice:
- p.wait()
- ps_slice=[]
- yield process_info(process_name_slice, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}, {"__type__": "update", "value": opt_root}, {"__type__": "update", "value": opt_root}, {"__type__": "update", "value": opt_root}
- else:
- yield process_info(process_name_slice, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}, {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
-
-def close_slice():
- global ps_slice
- if (ps_slice != []):
- for p_slice in ps_slice:
- try:
- kill_process(p_slice.pid, process_name_slice)
- except:
- traceback.print_exc()
- ps_slice=[]
- return process_info(process_name_slice, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-ps1a=[]
-process_name_1a = i18n("文本分词与特征提取")
-def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
- global ps1a
- inp_text = my_utils.clean_path(inp_text)
- inp_wav_dir = my_utils.clean_path(inp_wav_dir)
- if check_for_existance([inp_text,inp_wav_dir], is_dataset_processing=True):
- check_details([inp_text,inp_wav_dir], is_dataset_processing=True)
- if (ps1a == []):
- opt_dir="%s/%s"%(exp_root,exp_name)
- config={
- "inp_text":inp_text,
- "inp_wav_dir":inp_wav_dir,
- "exp_name":exp_name,
- "opt_dir":opt_dir,
- "bert_pretrained_dir":bert_pretrained_dir,
- }
- gpu_names=gpu_numbers.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- "is_half": str(is_half)
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1a.append(p)
- yield process_info(process_name_1a, "running"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1a:
- p.wait()
- opt = []
- for i_part in range(all_parts):
- txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
- with open(txt_path, "r", encoding="utf8") as f:
- opt += f.read().strip("\n").split("\n")
- os.remove(txt_path)
- path_text = "%s/2-name2text.txt" % opt_dir
- with open(path_text, "w", encoding="utf8") as f:
- f.write("\n".join(opt) + "\n")
- ps1a=[]
- if len("".join(opt)) > 0:
- yield process_info(process_name_1a, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_1a, "failed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_1a, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1a():
- global ps1a
- if ps1a != []:
- for p1a in ps1a:
- try:
- kill_process(p1a.pid, process_name_1a)
- except:
- traceback.print_exc()
- ps1a = []
- return process_info(process_name_1a, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-ps1b=[]
-process_name_1b = i18n("语音自监督特征提取")
-def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
- global ps1b
- inp_text = my_utils.clean_path(inp_text)
- inp_wav_dir = my_utils.clean_path(inp_wav_dir)
- if check_for_existance([inp_text,inp_wav_dir], is_dataset_processing=True):
- check_details([inp_text,inp_wav_dir], is_dataset_processing=True)
- if (ps1b == []):
- config={
- "inp_text":inp_text,
- "inp_wav_dir":inp_wav_dir,
- "exp_name":exp_name,
- "opt_dir": "%s/%s"%(exp_root,exp_name),
- "cnhubert_base_dir":ssl_pretrained_dir,
- "is_half": str(is_half)
- }
- gpu_names=gpu_numbers.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1b.append(p)
- yield process_info(process_name_1b, "running"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1b:
- p.wait()
- ps1b=[]
- yield process_info(process_name_1b, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_1b, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1b():
- global ps1b
- if (ps1b != []):
- for p1b in ps1b:
- try:
- kill_process(p1b.pid, process_name_1b)
- except:
- traceback.print_exc()
- ps1b=[]
- return process_info(process_name_1b, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-ps1c=[]
-process_name_1c = i18n("语义Token提取")
-def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
- global ps1c
- inp_text = my_utils.clean_path(inp_text)
- if check_for_existance([inp_text,''], is_dataset_processing=True):
- check_details([inp_text,''], is_dataset_processing=True)
- if (ps1c == []):
- opt_dir="%s/%s"%(exp_root,exp_name)
- config={
- "inp_text":inp_text,
- "exp_name":exp_name,
- "opt_dir":opt_dir,
- "pretrained_s2G":pretrained_s2G_path,
- "s2config_path":"GPT_SoVITS/configs/s2.json",
- "is_half": str(is_half)
- }
- gpu_names=gpu_numbers.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1c.append(p)
- yield process_info(process_name_1c, "running"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1c:
- p.wait()
- opt = ["item_name\tsemantic_audio"]
- path_semantic = "%s/6-name2semantic.tsv" % opt_dir
- for i_part in range(all_parts):
- semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
- with open(semantic_path, "r", encoding="utf8") as f:
- opt += f.read().strip("\n").split("\n")
- os.remove(semantic_path)
- with open(path_semantic, "w", encoding="utf8") as f:
- f.write("\n".join(opt) + "\n")
- ps1c=[]
- yield process_info(process_name_1c, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_1c, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1c():
- global ps1c
- if (ps1c != []):
- for p1c in ps1c:
- try:
- kill_process(p1c.pid, process_name_1c)
- except:
- traceback.print_exc()
- ps1c=[]
- return process_info(process_name_1c, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-ps1abc=[]
-process_name_1abc = i18n("训练集格式化一键三连")
-def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
- global ps1abc
- inp_text = my_utils.clean_path(inp_text)
- inp_wav_dir = my_utils.clean_path(inp_wav_dir)
- if check_for_existance([inp_text,inp_wav_dir], is_dataset_processing=True):
- check_details([inp_text,inp_wav_dir], is_dataset_processing=True)
- if (ps1abc == []):
- opt_dir="%s/%s"%(exp_root,exp_name)
- try:
- #############################1a
- path_text="%s/2-name2text.txt" % opt_dir
- if(os.path.exists(path_text)==False or (os.path.exists(path_text)==True and len(open(path_text,"r",encoding="utf8").read().strip("\n").split("\n"))<2)):
- config={
- "inp_text":inp_text,
- "inp_wav_dir":inp_wav_dir,
- "exp_name":exp_name,
- "opt_dir":opt_dir,
- "bert_pretrained_dir":bert_pretrained_dir,
- "is_half": str(is_half)
- }
- gpu_names=gpu_numbers1a.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1abc.append(p)
- yield i18n("进度") + ": 1A-Doing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1abc:p.wait()
-
- opt = []
- for i_part in range(all_parts):#txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
- txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
- with open(txt_path, "r",encoding="utf8") as f:
- opt += f.read().strip("\n").split("\n")
- os.remove(txt_path)
- with open(path_text, "w",encoding="utf8") as f:
- f.write("\n".join(opt) + "\n")
- assert len("".join(opt)) > 0, process_info(process_name_1a, "failed")
- yield i18n("进度") + ": 1A-Done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- ps1abc=[]
- #############################1b
- config={
- "inp_text":inp_text,
- "inp_wav_dir":inp_wav_dir,
- "exp_name":exp_name,
- "opt_dir":opt_dir,
- "cnhubert_base_dir":ssl_pretrained_dir,
- }
- gpu_names=gpu_numbers1Ba.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1abc.append(p)
- yield i18n("进度") + ": 1A-Done, 1B-Doing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1abc:p.wait()
- yield i18n("进度") + ": 1A-Done, 1B-Done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- ps1abc=[]
- #############################1c
- path_semantic = "%s/6-name2semantic.tsv" % opt_dir
- if(os.path.exists(path_semantic)==False or (os.path.exists(path_semantic)==True and os.path.getsize(path_semantic)<31)):
- config={
- "inp_text":inp_text,
- "exp_name":exp_name,
- "opt_dir":opt_dir,
- "pretrained_s2G":pretrained_s2G_path,
- "s2config_path":"GPT_SoVITS/configs/s2.json",
- }
- gpu_names=gpu_numbers1c.split("-")
- all_parts=len(gpu_names)
- for i_part in range(all_parts):
- config.update(
- {
- "i_part": str(i_part),
- "all_parts": str(all_parts),
- "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
- }
- )
- os.environ.update(config)
- cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py'%python_exec
- print(cmd)
- p = Popen(cmd, shell=True)
- ps1abc.append(p)
- yield i18n("进度") + ": 1A-Done, 1B-Done, 1C-Doing", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- for p in ps1abc:p.wait()
-
- opt = ["item_name\tsemantic_audio"]
- for i_part in range(all_parts):
- semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
- with open(semantic_path, "r",encoding="utf8") as f:
- opt += f.read().strip("\n").split("\n")
- os.remove(semantic_path)
- with open(path_semantic, "w",encoding="utf8") as f:
- f.write("\n".join(opt) + "\n")
- yield i18n("进度") + ": 1A-Done, 1B-Done, 1C-Done", {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
- ps1abc = []
- yield process_info(process_name_1abc, "finish"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- except:
- traceback.print_exc()
- close1abc()
- yield process_info(process_name_1abc, "failed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
- else:
- yield process_info(process_name_1abc, "occupy"), {"__type__": "update", "visible": False}, {"__type__": "update", "visible": True}
-
-def close1abc():
- global ps1abc
- if (ps1abc != []):
- for p1abc in ps1abc:
- try:
- kill_process(p1abc.pid, process_name_1abc)
- except:
- traceback.print_exc()
- ps1abc=[]
- return process_info(process_name_1abc, "closed"), {"__type__": "update", "visible": True}, {"__type__": "update", "visible": False}
-
-def switch_version(version_):
- os.environ["version"]=version_
- global version
- version = version_
- if pretrained_sovits_name[int(version[-1])-1] !='' and pretrained_gpt_name[int(version[-1])-1] !='':...
- else:
- gr.Warning(i18n('未下载模型') + ": " + version.upper())
- set_default()
- return {'__type__': 'update', 'value': pretrained_sovits_name[int(version[-1])-1]}, \
- {'__type__': 'update', 'value': pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, \
- {'__type__': 'update', 'value': pretrained_gpt_name[int(version[-1])-1]}, \
- {'__type__': 'update', 'value': pretrained_gpt_name[int(version[-1])-1]}, \
- {'__type__': 'update', 'value': pretrained_sovits_name[int(version[-1])-1]}, \
- {'__type__': 'update', "value": default_batch_size, "maximum": default_max_batch_size}, \
- {'__type__': 'update', "value": default_sovits_epoch, "maximum": max_sovits_epoch}, \
- {'__type__': 'update', "value": default_sovits_save_every_epoch,"maximum": max_sovits_save_every_epoch}, \
- {'__type__': 'update', "visible": True if version!="v3"else False}, \
- {'__type__': 'update', "value": False if not if_force_ckpt else True, "interactive": True if not if_force_ckpt else False}, \
- {'__type__': 'update', "interactive": False if version == "v3" else True, "value": False}, \
- {'__type__': 'update', "visible": True if version== "v3" else False}
-
-if os.path.exists('GPT_SoVITS/text/G2PWModel'):...
-else:
- cmd = '"%s" GPT_SoVITS/download.py'%python_exec
- p = Popen(cmd, shell=True)
- p.wait()
-
-def sync(text):
- return {'__type__': 'update', 'value': text}
-
-with gr.Blocks(title="GPT-SoVITS WebUI") as app:
- gr.Markdown(
- value=
- i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "
" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
- )
- gr.Markdown(
- value=
- i18n("中文教程文档") + ": " + "https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e"
- )
-
- with gr.Tabs():
- with gr.TabItem("0-"+i18n("前置数据集获取工具")):#提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
- gr.Markdown(value="0a-"+i18n("UVR5人声伴奏分离&去混响去延迟工具"))
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Row():
- uvr5_info = gr.Textbox(label=process_info(process_name_uvr5, "info"))
- open_uvr5 = gr.Button(value=process_info(process_name_uvr5, "open"),variant="primary",visible=True)
- close_uvr5 = gr.Button(value=process_info(process_name_uvr5, "close"),variant="primary",visible=False)
-
- gr.Markdown(value="0b-"+i18n("语音切分工具"))
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Row():
- slice_inp_path=gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"),value="")
- slice_opt_root=gr.Textbox(label=i18n("切分后的子音频的输出根目录"),value="output/slicer_opt")
- with gr.Row():
- threshold=gr.Textbox(label=i18n("threshold:音量小于这个值视作静音的备选切割点"),value="-34")
- min_length=gr.Textbox(label=i18n("min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值"),value="4000")
- min_interval=gr.Textbox(label=i18n("min_interval:最短切割间隔"),value="300")
- hop_size=gr.Textbox(label=i18n("hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)"),value="10")
- max_sil_kept=gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"),value="500")
- with gr.Row():
- _max=gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("max:归一化后最大值多少"),value=0.9,interactive=True)
- alpha=gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("alpha_mix:混多少比例归一化后音频进来"),value=0.25,interactive=True)
- with gr.Row():
- n_process=gr.Slider(minimum=1,maximum=n_cpu,step=1,label=i18n("切割使用的进程数"),value=4,interactive=True)
- slicer_info = gr.Textbox(label=process_info(process_name_slice, "info"))
- open_slicer_button = gr.Button(value=process_info(process_name_slice, "open"),variant="primary",visible=True)
- close_slicer_button = gr.Button(value=process_info(process_name_slice, "close"),variant="primary",visible=False)
-
- gr.Markdown(value="0bb-"+i18n("语音降噪工具"))
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Row():
- denoise_input_dir=gr.Textbox(label=i18n("输入文件夹路径"),value="")
- denoise_output_dir=gr.Textbox(label=i18n("输出文件夹路径"),value="output/denoise_opt")
- with gr.Row():
- denoise_info = gr.Textbox(label=process_info(process_name_denoise, "info"))
- open_denoise_button = gr.Button(value=process_info(process_name_denoise, "open"),variant="primary",visible=True)
- close_denoise_button = gr.Button(value=process_info(process_name_denoise, "close"),variant="primary",visible=False)
-
- gr.Markdown(value="0c-"+i18n("语音识别工具"))
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Row():
- asr_inp_dir = gr.Textbox(label=i18n("输入文件夹路径"), value="D:\\GPT-SoVITS\\raw\\xxx", interactive=True)
- asr_opt_dir = gr.Textbox(label=i18n("输出文件夹路径"), value="output/asr_opt", interactive=True)
- with gr.Row():
- asr_model = gr.Dropdown(label=i18n("ASR 模型"), choices=list(asr_dict.keys()), interactive=True, value="达摩 ASR (中文)")
- asr_size = gr.Dropdown(label=i18n("ASR 模型尺寸"), choices=["large"], interactive=True, value="large")
- asr_lang = gr.Dropdown(label=i18n("ASR 语言设置"), choices=["zh","yue"], interactive=True, value="zh")
- asr_precision = gr.Dropdown(label=i18n("数据类型精度"), choices=["float32"], interactive=True, value="float32")
- with gr.Row():
- asr_info = gr.Textbox(label=process_info(process_name_asr, "info"))
- open_asr_button = gr.Button(value=process_info(process_name_asr, "open"),variant="primary",visible=True)
- close_asr_button = gr.Button(value=process_info(process_name_asr, "close"),variant="primary",visible=False)
-
- def change_lang_choices(key): #根据选择的模型修改可选的语言
- return {"__type__": "update", "choices": asr_dict[key]['lang'], "value": asr_dict[key]['lang'][0]}
- def change_size_choices(key): # 根据选择的模型修改可选的模型尺寸
- return {"__type__": "update", "choices": asr_dict[key]['size'], "value": asr_dict[key]['size'][-1]}
- def change_precision_choices(key): #根据选择的模型修改可选的语言
- if key =="Faster Whisper (多语种)":
- if default_batch_size <= 4:
- precision = 'int8'
- elif is_half:
- precision = 'float16'
- else:
- precision = 'float32'
- else:
- precision = 'float32'
- return {"__type__": "update", "choices": asr_dict[key]['precision'], "value": precision}
- asr_model.change(change_lang_choices, [asr_model], [asr_lang])
- asr_model.change(change_size_choices, [asr_model], [asr_size])
- asr_model.change(change_precision_choices, [asr_model], [asr_precision])
-
- gr.Markdown(value="0d-"+i18n("语音文本校对标注工具"))
- with gr.Row():
- with gr.Column(scale=3):
- with gr.Row():
- path_list = gr.Textbox(label=i18n("标注文件路径 (含文件后缀 *.list)"), value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list", interactive=True)
- label_info = gr.Textbox(label=process_info(process_name_subfix, "info"))
- open_label = gr.Button(value=process_info(process_name_subfix, "open"),variant="primary",visible=True)
- close_label = gr.Button(value=process_info(process_name_subfix, "close"),variant="primary",visible=False)
-
- open_label.click(change_label, [path_list], [label_info,open_label,close_label])
- close_label.click(change_label, [path_list], [label_info,open_label,close_label])
- open_uvr5.click(change_uvr5, [], [uvr5_info,open_uvr5,close_uvr5])
- close_uvr5.click(change_uvr5, [], [uvr5_info,open_uvr5,close_uvr5])
-
- with gr.TabItem(i18n("1-GPT-SoVITS-TTS")):
- with gr.Row():
- with gr.Row():
- exp_name = gr.Textbox(label=i18n("*实验/模型名"), value="xxx", interactive=True)
- gpu_info = gr.Textbox(label=i18n("显卡信息"), value=gpu_info, visible=True, interactive=False)
- version_checkbox = gr.Radio(label=i18n("版本"),value=version,choices=['v1','v2','v3'])
- with gr.Row():
- pretrained_s2G = gr.Textbox(label=i18n("预训练SoVITS-G模型路径"), value=pretrained_sovits_name[int(version[-1])-1], interactive=True, lines=2, max_lines=3,scale=9)
- pretrained_s2D = gr.Textbox(label=i18n("预训练SoVITS-D模型路径"), value=pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D"), interactive=True, lines=2, max_lines=3,scale=9)
- pretrained_s1 = gr.Textbox(label=i18n("预训练GPT模型路径"), value=pretrained_gpt_name[int(version[-1])-1], interactive=True, lines=2, max_lines=3,scale=10)
-
- with gr.TabItem("1A-"+i18n("训练集格式化工具")):
- gr.Markdown(value=i18n("输出logs/实验名目录下应有23456开头的文件和文件夹"))
- with gr.Row():
- with gr.Row():
- inp_text = gr.Textbox(label=i18n("*文本标注文件"),value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",interactive=True,scale=10)
- with gr.Row():
- inp_wav_dir = gr.Textbox(
- label=i18n("*训练集音频文件目录"),
- # value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",
- interactive=True,
- placeholder=i18n("填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。"), scale=10
- )
-
- gr.Markdown(value="1Aa-"+process_name_1a)
- with gr.Row():
- with gr.Row():
- gpu_numbers1a = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"),value="%s-%s"%(gpus,gpus),interactive=True)
- with gr.Row():
- bert_pretrained_dir = gr.Textbox(label=i18n("预训练中文BERT模型路径"),value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",interactive=False,lines=2)
- with gr.Row():
- button1a_open = gr.Button(value=process_info(process_name_1a, "open"),variant="primary",visible=True)
- button1a_close = gr.Button(value=process_info(process_name_1a, "close"),variant="primary",visible=False)
- with gr.Row():
- info1a=gr.Textbox(label=process_info(process_name_1a, "info"))
-
- gr.Markdown(value="1Ab-"+process_name_1b)
- with gr.Row():
- with gr.Row():
- gpu_numbers1Ba = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"),value="%s-%s"%(gpus,gpus),interactive=True)
- with gr.Row():
- cnhubert_base_dir = gr.Textbox(label=i18n("预训练SSL模型路径"),value="GPT_SoVITS/pretrained_models/chinese-hubert-base",interactive=False,lines=2)
- with gr.Row():
- button1b_open = gr.Button(value=process_info(process_name_1b, "open"),variant="primary",visible=True)
- button1b_close = gr.Button(value=process_info(process_name_1b, "close"),variant="primary",visible=False)
- with gr.Row():
- info1b=gr.Textbox(label=process_info(process_name_1b, "info"))
-
- gr.Markdown(value="1Ac-"+process_name_1c)
- with gr.Row():
- with gr.Row():
- gpu_numbers1c = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"),value="%s-%s"%(gpus,gpus),interactive=True)
- with gr.Row():
- pretrained_s2G_ = gr.Textbox(label=i18n("预训练SoVITS-G模型路径"), value=pretrained_sovits_name[int(version[-1])-1], interactive=False,lines=2)
- with gr.Row():
- button1c_open = gr.Button(value=process_info(process_name_1c, "open"),variant="primary",visible=True)
- button1c_close = gr.Button(value=process_info(process_name_1c, "close"),variant="primary",visible=False)
- with gr.Row():
- info1c=gr.Textbox(label=process_info(process_name_1c, "info"))
-
- gr.Markdown(value="1Aabc-"+process_name_1abc)
- with gr.Row():
- with gr.Row():
- button1abc_open = gr.Button(value=process_info(process_name_1abc, "open"),variant="primary",visible=True)
- button1abc_close = gr.Button(value=process_info(process_name_1abc, "close"),variant="primary",visible=False)
- with gr.Row():
- info1abc=gr.Textbox(label=process_info(process_name_1abc, "info"))
-
- pretrained_s2G.change(sync,[pretrained_s2G],[pretrained_s2G_])
- open_asr_button.click(open_asr, [asr_inp_dir, asr_opt_dir, asr_model, asr_size, asr_lang, asr_precision], [asr_info,open_asr_button,close_asr_button,path_list,inp_text,inp_wav_dir])
- close_asr_button.click(close_asr, [], [asr_info,open_asr_button,close_asr_button])
- open_slicer_button.click(open_slice, [slice_inp_path,slice_opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_max,alpha,n_process], [slicer_info,open_slicer_button,close_slicer_button,asr_inp_dir,denoise_input_dir,inp_wav_dir])
- close_slicer_button.click(close_slice, [], [slicer_info,open_slicer_button,close_slicer_button])
- open_denoise_button.click(open_denoise, [denoise_input_dir,denoise_output_dir], [denoise_info,open_denoise_button,close_denoise_button,asr_inp_dir,inp_wav_dir])
- close_denoise_button.click(close_denoise, [], [denoise_info,open_denoise_button,close_denoise_button])
-
- button1a_open.click(open1a, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,bert_pretrained_dir], [info1a,button1a_open,button1a_close])
- button1a_close.click(close1a, [], [info1a,button1a_open,button1a_close])
- button1b_open.click(open1b, [inp_text,inp_wav_dir,exp_name,gpu_numbers1Ba,cnhubert_base_dir], [info1b,button1b_open,button1b_close])
- button1b_close.click(close1b, [], [info1b,button1b_open,button1b_close])
- button1c_open.click(open1c, [inp_text,exp_name,gpu_numbers1c,pretrained_s2G], [info1c,button1c_open,button1c_close])
- button1c_close.click(close1c, [], [info1c,button1c_open,button1c_close])
- button1abc_open.click(open1abc, [inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,cnhubert_base_dir,pretrained_s2G], [info1abc,button1abc_open,button1abc_close])
- button1abc_close.click(close1abc, [], [info1abc,button1abc_open,button1abc_close])
-
- with gr.TabItem("1B-"+i18n("微调训练")):
- gr.Markdown(value="1Ba-"+i18n("SoVITS 训练: 模型权重文件在 SoVITS_weights/"))
- with gr.Row():
- with gr.Column():
- with gr.Row():
- batch_size = gr.Slider(minimum=1,maximum=default_max_batch_size,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size,interactive=True)
- total_epoch = gr.Slider(minimum=1,maximum=max_sovits_epoch,step=1,label=i18n("总训练轮数total_epoch,不建议太高"),value=default_sovits_epoch,interactive=True)
- with gr.Row():
- text_low_lr_rate = gr.Slider(minimum=0.2,maximum=0.6,step=0.05,label=i18n("文本模块学习率权重"),value=0.4,visible=True if version!="v3"else False)#v3 not need
- lora_rank = gr.Radio(label=i18n("LoRA秩"), value="32", choices=['16', '32', '64', '128'],visible=True if version=="v3"else False)#v1v2 not need
- save_every_epoch = gr.Slider(minimum=1,maximum=max_sovits_save_every_epoch,step=1,label=i18n("保存频率save_every_epoch"),value=default_sovits_save_every_epoch,interactive=True)
- with gr.Column():
- with gr.Column():
- if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的权重文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
- if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
- if_grad_ckpt = gr.Checkbox(label="v3是否开启梯度检查点节省显存占用", value=False, interactive=True if version == "v3" else False, show_label=True,visible=False) # 只有V3s2可以用
- with gr.Row():
- gpu_numbers1Ba = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True)
- with gr.Row():
- with gr.Row():
- button1Ba_open = gr.Button(value=process_info(process_name_sovits, "open"),variant="primary",visible=True)
- button1Ba_close = gr.Button(value=process_info(process_name_sovits, "close"),variant="primary",visible=False)
- with gr.Row():
- info1Ba=gr.Textbox(label=process_info(process_name_sovits, "info"))
- gr.Markdown(value="1Bb-"+i18n("GPT 训练: 模型权重文件在 GPT_weights/"))
- with gr.Row():
- with gr.Column():
- with gr.Row():
- batch_size1Bb = gr.Slider(minimum=1,maximum=40,step=1,label=i18n("每张显卡的batch_size"),value=default_batch_size_s1,interactive=True)
- total_epoch1Bb = gr.Slider(minimum=2,maximum=50,step=1,label=i18n("总训练轮数total_epoch"),value=15,interactive=True)
- with gr.Row():
- save_every_epoch1Bb = gr.Slider(minimum=1,maximum=50,step=1,label=i18n("保存频率save_every_epoch"),value=5,interactive=True)
- if_dpo = gr.Checkbox(label=i18n("是否开启DPO训练选项(实验性)"), value=False, interactive=True, show_label=True)
- with gr.Column():
- with gr.Column():
- if_save_latest1Bb = gr.Checkbox(label=i18n("是否仅保存最新的权重文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
- if_save_every_weights1Bb = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
- with gr.Row():
- gpu_numbers1Bb = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True)
- with gr.Row():
- with gr.Row():
- button1Bb_open = gr.Button(value=process_info(process_name_gpt, "open"),variant="primary",visible=True)
- button1Bb_close = gr.Button(value=process_info(process_name_gpt, "close"),variant="primary",visible=False)
- with gr.Row():
- info1Bb=gr.Textbox(label=process_info(process_name_gpt, "info"))
-
- button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt,lora_rank], [info1Ba,button1Ba_open,button1Ba_close])
- button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
- button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_dpo,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
- button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
-
- with gr.TabItem("1C-"+i18n("推理")):
- gr.Markdown(value=i18n("选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。"))
- with gr.Row():
- with gr.Row():
- GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names,key=custom_sort_key),value=pretrained_gpt_name[0],interactive=True)
- SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names,key=custom_sort_key),value=pretrained_sovits_name[0],interactive=True)
- with gr.Row():
- gpu_number_1C=gr.Textbox(label=i18n("GPU卡号,只能填1个整数"), value=gpus, interactive=True)
- refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
- refresh_button.click(fn=change_choices,inputs=[],outputs=[SoVITS_dropdown,GPT_dropdown])
- with gr.Row():
- with gr.Row():
- batched_infer_enabled = gr.Checkbox(label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True)
- with gr.Row():
- open_tts = gr.Button(value=process_info(process_name_tts, "open"),variant='primary',visible=True)
- close_tts = gr.Button(value=process_info(process_name_tts, "close"),variant='primary',visible=False)
- with gr.Row():
- tts_info = gr.Textbox(label=process_info(process_name_tts, "info"))
- open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
- close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
-
- version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown,batch_size,total_epoch,save_every_epoch,text_low_lr_rate, if_grad_ckpt, batched_infer_enabled, lora_rank])
-
- with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音"))
-
- app.queue().launch(#concurrency_count=511, max_size=1022
- server_name="0.0.0.0",
- inbrowser=True,
- share=is_share,
- server_port=webui_port_main,
- quiet=True,
- )