Update audio_sr.py

This commit is contained in:
RVC-Boss 2025-02-27 22:25:24 +08:00 committed by GitHub
parent a68e3c4354
commit 0b20a949ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,7 +12,8 @@ import numpy as np
import torchaudio
import time
import torchaudio.functional as aF
from attrdict import AttrDict
# from attrdict import AttrDict####will be bug in py3.10
from datasets1.dataset import amp_pha_stft, amp_pha_istft
from models.model import APNet_BWE_Model
import soundfile as sf
@ -20,7 +21,7 @@ import matplotlib.pyplot as plt
from rich.progress import track
class AP_BWE():
def __init__(self,device,checkpoint_file=None):
def __init__(self,device,DictToAttrRecursive,checkpoint_file=None):
if checkpoint_file==None:
checkpoint_file="%s/24kto48k/g_24kto48k.zip"%(AP_BWE_main_dir_path)
if os.path.exists(checkpoint_file)==False:
@ -28,7 +29,8 @@ class AP_BWE():
config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
with open(config_file) as f:data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
# h = AttrDict(json_config)
h = DictToAttrRecursive(json_config)
model = APNet_BWE_Model(h).to(device)
state_dict = torch.load(checkpoint_file,map_location="cpu",weights_only=False)
model.load_state_dict(state_dict['generator'])
@ -46,4 +48,4 @@ class AP_BWE():
amp_wb_g, pha_wb_g, com_wb_g = self.model(amp_nb, pha_nb)
audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size)
# sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16')
return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate
return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate