Correct imports and parameter handling for api.py

- Fix import errors
- Resolve parameter assignment issue for DefaultRefer class
- Add missing `sr` keyword argument to the librosa.load() call
- Complete parameter passing in the initialization section
This commit is contained in:
Spr_Aachen 2025-04-05 17:26:33 +08:00
parent 9da7e17efe
commit 28effc2d44

115
api.py
View File

@ -150,7 +150,7 @@ sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir)) sys.path.append("%s/GPT_SoVITS" % (now_dir))
import signal import signal
from text.LangSegmenter import LangSegmenter from GPT_SoVITS.text.LangSegmenter import LangSegmenter
from time import time as ttime from time import time as ttime
import torch, torchaudio import torch, torchaudio
import librosa import librosa
@ -160,14 +160,14 @@ from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
from feature_extractor import cnhubert from GPT_SoVITS.feature_extractor import cnhubert
from io import BytesIO from io import BytesIO
from module.models import SynthesizerTrn, SynthesizerTrnV3 from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence from GPT_SoVITS.text import cleaned_text_to_sequence
from text.cleaner import clean_text from GPT_SoVITS.text.cleaner import clean_text
from module.mel_processing import spectrogram_torch from GPT_SoVITS.module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio from tools.my_utils import load_audio
import config as global_config import config as global_config
import logging import logging
@ -176,9 +176,9 @@ import subprocess
class DefaultRefer: class DefaultRefer:
def __init__(self, path, text, language): def __init__(self, path, text, language):
self.path = args.default_refer_path self.path = path
self.text = args.default_refer_text self.text = text
self.language = args.default_refer_language self.language = language
def is_ready(self) -> bool: def is_ready(self) -> bool:
return is_full(self.path, self.text, self.language) return is_full(self.path, self.text, self.language)
@ -200,7 +200,7 @@ def is_full(*items): # 任意一项为空返回False
def init_bigvgan(): def init_bigvgan():
global bigvgan_model global bigvgan_model
from BigVGAN import bigvgan from GPT_SoVITS.BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode # remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm() bigvgan_model.remove_weight_norm()
@ -221,7 +221,7 @@ def resample(audio_tensor, sr0):
return resample_transform_dict[sr0](audio_tensor) return resample_transform_dict[sr0](audio_tensor)
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch from GPT_SoVITS.module.mel_processing import spectrogram_torch,mel_spectrogram_torch
spec_min = -12 spec_min = -12
spec_max = 2 spec_max = 2
def norm_spec(x): def norm_spec(x):
@ -240,6 +240,34 @@ mel_fn=lambda x: mel_spectrogram_torch(x, **{
}) })
class DictToAttrRecursive(dict):
    """Dict subclass that exposes its keys as attributes, recursively.

    Nested plain dicts are wrapped in DictToAttrRecursive so that
    chained attribute access (e.g. ``cfg.data.sampling_rate``) works
    at any depth. Values are stored both as dict items and as
    instance attributes.

    NOTE(review): nested dicts are wrapped twice — once in __init__
    and again inside __setattr__ (a DictToAttrRecursive is itself a
    dict, so isinstance(value, dict) is True). Harmless redundant
    copy, but worth confirming it is intended.
    """
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            # Recursively wrap nested dicts for attribute-style access.
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            # Store as both item and attribute; __setattr__ then
            # overwrites the item with its (re-wrapped) value, so
            # self[key] and self.key end up the same object.
            self[key] = value
            setattr(self, key, value)
    def __getattr__(self, item):
        # Only called when normal attribute lookup fails: fall back
        # to dict lookup, translating KeyError to AttributeError.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
    def __setattr__(self, key, value):
        # Keep item and attribute views in sync on every assignment.
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)
    def __delattr__(self, item):
        # Deleting the attribute removes the dict entry as well.
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
sr_model=None sr_model=None
def audio_sr(audio,sr): def audio_sr(audio,sr):
global sr_model global sr_model
@ -270,7 +298,7 @@ class Sovits:
self.vq_model = vq_model self.vq_model = vq_model
self.hps = hps self.hps = hps
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3=os.path.exists(path_sovits_v3) is_exist_s2gv3=os.path.exists(path_sovits_v3)
@ -340,6 +368,7 @@ def get_sovits_weights(sovits_path):
sovits = Sovits(vq_model, hps) sovits = Sovits(vq_model, hps)
return sovits return sovits
class Gpt: class Gpt:
def __init__(self, max_sec, t2s_model): def __init__(self, max_sec, t2s_model):
self.max_sec = max_sec self.max_sec = max_sec
@ -363,6 +392,7 @@ def get_gpt_weights(gpt_path):
gpt = Gpt(max_sec, t2s_model) gpt = Gpt(max_sec, t2s_model)
return gpt return gpt
def change_gpt_sovits_weights(gpt_path,sovits_path): def change_gpt_sovits_weights(gpt_path,sovits_path):
try: try:
gpt = get_gpt_weights(gpt_path) gpt = get_gpt_weights(gpt_path)
@ -410,7 +440,8 @@ def get_bert_inf(phones, word2ph, norm_text, language):
return bert return bert
from text import chinese
from GPT_SoVITS.text import chinese
def get_phones_and_bert(text,language,version,final=False): def get_phones_and_bert(text,language,version,final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
formattext = text formattext = text
@ -475,36 +506,8 @@ def get_phones_and_bert(text,language,version,final=False):
return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text
class DictToAttrRecursive(dict):
    """Dict subclass that exposes its keys as attributes, recursively.

    Nested plain dicts are wrapped in DictToAttrRecursive so that
    chained attribute access (e.g. ``cfg.data.sampling_rate``) works
    at any depth. Values are stored both as dict items and as
    instance attributes.
    """
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            # Recursively wrap nested dicts for attribute-style access.
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            # Store as both item and attribute; __setattr__ below keeps
            # the two views pointing at the same object.
            self[key] = value
            setattr(self, key, value)
    def __getattr__(self, item):
        # Only called when normal attribute lookup fails: fall back
        # to dict lookup, translating KeyError to AttributeError.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
    def __setattr__(self, key, value):
        # Keep item and attribute views in sync on every assignment.
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)
    def __delattr__(self, item):
        # Deleting the attribute removes the dict entry as well.
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio,_ = librosa.load(filename, int(hps.data.sampling_rate)) audio,_ = librosa.load(filename, sr=int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
maxx=audio.abs().max() maxx=audio.abs().max()
if(maxx>1): if(maxx>1):
@ -934,15 +937,23 @@ parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, h
args = parser.parse_args() args = parser.parse_args()
sovits_path = args.sovits_path sovits_path = args.sovits_path
gpt_path = args.gpt_path gpt_path = args.gpt_path
default_refer_path = args.default_refer_path
default_refer_text = args.default_refer_text
default_refer_language = args.default_refer_language
device = args.device device = args.device
port = args.port port = args.port
host = args.bind_addr host = args.bind_addr
full_precision = args.full_precision
half_precision = args.half_precision
stream_mode = args.stream_mode
media_type = args.media_type
sub_type = args.sub_type
default_cut_punc = args.cut_punc
cnhubert_base_path = args.hubert_path cnhubert_base_path = args.hubert_path
bert_path = args.bert_path bert_path = args.bert_path
default_cut_punc = args.cut_punc
# 应用参数配置 # 应用参数配置
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) default_refer = DefaultRefer(default_refer_path, default_refer_text, default_refer_language)
# 模型路径检查 # 模型路径检查
if sovits_path == "": if sovits_path == "":
@ -963,24 +974,24 @@ else:
# 获取半精度 # 获取半精度
is_half = g_config.is_half is_half = g_config.is_half
if args.full_precision: if full_precision:
is_half = False is_half = False
if args.half_precision: if half_precision:
is_half = True is_half = True
if args.full_precision and args.half_precision: if full_precision and half_precision:
is_half = g_config.is_half # 炒饭fallback is_half = g_config.is_half # 炒饭fallback
logger.info(f"半精: {is_half}") logger.info(f"半精: {is_half}")
# 流式返回模式 # 流式返回模式
if args.stream_mode.lower() in ["normal","n"]: if stream_mode.lower() in ["normal","n"]:
stream_mode = "normal" stream_mode = "normal"
logger.info("流式返回已开启") logger.info("流式返回已开启")
else: else:
stream_mode = "close" stream_mode = "close"
# 音频编码格式 # 音频编码格式
if args.media_type.lower() in ["aac","ogg"]: if media_type.lower() in ["aac","ogg"]:
media_type = args.media_type.lower() media_type = media_type.lower()
elif stream_mode == "close": elif stream_mode == "close":
media_type = "wav" media_type = "wav"
else: else:
@ -988,7 +999,7 @@ else:
logger.info(f"编码格式: {media_type}") logger.info(f"编码格式: {media_type}")
# 音频数据类型 # 音频数据类型
if args.sub_type.lower() == 'int32': if sub_type.lower() == 'int32':
is_int32 = True is_int32 = True
logger.info(f"数据类型: int32") logger.info(f"数据类型: int32")
else: else: