mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-24 13:33:33 +08:00
* Update Install.sh * Format Code * Delete dev null * Update README, Support Dark Mode in CSS/JS
33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
import sys
|
|
import os
|
|
import torch
|
|
|
|
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
|
|
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
|
from ERes2NetV2 import ERes2NetV2
|
|
import kaldi as Kaldi
|
|
|
|
|
|
class SV:
|
|
def __init__(self, device, is_half):
|
|
pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
|
|
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
|
embedding_model.load_state_dict(pretrained_state)
|
|
embedding_model.eval()
|
|
self.embedding_model = embedding_model
|
|
if is_half == False:
|
|
self.embedding_model = self.embedding_model.to(device)
|
|
else:
|
|
self.embedding_model = self.embedding_model.half().to(device)
|
|
self.is_half = is_half
|
|
|
|
def compute_embedding3(self, wav):
|
|
with torch.no_grad():
|
|
if self.is_half == True:
|
|
wav = wav.half()
|
|
feat = torch.stack(
|
|
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
|
|
)
|
|
sv_emb = self.embedding_model.forward3(feat)
|
|
return sv_emb
|