mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-16 17:40:09 +08:00
Fix package pathing issues
This commit is contained in:
parent
6ca0eda2e9
commit
fe531567f1
@ -445,8 +445,9 @@ class TTS:
|
|||||||
def _init_models(
|
def _init_models(
|
||||||
self,
|
self,
|
||||||
):
|
):
|
||||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
# Don't auto load the weights, load them independently not on instantiation
|
||||||
self.init_vits_weights(self.configs.vits_weights_path)
|
# self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||||
|
# self.init_vits_weights(self.configs.vits_weights_path)
|
||||||
self.init_bert_weights(self.configs.bert_base_path)
|
self.init_bert_weights(self.configs.bert_base_path)
|
||||||
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
||||||
# self.enable_half_precision(self.configs.is_half)
|
# self.enable_half_precision(self.configs.is_half)
|
||||||
@ -468,10 +469,11 @@ class TTS:
|
|||||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||||
self.bert_model = self.bert_model.half()
|
self.bert_model = self.bert_model.half()
|
||||||
|
|
||||||
def init_vits_weights(self, weights_path: str):
|
def init_vits_weights(self, weights_path: str, vocoder_path: str = None):
|
||||||
self.configs.vits_weights_path = weights_path
|
self.configs.vits_weights_path = weights_path
|
||||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
|
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
|
||||||
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
|
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
|
||||||
|
print(if_lora_v3)
|
||||||
|
|
||||||
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
|
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
|
||||||
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
|
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
|
||||||
@ -502,9 +504,10 @@ class TTS:
|
|||||||
|
|
||||||
self.configs.update_version(model_version)
|
self.configs.update_version(model_version)
|
||||||
|
|
||||||
# print(f"model_version:{model_version}")
|
print(f"model_version:{model_version}")
|
||||||
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
|
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
|
||||||
if model_version not in {"v3", "v4"}:
|
if model_version not in {"v3", "v4"}:
|
||||||
|
print("Using v2 model")
|
||||||
vits_model = SynthesizerTrn(
|
vits_model = SynthesizerTrn(
|
||||||
self.configs.filter_length // 2 + 1,
|
self.configs.filter_length // 2 + 1,
|
||||||
self.configs.segment_size // self.configs.hop_length,
|
self.configs.segment_size // self.configs.hop_length,
|
||||||
@ -513,6 +516,7 @@ class TTS:
|
|||||||
)
|
)
|
||||||
self.configs.use_vocoder = False
|
self.configs.use_vocoder = False
|
||||||
else:
|
else:
|
||||||
|
print("Using v3 model")
|
||||||
kwargs["version"]=model_version
|
kwargs["version"]=model_version
|
||||||
vits_model = SynthesizerTrnV3(
|
vits_model = SynthesizerTrnV3(
|
||||||
self.configs.filter_length // 2 + 1,
|
self.configs.filter_length // 2 + 1,
|
||||||
@ -520,8 +524,9 @@ class TTS:
|
|||||||
n_speakers=self.configs.n_speakers,
|
n_speakers=self.configs.n_speakers,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.configs.use_vocoder = True
|
self.configs.use_vocoder = True
|
||||||
self.init_vocoder(model_version)
|
self.init_vocoder(model_version, vocoder_path)
|
||||||
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
|
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
|
||||||
del vits_model.enc_q
|
del vits_model.enc_q
|
||||||
|
|
||||||
@ -570,7 +575,7 @@ class TTS:
|
|||||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||||
self.t2s_model = self.t2s_model.half()
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
def init_vocoder(self, version: str):
|
def init_vocoder(self, version: str, vocoder_path: str = None):
|
||||||
if version == "v3":
|
if version == "v3":
|
||||||
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
|
if self.vocoder is not None and self.vocoder.__class__.__name__ == "BigVGAN":
|
||||||
return
|
return
|
||||||
@ -578,7 +583,7 @@ class TTS:
|
|||||||
self.vocoder.cpu()
|
self.vocoder.cpu()
|
||||||
del self.vocoder
|
del self.vocoder
|
||||||
self.empty_cache()
|
self.empty_cache()
|
||||||
|
|
||||||
self.vocoder = BigVGAN.from_pretrained(
|
self.vocoder = BigVGAN.from_pretrained(
|
||||||
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
|
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
|
||||||
use_cuda_kernel=False,
|
use_cuda_kernel=False,
|
||||||
@ -611,8 +616,13 @@ class TTS:
|
|||||||
gin_channels=0, is_bias=True
|
gin_channels=0, is_bias=True
|
||||||
)
|
)
|
||||||
self.vocoder.remove_weight_norm()
|
self.vocoder.remove_weight_norm()
|
||||||
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
|
print("loading vocoder",vocoder_path)
|
||||||
print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
|
try:
|
||||||
|
state_dict_g = torch.load(vocoder_path, map_location="cpu")
|
||||||
|
self.vocoder.load_state_dict(state_dict_g)
|
||||||
|
except:
|
||||||
|
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
|
||||||
|
print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
|
||||||
|
|
||||||
self.vocoder_configs["sr"] = 48000
|
self.vocoder_configs["sr"] = 48000
|
||||||
self.vocoder_configs["T_ref"] = 500
|
self.vocoder_configs["T_ref"] = 500
|
||||||
@ -988,7 +998,7 @@ class TTS:
|
|||||||
seed = inputs.get("seed", -1)
|
seed = inputs.get("seed", -1)
|
||||||
seed = -1 if seed in ["", None] else seed
|
seed = -1 if seed in ["", None] else seed
|
||||||
actual_seed = set_seed(seed)
|
actual_seed = set_seed(seed)
|
||||||
parallel_infer = inputs.get("parallel_infer", True)
|
parallel_infer = inputs.get("parallel_infer", False)
|
||||||
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
||||||
sample_steps = inputs.get("sample_steps", 32)
|
sample_steps = inputs.get("sample_steps", 32)
|
||||||
super_sampling = inputs.get("super_sampling", False)
|
super_sampling = inputs.get("super_sampling", False)
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
|
|||||||
get_pos_embed_indices,
|
get_pos_embed_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
from module.commons import sequence_mask
|
from GPT_SoVITS.module.commons import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
class TextEmbedding(nn.Module):
|
class TextEmbedding(nn.Module):
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from torch import nn
|
|||||||
|
|
||||||
from x_transformers.x_transformers import RotaryEmbedding
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
from f5_tts.model.modules import (
|
from GPT_SoVITS.f5_tts.model.modules import (
|
||||||
TimestepEmbedding,
|
TimestepEmbedding,
|
||||||
ConvPositionEmbedding,
|
ConvPositionEmbedding,
|
||||||
MMDiTBlock,
|
MMDiTBlock,
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import torch.nn.functional as F
|
|||||||
from x_transformers import RMSNorm
|
from x_transformers import RMSNorm
|
||||||
from x_transformers.x_transformers import RotaryEmbedding
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
from f5_tts.model.modules import (
|
from GPT_SoVITS.f5_tts.model.modules import (
|
||||||
TimestepEmbedding,
|
TimestepEmbedding,
|
||||||
ConvNeXtV2Block,
|
ConvNeXtV2Block,
|
||||||
ConvPositionEmbedding,
|
ConvPositionEmbedding,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from torch.nn import functional as F
|
|||||||
from GPT_SoVITS.module import commons
|
from GPT_SoVITS.module import commons
|
||||||
from GPT_SoVITS.module import modules
|
from GPT_SoVITS.module import modules
|
||||||
from GPT_SoVITS.module import attentions
|
from GPT_SoVITS.module import attentions
|
||||||
from f5_tts.model import DiT
|
from GPT_SoVITS.f5_tts.model import DiT
|
||||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||||
from GPT_SoVITS.module.commons import init_weights, get_padding
|
from GPT_SoVITS.module.commons import init_weights, get_padding
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from GPT_SoVITS.text.symbols import punctuation
|
|||||||
from GPT_SoVITS.text.symbols2 import symbols
|
from GPT_SoVITS.text.symbols2 import symbols
|
||||||
|
|
||||||
from builtins import str as unicode
|
from builtins import str as unicode
|
||||||
from text.en_normalization.expend import normalize
|
from GPT_SoVITS.text.en_normalization.expend import normalize
|
||||||
from nltk.tokenize import TweetTokenizer
|
from nltk.tokenize import TweetTokenizer
|
||||||
|
|
||||||
word_tokenize = TweetTokenizer().tokenize
|
word_tokenize = TweetTokenizer().tokenize
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user