2025-09-06 22:58:58 +08:00

47 lines
1.2 KiB
Python

import logging
import os
import torch
import torch.nn as nn
from transformers import (
HubertModel,
Wav2Vec2FeatureExtractor,
)
from transformers import logging as tf_logging
# Restrict transformers' own logging to error-level messages.
tf_logging.set_verbosity_error()
# Raise the "numba" logger threshold to WARNING so its lower-level
# messages are not emitted.
logging.getLogger("numba").setLevel(logging.WARNING)
# Default checkpoint location used by CNHubert when no base_path is given.
# It is a relative path, so it is resolved against the current working
# directory of the process.
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
class CNHubert(nn.Module):
    """HuBERT encoder paired with its Wav2Vec2 feature extractor.

    Both components are loaded from the same local path with
    ``local_files_only=True``.
    """

    def __init__(self, base_path: str = ""):
        """Load the model and feature extractor from ``base_path``.

        Args:
            base_path: Path of the pretrained checkpoint. An empty (falsy)
                value falls back to the module-level ``cnhubert_base_path``.

        Raises:
            FileNotFoundError: If the resolved path does not exist; the
                path is passed as the exception argument.
        """
        super().__init__()
        # Fall back to the module default when the caller passes nothing.
        base_path = base_path or cnhubert_base_path
        # Check the path up front, before either from_pretrained call runs.
        if not os.path.exists(base_path):
            raise FileNotFoundError(base_path)
        self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            base_path, local_files_only=True
        )

    def forward(self, x):
        """Return the model's ``last_hidden_state`` for input ``x``.

        Args:
            x: Input handed to the feature extractor; it must expose a
                ``.device`` attribute. The extractor is called with
                ``sampling_rate=16000``, so this is presumably a 16 kHz
                waveform tensor -- TODO confirm against callers.

        Returns:
            The ``"last_hidden_state"`` entry of the HuBERT model output.
        """
        extracted = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        )
        # Move the extracted values onto the same device as the input.
        input_values = extracted.input_values.to(x.device)
        outputs = self.model(input_values)
        return outputs["last_hidden_state"]
def get_model():
    """Create a CNHubert with default arguments and switch it to eval mode.

    Returns:
        The newly constructed CNHubert instance, after ``eval()`` has been
        called on it.
    """
    # nn.Module.eval() returns the module itself, so the call can be chained.
    return CNHubert().eval()
def get_content(hmodel, wav_16k_tensor):
    """Run ``hmodel`` on the input without gradient tracking.

    Args:
        hmodel: Callable applied to ``wav_16k_tensor``.
        wav_16k_tensor: Input forwarded unchanged to ``hmodel``; judging by
            the name, presumably a 16 kHz waveform tensor -- TODO confirm.

    Returns:
        The output of ``hmodel`` with dimensions 1 and 2 swapped.
    """
    # Only the forward pass runs inside the no-grad context.
    with torch.no_grad():
        hidden = hmodel(wav_16k_tensor)
    swapped = hidden.transpose(1, 2)
    return swapped