GPT-SoVITS/GPT_SoVITS/module/VoiceChange.py
2026-04-06 12:59:31 +08:00

175 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import math
from torchaudio.transforms import Resample
import VoiceSave
import uuid
def get_train_set(voice_file_path):
    """Load the stored tensors for one or more voice files.

    Args:
        voice_file_path: A single path string, or an iterable of path
            strings. A lone string (including any ``str`` subclass) is
            treated as a one-element collection rather than being iterated
            character by character.

    Returns:
        A list with one entry per input path, each being whatever
        ``VoiceSave.load_tensor`` returns for that path.
    """
    if isinstance(voice_file_path, str):
        voice_file_path = [voice_file_path]
    # Each load gets a unique "get_<uuid4>" tag and a fresh, empty set so
    # that no state is shared between files.
    # NOTE(review): the meaning of the tag and of ``MySet`` is defined by
    # VoiceSave.load_tensor — confirm against that module.
    return [
        VoiceSave.load_tensor(
            path,
            f"get_{uuid.uuid4()}",
            find_func=VoiceSave.__find_func__,
            MySet=set(),
        )
        for path in voice_file_path
    ]
class MelSpectrogram(nn.Module):
    """Log-magnitude mel-spectrogram extractor built on torchaudio.

    STFT settings are read from ``hps.data``. The number of mel bins passed
    to the transform is fixed at 192; ``hps.data.n_mel_channels`` is stored
    on the instance but is not used by the transform.
    """

    def __init__(self, hps):
        super().__init__()
        data_cfg = hps.data
        self.filter_length = data_cfg.filter_length
        self.hop_length = data_cfg.hop_length
        self.win_length = data_cfg.win_length
        self.sampling_rate = data_cfg.sampling_rate
        self.n_mel_channels = data_cfg.n_mel_channels
        # Optional frequency bounds: fall back to 0 / None when absent.
        self.mel_fmin = getattr(data_cfg, "mel_fmin", 0)
        self.mel_fmax = getattr(data_cfg, "mel_fmax", None)

        # Build the mel-spectrogram transform (magnitude spectrum, no
        # centre padding, Hann window).
        transform_kwargs = dict(
            sample_rate=self.sampling_rate,
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=192,  # self.n_mel_channels,
            window_fn=torch.hann_window,
            center=False,
            power=1.0,
        )
        self.mel_transform = torchaudio.transforms.MelSpectrogram(**transform_kwargs)

    def forward(self, audio):
        """Return the log mel-spectrogram of mono audio.

        Args:
            audio: Tensor shaped ``[B, 1, T]``, or ``[1, T]`` for a single
                clip (a 2-D input gets a leading dimension added first).

        Returns:
            Tensor shaped ``[B, 192, T']`` holding the natural log of the
            mel magnitudes, floored at ``1e-5`` before the log.
        """
        if audio.dim() == 2:
            audio = audio.unsqueeze(0)  # [1, T] -> [1, 1, T]
        # Drop the channel axis, then extract mel magnitudes: [B, n_mel, T'].
        magnitudes = self.mel_transform(audio.squeeze(1))
        # Log compression (standard for TTS features); the clamp avoids log(0).
        return magnitudes.clamp(min=1e-5).log()
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first inputs.

    The table is registered as the buffer ``pe`` with shape
    ``[1, max_seq_length, d_model]``, so it follows the module across
    ``.to(...)`` calls and is part of the state dict.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        max_seq_length: Largest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_seq_length, ceil(d_model / 2)]

        # Build the table in a local variable: assigning it to ``self.pe``
        # first would make ``register_buffer('pe', ...)`` raise KeyError
        # because the attribute name would already be taken.
        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even feature indices: sine
        # Odd feature indices: cosine. For odd d_model there is one fewer
        # odd column than even column, so trim the angle table to match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: Tensor shaped ``[batch, seq_len, d_model]``.

        Returns:
            ``x`` plus the encoding of positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]
