feat: Added VoiceChange.py

This commit is contained in:
Kaning123 2026-04-06 12:59:31 +08:00
parent fb50fc090f
commit 24d7290c11

View File

@ -0,0 +1,175 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import math
from torchaudio.transforms import Resample
import VoiceSave
import uuid
def get_train_set(voice_file_path):
if type(voice_file_path) == str:
voice_file_path = [voice_file_path]
ret = []
for i in voice_file_path:
tensors_ = VoiceSave.load_tensor(i,
f"get_{uuid.uuid4()}",
find_func=VoiceSave.__find_func__,
MySet=set())
ret.append(tensors_)
return ret
class MelSpectrogram(nn.Module):
def __init__(self, hps):
super().__init__()
self.filter_length = hps.data.filter_length
self.hop_length = hps.data.hop_length
self.win_length = hps.data.win_length
self.sampling_rate = hps.data.sampling_rate
self.n_mel_channels = hps.data.n_mel_channels
self.mel_fmin = hps.data.mel_fmin if hasattr(hps.data, 'mel_fmin') else 0
self.mel_fmax = hps.data.mel_fmax if hasattr(hps.data, 'mel_fmax') else None
# 构建梅尔频谱变换
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sampling_rate,
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
f_min=self.mel_fmin,
f_max=self.mel_fmax,
n_mels=192, # self.n_mel_channels,
window_fn=torch.hann_window,
center=False,
power=1.0,
)
def forward(self, audio):
"""
输入audio [B, 1, T] [1, T]单声道音频
输出mel_spec [B, n_mel_channels, T']
"""
if len(audio.shape) == 2:
audio = audio.unsqueeze(0) # [1, T] → [1, 1, T]
# 提取梅尔频谱
mel_spec = self.mel_transform(audio.squeeze(1)) # [B, n_mel, T']
# 对数缩放TTS标准操作
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
return mel_spec
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length=5000):
super(PositionalEncoding, self).__init__()
self.pe = torch.zeros(max_seq_length, d_model) # 初始化位置编码矩阵
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
self.pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置使用正弦函数
self.pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置使用余弦函数
self.register_buffer('pe', self.pe.unsqueeze(0)) # 注册为缓冲区
def forward(self, x):
# 将位置编码添加到输入中
return x + self.pe[:, :x.size(1)]
class Spliter(nn.Module):
'''output: z_p shape: torch.Size([1, 192, x]), y_mask shape: torch.Size([1, 1, x]), ge shape: torch.Size([1, 1024, 1])'''
def __init__(self,
hps,
ge,
device):
super().__init__()
self.hps = hps
self.ge = ge
self.device = device
#TODO: 将mel_spec与ge输入Transformer模型
self.mel_dim = 192
self.ge_dim = 1024
self.transformer_dim = 512
self.ge_proj = nn.Linear(self.ge_dim, self.transformer_dim).to(self.device)
self.mel_proj = nn.Linear(self.mel_dim, self.transformer_dim).to(self.device)
self.pos_encoder = PositionalEncoding(self.transformer_dim).to(self.device)
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=self.transformer_dim,
nhead=hps.model.nhead,
dim_feedforward=hps.model.ffn_dim,
batch_first=False,
dropout=0.1
),
num_layers=hps.model.num_layers
).to(self.device)
self.out_proj = nn.Linear(self.transformer_dim, self.mel_dim).to(self.device)
@torch.no_grad()
def mel_(self,audio_path, hps, device, dtype):
sr_target = int(hps.data.sampling_rate)
audio, sr_origin = torchaudio.load(audio_path)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sr_origin != sr_target:
resampler = Resample(sr_origin, sr_target).to(device)
audio = resampler(audio.to(device))
else:
audio = audio.to(device)
max_audio = audio.abs().max()
if max_audio > 1.0:
audio = audio / max_audio
mel_extractor = MelSpectrogram(hps).to(device)
mel_spec = mel_extractor(audio).to(dtype)
return mel_spec
def forward(self, audio_path, ge,device,dtype):
# 输入audio_path, ge
# 输出z_p, y_mask, ge
ge_ = ge
mel = self.mel_(audio_path, self.hps, device, dtype)
mel = mel.permute(2, 0, 1)
# 梅尔谱投影到Transformer维度[T, 1, 512]
mel_feat = self.mel_proj(mel)
# 全局情感特征GE处理[1,1024,1] → [1,1024] → [1,1,512]
ge = ge.to(device, dtype=dtype)
ge_squeeze = ge.squeeze(-1) # [1, 1024]
ge_feat = self.ge_proj(ge_squeeze).unsqueeze(0) # [1, 1, 512]
# ===================== 3. 特征融合与Transformer输入 =====================
# 将GE特征拼接在梅尔谱序列开头[T+1, 1, 512]
self.transformer_input = torch.cat([ge_feat, mel_feat], dim=0)
# 添加位置编码
self.transformer_input = self.pos_encoder(self.transformer_input)
# ===================== 4. Transformer编码 =====================
transformer_out = self.transformer(self.transformer_input) # [T+1, 1, 512]
# ===================== 5. 输出特征重构 =====================
# 去除GE开头提取梅尔谱对应的输出[T, 1, 512]
mel_out = transformer_out[1:, :, :]
# 投影回原始梅尔维度:[T, 1, 192]
mel_out = self.out_proj(mel_out)
# 转换为目标格式:[1, 192, T] → z_p
z_p = mel_out.permute(1, 2, 0)
# ===================== 6. 生成掩码 =====================
T = z_p.shape[-1] # 梅尔谱时间步
y_mask = torch.ones(1, 1, T, device=device, dtype=dtype) # [1,1,T] 全1掩码
# ===================== 7. 输出(严格匹配注释格式) =====================
return z_p, y_mask, ge_
class SpliterDataset(torch.utils.data.Dataset):
def __init__(self, voice_file_paths):
self.voice_file_paths = voice_file_paths
self.datas = get_train_set(voice_file_paths)
def __len__(self):
return len(self.datas)
def __getitem__(self, idx):
return self.datas[idx]