diff --git a/tools/audio_sr.py b/tools/audio_sr.py index b714ab1..d51f055 100644 --- a/tools/audio_sr.py +++ b/tools/audio_sr.py @@ -12,7 +12,8 @@ import numpy as np import torchaudio import time import torchaudio.functional as aF -from attrdict import AttrDict +# from attrdict import AttrDict  # removed: attrdict is incompatible with Python 3.10+ + from datasets1.dataset import amp_pha_stft, amp_pha_istft from models.model import APNet_BWE_Model import soundfile as sf @@ -20,7 +21,7 @@ import matplotlib.pyplot as plt from rich.progress import track class AP_BWE(): - def __init__(self,device,checkpoint_file=None): + def __init__(self,device,DictToAttrRecursive,checkpoint_file=None): if checkpoint_file==None: checkpoint_file="%s/24kto48k/g_24kto48k.zip"%(AP_BWE_main_dir_path) if os.path.exists(checkpoint_file)==False: @@ -28,7 +29,8 @@ class AP_BWE(): config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json') with open(config_file) as f:data = f.read() json_config = json.loads(data) - h = AttrDict(json_config) + # h = AttrDict(json_config) + h = DictToAttrRecursive(json_config) model = APNet_BWE_Model(h).to(device) state_dict = torch.load(checkpoint_file,map_location="cpu",weights_only=False) model.load_state_dict(state_dict['generator']) @@ -46,4 +48,4 @@ class AP_BWE(): amp_wb_g, pha_wb_g, com_wb_g = self.model(amp_nb, pha_nb) audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size) # sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16') - return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate \ No newline at end of file + return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate