mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-04-29 21:00:42 +08:00
feat: Added VoiceChange.py
This commit is contained in:
parent
fb50fc090f
commit
24d7290c11
175
GPT_SoVITS/module/VoiceChange.py
Normal file
175
GPT_SoVITS/module/VoiceChange.py
Normal file
@ -0,0 +1,175 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torchaudio
|
||||
import math
|
||||
from torchaudio.transforms import Resample
|
||||
import VoiceSave
|
||||
import uuid
|
||||
|
||||
def get_train_set(voice_file_path):
    """Load the stored tensors for one or more voice files.

    Args:
        voice_file_path: Either a single path string or an iterable of path
            strings. A lone string is treated as a one-element list.

    Returns:
        A list with one entry per input path, in input order. Each entry is
        whatever ``VoiceSave.load_tensor`` returned for that path.
    """
    # isinstance (not an exact type comparison) so that str subclasses are
    # wrapped too instead of being iterated character by character.
    if isinstance(voice_file_path, str):
        voice_file_path = [voice_file_path]

    return [
        VoiceSave.load_tensor(
            path,
            # A freshly generated unique name is passed for every file; what
            # load_tensor does with it is defined in VoiceSave.
            f"get_{uuid.uuid4()}",
            find_func=VoiceSave.__find_func__,
            # A new empty set per call, so no state is shared between files.
            MySet=set(),
        )
        for path in voice_file_path
    ]
|
||||
|
||||
class MelSpectrogram(nn.Module):
    """Log-magnitude mel-spectrogram extractor configured from ``hps.data``.

    ``filter_length``, ``hop_length``, ``win_length``, ``sampling_rate`` and
    ``n_mel_channels`` are read from ``hps.data``; ``mel_fmin`` and
    ``mel_fmax`` are optional and fall back to ``0`` and ``None``.

    NOTE: ``n_mel_channels`` is stored on the instance but the transform is
    built with a fixed 192 mel bins, not with ``n_mel_channels``.
    """

    def __init__(self, hps):
        super().__init__()
        data_cfg = hps.data
        self.filter_length = data_cfg.filter_length
        self.hop_length = data_cfg.hop_length
        self.win_length = data_cfg.win_length
        self.sampling_rate = data_cfg.sampling_rate
        self.n_mel_channels = data_cfg.n_mel_channels
        # Optional settings: use the config value when present.
        self.mel_fmin = getattr(data_cfg, 'mel_fmin', 0)
        self.mel_fmax = getattr(data_cfg, 'mel_fmax', None)

        # Build the mel-spectrogram transform.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sampling_rate,
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=192,  # fixed; self.n_mel_channels is intentionally not used
            window_fn=torch.hann_window,
            center=False,
            power=1.0,
        )

    def forward(self, audio):
        """Return the log mel-spectrogram of ``audio``.

        Args:
            audio: Mono audio shaped ``[B, 1, T]`` or ``[1, T]``.

        Returns:
            Tensor shaped ``[B, 192, T']``: the transform output clamped to a
            minimum of ``1e-5`` and passed through a natural log.
        """
        # Promote [1, T] to [1, 1, T] so both layouts share the path below.
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)

        # Drop the channel axis and extract the mel-spectrogram.
        magnitudes = self.mel_transform(audio.squeeze(1))

        # Log scaling; the clamp keeps log() away from zero.
        return torch.log(torch.clamp(magnitudes, min=1e-5))
|
||||
|
||||
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first sequences.

    A table of shape ``[1, max_seq_length, d_model]`` is precomputed and
    registered as the buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines.

    Args:
        d_model: Feature dimension of the sequences to encode.
        max_seq_length: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_seq_length, ceil(d_model / 2)]

        # Build the table in a local variable: assigning it to ``self.pe``
        # first would make register_buffer fail with
        # "attribute 'pe' already exists".
        table = torch.zeros(max_seq_length, d_model)
        table[:, 0::2] = torch.sin(angles)  # even columns: sine
        # With an odd d_model there is one fewer odd column than even column,
        # so the angles are trimmed to match.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd columns: cosine

        # Buffer, not parameter: follows .to()/state_dict but is not trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor shaped ``[B, T, d_model]``; positions are taken along
                dimension 1, and ``T`` must not exceed ``max_seq_length``.

        Returns:
            ``x`` plus the first ``T`` rows of the table, same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1)]
