This commit is contained in:
XXXXRT666 2025-09-05 21:32:27 +00:00
parent 49667f44e8
commit d4c9eb031c
13 changed files with 27 additions and 17 deletions

View File

@ -14,7 +14,7 @@ import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations weight_norm
from torch.nn.utils import weight_norm
from . import activations
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d

View File

@ -5,7 +5,7 @@ import glob
import os
import torch
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
def init_weights(m, mean=0.0, std=0.01):

View File

@ -345,11 +345,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
n_speakers=hps.data.n_speakers,
**hps.model,
).eval()
if "pretrained" not in sovits_path:
try:
if hasattr(vq_model, "enc_q"):
del vq_model.enc_q
finally:
pass
if is_lora is False:
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
@ -358,7 +357,8 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
console.print(f">> loading sovits_{model_version}spretrained_G")
dict_pretrain = torch.load(path_sovits)["weight"]
console.print(f">> loading sovits_{model_version}_lora{model_version}")
state_dict = dict_pretrain.update(dict_s2["weight"])
dict_pretrain.update(dict_s2["weight"])
state_dict = dict_pretrain
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],

View File

@ -8,7 +8,7 @@ from torch.cuda.amp import autocast
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
from GPT_SoVITS.f5_tts.model import DiT
from GPT_SoVITS.text import symbols as symbols_v1

View File

@ -6,7 +6,7 @@ from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
from GPT_SoVITS.f5_tts.model import DiT
from GPT_SoVITS.text import symbols as symbols_v1

View File

@ -7,7 +7,7 @@ from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
from . import commons
from .commons import get_padding, init_weights

View File

@ -3,7 +3,7 @@
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
from .attentions import MultiHeadAttention

View File

@ -30,7 +30,7 @@ def save_ckpt(ckpt, name, epoch, steps, hps, lora_rank=None):
if "enc_q" in key:
continue
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps.__dict__
opt["config"] = hps.to_dict()
opt["info"] = f"{epoch}epoch_{steps}iteration"
if lora_rank:
opt["lora_rank"] = lora_rank
@ -51,8 +51,8 @@ def inspect_version(
dict_s2 = torch.load(f, map_location="cpu", mmap=True)
hps = dict_s2["config"]
version: str | None = None
if "version" in hps:
version = hps.version
if "version" in hps.keys():
version = hps["version"]
is_lora = "lora_rank" in dict_s2.keys()
if version is not None:

View File

@ -93,7 +93,7 @@ def run(rank, n_gpus, hps):
diagnose=True,
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
)
console.print(hps.__dict__)
console.print(hps.to_dict())
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
else:

View File

@ -88,7 +88,7 @@ def run(rank, n_gpus, hps):
diagnose=True,
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
)
console.print(hps.__dict__)
console.print(hps.to_dict())
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
else:

View File

@ -97,7 +97,7 @@ def run(rank, n_gpus, hps):
diagnose=True,
format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
)
console.print(hps.__dict__)
console.print(hps.to_dict())
writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
else:

View File

@ -360,3 +360,13 @@ class HParams:
def __repr__(self):
return self.__dict__.__repr__()
def to_dict(self):
"""Convert HParams to a plain dictionary recursively"""
result = {}
for k, v in self.__dict__.items():
if isinstance(v, HParams):
result[k] = v.to_dict()
else:
result[k] = v
return result

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
# from utils import init_weights, get_padding