feat: 添加中间量导出功能

This commit is contained in:
Kaning123 2026-04-06 13:01:32 +08:00
parent 24d7290c11
commit e6a67650ff
2 changed files with 122 additions and 5 deletions

View File

@ -25,6 +25,53 @@ import contextlib
import random
import torchaudio
from torchaudio.transforms import Resample
import os
from pathlib import Path
def merge_dir_txt2(*TXT):
return Path(os.path.join(*TXT))
def get_my_dir():
return os.path.dirname(os.path.abspath(__file__))
def get_parent_dir(dir_path,depth=1):
parent_path = Path(dir_path)
for _ in range(depth):
parent_path = parent_path.parent
return parent_path
POOL:set = set()
def _get_unique_name(name,MySet:set=set()):
_id = 1
if name not in POOL and name not in MySet:
POOL.add(name)
return name
while name in POOL or name in MySet:
_id += 1
name = f'{name}_{_id}'
POOL.add(name)
return name
def find_func(zf,il):
f = zf.get_file_path("voice.json")
info = il.load_info(f)
if info is None:
return None
list_names = info["access_list"]
global POOL
POOL.update(list_names)
ret = []
for name in list_names:
try:
a = zf.get_file_path(name)
ret.append(a)
except FileNotFoundError:
continue
return ret
ROOT_DIR = str(get_parent_dir(get_my_dir()))
class StochasticDurationPredictor(nn.Module):
def __init__(
self,
@ -153,7 +200,7 @@ class DurationPredictor(nn.Module):
WINDOW = {}
class TextEncoder(nn.Module):
class TextEncoder(nn.Module):
def __init__(
self,
out_channels,
@ -990,7 +1037,7 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def ge_(self, refer, sv_emb, InjectGE=False, GE=None, LoadGE=True):
def ge_(self, refer, sv_emb=None, InjectGE=False, GE=None, LoadGE=True):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
@ -1004,6 +1051,7 @@ class SynthesizerTrn(nn.Module):
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
print(f"ge.shape : {ge.shape}")
return ge
if LoadGE:
@ -1021,11 +1069,17 @@ class SynthesizerTrn(nn.Module):
GE = torch.stack(GE, 0).mean(0)
ge = GE
else:
raise ValueError
raise ValueError("No GE stream provided!")
return ge
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None,
InjectGE=False,GE=None,LoadGE=True):
InjectGE=False,GE=None,LoadGE=True,
InjectZP=False,ZP=None,LoadZP=True,
OverWrite_Mask=False,Mask=None,
SaveGE=False,SaveZP=False,SaveMask=False,
GE_Name=None, ZP_Name=None, Mask_Name=None,
VoiceSave=None):
ge = self.ge_(refer, sv_emb, InjectGE, GE, LoadGE)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
@ -1042,14 +1096,75 @@ class SynthesizerTrn(nn.Module):
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
if InjectZP:
if type(ZP) == list:
ZP = torch.stack(ZP, 0).mean(0)
else:
ZP = ZP
z_p = ZP
else:
if LoadZP:
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
else:
raise ValueError("No z_p stream provided!")
if OverWrite_Mask:
if type(Mask) == list:
Mask = torch.stack(Mask, 0).mean(0)
if Mask is None:
raise ValueError("No mask stream provided!")
y_mask = Mask
print(f"z_p shape: {z_p.shape}, y_mask shape: {y_mask.shape}, ge shape: {ge.shape}")
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o
@torch.no_grad()
def decode2(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None,
InjectGE=False,GE=None,LoadGE=True,
InjectZP=False,ZP=None,LoadZP=True,
OverWrite_Mask=False,Mask=None,):
ge = self.ge_(refer, sv_emb, InjectGE, GE, LoadGE)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask, _, _ = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
if InjectZP:
if type(ZP) == list:
ZP = torch.stack(ZP, 0).mean(0)
else:
ZP = ZP
z_p = ZP
else:
if LoadZP:
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
else:
raise ValueError("No z_p stream provided!")
if OverWrite_Mask:
if type(Mask) == list:
Mask = torch.stack(Mask, 0).mean(0)
if Mask is None:
raise ValueError("No mask stream provided!")
y_mask = Mask
print(f"z_p shape: {z_p.shape}, y_mask shape: {y_mask.shape}, ge shape: {ge.shape}")
return z_p, y_mask, ge
@torch.no_grad()
def decode_streaming(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None, result_length:int=None, overlap_frames:torch.Tensor=None, padding_length:int=None):
def get_ge(refer, sv_emb):

View File

@ -432,6 +432,8 @@ class ResidualCouplingLayer(nn.Module):
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
print(f"x.shape: {x.shape}, x_mask.shape: {x_mask.shape}")
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)