|
||||
|
||||
class Spliter(nn.Module):
    """Transformer encoder mapping an audio file's mel-spectrogram to ``z_p``.

    The mel-spectrogram of the audio file and a global embedding ``ge`` are
    projected to a shared width, concatenated along time (``ge`` first),
    encoded by a Transformer and projected back to the mel width.

    ``forward`` returns ``(z_p, y_mask, ge)`` with shapes ``[1, 192, T]``,
    ``[1, 1, T]`` and (for the expected ``ge`` input) ``[1, 1024, 1]``.

    Args:
        hps: Hyper-parameter object. ``hps.model.nhead``,
            ``hps.model.ffn_dim`` and ``hps.model.num_layers`` configure the
            Transformer; ``hps.data`` configures mel extraction.
        ge: Stored on the instance as ``self.ge``; ``forward`` uses its own
            ``ge`` argument instead.
        device: Device the sub-modules are moved to at construction.
    """

    def __init__(self,
                 hps,
                 ge,
                 device):
        super().__init__()
        self.hps = hps

        self.ge = ge
        self.device = device

        self.mel_dim = 192          # must match the mel bin count of MelSpectrogram
        self.ge_dim = 1024
        self.transformer_dim = 512

        # Sub-modules are created in a fixed order so parameter
        # initialisation and state_dict layout stay stable.
        self.ge_proj = nn.Linear(self.ge_dim, self.transformer_dim).to(self.device)
        self.mel_proj = nn.Linear(self.mel_dim, self.transformer_dim).to(self.device)
        self.pos_encoder = PositionalEncoding(self.transformer_dim).to(self.device)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.transformer_dim,
                nhead=hps.model.nhead,
                dim_feedforward=hps.model.ffn_dim,
                batch_first=False,  # sequences are laid out [T, B, D]
                dropout=0.1
            ),
            num_layers=hps.model.num_layers
        ).to(self.device)

        self.out_proj = nn.Linear(self.transformer_dim, self.mel_dim).to(self.device)

    @torch.no_grad()
    def mel_(self, audio_path, hps, device, dtype):
        """Load an audio file and return its log mel-spectrogram.

        The audio is down-mixed to mono, resampled to
        ``hps.data.sampling_rate`` when needed, and scaled down only if its
        peak exceeds 1.0. Runs without gradient tracking.

        Args:
            audio_path: Path handed to ``torchaudio.load``.
            hps: Hyper-parameter object providing ``hps.data``.
            device: Device used for resampling and mel extraction.
            dtype: Dtype the resulting spectrogram is cast to.

        Returns:
            Tensor shaped ``[1, 192, T']`` in ``dtype`` on ``device``.
        """
        sr_target = int(hps.data.sampling_rate)
        audio, sr_origin = torchaudio.load(audio_path)  # [channels, samples]

        # Down-mix multi-channel audio to mono.
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)

        if sr_origin != sr_target:
            resampler = Resample(sr_origin, sr_target).to(device)
            audio = resampler(audio.to(device))
        else:
            audio = audio.to(device)

        # Only attenuate over-range audio; quiet audio is left untouched.
        max_audio = audio.abs().max()
        if max_audio > 1.0:
            audio = audio / max_audio

        mel_extractor = MelSpectrogram(hps).to(device)
        mel_spec = mel_extractor(audio).to(dtype)
        return mel_spec

    def forward(self, audio_path, ge, device, dtype):
        """Encode an audio file conditioned on ``ge``.

        Args:
            audio_path: Path of the audio file to encode.
            ge: Global embedding, expected shape ``[1, 1024, 1]``.
            device: Device for the intermediate tensors and the mask.
            dtype: Dtype for the mel features, ``ge`` and the mask.

        Returns:
            Tuple ``(z_p, y_mask, ge)``: ``z_p`` is ``[1, 192, T]``,
            ``y_mask`` is an all-ones ``[1, 1, T]`` tensor, and ``ge`` is the
            argument exactly as passed in (before the device/dtype cast).

        Raises:
            ValueError: If the sequence (mel frames plus the ``ge`` token) is
                longer than the positional-encoding table.
        """
        ge_ = ge  # returned unchanged

        mel = self.mel_(audio_path, self.hps, device, dtype)  # [1, 192, T]
        mel = mel.permute(2, 0, 1)                            # [T, 1, 192]
        # Project the mel frames to the Transformer width: [T, 1, 512]
        mel_feat = self.mel_proj(mel)

        # Global embedding: [1, 1024, 1] -> [1, 1024] -> [1, 1, 512]
        ge = ge.to(device, dtype=dtype)
        ge_squeeze = ge.squeeze(-1)
        ge_feat = self.ge_proj(ge_squeeze).unsqueeze(0)

        # Prepend the ge token to the mel sequence: [T+1, 1, 512].
        # Kept in a local so the module does not pin activations (and their
        # autograd graph) between calls.
        seq = torch.cat([ge_feat, mel_feat], dim=0)

        max_len = self.pos_encoder.pe.size(1)
        if seq.size(0) > max_len:
            raise ValueError(
                f"sequence length {seq.size(0)} exceeds the positional "
                f"encoding table size {max_len}"
            )

        # PositionalEncoding indexes positions along dim 1 (batch-first),
        # while this sequence is time-first. Without the transposes every
        # step would receive the encoding of position 0.
        seq = self.pos_encoder(seq.transpose(0, 1)).transpose(0, 1)

        # Transformer encoding: [T+1, 1, 512]
        transformer_out = self.transformer(seq)

        # Drop the leading ge token, keep the mel steps: [T, 1, 512]
        mel_out = transformer_out[1:, :, :]
        # Project back to the mel width: [T, 1, 192]
        mel_out = self.out_proj(mel_out)
        # Rearrange to the output layout: [1, 192, T]
        z_p = mel_out.permute(1, 2, 0)

        # All-ones mask over the T output steps: [1, 1, T]
        T = z_p.shape[-1]
        y_mask = torch.ones(1, 1, T, device=device, dtype=dtype)

        return z_p, y_mask, ge_
|
||||
|
||||
class SpliterDataset(torch.utils.data.Dataset):
    """Map-style dataset over items pre-loaded with ``get_train_set``.

    All items are loaded eagerly at construction time and kept in
    ``self.datas``; indexing returns the stored item as-is.
    """

    def __init__(self, voice_file_paths):
        """Remember ``voice_file_paths`` and load one item per path."""
        self.voice_file_paths = voice_file_paths
        loaded = get_train_set(voice_file_paths)
        self.datas = loaded

    def __len__(self):
        """Return the number of loaded items."""
        item_count = len(self.datas)
        return item_count

    def __getitem__(self, idx):
        """Return the loaded item stored at position ``idx``."""
        sample = self.datas[idx]
        return sample
|
||||
Loading…
x
Reference in New Issue
Block a user