class Spliter(nn.Module):
    """Transformer encoder mapping a mel-spectrogram plus a global embedding to ``z_p``.

    ``forward`` returns ``(z_p, y_mask, ge)``. Per the original author's
    note the shapes are ``z_p: [1, 192, T]``, ``y_mask: [1, 1, T]`` and
    ``ge: [1, 1024, 1]``.

    Args:
        hps: Hyper-parameter object. ``hps.model.nhead``,
            ``hps.model.ffn_dim`` and ``hps.model.num_layers`` configure the
            Transformer; ``hps.data`` configures mel extraction.
        ge: Global embedding, stored on the instance as ``self.ge``
            (``forward`` uses the ``ge`` passed to it instead).
        device: Device on which all sub-modules are placed.
        max_seq_length: Longest sequence (mel frames + 1) the positional
            encoder supports. Defaults to 5000.
    """

    def __init__(self, hps, ge, device, max_seq_length=5000):
        super().__init__()
        self.hps = hps
        self.ge = ge
        self.device = device
        # TODO: feed mel_spec and ge into the Transformer model
        self.mel_dim = 192
        self.ge_dim = 1024
        self.transformer_dim = 512
        self.ge_proj = nn.Linear(self.ge_dim, self.transformer_dim).to(self.device)
        self.mel_proj = nn.Linear(self.mel_dim, self.transformer_dim).to(self.device)
        self.pos_encoder = PositionalEncoding(
            self.transformer_dim, max_seq_length
        ).to(self.device)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.transformer_dim,
            nhead=hps.model.nhead,
            dim_feedforward=hps.model.ffn_dim,
            batch_first=False,  # the encoder consumes [seq, batch, dim]
            dropout=0.1,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=hps.model.num_layers,
        ).to(self.device)
        self.out_proj = nn.Linear(self.transformer_dim, self.mel_dim).to(self.device)

    @torch.no_grad()
    def mel_(self, audio_path, hps, device, dtype):
        """Load an audio file and return its log mel-spectrogram.

        The audio is down-mixed to mono, resampled to
        ``hps.data.sampling_rate`` when needed, and scaled down only if its
        peak magnitude exceeds 1.0. Runs without gradient tracking.

        Args:
            audio_path: Path handed to ``torchaudio.load``.
            hps: Hyper-parameters used to build the mel extractor.
            device: Device for resampling and mel extraction.
            dtype: Dtype the returned spectrogram is cast to.

        Returns:
            The output of ``MelSpectrogram`` for the clip, cast to ``dtype``.
        """
        sr_target = int(hps.data.sampling_rate)
        audio, sr_origin = torchaudio.load(audio_path)
        if audio.shape[0] > 1:
            # Average the channels to obtain mono audio.
            audio = audio.mean(0, keepdim=True)
        if sr_origin != sr_target:
            resampler = Resample(sr_origin, sr_target).to(device)
            audio = resampler(audio.to(device))
        else:
            audio = audio.to(device)
        max_audio = audio.abs().max()
        if max_audio > 1.0:
            audio = audio / max_audio
        mel_extractor = MelSpectrogram(hps).to(device)
        mel_spec = mel_extractor(audio).to(dtype)
        return mel_spec

    def forward(self, audio_path, ge, device, dtype):
        """Encode one audio file conditioned on a global embedding.

        Args:
            audio_path: Path of the audio file to encode.
            ge: Global embedding; assumed ``[1, 1024, 1]`` per the original
                comments — TODO confirm with callers.
            device: Device for the inputs and the returned mask.
            dtype: Dtype for the inputs and the returned mask.

        Returns:
            Tuple ``(z_p, y_mask, ge)``: the encoded features
            ``[1, 192, T]``, an all-ones mask ``[1, 1, T]``, and the ``ge``
            argument exactly as it was passed in.
        """
        ge_ = ge  # returned untouched at the end

        # Mel-spectrogram [1, 192, T] -> sequence-first [T, 1, 192].
        mel = self.mel_(audio_path, self.hps, device, dtype)
        mel = mel.permute(2, 0, 1)
        # Project the mel frames to the Transformer width: [T, 1, 512].
        mel_feat = self.mel_proj(mel)

        # Global embedding: [1, 1024, 1] -> [1, 1024] -> [1, 1, 512].
        ge = ge.to(device, dtype=dtype)
        ge_feat = self.ge_proj(ge.squeeze(-1)).unsqueeze(0)

        # Prepend the GE token to the mel sequence: [T + 1, 1, 512].
        sequence = torch.cat([ge_feat, mel_feat], dim=0)

        # The positional encoder indexes positions along dim 1 (batch-first),
        # while this sequence is [seq, batch, dim]. Transpose around the call
        # so each timestep gets its own position instead of all of them
        # sharing the position-0 encoding.
        sequence = self.pos_encoder(sequence.transpose(0, 1)).transpose(0, 1)

        # Kept as an attribute because the original implementation exposed
        # it; external code may inspect it.
        self.transformer_input = sequence

        # Transformer encoding: [T + 1, 1, 512].
        transformer_out = self.transformer(sequence)

        # Drop the leading GE token and project back to the mel width:
        # [T, 1, 512] -> [T, 1, 192].
        mel_out = self.out_proj(transformer_out[1:, :, :])
        # Rearrange to the target layout [1, 192, T].
        z_p = mel_out.permute(1, 2, 0)

        # All-ones mask over the T time steps: [1, 1, T].
        T = z_p.shape[-1]
        y_mask = torch.ones(1, 1, T, device=device, dtype=dtype)

        return z_p, y_mask, ge_
class SpliterDataset(torch.utils.data.Dataset):
    """Map-style dataset over tensors loaded eagerly from voice files.

    All files are loaded once at construction time via ``get_train_set``;
    indexing afterwards is a plain list lookup.
    """

    def __init__(self, voice_file_paths):
        """Remember the given path(s) and load their data immediately."""
        self.voice_file_paths = voice_file_paths
        loaded = get_train_set(voice_file_paths)
        self.datas = loaded

    def __len__(self):
        """Return the number of loaded entries."""
        entries = self.datas
        return len(entries)

    def __getitem__(self, idx):
        """Return the loaded entry stored at position ``idx``."""
        entries = self.datas
        return entries[idx]