diff --git a/api.py b/api.py index d92d9c8..0f34149 100644 --- a/api.py +++ b/api.py @@ -150,7 +150,7 @@ sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import signal -from text.LangSegmenter import LangSegmenter +from GPT_SoVITS.text.LangSegmenter import LangSegmenter from time import time as ttime import torch, torchaudio import librosa @@ -160,14 +160,14 @@ from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np -from feature_extractor import cnhubert +from GPT_SoVITS.feature_extractor import cnhubert from io import BytesIO -from module.models import SynthesizerTrn, SynthesizerTrnV3 +from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3 from peft import LoraConfig, PeftModel, get_peft_model -from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from text import cleaned_text_to_sequence -from text.cleaner import clean_text -from module.mel_processing import spectrogram_torch +from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule +from GPT_SoVITS.text import cleaned_text_to_sequence +from GPT_SoVITS.text.cleaner import clean_text +from GPT_SoVITS.module.mel_processing import spectrogram_torch from tools.my_utils import load_audio import config as global_config import logging @@ -176,9 +176,9 @@ import subprocess class DefaultRefer: def __init__(self, path, text, language): - self.path = args.default_refer_path - self.text = args.default_refer_text - self.language = args.default_refer_language + self.path = path + self.text = text + self.language = language def is_ready(self) -> bool: return is_full(self.path, self.text, self.language) @@ -200,7 +200,7 @@ def is_full(*items): # 任意一项为空返回False def init_bigvgan(): global bigvgan_model - from BigVGAN import bigvgan + from GPT_SoVITS.BigVGAN import bigvgan bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions # remove weight norm in the model and set to eval mode bigvgan_model.remove_weight_norm() @@ -221,7 +221,7 @@ def resample(audio_tensor, sr0): return resample_transform_dict[sr0](audio_tensor) -from module.mel_processing import spectrogram_torch,mel_spectrogram_torch +from GPT_SoVITS.module.mel_processing import spectrogram_torch,mel_spectrogram_torch spec_min = -12 spec_max = 2 def norm_spec(x): @@ -240,6 +240,34 @@ mel_fn=lambda x: mel_spectrogram_torch(x, **{ }) +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + sr_model=None def audio_sr(audio,sr): global sr_model @@ -270,7 +298,7 @@ class Sovits: self.vq_model = vq_model self.hps = hps -from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new +from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast,load_sovits_new def get_sovits_weights(sovits_path): path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" is_exist_s2gv3=os.path.exists(path_sovits_v3) @@ -340,6 +368,7 @@ def get_sovits_weights(sovits_path): sovits = Sovits(vq_model, hps) return sovits + class Gpt: def __init__(self, max_sec, t2s_model): self.max_sec = max_sec @@ -363,6 +392,7 @@ def get_gpt_weights(gpt_path): gpt = Gpt(max_sec, t2s_model) return gpt + def change_gpt_sovits_weights(gpt_path,sovits_path): try: gpt = get_gpt_weights(gpt_path) @@ -410,7 +440,8 @@ def get_bert_inf(phones, word2ph, norm_text, language): return bert -from text import chinese + +from GPT_SoVITS.text import chinese def get_phones_and_bert(text,language,version,final=False): if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: formattext = text @@ -475,36 +506,8 @@ def get_phones_and_bert(text,language,version,final=False): return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text -class DictToAttrRecursive(dict): - def __init__(self, input_dict): - super().__init__(input_dict) - for key, value in input_dict.items(): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - self[key] = value - setattr(self, key, value) - - def __getattr__(self, item): - try: - return self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - def __setattr__(self, key, value): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - super(DictToAttrRecursive, self).__setitem__(key, value) - super().__setattr__(key, value) - - def __delattr__(self, item): - try: - del self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - def get_spepc(hps, filename): - audio,_ = librosa.load(filename, int(hps.data.sampling_rate)) + audio,_ = librosa.load(filename, sr=int(hps.data.sampling_rate)) audio = torch.FloatTensor(audio) maxx=audio.abs().max() if(maxx>1): @@ -934,15 +937,23 @@ parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, h args = parser.parse_args() sovits_path = args.sovits_path gpt_path = args.gpt_path +default_refer_path = args.default_refer_path +default_refer_text = args.default_refer_text +default_refer_language = args.default_refer_language device = args.device port = args.port host = args.bind_addr +full_precision = args.full_precision +half_precision = args.half_precision +stream_mode = args.stream_mode +media_type = args.media_type +sub_type = args.sub_type +default_cut_punc = args.cut_punc cnhubert_base_path = args.hubert_path bert_path = args.bert_path -default_cut_punc = args.cut_punc # 应用参数配置 -default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) +default_refer = DefaultRefer(default_refer_path, default_refer_text, default_refer_language) # 模型路径检查 if sovits_path == "": @@ -963,24 +974,24 @@ else: # 获取半精度 is_half = g_config.is_half -if args.full_precision: +if full_precision: is_half = False -if args.half_precision: +if half_precision: is_half = True -if args.full_precision and args.half_precision: +if full_precision and half_precision: is_half = g_config.is_half # 炒饭fallback logger.info(f"半精: {is_half}") # 流式返回模式 -if args.stream_mode.lower() in ["normal","n"]: +if stream_mode.lower() in ["normal","n"]: stream_mode = "normal" logger.info("流式返回已开启") else: stream_mode = "close" # 音频编码格式 -if args.media_type.lower() in ["aac","ogg"]: - media_type = args.media_type.lower() +if media_type.lower() in ["aac","ogg"]: + media_type = media_type.lower() elif stream_mode == "close": media_type = "wav" else: @@ -988,7 +999,7 @@ else: logger.info(f"编码格式: {media_type}") # 音频数据类型 -if args.sub_type.lower() == 'int32': +if sub_type.lower() == 'int32': is_int32 = True logger.info(f"数据类型: int32") else: @@ -1102,4 +1113,4 @@ async def tts_endpoint( if __name__ == "__main__": - uvicorn.run(app, host=host, port=port, workers=1) + uvicorn.run(app, host=host, port=port, workers=1) \ No newline at end of file