diff --git a/tools/audio_sr.py b/tools/audio_sr.py new file mode 100644 index 0000000..b714ab1 --- /dev/null +++ b/tools/audio_sr.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import sys,os +import traceback +AP_BWE_main_dir_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'AP_BWE_main') +sys.path.append(AP_BWE_main_dir_path) +import glob +import argparse +import json +from re import S +import torch +import numpy as np +import torchaudio +import time +import torchaudio.functional as aF +from attrdict import AttrDict +from datasets1.dataset import amp_pha_stft, amp_pha_istft +from models.model import APNet_BWE_Model +import soundfile as sf +import matplotlib.pyplot as plt +from rich.progress import track + +class AP_BWE(): + def __init__(self,device,checkpoint_file=None): + if checkpoint_file==None: + checkpoint_file="%s/24kto48k/g_24kto48k.zip"%(AP_BWE_main_dir_path) + if os.path.exists(checkpoint_file)==False: + raise FileNotFoundError + config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json') + with open(config_file) as f:data = f.read() + json_config = json.loads(data) + h = AttrDict(json_config) + model = APNet_BWE_Model(h).to(device) + state_dict = torch.load(checkpoint_file,map_location="cpu",weights_only=False) + model.load_state_dict(state_dict['generator']) + model.eval() + self.device=device + self.model=model + self.h=h + + def __call__(self, audio,orig_sampling_rate): + with torch.no_grad(): + # audio, orig_sampling_rate = torchaudio.load(inp_path) + # audio = audio.to(self.device) + audio = aF.resample(audio, orig_freq=orig_sampling_rate, new_freq=self.h.hr_sampling_rate) + amp_nb, pha_nb, com_nb = amp_pha_stft(audio, self.h.n_fft, self.h.hop_size, self.h.win_size) + amp_wb_g, pha_wb_g, com_wb_g = self.model(amp_nb, pha_nb) + audio_hr_g = amp_pha_istft(amp_wb_g, pha_wb_g, self.h.n_fft, self.h.hop_size, self.h.win_size) + # sf.write(opt_path, audio_hr_g.squeeze().cpu().numpy(), self.h.hr_sampling_rate, 'PCM_16') + return audio_hr_g.squeeze().cpu().numpy(),self.h.hr_sampling_rate \ No newline at end of file