Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Synced 2025-09-24 21:30:00 +08:00

Fix package pathing issues

This commit is contained in:
parent 6ca0eda2e9
commit fe531567f1
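
In short: model weights are no longer loaded automatically when TTS is
instantiated; init_vits_weights and init_vocoder gain an optional
vocoder_path override that falls back to the bundled pretrained vocoder;
the parallel_infer default flips to False; and several intra-repo imports
become package-qualified (GPT_SoVITS.*) so they resolve outside the repo
root.
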
@@ -445,8 +445,9 @@ class TTS:
     def _init_models(
         self,
     ):
-        self.init_t2s_weights(self.configs.t2s_weights_path)
-        self.init_vits_weights(self.configs.vits_weights_path)
+        # Don't auto load the weights, load them independently not on instantiation
+        # self.init_t2s_weights(self.configs.t2s_weights_path)
+        # self.init_vits_weights(self.configs.vits_weights_path)
         self.init_bert_weights(self.configs.bert_base_path)
         self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
         # self.enable_half_precision(self.configs.is_half)
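
With auto-loading disabled in _init_models, callers now load the T2S and
SoVITS weights themselves after constructing the TTS instance. A minimal
sketch of the expected call pattern; the constructor argument is an
assumption, only the init_* methods and config fields appear in this diff:

    tts = TTS(tts_config)  # constructor signature assumed, not shown here
    tts.init_t2s_weights(tts.configs.t2s_weights_path)
    tts.init_vits_weights(tts.configs.vits_weights_path)
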
@@ -468,10 +469,11 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.bert_model = self.bert_model.half()

-    def init_vits_weights(self, weights_path: str):
+    def init_vits_weights(self, weights_path: str, vocoder_path: str = None):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
         path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
         print(if_lora_v3)

         if if_lora_v3 == True and os.path.exists(path_sovits) == False:
             info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
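
The new vocoder_path parameter is optional; when it is omitted or points at
an unloadable file, a later hunk falls back to the bundled pretrained
vocoder. A usage sketch with hypothetical paths:

    tts.init_vits_weights(
        "SoVITS_weights_v4/my_voice_e8_s800.pth",     # hypothetical
        vocoder_path="my_models/custom_vocoder.pth",  # hypothetical
    )
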
@@ -502,9 +504,10 @@ class TTS:

         self.configs.update_version(model_version)

-        # print(f"model_version:{model_version}")
+        print(f"model_version:{model_version}")
         # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
         if model_version not in {"v3", "v4"}:
+            print("Using v2 model")
             vits_model = SynthesizerTrn(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
@@ -513,6 +516,7 @@ class TTS:
             )
             self.configs.use_vocoder = False
         else:
+            print("Using v3 model")
             kwargs["version"]=model_version
             vits_model = SynthesizerTrnV3(
                 self.configs.filter_length // 2 + 1,
@@ -520,8 +524,9 @@ class TTS:
                 n_speakers=self.configs.n_speakers,
                 **kwargs,
             )
+
             self.configs.use_vocoder = True
-            self.init_vocoder(model_version)
+            self.init_vocoder(model_version, vocoder_path)
         if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
             del vits_model.enc_q

@@ -570,7 +575,7 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()

-    def init_vocoder(self, version: str):
+    def init_vocoder(self, version: str, vocoder_path: str = None):
         if version == "v3":
             if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
                 return
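
init_vocoder takes the same optional override, so vocoder_path flows from
init_vits_weights straight through; it can also be called on its own
(hypothetical path again):

    tts.init_vocoder("v4", vocoder_path="my_models/custom_vocoder.pth")
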
@@ -578,7 +583,7 @@ class TTS:
                 self.vocoder.cpu()
                 del self.vocoder
                 self.empty_cache()
-
+
             self.vocoder = BigVGAN.from_pretrained(
                 "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
                 use_cuda_kernel=False,
@@ -611,8 +616,13 @@ class TTS:
                 gin_channels=0, is_bias=True
             )
             self.vocoder.remove_weight_norm()
-            state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
-            print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
+            print("loading vocoder",vocoder_path)
+            try:
+                state_dict_g = torch.load(vocoder_path, map_location="cpu")
+                self.vocoder.load_state_dict(state_dict_g)
+            except:
+                state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+                print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))

             self.vocoder_configs["sr"] = 48000
             self.vocoder_configs["T_ref"] = 500
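
The bare except above swallows every failure mode, including vocoder_path
being None. An explicit-check variant of the same fallback, sketched for
comparison and not part of the commit:

    if vocoder_path is not None and os.path.exists(vocoder_path):
        state_dict_g = torch.load(vocoder_path, map_location="cpu")
    else:
        # fall back to the bundled v4 pretrained vocoder
        state_dict_g = torch.load(
            "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
            map_location="cpu",
        )
    self.vocoder.load_state_dict(state_dict_g)
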
@@ -988,7 +998,7 @@ class TTS:
         seed = inputs.get("seed", -1)
         seed = -1 if seed in ["", None] else seed
         actual_seed = set_seed(seed)
-        parallel_infer = inputs.get("parallel_infer", True)
+        parallel_infer = inputs.get("parallel_infer", False)
         repetition_penalty = inputs.get("repetition_penalty", 1.35)
         sample_steps = inputs.get("sample_steps", 32)
         super_sampling = inputs.get("super_sampling", False)
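
With the parallel_infer default flipped to False, callers that relied on
parallel inference must now opt in explicitly. A sketch, assuming this
inputs dict is the one the surrounding method reads with inputs.get (only
the keys shown in this hunk are confirmed):

    inputs = {
        "seed": 42,
        "parallel_infer": True,   # was the default before this commit
        "repetition_penalty": 1.35,
        "sample_steps": 32,
        "super_sampling": False,
    }
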
@@ -25,7 +25,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
     get_pos_embed_indices,
 )

-from module.commons import sequence_mask
+from GPT_SoVITS.module.commons import sequence_mask


 class TextEmbedding(nn.Module):
@@ -14,7 +14,7 @@ from torch import nn

 from x_transformers.x_transformers import RotaryEmbedding

-from f5_tts.model.modules import (
+from GPT_SoVITS.f5_tts.model.modules import (
     TimestepEmbedding,
     ConvPositionEmbedding,
     MMDiTBlock,
@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding

-from f5_tts.model.modules import (
+from GPT_SoVITS.f5_tts.model.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     ConvPositionEmbedding,
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from GPT_SoVITS.module import commons
 from GPT_SoVITS.module import modules
 from GPT_SoVITS.module import attentions
-from f5_tts.model import DiT
+from GPT_SoVITS.f5_tts.model import DiT
 from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from GPT_SoVITS.module.commons import init_weights, get_padding
@@ -9,7 +9,7 @@ from GPT_SoVITS.text.symbols import punctuation
 from GPT_SoVITS.text.symbols2 import symbols

 from builtins import str as unicode
-from text.en_normalization.expend import normalize
+from GPT_SoVITS.text.en_normalization.expend import normalize
 from nltk.tokenize import TweetTokenizer

 word_tokenize = TweetTokenizer().tokenize
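
The five hunks above all apply the same fix: bare intra-repo imports
(module.*, f5_tts.*, text.*) only resolve when Python runs from inside the
GPT_SoVITS directory, so they are rewritten as package-qualified
GPT_SoVITS.* imports. A sketch of what that enables, assuming the TTS class
above lives in GPT_SoVITS/TTS_infer_pack/TTS.py and with a hypothetical
checkout path:

    import sys
    sys.path.insert(0, "/path/to/GPT-SoVITS")  # repo root, hypothetical
    from GPT_SoVITS.TTS_infer_pack.TTS import TTS  # now resolves from anywhere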