Fix package pathing issues

This commit is contained in:
Jarod Mica 2025-05-23 02:48:09 -07:00
parent 6ca0eda2e9
commit fe531567f1
6 changed files with 25 additions and 15 deletions

View File

@ -445,8 +445,9 @@ class TTS:
def _init_models(
self,
):
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
# Don't auto load the weights, load them independently not on instantiation
# self.init_t2s_weights(self.configs.t2s_weights_path)
# self.init_vits_weights(self.configs.vits_weights_path)
self.init_bert_weights(self.configs.bert_base_path)
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
# self.enable_half_precision(self.configs.is_half)
@ -468,10 +469,11 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.bert_model = self.bert_model.half()
def init_vits_weights(self, weights_path: str):
def init_vits_weights(self, weights_path: str, vocoder_path: str = None):
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
print(if_lora_v3)
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
@ -502,9 +504,10 @@ class TTS:
self.configs.update_version(model_version)
# print(f"model_version:{model_version}")
print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version not in {"v3", "v4"}:
print("Using v2 model")
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
@ -513,6 +516,7 @@ class TTS:
)
self.configs.use_vocoder = False
else:
print("Using v3 model")
kwargs["version"]=model_version
vits_model = SynthesizerTrnV3(
self.configs.filter_length // 2 + 1,
@ -520,8 +524,9 @@ class TTS:
n_speakers=self.configs.n_speakers,
**kwargs,
)
self.configs.use_vocoder = True
self.init_vocoder(model_version)
self.init_vocoder(model_version, vocoder_path)
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q
@ -570,7 +575,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
def init_vocoder(self, version: str):
def init_vocoder(self, version: str, vocoder_path: str = None):
if version == "v3":
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
return
@ -578,7 +583,7 @@ class TTS:
self.vocoder.cpu()
del self.vocoder
self.empty_cache()
self.vocoder = BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
@ -611,8 +616,13 @@ class TTS:
gin_channels=0, is_bias=True
)
self.vocoder.remove_weight_norm()
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
print("loading vocoder",vocoder_path)
try:
state_dict_g = torch.load(vocoder_path, map_location="cpu")
self.vocoder.load_state_dict(state_dict_g)
except:
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
self.vocoder_configs["sr"] = 48000
self.vocoder_configs["T_ref"] = 500
@ -988,7 +998,7 @@ class TTS:
seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed
actual_seed = set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True)
parallel_infer = inputs.get("parallel_infer", False)
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)

View File

@ -25,7 +25,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
get_pos_embed_indices,
)
from module.commons import sequence_mask
from GPT_SoVITS.module.commons import sequence_mask
class TextEmbedding(nn.Module):

View File

@ -14,7 +14,7 @@ from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
from GPT_SoVITS.f5_tts.model.modules import (
TimestepEmbedding,
ConvPositionEmbedding,
MMDiTBlock,

View File

@ -17,7 +17,7 @@ import torch.nn.functional as F
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
from GPT_SoVITS.f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,

View File

@ -10,7 +10,7 @@ from torch.nn import functional as F
from GPT_SoVITS.module import commons
from GPT_SoVITS.module import modules
from GPT_SoVITS.module import attentions
from f5_tts.model import DiT
from GPT_SoVITS.f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from GPT_SoVITS.module.commons import init_weights, get_padding

View File

@ -9,7 +9,7 @@ from GPT_SoVITS.text.symbols import punctuation
from GPT_SoVITS.text.symbols2 import symbols
from builtins import str as unicode
from text.en_normalization.expend import normalize
from GPT_SoVITS.text.en_normalization.expend import normalize
from nltk.tokenize import TweetTokenizer
word_tokenize = TweetTokenizer().tokenize