diff --git a/GPT_SoVITS/module/VoiceChange.py b/GPT_SoVITS/module/VoiceChange.py
new file mode 100644
index 00000000..ca0a1dda
--- /dev/null
+++ b/GPT_SoVITS/module/VoiceChange.py
@@ -0,0 +1,175 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import torchaudio
+import math
+from torchaudio.transforms import Resample
+import VoiceSave
+import uuid
+
+def get_train_set(voice_file_path):
+    if isinstance(voice_file_path, str):
+        voice_file_path = [voice_file_path]
+    ret = []
+    for i in voice_file_path:
+        tensors_ = VoiceSave.load_tensor(i,
+                                         f"get_{uuid.uuid4()}",
+                                         find_func=VoiceSave.__find_func__,
+                                         MySet=set())
+        ret.append(tensors_)
+    return ret
+
+class MelSpectrogram(nn.Module):
+    def __init__(self, hps):
+        super().__init__()
+        self.filter_length = hps.data.filter_length
+        self.hop_length = hps.data.hop_length
+        self.win_length = hps.data.win_length
+        self.sampling_rate = hps.data.sampling_rate
+        self.n_mel_channels = hps.data.n_mel_channels
+        self.mel_fmin = hps.data.mel_fmin if hasattr(hps.data, 'mel_fmin') else 0
+        self.mel_fmax = hps.data.mel_fmax if hasattr(hps.data, 'mel_fmax') else None
+
+        # Build the mel-spectrogram transform
+        self.mel_transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=self.sampling_rate,
+            n_fft=self.filter_length,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            f_min=self.mel_fmin,
+            f_max=self.mel_fmax,
+            n_mels=192,  # hard-coded to match Spliter.mel_dim, not self.n_mel_channels
+            window_fn=torch.hann_window,
+            center=False,
+            power=1.0,
+        )
+
+    def forward(self, audio):
+        """
+        Input:  audio [B, 1, T] or [1, T] (mono audio)
+        Output: mel_spec [B, n_mel_channels, T']
+        """
+        if len(audio.shape) == 2:
+            audio = audio.unsqueeze(0)  # [1, T] -> [1, 1, T]
+
+        # Extract the mel spectrogram
+        mel_spec = self.mel_transform(audio.squeeze(1))  # [B, n_mel, T']
+
+        # Log compression (standard practice in TTS)
+        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
+
+        return mel_spec
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_seq_length=5000):
+        super(PositionalEncoding, self).__init__()
+        pe = torch.zeros(max_seq_length, d_model)  # positional-encoding table
+        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)  # sine at even dimensions
+        pe[:, 1::2] = torch.cos(position * div_term)  # cosine at odd dimensions
+        # Register as a buffer shaped [max_seq_length, 1, d_model] for
+        # sequence-first inputs (assigning self.pe first and then calling
+        # register_buffer('pe', ...) would raise KeyError)
+        self.register_buffer('pe', pe.unsqueeze(1))
+
+    def forward(self, x):
+        # x: [T, B, d_model]; add the encoding along the time axis
+        return x + self.pe[:x.size(0)].to(x.dtype)
+
+class Spliter(nn.Module):
+    '''output: z_p shape: torch.Size([1, 192, x]), y_mask shape: torch.Size([1, 1, x]), ge shape: torch.Size([1, 1024, 1])'''
+    def __init__(self, hps, ge, device):
+        super().__init__()
+        self.hps = hps
+
+        self.ge = ge
+        self.device = device
+        # Feed mel_spec together with ge into a Transformer encoder
+        self.mel_dim = 192
+        self.ge_dim = 1024
+        self.transformer_dim = 512
+        self.ge_proj = nn.Linear(self.ge_dim, self.transformer_dim).to(self.device)
+        self.mel_proj = nn.Linear(self.mel_dim, self.transformer_dim).to(self.device)
+        self.pos_encoder = PositionalEncoding(self.transformer_dim).to(self.device)
+        self.transformer = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(
+                d_model=self.transformer_dim,
+                nhead=hps.model.nhead,
+                dim_feedforward=hps.model.ffn_dim,
+                batch_first=False,
+                dropout=0.1
+            ),
+            num_layers=hps.model.num_layers
+        ).to(self.device)
+
+        self.out_proj = nn.Linear(self.transformer_dim, self.mel_dim).to(self.device)
+
+    @torch.no_grad()
+    def mel_(self, audio_path, hps, device, dtype):
+        sr_target = int(hps.data.sampling_rate)
+        audio, sr_origin = torchaudio.load(audio_path)
+        if audio.shape[0] > 1:
+            audio = audio.mean(0, keepdim=True)  # downmix to mono
+        if sr_origin != sr_target:
+            resampler = Resample(sr_origin, sr_target).to(device)
+            audio = resampler(audio.to(device))
+        else:
+            audio = audio.to(device)
+        max_audio = audio.abs().max()
+        if max_audio > 1.0:
+            audio = audio / max_audio  # peak-normalize
+        mel_extractor = MelSpectrogram(hps).to(device)
+        mel_spec = mel_extractor(audio).to(dtype)
+        return mel_spec
+
+    def forward(self, audio_path, ge, device, dtype):
+        # Input:  audio_path, ge
+        # Output: z_p, y_mask, ge
+        ge_ = ge
+        mel = self.mel_(audio_path, self.hps, device, dtype)
+
+        mel = mel.permute(2, 0, 1)  # [1, 192, T] -> [T, 1, 192]
+        # Project the mel spectrogram to the Transformer dimension: [T, 1, 512]
+        mel_feat = self.mel_proj(mel)
+
+        # Global emotion feature GE: [1, 1024, 1] -> [1, 1024] -> [1, 1, 512]
+        ge = ge.to(device, dtype=dtype)
+        ge_squeeze = ge.squeeze(-1)  # [1, 1024]
+        ge_feat = self.ge_proj(ge_squeeze).unsqueeze(0)  # [1, 1, 512]
+
+        # ===================== 3. Feature fusion and Transformer input =====================
+        # Prepend the GE feature to the mel sequence: [T+1, 1, 512]
+        transformer_input = torch.cat([ge_feat, mel_feat], dim=0)
+        # Add positional encoding
+        transformer_input = self.pos_encoder(transformer_input)
+
+        # ===================== 4. Transformer encoding =====================
+        transformer_out = self.transformer(transformer_input)  # [T+1, 1, 512]
+
+        # ===================== 5. Output feature reconstruction =====================
+        # Drop the leading GE token and keep the mel-aligned outputs: [T, 1, 512]
+        mel_out = transformer_out[1:, :, :]
+        # Project back to the original mel dimension: [T, 1, 192]
+        mel_out = self.out_proj(mel_out)
+        # Rearrange to the target layout: [1, 192, T] -> z_p
+        z_p = mel_out.permute(1, 2, 0)
+
+        # ===================== 6. Mask generation =====================
+        T = z_p.shape[-1]  # number of mel time steps
+        y_mask = torch.ones(1, 1, T, device=device, dtype=dtype)  # [1, 1, T] all-ones mask
+
+        # ===================== 7. Outputs (strictly matching the docstring) =====================
+        return z_p, y_mask, ge_
+
+class SpliterDataset(torch.utils.data.Dataset):
+    def __init__(self, voice_file_paths):
+        self.voice_file_paths = voice_file_paths
+        self.datas = get_train_set(voice_file_paths)
+
+    def __len__(self):
+        return len(self.datas)
+
+    def __getitem__(self, idx):
+        return self.datas[idx]
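
Usage sketch (not part of the diff): a minimal smoke test for `Spliter.forward`. The `hps` values below are illustrative stand-ins for whatever the real GPT-SoVITS config supplies, `ref.wav` is a hypothetical audio file, and `ge` is a random tensor rather than a real embedding; the import assumes `GPT_SoVITS/module` is on `sys.path`, matching the file's own plain `import VoiceSave`.

```python
from types import SimpleNamespace

import torch

from VoiceChange import Spliter  # assumes GPT_SoVITS/module is on sys.path

# Illustrative hyperparameters; real values come from the project's config files.
hps = SimpleNamespace(
    data=SimpleNamespace(filter_length=2048, hop_length=640, win_length=2048,
                         sampling_rate=32000, n_mel_channels=128),
    model=SimpleNamespace(nhead=8, ffn_dim=2048, num_layers=6),
)
device, dtype = torch.device("cpu"), torch.float32

ge = torch.randn(1, 1024, 1)        # dummy global embedding [1, 1024, 1]
spliter = Spliter(hps, ge, device)
z_p, y_mask, ge_out = spliter("ref.wav", ge, device, dtype)  # hypothetical file
print(z_p.shape, y_mask.shape, ge_out.shape)  # [1, 192, T], [1, 1, T], [1, 1024, 1]
```

Note that `n_mels` is hard-coded to 192 inside `MelSpectrogram`, so `z_p` always has 192 channels regardless of `hps.data.n_mel_channels`.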
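Since the encoder layers are built with `batch_first=False`, inputs are sequence-first `[T, B, d]`, and the positional-encoding buffer must be indexed along the time axis; the pre-fix version registered a batch-first `[1, max_len, d]` buffer and would have added `pe[0]` to every time step (and the double `self.pe` assignment plus `register_buffer('pe', ...)` would have raised `KeyError`). A quick sanity check of the fixed behavior:

```python
import torch

from VoiceChange import PositionalEncoding  # assumes GPT_SoVITS/module is on sys.path

pe = PositionalEncoding(d_model=512)
x = torch.zeros(10, 1, 512)            # sequence-first dummy input [T=10, B=1, d=512]
y = pe(x)
assert y.shape == x.shape
assert not torch.allclose(y[0], y[1])  # consecutive time steps get distinct encodings
```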