mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
30 lines
756 B
Python
30 lines
756 B
Python
import logging
|
|
from functools import cache
|
|
|
|
import torch
|
|
|
|
from ..inference import inference
|
|
from .train import Denoiser, HParams
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@cache
def load_denoiser(run_dir, device):
    """Construct a ``Denoiser``, restoring trained weights when ``run_dir`` is given.

    The result is memoized by ``functools.cache`` on the ``(run_dir, device)``
    pair, so both arguments must be hashable and repeated calls with equal
    arguments return the very same model object.

    Args:
        run_dir: Directory of a training run, or ``None``. It must support the
            ``/`` path-joining operator (e.g. ``pathlib.Path``); the
            hyper-parameters are read via ``HParams.load(run_dir)`` and the
            weights from ``ds/G/default/mp_rank_00_model_states.pt`` below it.
        device: Target passed to ``Module.to`` for the restored model.

    Returns:
        A ``Denoiser``. With ``run_dir=None`` it is built from default
        ``HParams()`` and returned as constructed.
    """
    if run_dir is not None:
        model = Denoiser(HParams.load(run_dir))

        ckpt_path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # Deserialize onto CPU first; the model is moved to ``device`` afterwards.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint["module"])

        model.eval()
        model.to(device)
        return model

    # NOTE(review): this branch skips both ``eval()`` and ``to(device)``, unlike
    # the checkpoint branch above — confirm that is intended.
    return Denoiser(HParams())
|
|
|
|
|
|
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run the (cached) denoiser over ``dwav`` and return the inference result.

    The whole call executes under ``torch.inference_mode()``.

    Args:
        dwav: Input forwarded unchanged to ``inference`` as ``dwav``
            (presumably a waveform tensor — confirm against ``inference``).
        sr: Forwarded unchanged to ``inference`` as ``sr``
            (presumably the sample rate — confirm against ``inference``).
        run_dir: Passed to ``load_denoiser`` to select the model; may be
            ``None``.
        device: Passed both to ``load_denoiser`` and to ``inference``.

    Returns:
        Whatever ``inference`` returns for the loaded model and inputs.
    """
    model = load_denoiser(run_dir, device)
    result = inference(model=model, dwav=dwav, sr=sr, device=device)
    return result
|