mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-04 04:58:12 +08:00
.
This commit is contained in:
parent
49667f44e8
commit
d4c9eb031c
@ -14,7 +14,7 @@ import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
from torch.nn.utils.parametrizations weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from . import activations
|
||||
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
|
||||
@ -5,7 +5,7 @@ import glob
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
|
||||
@ -345,11 +345,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).eval()
|
||||
|
||||
if "pretrained" not in sovits_path:
|
||||
try:
|
||||
if hasattr(vq_model, "enc_q"):
|
||||
del vq_model.enc_q
|
||||
finally:
|
||||
pass
|
||||
|
||||
if is_lora is False:
|
||||
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
@ -358,7 +357,8 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
console.print(f">> loading sovits_{model_version}spretrained_G")
|
||||
dict_pretrain = torch.load(path_sovits)["weight"]
|
||||
console.print(f">> loading sovits_{model_version}_lora{model_version}")
|
||||
state_dict = dict_pretrain.update(dict_s2["weight"])
|
||||
dict_pretrain.update(dict_s2["weight"])
|
||||
state_dict = dict_pretrain
|
||||
lora_rank = dict_s2["lora_rank"]
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
|
||||
@ -8,7 +8,7 @@ from torch.cuda.amp import autocast
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from GPT_SoVITS.f5_tts.model import DiT
|
||||
from GPT_SoVITS.text import symbols as symbols_v1
|
||||
|
||||
@ -6,7 +6,7 @@ from torch import nn
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from GPT_SoVITS.f5_tts.model import DiT
|
||||
from GPT_SoVITS.text import symbols as symbols_v1
|
||||
|
||||
@ -7,7 +7,7 @@ from torch import nn
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from . import commons
|
||||
from .commons import get_padding, init_weights
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from .attentions import MultiHeadAttention
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ def save_ckpt(ckpt, name, epoch, steps, hps, lora_rank=None):
|
||||
if "enc_q" in key:
|
||||
continue
|
||||
opt["weight"][key] = ckpt[key].half()
|
||||
opt["config"] = hps.__dict__
|
||||
opt["config"] = hps.to_dict()
|
||||
opt["info"] = f"{epoch}epoch_{steps}iteration"
|
||||
if lora_rank:
|
||||
opt["lora_rank"] = lora_rank
|
||||
@ -51,8 +51,8 @@ def inspect_version(
|
||||
dict_s2 = torch.load(f, map_location="cpu", mmap=True)
|
||||
hps = dict_s2["config"]
|
||||
version: str | None = None
|
||||
if "version" in hps:
|
||||
version = hps.version
|
||||
if "version" in hps.keys():
|
||||
version = hps["version"]
|
||||
is_lora = "lora_rank" in dict_s2.keys()
|
||||
|
||||
if version is not None:
|
||||
|
||||
@ -93,7 +93,7 @@ def run(rank, n_gpus, hps):
|
||||
diagnose=True,
|
||||
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
|
||||
)
|
||||
console.print(hps.__dict__)
|
||||
console.print(hps.to_dict())
|
||||
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
else:
|
||||
|
||||
@ -88,7 +88,7 @@ def run(rank, n_gpus, hps):
|
||||
diagnose=True,
|
||||
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
|
||||
)
|
||||
console.print(hps.__dict__)
|
||||
console.print(hps.to_dict())
|
||||
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
else:
|
||||
|
||||
@ -97,7 +97,7 @@ def run(rank, n_gpus, hps):
|
||||
diagnose=True,
|
||||
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
|
||||
)
|
||||
console.print(hps.__dict__)
|
||||
console.print(hps.to_dict())
|
||||
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
else:
|
||||
|
||||
@ -360,3 +360,13 @@ class HParams:
|
||||
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert HParams to a plain dictionary recursively"""
|
||||
result = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, HParams):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
# from utils import init_weights, get_padding
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user