Merge pull request #14 from AnyaCoder/patch-3

FIx: cannot identify one class to a dict(needed)
This commit is contained in:
RVC-Boss 2024-01-17 15:50:20 +08:00 committed by GitHub
commit ff4ff6b637
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,16 +17,11 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
is_half = eval(os.environ.get("is_half", "True")) is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch, numpy as np import numpy as np
import os, librosa, torch import librosa,torch
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
from feature_extractor import cnhubert from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
cnhubert.cnhubert_base_path = cnhubert_base_path
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
@ -63,21 +58,40 @@ def get_bert_feature(text, word2ph):
n_semantic = 1024 n_semantic = 1024
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
class DictToAttrRecursive: class DictToAttrRecursive(dict):
def __init__(self, input_dict): def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items(): for key, value in input_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
# 如果值是字典,递归调用构造函数 value = DictToAttrRecursive(value)
setattr(self, key, DictToAttrRecursive(value)) self[key] = value
else: setattr(self, key, value)
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz" hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu") dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"] config = dict_s1["config"]