ruff format --line-length 120 --target-version py39

XXXXRT666 2025-04-01 11:14:56 +01:00
parent a893a4e283
commit dec3df3282
130 changed files with 7986 additions and 6424 deletions
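For reference, the command in the commit title corresponds to pinning the formatter settings in the project's pyproject.toml, so that later plain `ruff format` runs pick them up without flags. This is only a sketch of the equivalent configuration; the commit itself records just the command line, and whether the repository actually carries such a config section is an assumption:

[tool.ruff]
line-length = 120
target-version = "py39"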

View File

@@ -1,5 +1,8 @@
 # Download moda ASR related models
 from modelscope import snapshot_download
-model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
+model_dir = snapshot_download(
+    "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
+)
+model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
+model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")

View File

@@ -4,14 +4,11 @@ import itertools
 import math
 import random
 from random import shuffle
-from typing import Iterator
-from typing import Optional
-from typing import TypeVar
+from typing import Iterator, Optional, TypeVar
 
 import torch
 import torch.distributed as dist
-from torch.utils.data import Dataset
-from torch.utils.data import Sampler
+from torch.utils.data import Dataset, Sampler
 
 __all__ = [
     "DistributedBucketSampler",
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
         if torch.cuda.is_available():
             torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
-            raise ValueError(
-                "Invalid rank {}, rank should be in the interval"
-                " [0, {}]".format(rank, num_replicas - 1)
-            )
+            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
         self.drop_last = drop_last
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
-        if (
-            self.drop_last and len(self.dataset) % self.num_replicas != 0
-        ):  # type: ignore[arg-type]
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
             self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas)
-                / self.num_replicas  # type: ignore[arg-type]
+                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
             )
         else:
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas
+                len(self.dataset) / self.num_replicas,
             )  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             grouped_batch_size = self.batch_size * self.num_replicas
             shuffled_bucket = list(itertools.chain(*shuffled_bucket))
             n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
-            batches = [
-                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
-                for b in range(n_batch)
-            ]
+            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
             shuffle(batches)
             indices = list(itertools.chain(*batches))
         else:
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[
-                    :padding_size
-                ]
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
         else:
             # remove tail of data to make it evenly divisible.
             indices = indices[: self.total_size]

View File

@@ -1,9 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
 # reference: https://github.com/lifeiteng/vall-e
 from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
 from AR.data.bucket_sampler import DistributedBucketSampler
 from AR.data.dataset import Text2SemanticDataset
-from torch.utils.data import DataLoader
 
 
 class Text2SemanticDataModule(LightningDataModule):
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
         #     pad_val=self.config['data']['pad_val'])
 
     def train_dataloader(self):
-        batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
-        batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
+        batch_size = (
+            self.config["train"]["batch_size"] // 2
+            if self.config["train"].get("if_dpo", False) is True
+            else self.config["train"]["batch_size"]
+        )
+        batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # 防止不保存
         sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
         return DataLoader(
             self._train_dataset,

View File

@@ -2,18 +2,16 @@
 # reference: https://github.com/lifeiteng/vall-e
 # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
-import traceback
 import os
-from typing import Dict
-from typing import List
+import traceback
+from typing import Dict, List
 
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 
-version = os.environ.get('version',None)
+version = os.environ.get("version", None)
 
 from text import cleaned_text_to_sequence
@@ -32,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
     padded_sequences = []
     for seq, length in zip(sequences, seq_lengths):
-        padding = (
-            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
-        )
+        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
         padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
         padded_sequences.append(padded_seq)
     batch = np.stack(padded_sequences)
@@ -59,12 +55,16 @@ class Text2SemanticDataset(Dataset):
         super().__init__()
 
         self.semantic_data = pd.read_csv(
-            semantic_path, delimiter="\t", encoding="utf-8"
+            semantic_path,
+            delimiter="\t",
+            encoding="utf-8",
         )
         # get dict
         self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
         self.path3 = "%s/3-bert" % (
-            os.path.dirname(phoneme_path)
+            os.path.dirname(
+                phoneme_path,
+            )
         )  # "%s/3-bert"%exp_dir#bert_dir
         self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
         assert os.path.exists(self.path2)
@@ -125,7 +125,7 @@ class Text2SemanticDataset(Dataset):
         for i in range(semantic_data_len):
             # 先依次遍历
             # get str
-            item_name = self.semantic_data.iloc[i,0]
+            item_name = self.semantic_data.iloc[i, 0]
             # print(self.phoneme_data)
             try:
                 phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -135,7 +135,7 @@ class Text2SemanticDataset(Dataset):
                 num_not_in += 1
                 continue
-            semantic_str = self.semantic_data.iloc[i,1]
+            semantic_str = self.semantic_data.iloc[i, 1]
             # get token list
             semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
             # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
@@ -156,9 +156,7 @@ class Text2SemanticDataset(Dataset):
                 num_not_in += 1
                 continue
             # if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
-            if (
-                len(phoneme_ids) > self.max_sec * self.hz / 2.5
-            ):  ###########2改为恒定限制为semantic/2.5就行
+            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  ###########2改为恒定限制为semantic/2.5就行
                 num_deleted_ps += 1
                 continue
             # if len(semantic_ids) > 1000:###########3
@@ -167,9 +165,7 @@ class Text2SemanticDataset(Dataset):
             ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
 
-            if (
-                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
-            ):  ##########4#3~25#每秒多少个phone
+            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  ##########4#3~25#每秒多少个phone
                 num_deleted_ps += 1
                 # print(item_name)
                 continue
@@ -192,12 +188,12 @@ class Text2SemanticDataset(Dataset):
             print(f"there are {num_not_in} semantic datas not in phoneme datas")
         if num_deleted_bigger > 0:
             print(
-                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
+                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
             )
         if num_deleted_ps > 0:
             # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
             print(
-                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
+                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
             )
         """
         there are 31 semantic datas not in phoneme datas
@@ -304,7 +300,10 @@ if __name__ == "__main__":
     batch_size = 12
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
+        dataset,
+        batch_size=batch_size,
+        collate_fn=dataset.collate,
+        shuffle=False,
     )
     for i, batch in enumerate(dataloader):
         if i % 1000 == 0:

View File

@@ -9,10 +9,12 @@ from typing import Dict
 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
 
 
 class Text2SemanticLightningModule(LightningModule):
     def __init__(self, config, output_dir, is_train=True):
         super().__init__()
@@ -24,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
                 )
             )
         if is_train:
@@ -36,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
         scheduler = self.lr_schedulers()
-        forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
+        forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
         loss, acc = forward(
             batch["phoneme_ids"],
             batch["phoneme_ids_len"],
@@ -114,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
         model_parameters = self.model.parameters()
         parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
         lm_opt = ScaledAdam(
             model_parameters,
             lr=0.01,

View File

@@ -9,6 +9,7 @@ from typing import Dict
 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model_onnx import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
@@ -25,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
-                )
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
+                ),
             )
         if is_train:
             self.automatic_optimization = False
@@ -80,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
         model_parameters = self.model.parameters()
         parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
         lm_opt = ScaledAdam(
             model_parameters,
             lr=0.01,

View File

@@ -2,25 +2,24 @@
 # reference: https://github.com/lifeiteng/vall-e
 import math
 from typing import List, Optional
-import torch
-from tqdm import tqdm
-from AR.models.utils import make_pad_mask, make_pad_mask_left
-from AR.models.utils import (
-    topk_sampling,
-    sample,
-    dpo_loss,
-    make_reject_y,
-    get_batch_logps
-)
-from AR.modules.embedding import SinePositionalEmbedding
-from AR.modules.embedding import TokenEmbedding
-from AR.modules.transformer import LayerNorm
-from AR.modules.transformer import TransformerEncoder
-from AR.modules.transformer import TransformerEncoderLayer
+
+import torch
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
+from tqdm import tqdm
+
+from AR.models.utils import (
+    dpo_loss,
+    get_batch_logps,
+    make_pad_mask,
+    make_pad_mask_left,
+    make_reject_y,
+    sample,
+    topk_sampling,
+)
+from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
 
 default_config = {
     "embedding_dim": 512,
@@ -34,10 +33,17 @@ default_config = {
     "EOS": 1024,
 }
 
+
 # @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
 # Efficient implementation equivalent to the following:
-def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
-    B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
+def scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
     if scale is None:
         scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
     else:
@@ -57,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
         if attn_mask.dtype == torch.bool:
             attn_weight.masked_fill_(attn_mask, 0)
         else:
-            attn_mask[attn_mask!=float("-inf")] =0
-            attn_mask[attn_mask==float("-inf")] =1
+            attn_mask[attn_mask != float("-inf")] = 0
+            attn_mask[attn_mask == float("-inf")] = 1
             attn_weight.masked_fill_(attn_mask, 0)
 
     return attn_weight @ value
 
+
 @torch.jit.script
 class T2SMLP:
     def __init__(self, w1, b1, w2, b2):
@@ -80,20 +87,20 @@ class T2SMLP:
 @torch.jit.script
 class T2SBlock:
     def __init__(
         self,
         num_heads,
         hidden_dim: int,
         mlp: T2SMLP,
         qkv_w,
         qkv_b,
         out_w,
         out_b,
         norm_w1,
         norm_b1,
         norm_eps1,
         norm_w2,
         norm_b2,
         norm_eps2,
     ):
         self.num_heads = num_heads
         self.mlp = mlp
@@ -112,24 +119,32 @@ class T2SBlock:
         self.false = torch.tensor(False, dtype=torch.bool)
 
     @torch.jit.ignore
-    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+    def to_mask(
+        self,
+        x: torch.Tensor,
+        padding_mask: Optional[torch.Tensor],
+    ):
         if padding_mask is None:
             return x
 
         if padding_mask.dtype == torch.bool:
             return x.masked_fill(padding_mask, 0)
         else:
             return x * padding_mask
 
-    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+    def process_prompt(
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
 
         batch_size = q.shape[0]
         q_len = q.shape[1]
         kv_len = k.shape[1]
 
         q = self.to_mask(q, padding_mask)
         k_cache = self.to_mask(k, padding_mask)
         v_cache = self.to_mask(v, padding_mask)
@@ -147,9 +162,7 @@ class T2SBlock:
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
 
         x = x + attn
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
             x,
@@ -159,13 +172,20 @@ class T2SBlock:
             self.norm_eps2,
         )
         return x, k_cache, v_cache
 
-    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
+    def decode_next_token(
+        self,
+        x: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_mask: torch.Tensor = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
 
         k_cache = torch.cat([k_cache, k], dim=1)
         v_cache = torch.cat([v_cache, v], dim=1)
 
         batch_size = q.shape[0]
         q_len = q.shape[1]
         kv_len = k_cache.shape[1]
@@ -174,7 +194,6 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
 
-
         if torch_sdpa:
             attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
         else:
@@ -185,7 +204,11 @@ class T2SBlock:
         x = x + attn
         x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            x,
+            [self.hidden_dim],
+            self.norm_w1,
+            self.norm_b1,
+            self.norm_eps1,
         )
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
@@ -200,17 +223,19 @@ class T2SBlock:
 @torch.jit.script
 class T2STransformer:
-    def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
-        self.num_blocks : int = num_blocks
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
         self.blocks = blocks
 
     def process_prompt(
-        self, x:torch.Tensor, attn_mask : torch.Tensor,
-        padding_mask : Optional[torch.Tensor]=None,
-        torch_sdpa:bool=True
-        ):
-        k_cache : List[torch.Tensor] = []
-        v_cache : List[torch.Tensor] = []
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        torch_sdpa: bool = True,
+    ):
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
         for i in range(self.num_blocks):
             x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
             k_cache.append(k_cache_)
@@ -218,14 +243,17 @@ class T2STransformer:
         return x, k_cache, v_cache
 
     def decode_next_token(
-        self, x:torch.Tensor,
-        k_cache: List[torch.Tensor],
-        v_cache: List[torch.Tensor],
-        attn_mask : torch.Tensor=None,
-        torch_sdpa:bool=True
+        self,
+        x: torch.Tensor,
+        k_cache: List[torch.Tensor],
+        v_cache: List[torch.Tensor],
+        attn_mask: torch.Tensor = None,
+        torch_sdpa: bool = True,
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
+                x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
+            )
         return x, k_cache, v_cache
@@ -247,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
         self.ar_text_embedding = TokenEmbedding(
-            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.phoneme_vocab_size,
+            self.p_dropout,
         )
         self.ar_text_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )
         self.ar_audio_embedding = TokenEmbedding(
-            self.embedding_dim, self.vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.vocab_size,
+            self.p_dropout,
         )
         self.ar_audio_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )
 
         self.h = TransformerEncoder(
@@ -291,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.linear1.weight,
                 layer.linear1.bias,
                 layer.linear2.weight,
-                layer.linear2.bias
+                layer.linear2.bias,
             )
 
             block = T2SBlock(
@@ -307,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
                 layer.norm1.eps,
                 layer.norm2.weight,
                 layer.norm2.bias,
-                layer.norm2.eps
+                layer.norm2.eps,
             )
 
             blocks.append(block)
 
         self.t2s_transformer = T2STransformer(self.num_layers, blocks)
 
     def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
@@ -385,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
         logits = self.ar_predict_layer(xy_dec[:, x_len:])
 
         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
+            x, x_lens, reject_y, reject_y_lens, bert_feature
+        )
 
         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -402,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
         A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
         loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
 
         loss = loss_1 + loss_2
 
         return loss, acc
@@ -471,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
     # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
     def infer(
         self,
         x,
         x_lens,
         prompts,
         bert_feature,
         top_k: int = -100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -506,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            y.device
-        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
 
         xy_dec, _ = self.h(
             (xy_pos, None),
             mask=xy_attn_mask,
         )
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = topk_sampling(
-            logits, top_k=top_k, top_p=1.0, temperature=temperature
-        )
+        samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
 
         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             print("use early stop num:", early_stop_num)
@@ -540,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
         return y
 
     def pad_y_eos(self, y, y_mask_int, eos_id):
-        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
-            y_mask_int, (0, 1), value=1
-        )
+        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # 错位
         return targets[:, :-1], targets[:, 1:]
 
     def infer_panel_batch_infer(
         self,
-        x:List[torch.LongTensor],  #####全部文本token
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  ####参考音频token
-        bert_feature:List[torch.LongTensor],
+        x: List[torch.LongTensor],  #####全部文本token
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  ####参考音频token
+        bert_feature: List[torch.LongTensor],
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
@@ -561,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
     ):
         if prompts is None:
             print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
-            return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
+            return self.infer_panel_naive_batched(
+                x,
+                x_lens,
+                prompts,
+                bert_feature,
+                top_k=top_k,
+                top_p=top_p,
+                early_stop_num=early_stop_num,
+                temperature=temperature,
+                **kwargs,
+            )
 
-        max_len = kwargs.get("max_len",x_lens.max())
+        max_len = kwargs.get("max_len", x_lens.max())
 
         x_list = []
         for x_item, bert_item in zip(x, bert_feature):
             # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -572,14 +615,15 @@ class Text2SemanticDecoder(nn.Module):
             x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
             x_item = self.ar_text_position(x_item).squeeze(0)
             # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
-            x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
+            x_item = (
+                F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
+            )  ### padding left
             x_list.append(x_item)
-        x:torch.Tensor = torch.stack(x_list, dim=0)
+        x: torch.Tensor = torch.stack(x_list, dim=0)
 
         # AR Decoder
         y = prompts
 
         x_len = x.shape[1]
         stop = False
@@ -592,34 +636,32 @@ class Text2SemanticDecoder(nn.Module):
             y_emb = self.ar_audio_embedding(y)
             y_len = y_emb.shape[1]
             prefix_len = y.shape[1]
-            y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
+            y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
 
         ##### create mask #####
         bsz = x.shape[0]
         src_len = x_len + y_len
         y_paddind_mask = make_pad_mask_left(y_lens, y_len)
         x_paddind_mask = make_pad_mask_left(x_lens, max_len)
 
         # (bsz, x_len + y_len)
         padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
 
         x_mask = F.pad(
             torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
             (0, y_len),
             value=True,
         )
         y_mask = F.pad(  ###yy的右上1扩展到左边xy的0,(y,x+y)
             torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
             (x_len, 0),
             value=False,
         )
 
-        causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
+        causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
         # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
         ### 上面是错误的会导致padding的token被"看见"
@@ -637,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
         padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
 
-        attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
+        attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
         attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
 
         # 正确的attn_mask应该是这样的
         # | pad_len | x_len | y_len |
         # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -653,74 +694,69 @@ class Text2SemanticDecoder(nn.Module):
         #  [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
         #  [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
 
         ###### decode #####
-        y_list = [None]*y.shape[0]
+        y_list = [None] * y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
-        idx_list = [None]*y.shape[0]
+        idx_list = [None] * y.shape[0]
         for idx in tqdm(range(1500)):
             if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
 
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])
 
             if idx == 0:
-                attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
+                attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
                 logits = logits[:, :-1]
             else:
-                attn_mask = F.pad(attn_mask,(0,1),value=False)
+                attn_mask = F.pad(attn_mask, (0, 1), value=False)
 
             samples = sample(
                 logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
             )[0]
 
             y = torch.concat([y, samples], dim=1)
 
             ####### 移除batch中已经生成完毕的序列,进一步优化计算量
             tokens = torch.argmax(logits, dim=-1)
             reserved_idx_of_batch_for_y = None
-            if (self.EOS in samples[:, 0]) or \
-                    (self.EOS in tokens):  ###如果生成到EOS则停止
-                l1 = samples[:, 0]==self.EOS
-                l2 = tokens==self.EOS
+            if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  ###如果生成到EOS则停止
+                l1 = samples[:, 0] == self.EOS
+                l2 = tokens == self.EOS
                 l = l1.logical_or(l2)
-                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
-                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
+                removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l == False)[0]
                 # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
                 for i in removed_idx_of_batch_for_y:
                     batch_index = batch_idx_map[i]
                     idx_list[batch_index] = idx
                     y_list[batch_index] = y[i, :-1]
 
                 batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
 
             # 只保留batch中未生成完毕的序列
             if reserved_idx_of_batch_for_y is not None:
                 # index = torch.LongTensor(batch_idx_map).to(y.device)
                 y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
                 attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
-                if k_cache is not None :
+                if k_cache is not None:
                     for i in range(len(k_cache)):
                         k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
                         v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
 
-            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
+            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
                 print("use early stop num:", early_stop_num)
                 stop = True
                 for i, batch_index in enumerate(batch_idx_map):
                     batch_index = batch_idx_map[i]
                     idx_list[batch_index] = idx
                     y_list[batch_index] = y[i, :-1]
 
             if None not in idx_list:
                 stop = True
 
             if stop:
-                if y.shape[1]==0:
+                if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -728,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)
 
-        if (None in idx_list):
+        if None in idx_list:
            for i in range(x.shape[0]):
                 if idx_list[i] is None:
-                    idx_list[i] = 1500-1  ###如果没有生成到EOS就用最大长度代替
+                    idx_list[i] = 1500 - 1  ###如果没有生成到EOS就用最大长度代替
        if ref_free:
-            return y_list, [0]*x.shape[0]
+            return y_list, [0] * x.shape[0]
         # print(idx_list)
         return y_list, idx_list
 
-    def infer_panel_naive_batched(self,
-        x:List[torch.LongTensor],  #####全部文本token
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  ####参考音频token
-        bert_feature:List[torch.LongTensor],
+    def infer_panel_naive_batched(
+        self,
+        x: List[torch.LongTensor],  #####全部文本token
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  ####参考音频token
+        bert_feature: List[torch.LongTensor],
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
    ):
         y_list = []
         idx_list = []
         for i in range(len(x)):
-            y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
-                                            x_lens[i],
-                                            prompts[i].unsqueeze(0) if prompts is not None else None,
-                                            bert_feature[i].unsqueeze(0),
-                                            top_k,
-                                            top_p,
-                                            early_stop_num,
-                                            temperature,
-                                            repetition_penalty,
-                                            **kwargs)
+            y, idx = self.infer_panel_naive(
+                x[i].unsqueeze(0),
+                x_lens[i],
+                prompts[i].unsqueeze(0) if prompts is not None else None,
+                bert_feature[i].unsqueeze(0),
+                top_k,
+                top_p,
+                early_stop_num,
+                temperature,
+                repetition_penalty,
+                **kwargs,
+            )
             y_list.append(y[0])
             idx_list.append(idx)
 
         return y_list, idx_list
 
     def infer_panel_naive(
         self,
-        x:torch.LongTensor,  #####全部文本token
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  ####参考音频token
-        bert_feature:torch.LongTensor,
+        x: torch.LongTensor,  #####全部文本token
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  ####参考音频token
+        bert_feature: torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
    ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -826,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
-            .unsqueeze(0)\
-            .expand(bsz*self.num_head, -1, -1)\
-            .view(bsz, self.num_head, src_len, src_len)\
-            .to(device=x.device, dtype=torch.bool)
+        xy_attn_mask = (
+            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+            .unsqueeze(0)
+            .expand(bsz * self.num_head, -1, -1)
+            .view(bsz, self.num_head, src_len, src_len)
+            .to(device=x.device, dtype=torch.bool)
+        )
 
         for idx in tqdm(range(1500)):
             if xy_attn_mask is not None:
@@ -838,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
 
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])
 
             if idx == 0:
                 xy_attn_mask = None
-            if(idx<11):###至少预测出10个token不然不给停止0.4s
+            if idx < 11:  ###至少预测出10个token不然不给停止0.4s
                 logits = logits[:, :-1]
 
             samples = sample(
@@ -868,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)
 
         if ref_free:
             return y[:, :-1], 0
 
         return y[:, :-1], idx
 
     def infer_panel(
         self,
-        x:torch.LongTensor,  #####全部文本token
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  ####参考音频token
-        bert_feature:torch.LongTensor,
+        x: torch.LongTensor,  #####全部文本token
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  ####参考音频token
+        bert_feature: torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
    ):
-        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
+        return self.infer_panel_naive(
+            x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
+        )

View File

@ -1,16 +1,13 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e # reference: https://github.com/lifeiteng/vall-e
import torch import torch
from AR.modules.embedding_onnx import SinePositionalEmbedding
from AR.modules.embedding_onnx import TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm
from AR.modules.transformer_onnx import TransformerEncoder
from AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = { default_config = {
"embedding_dim": 512, "embedding_dim": 512,
"hidden_dim": 512, "hidden_dim": 512,
@ -25,12 +22,13 @@ default_config = {
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float() inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs( def logits_to_probs(
logits, logits,
previous_tokens = None, previous_tokens=None,
temperature: float = 1.0, temperature: float = 1.0,
top_k = None, top_k=None,
top_p = None, top_p=None,
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
): ):
previous_tokens = previous_tokens.squeeze() previous_tokens = previous_tokens.squeeze()
@ -38,19 +36,27 @@ def logits_to_probs(
previous_tokens = previous_tokens.long() previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens) score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where( score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty score < 0,
score * repetition_penalty,
score / repetition_penalty,
) )
logits.scatter_(dim=0, index=previous_tokens, src=score) logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum( cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
) )
sorted_indices_to_remove = cum_probs > top_p sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
) )
logits = logits.masked_fill(indices_to_remove, -float("Inf")) logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -66,7 +72,7 @@ def logits_to_probs(
def multinomial_sample_one_no_sync( def multinomial_sample_one_no_sync(
probs_sort probs_sort,
): # Does multinomial sampling without a cuda synchronization ): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort) q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@ -78,7 +84,9 @@ def sample(
**sampling_kwargs, **sampling_kwargs,
): ):
probs = logits_to_probs( probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
) )
idx_next = multinomial_sample_one_no_sync(probs) idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs return idx_next, probs
@ -90,7 +98,7 @@ class OnnxEncoder(nn.Module):
self.ar_text_embedding = ar_text_embedding self.ar_text_embedding = ar_text_embedding
self.bert_proj = bert_proj self.bert_proj = bert_proj
self.ar_text_position = ar_text_position self.ar_text_position = ar_text_position
def forward(self, x, bert_feature): def forward(self, x, bert_feature):
x = self.ar_text_embedding(x) x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -98,8 +106,18 @@ class OnnxEncoder(nn.Module):
class T2SFirstStageDecoder(nn.Module): class T2SFirstStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, def __init__(
top_k, early_stop_num, num_layers): self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__() super().__init__()
self.ar_audio_embedding = ar_audio_embedding self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position self.ar_audio_position = ar_audio_position
@ -110,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
self.top_k = top_k self.top_k = top_k
self.early_stop_num = early_stop_num self.early_stop_num = early_stop_num
self.num_layers = num_layers self.num_layers = num_layers
def forward(self, x, prompt): def forward(self, x, prompt):
y = prompt y = prompt
x_example = x[:,:,0] * 0.0 x_example = x[:, :, 0] * 0.0
#N, 1, 512 # N, 1, 512
cache = { cache = {
"all_stage": self.num_layers, "all_stage": self.num_layers,
"k": None, "k": None,
@ -131,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1) xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:,:,0] * 0.0 y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool() x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0 torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
) )
y_attn_mask = y_attn_mask > 0 y_attn_mask = y_attn_mask > 0
@ -144,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ cache["k"] = (
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1) torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\ .unsqueeze(1)
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1) .repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1]) logits = self.ar_predict_layer(xy_dec[:, -1])
@ -159,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
class T2SStageDecoder(nn.Module): class T2SStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric, def __init__(
top_k, early_stop_num, num_layers): self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__() super().__init__()
self.ar_audio_embedding = ar_audio_embedding self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position self.ar_audio_position = ar_audio_position
@ -183,14 +221,18 @@ class T2SStageDecoder(nn.Module):
} }
y_emb = torch.cat( y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 [
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
) )
cache["y_emb"] = y_emb cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb) y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:] xy_pos = y_pos[:, -1:]
y_example = y_pos[:,:,0] * 0.0 y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1) xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@ -249,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self): def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position) self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, self.ar_audio_embedding,
self.num_layers) self.ar_audio_position,
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num, self.ar_predict_layer,
self.num_layers) self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature): def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num early_stop_num = self.early_stop_num
@ -285,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
y = prompts y = prompts
prefix_len = y.shape[1] prefix_len = y.shape[1]
x_len = x.shape[1] x_len = x.shape[1]
x_example = x[:,:,0] * 0.0 x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example) x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool) x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@ -302,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
if cache["first_infer"] == 1: if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y) y_emb = self.ar_audio_embedding(y)
else: else:
y_emb = torch.cat( y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
cache["y_emb"] = y_emb cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb) y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1: if cache["first_infer"] == 1:
@ -316,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True) x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad( y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0), value=False (x_len, 0),
value=False,
) )
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else: else:
@ -334,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
break break
y = torch.concat([y, samples], dim=1) y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0 cache["first_infer"] = 0
return y, idx return y, idx
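For reference, the combined text/audio attention mask assembled above keeps the text prefix fully visible while the audio suffix attends causally. A minimal standalone sketch with toy lengths (True marks blocked positions):

import torch
import torch.nn.functional as F

x_len, y_len = 4, 3  # text prefix length, audio suffix length (illustrative values)
x_attn_mask = torch.zeros(x_len, x_len, dtype=torch.bool)      # text rows attend to all text
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)   # ... but never to audio
y_attn_mask = F.pad(                                           # audio rows: all text + causal audio
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
print(xy_attn_mask.int())  # (x_len + y_len) x (x_len + y_len); 1 = masked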


@ -1,8 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e # reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Tuple
def sequence_mask(length, max_length=None): def sequence_mask(length, max_length=None):
if max_length is None: if max_length is None:
@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
n = lengths.size(0) n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device) seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1) expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1) expaned_lengths -= (max_len - lengths).unsqueeze(-1)
return expaned_lengths<0 return expaned_lengths < 0
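make_pad_mask_left marks the padded head of each left-padded sequence as True. A quick check, assuming max_len=5:

import torch

lengths = torch.tensor([3, 5])
print(make_pad_mask_left(lengths, max_len=5))
# tensor([[ True,  True, False, False, False],
#         [False, False, False, False, False]])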
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering( def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
): ):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args: Args:
@ -105,9 +111,7 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0 sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value logits[indices_to_remove] = filter_value
return logits return logits
@ -156,19 +160,21 @@ def logits_to_probs(
previous_tokens = previous_tokens.long() previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens) score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where( score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty score < 0,
score * repetition_penalty,
score / repetition_penalty,
) )
logits.scatter_(dim=1, index=previous_tokens, src=score) logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum( cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
) )
logits = logits.masked_fill(indices_to_remove, -float("Inf")) logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -176,7 +182,7 @@ def logits_to_probs(
if top_k is not None: if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1) pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits) logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1) probs = torch.nn.functional.softmax(logits, dim=-1)
@ -188,18 +194,19 @@ def sample(
previous_tokens: Optional[torch.Tensor] = None, previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs, **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs( probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
)
idx_next = multinomial_sample_one_no_sync(probs) idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs return idx_next, probs
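A hedged usage sketch of the sampling path above: sample forwards its keyword arguments to logits_to_probs, which penalizes previously generated tokens and then applies top-p/top-k filtering before the multinomial draw. Vocabulary size, prefix length, and the remaining default arguments are assumptions here:

import torch

torch.manual_seed(0)
logits = torch.randn(1, 1025)                     # one autoregressive step (batch 1, assumed vocab size)
previous_tokens = torch.randint(0, 1025, (1, 8))  # tokens decoded so far, shape (bs, n_prev)
idx_next, probs = sample(
    logits,
    previous_tokens=previous_tokens,
    top_k=15,
    top_p=0.95,
    repetition_penalty=1.35,
)
print(idx_next, probs.shape)  # next token id(s) and the filtered distribution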
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor, def dpo_loss(
reference_chosen_logps: torch.FloatTensor, policy_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor,
beta: float, reference_chosen_logps: torch.FloatTensor,
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps
@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
return losses.mean(), chosen_rewards, rejected_rewards return losses.mean(), chosen_rewards, rejected_rewards
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# dummy token; we'll ignore the losses on these tokens later # dummy token; we'll ignore the losses on these tokens later
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2) per_token_logps_target = torch.gather(
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2) logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1) return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens): def make_reject_y(y_o, y_lens):
def repeat_P(y): def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]] pre = y[: range_idx[0]]
shf = y[range_idx[1]:] shf = y[range_idx[1] :]
range_text = y[range_idx[0]:range_idx[1]] range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf]) new_y = torch.cat([pre, range_text, range_text, shf])
return new_y return new_y
def lost_P(y): def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]] pre = y[: range_idx[0]]
shf = y[range_idx[1]:] shf = y[range_idx[1] :]
range_text = y[range_idx[0]:range_idx[1]] range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf]) new_y = torch.cat([pre, shf])
return new_y return new_y
bs = len(y_lens) bs = len(y_lens)
reject_y = [] reject_y = []
reject_y_lens = [] reject_y_lens = []
for b in range(bs): for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0] process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0: if process_item_idx == 0:
new_y = repeat_P(y_o[b]) new_y = repeat_P(y_o[b])
reject_y.append(new_y) reject_y.append(new_y)
reject_y_lens.append(len(new_y)) reject_y_lens.append(len(new_y))
elif process_item_idx==1: elif process_item_idx == 1:
new_y = lost_P(y_o[b]) new_y = lost_P(y_o[b])
reject_y.append(new_y) reject_y.append(new_y)
reject_y_lens.append(len(new_y)) reject_y_lens.append(len(new_y))
@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
pad_length = max_length - reject_y_lens[b] pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0) reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim = 0) reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device) reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens return reject_y, reject_y_lens
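make_reject_y builds a "rejected" sequence per sample by duplicating or dropping a random span; together with get_batch_logps and dpo_loss this yields the DPO objective. A minimal sketch with random tensors standing in for the policy and the frozen reference model (shapes and vocabulary size are assumptions):

import torch

bs, T, V = 2, 10, 1025
labels_target = torch.randint(0, V, (bs, T))
labels_reject = torch.randint(0, V, (bs, T))

policy_chosen_logps, policy_rejected_logps = get_batch_logps(
    torch.randn(bs, T, V), torch.randn(bs, T, V), labels_target, labels_reject
)
reference_chosen_logps, reference_rejected_logps = get_batch_logps(
    torch.randn(bs, T, V), torch.randn(bs, T, V), labels_target, labels_reject
)
loss, chosen_rewards, rejected_rewards = dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    beta=1.0,
    reference_free=False,
)
print(loss.item(), chosen_rewards.shape, rejected_rewards.shape)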


@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional from typing import Optional, Tuple
from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import Linear from torch.nn import Linear, Module
from torch.nn import Module from torch.nn import functional as F
from torch.nn.init import constant_ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched F.multi_head_attention_forward = multi_head_attention_forward_patched
@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value) >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
""" """
__constants__ = ["batch_first"] __constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor] bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor] bias_v: Optional[torch.Tensor]
@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout self.dropout = dropout
self.batch_first = batch_first self.batch_first = batch_first
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
assert ( assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv: if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear: if linear1_cls == Linear:
if not self._qkv_same_embed_dim: if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter( self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs) torch.empty((embed_dim, embed_dim), **factory_kwargs),
) )
self.k_proj_weight = Parameter( self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs) torch.empty((embed_dim, self.kdim), **factory_kwargs),
) )
self.v_proj_weight = Parameter( self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs) torch.empty((embed_dim, self.vdim), **factory_kwargs),
) )
self.register_parameter("in_proj_weight", None) self.register_parameter("in_proj_weight", None)
else: else:
self.in_proj_weight = Parameter( self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
) )
self.register_parameter("q_proj_weight", None) self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None) self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None) self.register_parameter("v_proj_weight", None)
if bias: if bias:
self.in_proj_bias = Parameter( self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
torch.empty(3 * embed_dim, **factory_kwargs)
)
else: else:
self.register_parameter("in_proj_bias", None) self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear( self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
) )
self._reset_parameters() self._reset_parameters()
@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError raise NotImplementedError
else: else:
self.in_proj_linear = linear1_cls( self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
) )
self.in_proj_weight = self.in_proj_linear.weight self.in_proj_weight = self.in_proj_linear.weight
@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None) self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls( self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
) )
if self.bias_k is not None: if self.bias_k is not None:
@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None: if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype _kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point( if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask key_padding_mask,
): ):
raise AssertionError( raise AssertionError("only bool and floating types of key_padding_mask are supported")
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = "" why_not_fast_path = ""
if not is_batched: if not is_batched:
why_not_fast_path = ( why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value: elif query is not key or key is not value:
# When lifting this restriction, don't forget to either # When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where # enforce that the dtypes all match or test cases where
# they don't! # they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" why_not_fast_path = (
elif ( f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype )
): elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message. # this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training: elif self.training:
why_not_fast_path = "training is enabled" why_not_fast_path = "training is enabled"
elif not self.batch_first: elif not self.batch_first:
@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None: elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None" why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None: elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = ( why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1: elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd" why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled(): elif torch.is_autocast_enabled():
@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions. # generator expressions.
if torch.overrides.has_torch_function(tensor_args): if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function" why_not_fast_path = "some Tensor argument has_torch_function"
elif not all( elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any( elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
[x is not None and x.requires_grad for x in tensor_args] why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path: if not why_not_fast_path:
return torch._native_multi_head_attention( return torch._native_multi_head_attention(
query, query,
@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask, key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights, need_weights,
average_attn_weights, average_attn_weights,
1 1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
) )
any_nested = query.is_nested or key.is_nested or value.is_nested any_nested = query.is_nested or key.is_nested or value.is_nested


@ -1,13 +1,10 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional from typing import Optional, Tuple
from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import Linear from torch.nn import Linear, Module
from torch.nn import Module from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -46,9 +43,7 @@ class MultiheadAttention(Module):
self.dropout = dropout self.dropout = dropout
self.batch_first = batch_first self.batch_first = batch_first
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
assert ( assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv: if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@ -59,18 +54,30 @@ class MultiheadAttention(Module):
if linear1_cls == Linear: if linear1_cls == Linear:
if not self._qkv_same_embed_dim: if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter( self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs) torch.empty(
(embed_dim, embed_dim),
**factory_kwargs,
)
) )
self.k_proj_weight = Parameter( self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs) torch.empty(
(embed_dim, self.kdim),
**factory_kwargs,
)
) )
self.v_proj_weight = Parameter( self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs) torch.empty(
(embed_dim, self.vdim),
**factory_kwargs,
)
) )
self.register_parameter("in_proj_weight", None) self.register_parameter("in_proj_weight", None)
else: else:
self.in_proj_weight = Parameter( self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) torch.empty(
(3 * embed_dim, embed_dim),
**factory_kwargs,
)
) )
self.register_parameter("q_proj_weight", None) self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None) self.register_parameter("k_proj_weight", None)
@ -78,13 +85,11 @@ class MultiheadAttention(Module):
if bias: if bias:
self.in_proj_bias = Parameter( self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs) torch.empty(3 * embed_dim, **factory_kwargs),
) )
else: else:
self.register_parameter("in_proj_bias", None) self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear( self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters() self._reset_parameters()
else: else:
@ -92,7 +97,10 @@ class MultiheadAttention(Module):
raise NotImplementedError raise NotImplementedError
else: else:
self.in_proj_linear = linear1_cls( self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
) )
self.in_proj_weight = self.in_proj_linear.weight self.in_proj_weight = self.in_proj_linear.weight
@ -106,7 +114,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None) self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls( self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
) )
if self.bias_k is not None: if self.bias_k is not None:


@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
return return
pe = torch.zeros(x.size(1), self.embedding_dim) pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse: if self.reverse:
position = torch.arange( position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else: else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
* -(math.log(10000.0) / self.embedding_dim)
) )
pe[:, 0::2] = torch.sin(position * div_term) pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
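The table filled in above is the standard sinusoidal positional encoding; a standalone toy computation with arbitrarily chosen sizes:

import math
import torch

embedding_dim, seq_len = 8, 5
position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / embedding_dim))
pe = torch.zeros(seq_len, embedding_dim)
pe[:, 0::2] = torch.sin(position * div_term)  # even channels: sine
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels: cosine
print(pe.shape)  # torch.Size([5, 8])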


@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim)) self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x): def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1) position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0) scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0) pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim) pe = pe.contiguous().view(1, -1, self.embedding_dim)


@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
lr = self.end_lr lr = self.end_lr
else: else:
decay_ratio = (self._current_step - self.warmup_steps) / ( decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0: if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError( raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
)
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
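After warmup, the schedule follows a half cosine from peak_lr down to end_lr. The same computation restated as a hypothetical standalone helper:

import math

def cosine_lr(step, peak_lr, end_lr, warmup_steps, total_steps):
    # decay branch only, valid for warmup_steps <= step <= total_steps
    decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return end_lr + coeff * (peak_lr - end_lr)

print(cosine_lr(2000, 2e-4, 1e-6, 2000, 20000))   # ~2e-4 right after warmup
print(cosine_lr(11000, 2e-4, 1e-6, 2000, 20000))  # ~1e-4 halfway through the decay
print(cosine_lr(20000, 2e-4, 1e-6, 2000, 20000))  # ~1e-6 at the final step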
@ -70,7 +66,13 @@ if __name__ == "__main__":
m = nn.Linear(10, 10) m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4) opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule( s = WarmupCosineLRSchedule(
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0 opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0,
) )
lrs = [] lrs = []
for i in range(25000): for i in range(25000):


@ -16,8 +16,7 @@
import contextlib import contextlib
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import List from typing import List, Tuple
from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
group_params_names: name for each parameter in group, group_params_names: name for each parameter in group,
which is List[str]. which is List[str].
""" """
batches = defaultdict( batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
list batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names) assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names): for p, named_p in zip(param_group, group_params_names):
@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
batches_names[key].append(named_p) batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys()) batches_names_keys = list(batches_names.keys())
sorted_idx = sorted( sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
range(len(batches_names)), key=lambda i: batches_names_keys[i]) batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict() stacked_params_dict = dict()
@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
# group. class Optimizer will take care of saving/loading state. # group. class Optimizer will take care of saving/loading state.
state = self.state[p] state = self.state[p]
p_stacked = torch.stack(batch) p_stacked = torch.stack(batch)
grad = torch.stack([ grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
])
p_stacked.grad = grad p_stacked.grad = grad
stacked_params_dict[key] = p_stacked stacked_params_dict[key] = p_stacked
tuples.append((p_stacked, state, batch_names)) tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here! yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches): for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i]) p.copy_(stacked_params[i])
@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
""" """
def __init__( def __init__(
self, self,
params, params,
lr=3e-02, lr=3e-02,
clipping_scale=None, clipping_scale=None,
betas=(0.9, 0.98), betas=(0.9, 0.98),
scalar_lr_scale=0.1, scalar_lr_scale=0.1,
eps=1.0e-08, eps=1.0e-08,
param_min_rms=1.0e-05, param_min_rms=1.0e-05,
param_max_rms=3.0, param_max_rms=3.0,
scalar_max=10.0, scalar_max=10.0,
size_update_period=4, size_update_period=4,
clipping_update_period=100, clipping_update_period=100,
parameters_names=None, parameters_names=None,
show_dominant_parameters=True, ): show_dominant_parameters=True,
):
assert parameters_names is not None, ( assert parameters_names is not None, (
"Please prepare parameters_names," "Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
"which is a List[List[str]]. Each List[str] is for a group" )
"and each str is for a parameter")
defaults = dict( defaults = dict(
lr=lr, lr=lr,
clipping_scale=clipping_scale, clipping_scale=clipping_scale,
@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms=param_max_rms, param_max_rms=param_max_rms,
scalar_max=scalar_max, scalar_max=scalar_max,
size_update_period=size_update_period, size_update_period=size_update_period,
clipping_update_period=clipping_update_period, ) clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults) super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names) assert len(self.param_groups) == len(parameters_names)
@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
batch = True batch = True
for group, group_params_names in zip(self.param_groups, for group, group_params_names in zip(self.param_groups, self.parameters_names):
self.parameters_names): with self.batched_params(group["params"], group_params_names) as batches:
with self.batched_params(group["params"],
group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like # batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to # a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim. # a stacking dim, it is not a real dim.
if (len(batches[0][1]) == if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
0): # if len(first state) == 0: not yet initialized
clipping_scale = 1 clipping_scale = 1
else: else:
clipping_scale = self._get_clipping_scale(group, batches) clipping_scale = self._get_clipping_scale(group, batches)
@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
# grad is not going to be None, we handled that when creating the batches. # grad is not going to be None, we handled that when creating the batches.
grad = p.grad grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
"ScaledAdam optimizer does not support sparse gradients"
)
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
self._init_state(group, p, state) self._init_state(group, p, state)
@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of # parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam, # update. this is equivalent to how it's done in Adam,
# except for the first few steps. # except for the first few steps.
state["delta"] = torch.zeros_like( state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
p, memory_format=torch.preserve_format)
batch_size = p.shape[0] batch_size = p.shape[0]
numel = p.numel() // batch_size numel = p.numel() // batch_size
@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of # "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor. # the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1) # it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = ( param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
state["param_rms"] = param_rms state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period, state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
*param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam. # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like( state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
p, memory_format=torch.preserve_format)
def _get_clipping_scale(self, def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
group: dict,
tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
""" """
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update. by this amount before applying the rest of the update.
@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"] clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device) tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples: for p, state, param_names in tuples:
grad = p.grad grad = p.grad
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
"ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else: else:
tot_sumsq += ((grad * state["param_rms"])**2).sum() tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt() tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state: if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros( first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0: if step % clipping_update_period == 0:
@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
for n in range(0, 5): for n in range(0, 5):
index = min( index = min(
clipping_update_period - 1, clipping_update_period - 1,
(clipping_update_period // 4) * n, ) (clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item()) quartiles.append(sorted_norms[index].item())
median = quartiles[2] median = quartiles[2]
threshold = clipping_scale * median threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 / percent_clipped = (
clipping_update_period first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
if "num_clipped" in first_state else 0.0) )
first_state["num_clipped"] = 0 first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles]) quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info( logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
) )
if step < clipping_update_period: if step < clipping_update_period:
@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
model_norm_threshold = first_state["model_norm_threshold"] model_norm_threshold = first_state["model_norm_threshold"]
except KeyError: except KeyError:
logging.info( logging.info(
"Warning: model_norm_threshold not in state: possibly " "Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
"you changed config when restarting, adding clipping_scale option?"
) )
return 1.0 return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0: if ans < 1.0:
first_state["num_clipped"] += 1 first_state["num_clipped"] += 1
if ans < 0.1: if ans < 0.1:
logging.warn( logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
if self.show_dominant_parameters: if self.show_dominant_parameters:
assert p.shape[0] == len(param_names) assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq) self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans return ans
def _show_gradient_dominating_parameter( def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
self, tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor):
""" """
Show information of parameter wihch dominanting tot_sumsq. Show information of parameter wihch dominanting tot_sumsq.
@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time. from tuples, we still pass it to save some time.
""" """
all_sumsq_orig = {} all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples: for p, state, batch_param_names in tuples:
# p is a stacked batch parameters. # p is a stacked batch parameters.
batch_grad = p.grad batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars if p.numel() == p.shape[0]: # a batch of scalars
@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
batch_rms_orig = torch.ones(p.shape[0]) batch_rms_orig = torch.ones(p.shape[0])
else: else:
batch_rms_orig = state["param_rms"] batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum( batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(batch_param_names,
batch_sumsq_orig,
batch_rms_orig, batch_grad):
for name, sumsq_orig, rms, grad in zip(
batch_param_names,
batch_sumsq_orig,
batch_rms_orig,
batch_grad,
):
proportion_orig = sumsq_orig / tot_sumsq proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose( assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(), sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0), ) torch.tensor(1.0),
)
sorted_by_proportion = { sorted_by_proportion = {
k: v k: v
for k, v in sorted( for k, v in sorted(
all_sumsq_orig.items(), all_sumsq_orig.items(),
key=lambda item: item[1][0], key=lambda item: item[1][0],
reverse=True, ) reverse=True,
)
} }
dominant_param_name = next(iter(sorted_by_proportion)) dominant_param_name = next(iter(sorted_by_proportion))
(dominant_proportion, dominant_sumsq, dominant_rms, (
dominant_grad, ) = sorted_by_proportion[dominant_param_name] dominant_proportion,
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}" dominant_sumsq,
f" with proportion {dominant_proportion:.2f}," dominant_rms,
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" dominant_grad,
f"={dominant_sumsq:.3e}," ) = sorted_by_proportion[dominant_param_name]
f" grad_sumsq = {(dominant_grad**2).sum():.3e}," logging.info(
f" orig_rms_sq={(dominant_rms**2).item():.3e}") f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(self, def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
""" """
Do the step for one parameter, which is actually going to be a batch of Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim. `real` parameters, with dim 0 as the batch dim.
@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
if numel > 1: if numel > 1:
# Update the size/scale of p, and set param_rms # Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"] scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum( scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1: if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2) param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt())
if step > 0: if step > 0:
# self._size_update() learns the overall scale on the # self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it. # parameter, by shrinking or expanding it.
@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = step + 1 state["step"] = step + 1
def _size_update(self, def _size_update(
group: dict, self,
scale_grads: Tensor, group: dict,
p: Tensor, scale_grads: Tensor,
state: dict) -> None: p: Tensor,
state: dict,
) -> None:
""" """
Called only where p.numel() > 1, this updates the scale of the parameter. Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing If we imagine: p = underlying_param * scale.exp(), and we are doing
@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level. # faster decay at this level.
beta2_corr = beta2**size_update_period beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state[ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_( scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period` (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...) alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1. # The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period size_step = (step + 1) // size_update_period
@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (-size_lr * (bias_correction2**0.5) * scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
scale_grads.sum(dim=0) / denom)
is_too_small = param_rms < param_min_rms is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms is_too_large = param_rms > param_max_rms
@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"] exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"] this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1)
bias_correction2 = 1 - beta2**(this_step + 1)
if bias_correction2 < 0.99: if bias_correction2 < 0.99:
# note: not in-place. # note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
# bias_correction2 is like in Adam. Don't bother with bias_correction1; # bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway. # slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2**(state["step"] + 1) bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"] delta = state["delta"]
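Since the constructor above asserts that parameters_names (one list of parameter names per param group) is supplied, a hedged sketch of wiring the optimizer up looks roughly like this; the actual training wiring elsewhere in the repo may differ:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
parameters_names = [[name for name, _ in model.named_parameters()]]  # one group -> one list of names
optimizer = ScaledAdam(
    model.parameters(),
    lr=1e-2,
    betas=(0.9, 0.95),
    clipping_scale=2.0,
    parameters_names=parameters_names,
    show_dominant_parameters=False,
)
loss = nn.functional.mse_loss(model(torch.randn(8, 16)), torch.randn(8, 4))
loss.backward()
optimizer.step()
optimizer.zero_grad()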


@ -24,18 +24,18 @@ def multi_head_attention_forward_patched(
dropout_p: float, dropout_p: float,
out_proj_weight, out_proj_weight,
out_proj_bias, out_proj_bias,
training = True, training=True,
key_padding_mask = None, key_padding_mask=None,
need_weights = True, need_weights=True,
attn_mask = None, attn_mask=None,
use_separate_proj_weight = False, use_separate_proj_weight=False,
q_proj_weight = None, q_proj_weight=None,
k_proj_weight = None, k_proj_weight=None,
v_proj_weight = None, v_proj_weight=None,
static_k = None, static_k=None,
static_v = None, static_v=None,
average_attn_weights = True, average_attn_weights=True,
is_causal = False, is_causal=False,
cache=None, cache=None,
): ):
r""" r"""
@ -155,9 +155,7 @@ def multi_head_attention_forward_patched(
cache=cache, cache=cache,
) )
is_batched = _mha_shape_check( is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the # is batched, run the computation and before returning squeeze the
@ -210,45 +208,33 @@ def multi_head_attention_forward_patched(
# longer causal. # longer causal.
is_causal = False is_causal = False
assert ( assert embed_dim == embed_dim_to_check, (
embed_dim == embed_dim_to_check f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" )
if isinstance(embed_dim, torch.Tensor): if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing # embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc") head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else: else:
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
assert ( assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight: if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used # allow MHA to have different embedding dimensions when separate projection weights are used
assert ( assert key.shape[:2] == value.shape[:2], (
key.shape[:2] == value.shape[:2] f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" )
else: else:
assert ( assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
# #
# compute in-projection # compute in-projection
# #
if not use_separate_proj_weight: if not use_separate_proj_weight:
assert ( assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else: else:
assert ( assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
q_proj_weight is not None assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
), "use_separate_proj_weight is True but q_proj_weight is None" assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None: if in_proj_bias is None:
b_q = b_k = b_v = None b_q = b_k = b_v = None
else: else:
@ -311,9 +297,7 @@ def multi_head_attention_forward_patched(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
) )
else: else:
raise RuntimeError( raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second) # add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None: if bias_k is not None and bias_v is not None:
@ -337,34 +321,26 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert ( assert static_k.size(0) == bsz * num_heads, (
static_k.size(0) == bsz * num_heads f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" )
assert ( assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k k = static_k
if static_v is None: if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert ( assert static_v.size(0) == bsz * num_heads, (
static_v.size(0) == bsz * num_heads f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" )
assert ( assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v v = static_v
# add zero attention along batch dimension (now first) # add zero attention along batch dimension (now first)
if add_zero_attn: if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim) zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat( k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None: if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1)) attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None: if key_padding_mask is not None:
@ -380,9 +356,7 @@ def multi_head_attention_forward_patched(
src_len, src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = ( key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len) key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
) )
if attn_mask is None: if attn_mask is None:
attn_mask = key_padding_mask attn_mask = key_padding_mask
@ -401,14 +375,10 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape B, Nt, E = q.shape
q_scaled = q / math.sqrt(E) q_scaled = q / math.sqrt(E)
assert not ( assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None: if attn_mask is not None:
attn_output_weights = torch.baddbmm( attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_mask, q_scaled, k.transpose(-2, -1)
)
else: else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -417,9 +387,7 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
attn_output = ( attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -448,13 +416,9 @@ def multi_head_attention_forward_patched(
v = v.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention( attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = ( attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

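Note: the hunks above mostly collapse multi-line calls onto single lines; the substance is the two attention paths, a manual baddbmm/softmax/bmm path (used when attention weights are needed) and torch.nn.functional.scaled_dot_product_attention. A minimal equivalence sketch, assuming PyTorch 2.x and made-up shapes rather than the patched function's real inputs:

import math
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 16
q = torch.randn(bsz * num_heads, tgt_len, head_dim)
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)
attn_mask = torch.zeros(bsz * num_heads, tgt_len, src_len)  # additive float mask

# Manual path: scale, add mask via baddbmm, softmax, weighted sum.
q_scaled = q / math.sqrt(head_dim)
weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
weights = torch.softmax(weights, dim=-1)
out_manual = torch.bmm(weights, v)

# Fused path: same result (up to numerics) without materialising the weights.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(torch.allclose(out_manual, out_sdpa, atol=1e-5))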
View File

@ -3,6 +3,7 @@ from torch.nn.functional import (
_canonical_mask, _canonical_mask,
) )
def multi_head_attention_forward_patched( def multi_head_attention_forward_patched(
query, query,
key, key,
@ -31,7 +32,6 @@ def multi_head_attention_forward_patched(
is_causal: bool = False, is_causal: bool = False,
cache=None, cache=None,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
# set up shape vars # set up shape vars
_, _, embed_dim = query.shape _, _, embed_dim = query.shape
attn_mask = _canonical_mask( attn_mask = _canonical_mask(
@ -77,12 +77,8 @@ def multi_head_attention_forward_patched(
q = q.view(num_heads, -1, head_dim).unsqueeze(0) q = q.view(num_heads, -1, head_dim).unsqueeze(0)
k = k.view(num_heads, -1, head_dim).unsqueeze(0) k = k.view(num_heads, -1, head_dim).unsqueeze(0)
v = v.view(num_heads, -1, head_dim).unsqueeze(0) v = v.view(num_heads, -1, head_dim).unsqueeze(0)
attn_output = scaled_dot_product_attention( attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
q, k, v, attn_mask, dropout_p, is_causal attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(-1, 1, attn_output.size(1)) attn_output = attn_output.view(-1, 1, attn_output.size(1))

View File

@ -58,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving. # floors), should be expectation-preserving.
floor = -0.043637 floor = -0.043637
ceil = 1.2 ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
deriv
)
if __name__ == "__main__": if __name__ == "__main__":
# for self-testing only. # for self-testing only.
assert d_scaled.min() >= 0.0 assert d_scaled.min() >= 0.0
@ -150,13 +148,9 @@ def _compute_scale_factor(
else: else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean << min_abs. # x_abs_mean << min_abs.
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
min=0, max=max_factor
)
return below_threshold - above_threshold return below_threshold - above_threshold
@ -178,18 +172,16 @@ def _compute_sign_factor(
else: else:
# 0 if proportion_positive >= min_positive, else can be # 0 if proportion_positive >= min_positive, else can be
# as large as max_factor. # as large as max_factor.
factor1 = ( factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
if max_positive == 1.0: if max_positive == 1.0:
factor2 = 0.0 factor2 = 0.0
else: else:
# 0 if self.proportion_positive <= max_positive, else can be # 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor. # as large as -max_factor.
factor2 = ( factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) min=0, max=max_factor
).clamp_(min=0, max=max_factor) )
sign_factor = factor1 - factor2 sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1: # require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float) assert not isinstance(sign_factor, float)
@ -317,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x) return _no_op(x)
def BalancedDoubleSwish( def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
""" """
ActivationBalancer -> DoubleSwish ActivationBalancer -> DoubleSwish
""" """
balancer = ActivationBalancer( balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential( return nn.Sequential(
balancer, balancer,
DoubleSwish(), DoubleSwish(),

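Note: the scaling.py hunks above reformat icefall-style utilities. DoubleSwish itself is x * sigmoid(x - 1); the floor/ceil/rand_like lines belong to its memory-saving backward, which quantises the stored derivative to 8 bits. A forward-only sketch (not the custom autograd Function):

import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # DoubleSwish(x) = x * sigmoid(x - 1); a shifted Swish that keeps
    # activations roughly zero-mean for typical inputs.
    return x * torch.sigmoid(x - 1.0)

x = torch.linspace(-4.0, 4.0, steps=9)
print(double_swish(x))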
View File

@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.elementwise_affine = elementwise_affine
if self.elementwise_affine: if self.elementwise_affine:
self.weight = nn.Parameter( self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs) self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else: else:
self.register_parameter("weight", None) self.register_parameter("weight", None)
self.register_parameter("bias", None) self.register_parameter("bias", None)
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
) )
assert embedding is None assert embedding is None
return F.layer_norm( return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return ( return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module): class IdentityNorm(nn.Module):
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512) >>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src) >>> out = transformer_encoder(src)
""" """
__constants__ = ["norm"] __constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None): def __init__(self, encoder_layer, num_layers, norm=None):
@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
) )
# Implementation of Feedforward model # Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls( self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls( self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout) self.dropout1 = nn.Dropout(dropout)
@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None: if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype _skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point( if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
src_key_padding_mask raise AssertionError("only bool and floating types of key_padding_mask are supported")
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if self.norm_first: if self.norm_first:
x = x + self._sa_block( x = x + self._sa_block(

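Note: the custom LayerNorm above only adds an (unused here) embedding argument and otherwise defers to F.layer_norm with its own weight/bias parameters. A quick equivalence check against torch.nn.LayerNorm, assuming an embedding dimension of 8:

import torch
import torch.nn.functional as F

d_model = 8
x = torch.randn(4, 10, d_model)

ln = torch.nn.LayerNorm(d_model, eps=1e-5, elementwise_affine=True)
# The functional call with the module's own parameters gives the same output.
y_functional = F.layer_norm(x, (d_model,), ln.weight, ln.bias, ln.eps)
print(torch.allclose(ln(x), y_functional))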
View File

@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.elementwise_affine = elementwise_affine
if self.elementwise_affine: if self.elementwise_affine:
self.weight = nn.Parameter( self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs) self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else: else:
self.register_parameter("weight", None) self.register_parameter("weight", None)
self.register_parameter("bias", None) self.register_parameter("bias", None)
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
) )
assert embedding is None assert embedding is None
return F.layer_norm( return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return ( return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module): class IdentityNorm(nn.Module):
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512) >>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src) >>> out = transformer_encoder(src)
""" """
__constants__ = ["norm"] __constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None): def __init__(self, encoder_layer, num_layers, norm=None):
@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
class TransformerEncoderLayer(nn.Module): class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"] __constants__ = ["batch_first", "norm_first"]
def __init__( def __init__(
self, self,
d_model: int, d_model: int,
@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
linear2_cls=linear2_self_attention_cls, linear2_cls=linear2_self_attention_cls,
**factory_kwargs, **factory_kwargs,
) )
self.linear1 = linear1_feedforward_cls( self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls( self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout) self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout)

View File

@ -30,9 +30,7 @@ class GruutPhonemizer:
"«": "«", "«": "«",
"»": "»", "»": "»",
} }
self._punctuation_regexp: str = ( self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str: def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text) text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@ -53,13 +51,8 @@ class GruutPhonemizer:
def phonemize(self, text: str, espeak: bool = False) -> str: def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text) text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [ sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
sent words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
return " ".join(words) return " ".join(words)
def transform(self, phonemes): def transform(self, phonemes):

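Note: GruutPhonemizer builds a single character-class regexp from its special-case keys and needs the third-party regex module for the \pZ (separator) property. A stripped-down sketch of the punctuation normalisation, with an assumed, much smaller special-case table:

import regex  # pip install regex; the stdlib re module has no \p{...} classes

special_cases = {";": "; ", ",": ", ", "«": "«", "»": "»"}
punctuation_regexp = rf"([{''.join(special_cases.keys())}])"

def normalize_punctuation(text: str) -> str:
    # Drop separator characters (\pZ) that sit directly before or after punctuation.
    text = regex.sub(rf"\pZ+{punctuation_regexp}", r"\1", text)
    text = regex.sub(rf"{punctuation_regexp}\pZ+", r"\1", text)
    return text

print(normalize_punctuation("hello , world « quoted »"))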
View File

@ -3,7 +3,9 @@
PAD = "_" PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” ' PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'" IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ") SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)} SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}

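Note: the symbols module above just concatenates PAD, punctuation, letters and IPA letters and builds SYMBOL_TO_ID from the result. A toy version of the lookup, with an assumed smaller table:

PAD = "_"
SYMBOLS = [PAD] + list(";:,.!? ") + list("abcdefghijklmnopqrstuvwxyz")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}

def text_to_ids(text: str) -> list[int]:
    # Unknown characters fall back to PAD (index 0) in this toy version.
    return [SYMBOL_TO_ID.get(ch, 0) for ch in text]

print(text_to_ids("hello, world!"))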
View File

@ -2,12 +2,12 @@ import re
def str2bool(str): def str2bool(str):
return True if str.lower() == 'true' else False return True if str.lower() == "true" else False
def get_newest_ckpt(string_list): def get_newest_ckpt(string_list):
# Define a regex pattern that matches the numbers in the string # Define a regex pattern that matches the numbers in the string
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt' pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
# Use the regex to extract the numbers from each string and build a list of tuples # Use the regex to extract the numbers from each string and build a list of tuples
extracted_info = [] extracted_info = []
@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
step = int(match.group(2)) step = int(match.group(2))
extracted_info.append((epoch, step, string)) extracted_info.append((epoch, step, string))
# Sort by the number after "epoch" and then the number after "step" # Sort by the number after "epoch" and then the number after "step"
sorted_info = sorted( sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
# Get the newest ckpt filename # Get the newest ckpt filename
newest_ckpt = sorted_info[0][2] newest_ckpt = sorted_info[0][2]
return newest_ckpt return newest_ckpt
@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
# Return True when the text file exists and is non-empty # Return True when the text file exists and is non-empty
def check_txt_file(file_path): def check_txt_file(file_path):
try: try:
with open(file_path, 'r') as file: with open(file_path, "r") as file:
text = file.readline().strip() text = file.readline().strip()
assert text.strip() != '' assert text.strip() != ""
return text return text
except Exception: except Exception:
return False return False

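Note: get_newest_ckpt above parses epoch and step out of Lightning-style checkpoint names and sorts on the pair. A self-contained sketch with dummy filenames:

import re

def newest_ckpt(names: list[str]) -> str:
    pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
    parsed = []
    for name in names:
        m = re.match(pattern, name)
        if m:
            parsed.append((int(m.group(1)), int(m.group(2)), name))
    # Highest epoch wins; ties are broken by the highest step.
    return sorted(parsed, key=lambda x: (x[0], x[1]), reverse=True)[0][2]

print(newest_ckpt(["epoch=3-step=100.ckpt", "epoch=12-step=50.ckpt", "epoch=12-step=75.ckpt"]))
# -> epoch=12-step=75.ckpt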
View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks.""" """Initialize modules for espnet2 neural networks."""
import torch import torch
from typeguard import check_argument_types from typeguard import check_argument_types

View File

@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
def write_args(args, path): def write_args(args, path):
args_dict = dict( args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
with open(path, "a") as args_file: with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__)) args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write( args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> Cmd:\n") args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv)) args_file.write(str(sys.argv))
args_file.write("\n==> args:\n") args_file.write("\n==> args:\n")

View File

@ -23,9 +23,7 @@ class Snake(nn.Module):
>>> x = a1(x) >>> x = a1(x)
""" """
def __init__( def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
""" """
Initialization. Initialization.
INPUT: INPUT:
@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = a1(x) >>> x = a1(x)
""" """
def __init__( def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
""" """
Initialization. Initialization.
INPUT: INPUT:

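Note: Snake and SnakeBeta above implement the periodic activation x + (1/alpha) * sin^2(alpha * x), with alpha (and beta for SnakeBeta) learned per channel and optionally stored in log scale. A forward-only sketch assuming a [B, C, T] layout and a plain (non-log) alpha:

import torch

def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # x: [B, C, T], alpha: [C]; broadcast alpha over batch and time.
    alpha = alpha.view(1, -1, 1)
    return x + (1.0 / (alpha + eps)) * torch.sin(alpha * x) ** 2

x = torch.randn(2, 4, 16)
alpha = torch.ones(4)          # with alpha_logscale=True the module would store log(alpha) and exp() it first
print(snake(x, alpha).shape)   # torch.Size([2, 4, 16])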
View File

@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta): def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
activation_results = anti_alias_activation_cuda.forward( activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
inputs, up_ftr, down_ftr, alpha, beta
)
return activation_results return activation_results
@ -61,17 +59,11 @@ class Activation1d(nn.Module):
if self.act.__class__.__name__ == "Snake": if self.act.__class__.__name__ == "Snake":
beta = self.act.alpha.data # Snake uses same params for alpha and beta beta = self.act.alpha.data # Snake uses same params for alpha and beta
else: else:
beta = ( beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
self.act.beta.data
) # Snakebeta uses different params for alpha and beta
alpha = self.act.alpha.data alpha = self.act.alpha.data
if ( if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
not self.act.alpha_logscale
): # Exp baked into cuda kernel, cancel it out with a log
alpha = torch.log(alpha) alpha = torch.log(alpha)
beta = torch.log(beta) beta = torch.log(beta)
x = FusedAntiAliasActivation.apply( x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
)
return x return x

View File

@ -58,17 +58,13 @@ def load():
srcpath / "anti_alias_activation.cpp", srcpath / "anti_alias_activation.cpp",
srcpath / "anti_alias_activation_cuda.cu", srcpath / "anti_alias_activation_cuda.cu",
] ]
anti_alias_activation_cuda = _cpp_extention_load_helper( anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
"anti_alias_activation_cuda", sources, extra_cuda_flags
)
return anti_alias_activation_cuda return anti_alias_activation_cuda
def _get_cuda_bare_metal_version(cuda_dir): def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output( raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split() output = raw_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
release = output[release_idx].split(".") release = output[release_idx].split(".")

View File

@ -27,9 +27,7 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html # https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d( def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0 even = kernel_size % 2 == 0
half_size = kernel_size // 2 half_size = kernel_size // 2

View File

@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None): def __init__(self, ratio=2, kernel_size=None):
super().__init__() super().__init__()
self.ratio = ratio self.ratio = ratio
self.kernel_size = ( self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.stride = ratio self.stride = ratio
self.pad = self.kernel_size // ratio - 1 self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = ( self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter) self.register_buffer("filter", filter)
# x: [B, C, T] # x: [B, C, T]
@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape _, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate") x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d( x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right] x = x[..., self.pad_left : -self.pad_right]
return x return x
@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None): def __init__(self, ratio=2, kernel_size=None):
super().__init__() super().__init__()
self.ratio = ratio self.ratio = ratio
self.kernel_size = ( self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d( self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio, cutoff=0.5 / ratio,
half_width=0.6 / ratio, half_width=0.6 / ratio,

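Note: UpSample1d/DownSample1d above lean on kaiser_sinc_filter1d for anti-aliasing before striding. The sketch below is a generic windowed-sinc low-pass design in the same spirit (normalized cutoff, Kaiser window, unit DC gain); it is a simplified stand-in, not the repository's exact filter:

import math
import torch
import torch.nn.functional as F

def lowpass_sinc_filter(cutoff: float, kernel_size: int, beta: float = 8.0) -> torch.Tensor:
    # cutoff is a normalized frequency in (0, 0.5]; returns taps shaped [1, 1, kernel_size].
    half = (kernel_size - 1) / 2.0
    t = torch.arange(kernel_size, dtype=torch.float32) - half
    ideal = torch.where(t == 0, torch.tensor(2.0 * cutoff), torch.sin(2 * math.pi * cutoff * t) / (math.pi * t))
    taps = ideal * torch.kaiser_window(kernel_size, periodic=False, beta=beta)
    taps = taps / taps.sum()  # unit DC gain
    return taps.view(1, 1, -1)

# Depthwise low-pass filtering of a [B, C, T] signal, as DownSample1d does before striding.
x = torch.randn(1, 3, 64)
taps = lowpass_sinc_filter(cutoff=0.25, kernel_size=12)
y = F.conv1d(x, taps.expand(3, -1, -1), stride=2, padding=5, groups=3)
print(y.shape)  # torch.Size([1, 3, 32])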
View File

@ -50,7 +50,7 @@ class AMPBlock1(torch.nn.Module):
activation: str = None, activation: str = None,
): ):
super().__init__() super().__init__()
self.h = h self.h = h
self.convs1 = nn.ModuleList( self.convs1 = nn.ModuleList(
@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
) )
self.convs2.apply(init_weights) self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len( self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
self.convs2
) # Total number of conv layers
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False): if self.h.get("use_cuda_kernel", False):
@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
if activation == "snake": if activation == "snake":
self.activations = nn.ModuleList( self.activations = nn.ModuleList(
[ [
Activation1d( Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
for _ in range(self.num_layers) for _ in range(self.num_layers)
] ]
) )
elif activation == "snakebeta": elif activation == "snakebeta":
self.activations = nn.ModuleList( self.activations = nn.ModuleList(
[ [
Activation1d( Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
for _ in range(self.num_layers) for _ in range(self.num_layers)
] ]
) )
@ -169,7 +159,7 @@ class AMPBlock2(torch.nn.Module):
activation: str = None, activation: str = None,
): ):
super().__init__() super().__init__()
self.h = h self.h = h
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
if activation == "snake": if activation == "snake":
self.activations = nn.ModuleList( self.activations = nn.ModuleList(
[ [
Activation1d( Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
for _ in range(self.num_layers) for _ in range(self.num_layers)
] ]
) )
elif activation == "snakebeta": elif activation == "snakebeta":
self.activations = nn.ModuleList( self.activations = nn.ModuleList(
[ [
Activation1d( Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
for _ in range(self.num_layers) for _ in range(self.num_layers)
] ]
) )
@ -283,9 +265,7 @@ class BigVGAN(
self.num_upsamples = len(h.upsample_rates) self.num_upsamples = len(h.upsample_rates)
# Pre-conv # Pre-conv
self.conv_pre = weight_norm( self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
if h.resblock == "1": if h.resblock == "1":
@ -293,9 +273,7 @@ class BigVGAN(
elif h.resblock == "2": elif h.resblock == "2":
resblock_class = AMPBlock2 resblock_class = AMPBlock2
else: else:
raise ValueError( raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
)
# Transposed conv-based upsamplers. does not apply anti-aliasing # Transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
@ -320,22 +298,14 @@ class BigVGAN(
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1)) ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate( for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
):
self.resblocks.append(
resblock_class(h, ch, k, d, activation=h.activation)
)
# Post-conv # Post-conv
activation_post = ( activation_post = (
activations.Snake(ch, alpha_logscale=h.snake_logscale) activations.Snake(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snake" if h.activation == "snake"
else ( else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snakebeta"
else None
)
) )
if activation_post is None: if activation_post is None:
raise NotImplementedError( raise NotImplementedError(
@ -346,9 +316,7 @@ class BigVGAN(
# Whether to use bias for the final conv_post. Default to True for backward compatibility # Whether to use bias for the final conv_post. Default to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True) self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm( self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
)
# Weight initialization # Weight initialization
for i in range(len(self.ups)): for i in range(len(self.ups)):

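Note: the generator hunks above wire up weight-normalised ConvTranspose1d upsamplers whose channel count halves at every stage (upsample_initial_channel // 2 ** (i + 1)). A bare-bones sketch of that skeleton with assumed hyperparameters, leaving out the AMP blocks and anti-aliased activations:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

upsample_rates = [8, 8, 2, 2]          # assumed; real values come from the hparams config
upsample_kernel_sizes = [16, 16, 4, 4]
initial_channels = 512

ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
    in_ch = initial_channels // (2 ** i)
    out_ch = initial_channels // (2 ** (i + 1))
    ups.append(weight_norm(nn.ConvTranspose1d(in_ch, out_ch, k, stride=u, padding=(k - u) // 2)))

x = torch.randn(1, initial_channels, 32)   # [B, C, mel frames]
for up in ups:
    x = up(x)
print(x.shape)  # time axis grows by prod(upsample_rates) = 256x -> torch.Size([1, 32, 8192])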
View File

@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
), ),
] ]
) )
self.conv_post = norm_f( self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = [] fmap = []
@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
self.mpd_reshapes = h.mpd_reshapes self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}") print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList(
[ [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
for rs in self.mpd_reshapes
]
) )
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor], List[torch.Tensor],
List[List[torch.Tensor]], List[List[torch.Tensor]],
@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
super().__init__() super().__init__()
self.resolution = resolution self.resolution = resolution
assert ( assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
len(self.resolution) == 3
), f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1 self.lrelu_slope = 0.1
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"): if hasattr(cfg, "mrd_use_spectral_norm"):
print( print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}" norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
)
norm_f = (
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
)
self.d_mult = cfg.discriminator_channel_mult self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"): if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}") print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
), ),
] ]
) )
self.conv_post = norm_f( self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = [] fmap = []
@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False): def __init__(self, cfg, debug=False):
super().__init__() super().__init__()
self.resolutions = cfg.resolutions self.resolutions = cfg.resolutions
assert ( assert len(self.resolutions) == 3, (
len(self.resolutions) == 3 f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
) )
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor], List[torch.Tensor],
List[List[torch.Tensor]], List[List[torch.Tensor]],
@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
convs = lambda: nn.ModuleList( convs = lambda: nn.ModuleList(
[ [
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm( weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
), weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm( weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
] ]
) )
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm( self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]: def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset # Remove DC offset
@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
super().__init__() super().__init__()
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h. # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512]) self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor], List[torch.Tensor],
List[List[torch.Tensor]], List[List[torch.Tensor]],
List[List[torch.Tensor]], List[List[torch.Tensor]],
]: ]:
y_d_rs = [] y_d_rs = []
y_d_gs = [] y_d_gs = []
fmap_rs = [] fmap_rs = []
@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license. # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory. # LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module): class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int): def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
in_chs = min(self.filters_scale * self.filters, self.max_filters) in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations): for i, dilation in enumerate(self.dilations):
out_chs = min( out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
)
self.convs.append( self.convs.append(
weight_norm( weight_norm(
nn.Conv2d( nn.Conv2d(
@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
in_chs, in_chs,
out_chs, out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]), kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding( padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
(self.kernel_size[0], self.kernel_size[0])
),
) )
) )
) )
@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
# Multi-scale params to loop over # Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256]) self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9]) self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get( self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
"cqtd_bins_per_octaves", [24, 36, 48]
)
self.discriminators = nn.ModuleList( self.discriminators = nn.ModuleList(
[ [
@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
] ]
) )
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor], List[torch.Tensor],
List[List[torch.Tensor]], List[List[torch.Tensor]],
List[List[torch.Tensor]], List[List[torch.Tensor]],
]: ]:
y_d_rs = [] y_d_rs = []
y_d_gs = [] y_d_gs = []
fmap_rs = [] fmap_rs = []
@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
super().__init__() super().__init__()
self.discrimiantor = nn.ModuleList(list_discriminator) self.discrimiantor = nn.ModuleList(list_discriminator)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[ def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor], List[torch.Tensor],
List[torch.Tensor], List[torch.Tensor],
List[List[torch.Tensor]], List[List[torch.Tensor]],
List[List[torch.Tensor]], List[List[torch.Tensor]],
]: ]:
y_d_rs = [] y_d_rs = []
y_d_gs = [] y_d_gs = []
fmap_rs = [] fmap_rs = []

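Note: DiscriminatorP above uses the standard multi-period trick: pad the waveform to a multiple of the period, fold time into a [T/period, period] grid and run Conv2d stacks over it. A minimal sketch of just that reshape, with an assumed period of 3:

import torch
import torch.nn.functional as F

def fold_by_period(x: torch.Tensor, period: int) -> torch.Tensor:
    # x: [B, 1, T] -> [B, 1, T/period, period], reflect-padding T up to a multiple of period.
    b, c, t = x.shape
    if t % period != 0:
        n_pad = period - (t % period)
        x = F.pad(x, (0, n_pad), mode="reflect")
        t = t + n_pad
    return x.view(b, c, t // period, period)

audio = torch.randn(2, 1, 22050)
print(fold_by_period(audio, period=3).shape)  # torch.Size([2, 1, 7350, 3])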
View File

@ -35,9 +35,7 @@ def inference(a, h):
with torch.no_grad(): with torch.no_grad():
for i, filname in enumerate(filelist): for i, filname in enumerate(filelist):
# Load the ground truth audio and resample if necessary # Load the ground truth audio and resample if necessary
wav, sr = librosa.load( wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
)
wav = torch.FloatTensor(wav).to(device) wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio # Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h) x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@ -48,9 +46,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16") audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join( output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
)
write(output_file, h.sampling_rate, audio) write(output_file, h.sampling_rate, audio)
print(output_file) print(output_file)

View File

@ -61,9 +61,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16") audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join( output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
)
write(output_file, h.sampling_rate, audio) write(output_file, h.sampling_rate, audio)
print(output_file) print(output_file)

View File

@ -116,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
window_type, window_type,
): ):
""" """
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
""" """
B, C, T = wav.shape B, C, T = wav.shape
if match_stride: if match_stride:
assert ( assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
hop_length == window_length // 4
), "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2 pad = (window_length - hop_length) // 2
else: else:
@ -154,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
magnitude = torch.abs(stft) magnitude = torch.abs(stft)
nf = magnitude.shape[2] nf = magnitude.shape[2]
mel_basis = self.get_mel_filters( mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
)
mel_basis = torch.from_numpy(mel_basis).to(wav.device) mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2) mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@ -181,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
""" """
loss = 0.0 loss = 0.0
for n_mels, fmin, fmax, s in zip( for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
kwargs = { kwargs = {
"n_mels": n_mels, "n_mels": n_mels,
"fmin": fmin, "fmin": fmin,
@ -196,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
x_mels = self.mel_spectrogram(x, **kwargs) x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs) y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log( x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
x_mels.clamp(min=self.clamp_eps).pow(self.pow) y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels) loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
@ -210,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
# Loss functions # Loss functions
def feature_loss( def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
loss = 0 loss = 0
for dr, dg in zip(fmap_r, fmap_g): for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg): for rl, gl in zip(dr, dg):
@ -225,7 +212,6 @@ def feature_loss(
def discriminator_loss( def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0 loss = 0
r_losses = [] r_losses = []
g_losses = [] g_losses = []
@ -242,7 +228,6 @@ def discriminator_loss(
def generator_loss( def generator_loss(
disc_outputs: List[torch.Tensor], disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]: ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0 loss = 0
gen_losses = [] gen_losses = []
for dg in disc_outputs: for dg in disc_outputs:

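Note: the loss hunks above reduce to L1 distances between log-compressed mel spectrograms at several STFT scales, plus the usual feature-matching and LSGAN terms. A small sketch of the log-mel distance on precomputed mel tensors, with assumed clamp/power settings:

import torch

def log_mel_l1(x_mels: torch.Tensor, y_mels: torch.Tensor, clamp_eps: float = 1e-5, power: float = 1.0) -> torch.Tensor:
    # log10 of clamped, optionally power-compressed mel energies, then mean absolute error.
    x_log = torch.log(x_mels.clamp(min=clamp_eps).pow(power)) / torch.log(torch.tensor(10.0))
    y_log = torch.log(y_mels.clamp(min=clamp_eps).pow(power)) / torch.log(torch.tensor(10.0))
    return torch.nn.functional.l1_loss(x_log, y_log)

x = torch.rand(2, 80, 100)  # [B, n_mels, frames]
y = torch.rand(2, 80, 100)
print(log_mel_l1(x, y))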
View File

@ -86,9 +86,7 @@ def mel_spectrogram(
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache: if key not in mel_basis_cache:
mel = librosa_mel_fn( mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device) hann_window_cache[key] = torch.hann_window(win_size).to(device)
@ -96,9 +94,7 @@ def mel_spectrogram(
hann_window = hann_window_cache[key] hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2 padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad( y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
spec = torch.stft( spec = torch.stft(
y, y,
@ -150,17 +146,13 @@ def get_dataset_filelist(a):
with open(a.input_training_file, "r", encoding="utf-8") as fi: with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [ training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
for x in fi.read().split("\n")
if len(x) > 0
] ]
print(f"first training file: {training_files[0]}") print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi: with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [ validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
for x in fi.read().split("\n")
if len(x) > 0
] ]
print(f"first validation file: {validation_files[0]}") print(f"first validation file: {validation_files[0]}")
@ -171,9 +163,7 @@ def get_dataset_filelist(a):
for x in fi.read().split("\n") for x in fi.read().split("\n")
if len(x) > 0 if len(x) > 0
] ]
print( print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
)
list_unseen_validation_files.append(unseen_validation_files) list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files return training_files, validation_files, list_unseen_validation_files
@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
print("[INFO] checking dataset integrity...") print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))): for i in tqdm(range(len(self.audio_files))):
assert os.path.exists( assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
self.audio_files[i]
), f"{self.audio_files[i]} not found"
def __getitem__( def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try: try:
filename = self.audio_files[index] filename = self.audio_files[index]
@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
# Obtain randomized audio chunk # Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate: if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different # Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil( target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
self.segment_size
* (source_sampling_rate / self.sampling_rate)
)
else: else:
target_segment_size = self.segment_size target_segment_size = self.segment_size
# Compute upper bound index for the random chunk # Compute upper bound index for the random chunk
random_chunk_upper_bound = max( random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
0, audio.shape[0] - target_segment_size
)
# Crop or pad audio to obtain random chunk with target_segment_size # Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size: if audio.shape[0] >= target_segment_size:
@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
else: else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate # For fine-tuning, assert that the waveform is in the defined sampling_rate
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly) # Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
assert ( assert source_sampling_rate == self.sampling_rate, (
source_sampling_rate == self.sampling_rate f"For fine_tuning, waveform must be in the specified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
), f"For fine_tuning, waveform must be in the specified sampling rate {self.sampling_rate}, got {source_sampling_rate}" )
# Cast ndarray to torch tensor # Cast ndarray to torch tensor
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
mel = mel[:, :, mel_start : mel_start + frames_per_seg] mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[ audio = audio[
:, :,
mel_start mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
* self.hop_size : (mel_start + frames_per_seg)
* self.hop_size,
] ]
# Pad pre-computed mel and audio to matching lengths, to ensure fine-tuning runs without error. # Pad pre-computed mel and audio to matching lengths, to ensure fine-tuning runs without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio> # NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size # To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
mel = torch.nn.functional.pad( mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
mel, (0, frames_per_seg - mel.size(2)), "constant" audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None) # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram( mel_loss = mel_spectrogram(
@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
# Shape sanity checks # Shape sanity checks
assert ( assert (
audio.shape[1] == mel.shape[2] * self.hop_size audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
and audio.shape[1] == mel_loss.shape[2] * self.hop_size ), (
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}" f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
if self.fine_tuning: if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly. raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else: else:
print( print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
)
return self[random.randrange(len(self))] return self[random.randrange(len(self))]
def __len__(self): def __len__(self):

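Note: MelDataset above crops a random fixed-size chunk from each waveform (zero-padding short ones) and rescales the chunk size when the source sample rate differs from the target. A compact sketch of that cropping logic:

import math
import random
import torch

def random_chunk(audio: torch.Tensor, segment_size: int, source_sr: int, target_sr: int) -> torch.Tensor:
    # audio: 1-D waveform at source_sr; returns a chunk sized for cropping at source_sr.
    if source_sr != target_sr:
        target_segment_size = math.ceil(segment_size * (source_sr / target_sr))
    else:
        target_segment_size = segment_size
    upper_bound = max(0, audio.shape[0] - target_segment_size)
    if audio.shape[0] >= target_segment_size:
        start = random.randint(0, upper_bound)
        return audio[start : start + target_segment_size]
    # Too short: zero-pad up to the target size.
    return torch.nn.functional.pad(audio, (0, target_segment_size - audio.shape[0]), "constant")

wav = torch.randn(30000)
print(random_chunk(wav, segment_size=8192, source_sr=44100, target_sr=24000).shape)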
View File

@ -3,6 +3,7 @@
import os import os
import sys import sys
# to import modules from parent_dir # to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir) sys.path.append(parent_dir)
@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda") data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch # Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d( fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
activation=Snake(10), fused=True
).cuda()
fused_activation_output = fused_anti_alias_activation(data) fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d( torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
activation=Snake(10), fused=False
).cuda()
torch_activation_output = torch_anti_alias_activation(data) torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs() test_result = (fused_activation_output - torch_activation_output).abs()

View File

@ -3,6 +3,7 @@
import os import os
import sys import sys
# to import modules from parent_dir # to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir) sys.path.append(parent_dir)
@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda") data = torch.rand((10, 10, 200), device="cuda")
# Check activations, Snake CUDA vs. Torch # Check activations, Snake CUDA vs. Torch
fused_anti_alias_activation = activation1d.Activation1d( fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
activation=SnakeBeta(10), fused=True
).cuda()
fused_activation_output = fused_anti_alias_activation(data) fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d( torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
activation=SnakeBeta(10), fused=False
).cuda()
torch_activation_output = torch_anti_alias_activation(data) torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs() test_result = (fused_activation_output - torch_activation_output).abs()
@ -57,7 +54,6 @@ def test_anti_alias_activation():
) )
if __name__ == "__main__": if __name__ == "__main__":
from alias_free_activation.cuda import load from alias_free_activation.cuda import load

View File

@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
def get_mel(x, h): def get_mel(x, h):
return mel_spectrogram( return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
)
def load_checkpoint(filepath, device): def load_checkpoint(filepath, device):
@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
description="Test script to check CUDA kernel correctness."
)
parser.add_argument( parser.add_argument(
"--checkpoint_file", "--checkpoint_file",
type=str, type=str,
@ -91,27 +87,25 @@ if __name__ == "__main__":
# define number of samples and length of mel frame to benchmark # define number of samples and length of mel frame to benchmark
num_sample = 10 num_sample = 10
num_mel_frame = 16384 num_mel_frame = 16384
# CUDA kernel correctness check # CUDA kernel correctness check
diff = 0.0 diff = 0.0
for i in tqdm(range(num_sample)): for i in tqdm(range(num_sample)):
# Random mel # Random mel
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
with torch.inference_mode(): with torch.inference_mode():
audio_original = generator_original(data) audio_original = generator_original(data)
with torch.inference_mode(): with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data) audio_cuda_kernel = generator_cuda_kernel(data)
# Both outputs should be (almost) the same # Both outputs should be (almost) the same
test_result = (audio_original - audio_cuda_kernel).abs() test_result = (audio_original - audio_cuda_kernel).abs()
diff += test_result.mean(dim=-1).item() diff += test_result.mean(dim=-1).item()
diff /= num_sample diff /= num_sample
if ( if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
diff <= 2e-3
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
print( print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference" f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}" f"\n > mean_difference={diff}"
@ -125,9 +119,9 @@ if __name__ == "__main__":
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, " f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
) )
del data, audio_original, audio_cuda_kernel del data, audio_original, audio_cuda_kernel
# Variables for tracking total time and VRAM usage # Variables for tracking total time and VRAM usage
toc_total_original = 0 toc_total_original = 0
toc_total_cuda_kernel = 0 toc_total_cuda_kernel = 0
@ -145,10 +139,10 @@ if __name__ == "__main__":
audio_original = generator_original(data) audio_original = generator_original(data)
torch.cuda.synchronize() torch.cuda.synchronize()
toc = time() - tic toc = time() - tic
toc_total_original += toc toc_total_original += toc
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda") vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_original del data, audio_original
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -163,11 +157,11 @@ if __name__ == "__main__":
torch.cuda.synchronize() torch.cuda.synchronize()
toc = time() - tic toc = time() - tic
toc_total_cuda_kernel += toc toc_total_cuda_kernel += toc
audio_length_total += audio_cuda_kernel.shape[-1] audio_length_total += audio_cuda_kernel.shape[-1]
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda") vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_cuda_kernel del data, audio_cuda_kernel
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -175,8 +169,8 @@ if __name__ == "__main__":
audio_second = audio_length_total / h.sampling_rate audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000 khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000 khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3) vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3) vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results # Print results
print( print(

View File

@ -77,24 +77,18 @@ def train(rank, a, h):
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default # Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print( print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
# Variable name is kept as "mrd" for backward compatibility & minimal code change # Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device) mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print( print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device) mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1 else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device) mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False): if h.get("use_multiscale_melloss", False):
print( print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
)
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss( fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input ) # NOTE: accepts waveform as input
@ -114,9 +108,7 @@ def train(rank, a, h):
if os.path.isdir(a.checkpoint_path): if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint( cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
)
cp_do = scan_checkpoint( cp_do = scan_checkpoint(
a.checkpoint_path, a.checkpoint_path,
prefix="do_", prefix="do_",
@ -143,9 +135,7 @@ def train(rank, a, h):
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device) mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW( optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
)
optim_d = torch.optim.AdamW( optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()), itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate, h.learning_rate,
@ -156,12 +146,8 @@ def train(rank, a, h):
optim_g.load_state_dict(state_dict_do["optim_g"]) optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"]) optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR( scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
optim_g, gamma=h.lr_decay, last_epoch=last_epoch scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
)
# Define training and validation datasets # Define training and validation datasets
@ -169,9 +155,7 @@ def train(rank, a, h):
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK Example: trained on LibriTTS, validate on VCTK
""" """
training_filelist, validation_filelist, list_unseen_validation_filelist = ( training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
get_dataset_filelist(a)
)
trainset = MelDataset( trainset = MelDataset(
training_filelist, training_filelist,
@ -324,33 +308,26 @@ def train(rank, a, h):
h.fmax_for_loss, h.fmax_for_loss,
) )
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1)) min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item() val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out) # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if ( if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
"nonspeech" not in mode
): # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq # Resample to 16000 for pesq
y_16k = pesq_resampler(y) y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1)) y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = ( y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
)
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb") val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
# MRSTFT calculation # MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1)) min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item() val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard # Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0: if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate) sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if ( if a.save_audio: # Also save audio to disk if --save_audio is set to True
a.save_audio
): # Also save audio to disk if --save_audio is set to True
save_audio( save_audio(
y[0], y[0],
os.path.join( os.path.join(
@ -373,9 +350,7 @@ def train(rank, a, h):
steps, steps,
h.sampling_rate, h.sampling_rate,
) )
if ( if a.save_audio: # Also save audio to disk if --save_audio is set to True
a.save_audio
): # Also save audio to disk if --save_audio is set to True
save_audio( save_audio(
y_g_hat[0, 0], y_g_hat[0, 0],
os.path.join( os.path.join(
@ -487,15 +462,11 @@ def train(rank, a, h):
# MPD # MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
y_df_hat_r, y_df_hat_g
)
# MRD # MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach()) y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
y_ds_hat_r, y_ds_hat_g
)
loss_disc_all = loss_disc_s + loss_disc_f loss_disc_all = loss_disc_s + loss_disc_f
@ -505,17 +476,11 @@ def train(rank, a, h):
# Whether to freeze D for initial training steps # Whether to freeze D for initial training steps
if steps >= a.freeze_step: if steps >= a.freeze_step:
loss_disc_all.backward() loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_( grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
mpd.parameters(), clip_grad_norm grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
mrd.parameters(), clip_grad_norm
)
optim_d.step() optim_d.step()
else: else:
print( print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
)
grad_norm_mpd = 0.0 grad_norm_mpd = 0.0
grad_norm_mrd = 0.0 grad_norm_mrd = 0.0
@ -523,9 +488,7 @@ def train(rank, a, h):
optim_g.zero_grad() optim_g.zero_grad()
# L1 Mel-Spectrogram Loss # L1 Mel-Spectrogram Loss
lambda_melloss = h.get( lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
"lambda_melloss", 45.0
) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss else: # Uses mel <y_mel, y_g_hat_mel> for loss
@ -542,27 +505,19 @@ def train(rank, a, h):
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step: if steps >= a.freeze_step:
loss_gen_all = ( loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
)
else: else:
print( print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
)
loss_gen_all = loss_mel loss_gen_all = loss_mel
loss_gen_all.backward() loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_( grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
generator.parameters(), clip_grad_norm
)
optim_g.step() optim_g.step()
if rank == 0: if rank == 0:
# STDOUT logging # STDOUT logging
if steps % a.stdout_interval == 0: if steps % a.stdout_interval == 0:
mel_error = ( mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to stdout
print( print(
f"Steps: {steps:d}, " f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, " f"Gen Loss Total: {loss_gen_all:4.3f}, "
@ -577,11 +532,7 @@ def train(rank, a, h):
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}" checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint( save_checkpoint(
checkpoint_path, checkpoint_path,
{ {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
"generator": (
generator.module if h.num_gpus > 1 else generator
).state_dict()
},
) )
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}" checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint( save_checkpoint(
@ -598,9 +549,7 @@ def train(rank, a, h):
# Tensorboard summary logging # Tensorboard summary logging
if steps % a.summary_interval == 0: if steps % a.summary_interval == 0:
mel_error = ( mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps) sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps) sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@ -612,12 +561,8 @@ def train(rank, a, h):
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps) sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps) sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar( sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
)
sw.add_scalar(
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
)
sw.add_scalar("training/epoch", epoch + 1, steps) sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation # Validation
@ -660,9 +605,7 @@ def train(rank, a, h):
scheduler_d.step() scheduler_d.step()
if rank == 0: if rank == 0:
print( print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
)
def main(): def main():
@ -674,12 +617,8 @@ def main():
parser.add_argument("--input_wavs_dir", default="LibriTTS") parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset") parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument( parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
"--input_training_file", default="tests/LibriTTS/train-full.txt" parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
)
parser.add_argument(
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
)
parser.add_argument( parser.add_argument(
"--list_input_unseen_wavs_dir", "--list_input_unseen_wavs_dir",

File diff suppressed because it is too large

View File

@ -1,9 +1,9 @@
import os import os
import sys import sys
import threading import threading
from tqdm import tqdm from tqdm import tqdm
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
@ -19,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto") language = os.environ.get("language", "Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language) i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '…', ',', '.', '-']) punctuation = set(["!", "?", "…", ",", ".", "-"])
def get_first(text:str) -> str:
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip() text = re.split(pattern, text)[0].strip()
return text return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
def merge_short_text_in_array(texts: str, threshold: int) -> list:
if (len(texts)) < 2: if (len(texts)) < 2:
return texts return texts
result = [] result = []
@ -39,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
if len(text) >= threshold: if len(text) >= threshold:
result.append(text) result.append(text)
text = "" text = ""
if (len(text) > 0): if len(text) > 0:
if len(result) == 0: if len(result) == 0:
result.append(text) result.append(text)
else: else:
@ -47,28 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result return result
class TextPreprocessor: class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM, def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
tokenizer:AutoTokenizer, device:torch.device):
self.bert_model = bert_model self.bert_model = bert_model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.device = device self.device = device
self.bert_lock = threading.RLock() self.bert_lock = threading.RLock()
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]: def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f'############ {i18n("切分文本")} ############') print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text) text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method) texts = self.pre_seg_text(text, lang, text_split_method)
result = [] result = []
print(f'############ {i18n("提取文本Bert特征")} ############') print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts): for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text=="": if phones is None or norm_text == "":
continue continue
res={ res = {
"phones": phones, "phones": phones,
"bert_features": bert_features, "bert_features": bert_features,
"norm_text": norm_text, "norm_text": norm_text,
@ -76,11 +74,11 @@ class TextPreprocessor:
result.append(res) result.append(res)
return result return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str): def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n") text = text.strip("\n")
if len(text) == 0: if len(text) == 0:
return [] return []
if (text[0] not in splits and len(get_first(text)) < 4): if text[0] not in splits and len(get_first(text)) < 4:
text = "" + text if lang != "en" else "." + text text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:")) print(i18n("实际输入的目标文本:"))
print(text) print(text)
@ -96,18 +94,18 @@ class TextPreprocessor:
_texts = merge_short_text_in_array(_texts, 5) _texts = merge_short_text_in_array(_texts, 5)
texts = [] texts = []
for text in _texts: for text in _texts:
# 解决输入目标文本的空行导致报错的问题 # 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0): if len(text.strip()) == 0:
continue continue
if not re.sub("\W+", "", text): if not re.sub("\W+", "", text):
# 检测一下,如果是纯符号,就跳过。 # 检测一下,如果是纯符号,就跳过。
continue continue
if (text[-1] not in splits): text += "。" if lang != "en" else "." if text[-1] not in splits:
text += "。" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题 # 解决句子过长导致Bert报错的问题
if (len(text) > 510): if len(text) > 510:
texts.extend(split_big_text(text)) texts.extend(split_big_text(text))
else: else:
texts.append(text) texts.append(text)
@ -116,78 +114,79 @@ class TextPreprocessor:
print(texts) print(texts)
return texts return texts
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]: def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version) return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False): def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock: with self.bert_lock:
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
# language = language.replace("all_","") # language = language.replace("all_","")
formattext = text formattext = text
while " " in formattext: while " " in formattext:
formattext = formattext.replace(" ", " ") formattext = formattext.replace(" ", " ")
if language == "all_zh": if language == "all_zh":
if re.search(r'[A-Za-z]', formattext): if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"zh",version) return self.get_phones_and_bert(formattext, "zh", version)
else: else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device) bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext) formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version) return self.get_phones_and_bert(formattext, "yue", version)
else: else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros( bert = torch.zeros(
(1024, len(phones)), (1024, len(phones)),
dtype=torch.float32, dtype=torch.float32,
).to(self.device) ).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[] textlist = []
langlist=[] langlist = []
if language == "auto": if language == "auto":
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
elif language == "auto_yue": elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh": if tmp["lang"] == "zh":
tmp["lang"] = "yue" tmp["lang"] = "yue"
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
else: else:
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en": if tmp["lang"] == "en":
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
else: else:
# 因无法区别中日韩文汉字,以用户输入为准 # 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language) langlist.append(language)
textlist.append(tmp["text"]) textlist.append(tmp["text"])
# print(textlist) # print(textlist)
# print(langlist) # print(langlist)
phones_list = [] phones_list = []
bert_list = [] bert_list = []
norm_text_list = [] norm_text_list = []
for i in range(len(textlist)): for i in range(len(textlist)):
lang = langlist[i] lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang) bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones) phones_list.append(phones)
norm_text_list.append(norm_text) norm_text_list.append(norm_text)
bert_list.append(bert) bert_list.append(bert)
bert = torch.cat(bert_list, dim=1) bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, []) phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list) norm_text = "".join(norm_text_list)
if not final and len(phones) < 6: if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True) return self.get_phones_and_bert("." + text, language, version, final=True)
return phones, bert, norm_text return phones, bert, norm_text
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
with torch.no_grad(): with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt") inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs: for i in inputs:
@ -202,14 +201,14 @@ class TextPreprocessor:
phone_level_feature = torch.cat(phone_level_feature, dim=0) phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"): def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_","") language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, language, version) phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version) phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str): def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language=language.replace("all_","") language = language.replace("all_", "")
if language == "zh": if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device) feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else: else:
@ -220,21 +219,19 @@ class TextPreprocessor:
return feature return feature
def filter_text(self, texts):
def filter_text(self,texts): _text = []
_text=[] if all(text in [None, " ", "\n", ""] for text in texts):
if all(text in [None, " ", "\n",""] for text in texts):
raise ValueError(i18n("请输入有效文本")) raise ValueError(i18n("请输入有效文本"))
for text in texts: for text in texts:
if text in [None, " ", ""]: if text in [None, " ", ""]:
pass pass
else: else:
_text.append(text) _text.append(text)
return _text return _text
def replace_consecutive_punctuation(self, text):
def replace_consecutive_punctuation(self,text): punctuations = "".join(re.escape(p) for p in punctuation)
punctuations = ''.join(re.escape(p) for p in punctuation) pattern = f"([{punctuations}])([{punctuations}])+"
pattern = f'([{punctuations}])([{punctuations}])+' result = re.sub(pattern, r"\1", text)
result = re.sub(pattern, r'\1', text)
return result return result

View File

@ -1 +1 @@
from . import TTS, text_segmentation_method from . import TTS, text_segmentation_method

View File

@ -1,41 +1,57 @@
import re import re
from typing import Callable from typing import Callable
punctuation = set(['!', '?', '…', ',', '.', '-'," "]) punctuation = set(["!", "?", "…", ",", ".", "-", " "])
METHODS = dict() METHODS = dict()
def get_method(name:str)->Callable:
def get_method(name: str) -> Callable:
method = METHODS.get(name, None) method = METHODS.get(name, None)
if method is None: if method is None:
raise ValueError(f"Method {name} not found") raise ValueError(f"Method {name} not found")
return method return method
def get_method_names()->list:
def get_method_names() -> list:
return list(METHODS.keys()) return list(METHODS.keys())
def register_method(name): def register_method(name):
def decorator(func): def decorator(func):
METHODS[name] = func METHODS[name] = func
return func return func
return decorator return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
}
def split_big_text(text, max_len=510): def split_big_text(text, max_len=510):
# 定义全角和半角标点符号 # 定义全角和半角标点符号
punctuation = "".join(splits) punctuation = "".join(splits)
# 切割文本 # 切割文本
segments = re.split('([' + punctuation + '])', text) segments = re.split("([" + punctuation + "])", text)
# 初始化结果列表和当前片段 # 初始化结果列表和当前片段
result = [] result = []
current_segment = '' current_segment = ""
for segment in segments: for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段 # 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
if len(current_segment + segment) > max_len: if len(current_segment + segment) > max_len:
@ -43,13 +59,12 @@ def split_big_text(text, max_len=510):
current_segment = segment current_segment = segment
else: else:
current_segment += segment current_segment += segment
# 将最后一个片段加入结果列表 # 将最后一个片段加入结果列表
if current_segment: if current_segment:
result.append(current_segment) result.append(current_segment)
return result
return result
def split(todo_text): def split(todo_text):
@ -90,7 +105,7 @@ def cut1(inp):
if len(split_idx) > 1: if len(split_idx) > 1:
opts = [] opts = []
for idx in range(len(split_idx) - 1): for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else: else:
opts = [inp] opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)] opts = [item for item in opts if not set(item).issubset(punctuation)]
@ -123,6 +138,7 @@ def cut2(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)] opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts) return "\n".join(opts)
# 按中文句号。切 # 按中文句号。切
@register_method("cut3") @register_method("cut3")
def cut3(inp): def cut3(inp):
@ -131,26 +147,28 @@ def cut3(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)] opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts) return "\n".join(opts)
#按英文句号.切
# 按英文句号.切
@register_method("cut4") @register_method("cut4")
def cut4(inp): def cut4(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip(".")) opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)] opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts) return "\n".join(opts)
# 按标点符号切 # 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5") @register_method("cut5")
def cut5(inp): def cut5(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '、', '，', '。', '？', '！', ';', '：', '…'} punds = {",", ".", ";", "?", "!", "、", "，", "。", "？", "！", ";", "：", "…"}
mergeitems = [] mergeitems = []
items = [] items = []
for i, char in enumerate(inp): for i, char in enumerate(inp):
if char in punds: if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char) items.append(char)
else: else:
items.append(char) items.append(char)
@ -166,8 +184,6 @@ def cut5(inp):
return "\n".join(opt) return "\n".join(opt)
if __name__ == "__main__":
if __name__ == '__main__':
method = get_method("cut5") method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。")) print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@ -1,6 +1,13 @@
import os import os
import sys import sys
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.insert(0, now_dir) sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin from text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
v_to_u=False,
neutral_tone_with_five=True,
)

View File

@ -32,7 +32,8 @@ default_config = {
"EOS": 1024, "EOS": 1024,
} }
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"] config = dict_s1["config"]
config["model"]["dropout"] = float(config["model"]["dropout"]) config["model"]["dropout"] = float(config["model"]["dropout"])
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
@ -40,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
t2s_model = t2s_model.eval() t2s_model = t2s_model.eval()
return t2s_model return t2s_model
@torch.jit.script @torch.jit.script
def logits_to_probs( def logits_to_probs(
logits, logits,
@ -56,39 +58,35 @@ def logits_to_probs(
if previous_tokens is not None and repetition_penalty != 1.0: if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long() previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens) score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where( score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=1, index=previous_tokens, src=score) logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum( cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf")) logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5) logits = logits / max(temperature, 1e-5)
if top_k is not None: if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1) pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits) logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1) probs = torch.nn.functional.softmax(logits, dim=-1)
return probs return probs
@torch.jit.script @torch.jit.script
def multinomial_sample_one_no_sync(probs_sort): def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort) q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@torch.jit.script @torch.jit.script
def sample( def sample(
logits, logits,
@ -99,15 +97,20 @@ def sample(
repetition_penalty: float = 1.0, repetition_penalty: float = 1.0,
): ):
probs = logits_to_probs( probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty logits=logits,
previous_tokens=previous_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
) )
idx_next = multinomial_sample_one_no_sync(probs) idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs return idx_next, probs
@torch.jit.script @torch.jit.script
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False): def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype) hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -157,6 +160,7 @@ class DictToAttrRecursive(dict):
except KeyError: except KeyError:
raise AttributeError(f"Attribute {item} not found") raise AttributeError(f"Attribute {item} not found")
@torch.jit.script @torch.jit.script
class T2SMLP: class T2SMLP:
def __init__(self, w1, b1, w2, b2): def __init__(self, w1, b1, w2, b2):
@ -170,23 +174,24 @@ class T2SMLP:
x = F.linear(x, self.w2, self.b2) x = F.linear(x, self.w2, self.b2)
return x return x
@torch.jit.script @torch.jit.script
class T2SBlock: class T2SBlock:
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
hidden_dim: int, hidden_dim: int,
mlp: T2SMLP, mlp: T2SMLP,
qkv_w, qkv_w,
qkv_b, qkv_b,
out_w, out_w,
out_b, out_b,
norm_w1, norm_w1,
norm_b1, norm_b1,
norm_eps1: float, norm_eps1: float,
norm_w2, norm_w2,
norm_b2, norm_b2,
norm_eps2: float, norm_eps2: float,
): ):
self.num_heads = num_heads self.num_heads = num_heads
self.mlp = mlp self.mlp = mlp
@ -205,22 +210,22 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool) self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore @torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]): def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
if padding_mask is None: if padding_mask is None:
return x return x
if padding_mask.dtype == torch.bool: if padding_mask.dtype == torch.bool:
return x.masked_fill(padding_mask, 0) return x.masked_fill(padding_mask, 0)
else: else:
return x * padding_mask return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None): def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0] batch_size = q.shape[0]
q_len = q.shape[1] q_len = q.shape[1]
kv_len = k.shape[1] kv_len = k.shape[1]
q = self.to_mask(q, padding_mask) q = self.to_mask(q, padding_mask)
k_cache = self.to_mask(k, padding_mask) k_cache = self.to_mask(k, padding_mask)
v_cache = self.to_mask(v, padding_mask) v_cache = self.to_mask(v, padding_mask)
@ -231,22 +236,20 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
if padding_mask is not None: if padding_mask is not None:
for i in range(batch_size): for i in range(batch_size):
# mask = padding_mask[i,:,0] # mask = padding_mask[i,:,0]
if self.false.device!= padding_mask.device: if self.false.device != padding_mask.device:
self.false = self.false.to(padding_mask.device) self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i,:,0]==self.false)[0] idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
x_item = x[i,idx,:].unsqueeze(0) x_item = x[i, idx, :].unsqueeze(0)
attn_item = attn[i,idx,:].unsqueeze(0) attn_item = attn[i, idx, :].unsqueeze(0)
x_item = x_item + attn_item x_item = x_item + attn_item
x_item = F.layer_norm( x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x_item = x_item + self.mlp.forward(x_item) x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm( x_item = F.layer_norm(
x_item, x_item,
@ -255,13 +258,11 @@ class T2SBlock:
self.norm_b2, self.norm_b2,
self.norm_eps2, self.norm_eps2,
) )
x[i,idx,:] = x_item.squeeze(0) x[i, idx, :] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask) x = self.to_mask(x, padding_mask)
else: else:
x = x + attn x = x + attn
x = F.layer_norm( x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x) x = x + self.mlp.forward(x)
x = F.layer_norm( x = F.layer_norm(
x, x,
@ -271,13 +272,13 @@ class T2SBlock:
self.norm_eps2, self.norm_eps2,
) )
return x, k_cache, v_cache return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor): def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1) k_cache = torch.cat([k_cache, k], dim=1)
v_cache = torch.cat([v_cache, v], dim=1) v_cache = torch.cat([v_cache, v], dim=1)
batch_size = q.shape[0] batch_size = q.shape[0]
q_len = q.shape[1] q_len = q.shape[1]
kv_len = k_cache.shape[1] kv_len = k_cache.shape[1]
@ -288,14 +289,12 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v) attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim) attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b) attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn x = x + attn
x = F.layer_norm( x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = x + self.mlp.forward(x) x = x + self.mlp.forward(x)
x = F.layer_norm( x = F.layer_norm(
x, x,
@ -306,48 +305,46 @@ class T2SBlock:
) )
return x, k_cache, v_cache return x, k_cache, v_cache
@torch.jit.script @torch.jit.script
class T2STransformer: class T2STransformer:
def __init__(self, num_blocks : int, blocks: list[T2SBlock]): def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
self.num_blocks : int = num_blocks self.num_blocks: int = num_blocks
self.blocks = blocks self.blocks = blocks
def process_prompt( def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None): k_cache: list[torch.Tensor] = []
k_cache : list[torch.Tensor] = [] v_cache: list[torch.Tensor] = []
v_cache : list[torch.Tensor] = []
for i in range(self.num_blocks): for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask) x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
k_cache.append(k_cache_) k_cache.append(k_cache_)
v_cache.append(v_cache_) v_cache.append(v_cache_)
return x, k_cache, v_cache return x, k_cache, v_cache
def decode_next_token( def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
self, x:torch.Tensor,
k_cache: list[torch.Tensor],
v_cache: list[torch.Tensor]):
for i in range(self.num_blocks): for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache return x, k_cache, v_cache
class VitsModel(nn.Module): class VitsModel(nn.Module):
def __init__(self, vits_path): def __init__(self, vits_path):
super().__init__() super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu") # dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path) dict_s2 = torch.load(vits_path)
self.hps = dict_s2["config"] self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1" self.hps["model"]["version"] = "v1"
else: else:
self.hps["model"]["version"] = "v2" self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps) self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz" self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn( self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1, self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length, self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers, n_speakers=self.hps.data.n_speakers,
**self.hps.model **self.hps.model,
) )
self.vq_model.eval() self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False) self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@ -359,12 +356,13 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate, self.hps.data.sampling_rate,
self.hps.data.hop_length, self.hps.data.hop_length,
self.hps.data.win_length, self.hps.data.win_length,
center=False center=False,
) )
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0] return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
class T2SModel(nn.Module): class T2SModel(nn.Module):
def __init__(self,raw_t2s:Text2SemanticLightningModule): def __init__(self, raw_t2s: Text2SemanticLightningModule):
super(T2SModel, self).__init__() super(T2SModel, self).__init__()
self.model_dim = raw_t2s.model.model_dim self.model_dim = raw_t2s.model.model_dim
self.embedding_dim = raw_t2s.model.embedding_dim self.embedding_dim = raw_t2s.model.embedding_dim
@ -373,7 +371,7 @@ class T2SModel(nn.Module):
self.vocab_size = raw_t2s.model.vocab_size self.vocab_size = raw_t2s.model.vocab_size
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
# self.p_dropout = float(raw_t2s.model.p_dropout) # self.p_dropout = float(raw_t2s.model.p_dropout)
self.EOS:int = int(raw_t2s.model.EOS) self.EOS: int = int(raw_t2s.model.EOS)
self.norm_first = raw_t2s.model.norm_first self.norm_first = raw_t2s.model.norm_first
assert self.EOS == self.vocab_size - 1 assert self.EOS == self.vocab_size - 1
self.hz = 50 self.hz = 50
@ -383,7 +381,7 @@ class T2SModel(nn.Module):
self.ar_text_position = raw_t2s.model.ar_text_position self.ar_text_position = raw_t2s.model.ar_text_position
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
self.ar_audio_position = raw_t2s.model.ar_audio_position self.ar_audio_position = raw_t2s.model.ar_audio_position
# self.t2s_transformer = T2STransformer(self.num_layers, blocks) # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
# self.t2s_transformer = raw_t2s.model.t2s_transformer # self.t2s_transformer = raw_t2s.model.t2s_transformer
@ -392,12 +390,7 @@ class T2SModel(nn.Module):
for i in range(self.num_layers): for i in range(self.num_layers):
layer = h.layers[i] layer = h.layers[i]
t2smlp = T2SMLP( t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
block = T2SBlock( block = T2SBlock(
self.num_head, self.num_head,
@ -412,11 +405,11 @@ class T2SModel(nn.Module):
layer.norm1.eps, layer.norm1.eps,
layer.norm2.weight, layer.norm2.weight,
layer.norm2.bias, layer.norm2.bias,
layer.norm2.eps layer.norm2.eps,
) )
blocks.append(block) blocks.append(block)
self.t2s_transformer = T2STransformer(self.num_layers, blocks) self.t2s_transformer = T2STransformer(self.num_layers, blocks)
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
@ -425,20 +418,27 @@ class T2SModel(nn.Module):
self.max_sec = raw_t2s.config["data"]["max_sec"] self.max_sec = raw_t2s.config["data"]["max_sec"]
self.top_k = int(raw_t2s.config["inference"]["top_k"]) self.top_k = int(raw_t2s.config["inference"]["top_k"])
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor): def forward(
self,
prompts: LongTensor,
ref_seq: LongTensor,
text_seq: LongTensor,
ref_bert: torch.Tensor,
text_bert: torch.Tensor,
top_k: LongTensor,
):
bert = torch.cat([ref_bert.T, text_bert.T], 1) bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0) bert = bert.unsqueeze(0)
x = self.ar_text_embedding(all_phoneme_ids) x = self.ar_text_embedding(all_phoneme_ids)
x = x + self.bert_proj(bert.transpose(1, 2)) x = x + self.bert_proj(bert.transpose(1, 2))
x:torch.Tensor = self.ar_text_position(x) x: torch.Tensor = self.ar_text_position(x)
early_stop_num = self.early_stop_num early_stop_num = self.early_stop_num
# [1,N,512] [1,N]
#[1,N,512] [1,N]
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
y = prompts y = prompts
# x_example = x[:,:,0] * 0.0 # x_example = x[:,:,0] * 0.0
@ -464,15 +464,17 @@ class T2SModel(nn.Module):
(x_len, 0), (x_len, 0),
value=False, value=False,
) )
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\ xy_attn_mask = (
.unsqueeze(0)\ torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.expand(bsz*self.num_head, -1, -1)\ .unsqueeze(0)
.view(bsz, self.num_head, src_len, src_len)\ .expand(bsz * self.num_head, -1, -1)
.to(device=x.device, dtype=torch.bool) .view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
idx = 0 idx = 0
top_k = int(top_k) top_k = int(top_k)
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
logits = self.ar_predict_layer(xy_dec[:, -1]) logits = self.ar_predict_layer(xy_dec[:, -1])
@ -480,23 +482,25 @@ class T2SModel(nn.Module):
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1) y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(y[:, -1:]) y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
stop = False stop = False
# for idx in range(1, 50): # for idx in range(1, 50):
for idx in range(1, 1500): for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1]) logits = self.ar_predict_layer(xy_dec[:, -1])
if(idx<11):###至少预测出10个token不然不给停止0.4s if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1] logits = logits[:, :-1]
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1) y = torch.concat([y, samples], dim=1)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
@ -507,20 +511,22 @@ class T2SModel(nn.Module):
break break
y_emb = self.ar_audio_embedding(y[:, -1:]) y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device) xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
y[0,-1] = 0 ].to(dtype=y_emb.dtype, device=y_emb.device)
y[0, -1] = 0
return y[:, -idx:].unsqueeze(0) return y[:, -idx:].unsqueeze(0)
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path
@torch.jit.script @torch.jit.script
def build_phone_level_feature(res:Tensor, word2ph:IntTensor): def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
phone_level_feature = [] phone_level_feature = []
for i in range(word2ph.shape[0]): for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1) repeat_feature = res[i].repeat(word2ph[i].item(), 1)
@ -529,103 +535,111 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
# [sum(word2ph), 1024] # [sum(word2ph), 1024]
return phone_level_feature return phone_level_feature
class MyBertModel(torch.nn.Module): class MyBertModel(torch.nn.Module):
def __init__(self, bert_model): def __init__(self, bert_model):
super(MyBertModel, self).__init__() super(MyBertModel, self).__init__()
self.bert = bert_model self.bert = bert_model
def forward(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1] # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1] res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
return build_phone_level_feature(res, word2ph) return build_phone_level_feature(res, word2ph)
class SSLModel(torch.nn.Module): class SSLModel(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.ssl = cnhubert.get_model().model self.ssl = cnhubert.get_model().model
def forward(self, ref_audio_16k)-> torch.Tensor: def forward(self, ref_audio_16k) -> torch.Tensor:
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2) ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
return ssl_content return ssl_content
class ExportSSLModel(torch.nn.Module): class ExportSSLModel(torch.nn.Module):
def __init__(self,ssl:SSLModel): def __init__(self, ssl: SSLModel):
super().__init__() super().__init__()
self.ssl = ssl self.ssl = ssl
def forward(self, ref_audio:torch.Tensor): def forward(self, ref_audio: torch.Tensor):
return self.ssl(ref_audio) return self.ssl(ref_audio)
@torch.jit.export @torch.jit.export
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor: def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
audio = resamplex(ref_audio,src_sr,dst_sr).float() audio = resamplex(ref_audio, src_sr, dst_sr).float()
return audio return audio
def export_bert(output_path): def export_bert(output_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么." text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
ref_bert_inputs = tokenizer(text, return_tensors="pt") ref_bert_inputs = tokenizer(text, return_tensors="pt")
word2ph = [] word2ph = []
for c in text: for c in text:
if c in ['','','','',",",".","?"]: if c in ["", "", "", "", ",", ".", "?"]:
word2ph.append(1) word2ph.append(1)
else: else:
word2ph.append(2) word2ph.append(2)
ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int() ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
my_bert_model = MyBertModel(bert_model) my_bert_model = MyBertModel(bert_model)
ref_bert_inputs = { ref_bert_inputs = {
'input_ids': ref_bert_inputs['input_ids'], "input_ids": ref_bert_inputs["input_ids"],
'attention_mask': ref_bert_inputs['attention_mask'], "attention_mask": ref_bert_inputs["attention_mask"],
'token_type_ids': ref_bert_inputs['token_type_ids'], "token_type_ids": ref_bert_inputs["token_type_ids"],
'word2ph': ref_bert_inputs['word2ph'] "word2ph": ref_bert_inputs["word2ph"],
} }
torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1) torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1) torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1) torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0) torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs) my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
output_path = os.path.join(output_path, "bert_model.pt") output_path = os.path.join(output_path, "bert_model.pt")
my_bert_model.save(output_path) my_bert_model.save(output_path)
print('#### exported bert ####') print("#### exported bert ####")
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
if not os.path.exists(output_path): if not os.path.exists(output_path):
os.makedirs(output_path) os.makedirs(output_path)
print(f"目录已创建: {output_path}") print(f"目录已创建: {output_path}")
else: else:
print(f"目录已存在: {output_path}") print(f"目录已存在: {output_path}")
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel() ssl = SSLModel()
if export_bert_and_ssl: if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio))) s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt") ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path) torch.jit.script(s).save(ssl_path)
print('#### exported ssl ####') print("#### exported ssl ####")
export_bert(output_path) export_bert(output_path)
else: else:
s = ExportSSLModel(ssl) s = ExportSSLModel(ssl)
print(f"device: {device}") print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
ref_seq = torch.LongTensor([ref_seq_id]).to(device) ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device) ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device) text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device) text_bert = text_bert_T.T.to(text_seq.device)
ssl_content = ssl(ref_audio).to(device) ssl_content = ssl(ref_audio).to(device)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path).to(device) vits = VitsModel(vits_path).to(device)
vits.eval() vits.eval()
@ -633,18 +647,18 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
# dict_s1 = torch.load(gpt_path, map_location=device) # dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path) dict_s1 = torch.load(gpt_path)
raw_t2s = get_raw_t2s_model(dict_s1).to(device) raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print('#### get_raw_t2s_model ####') print("#### get_raw_t2s_model ####")
print(raw_t2s.config) print(raw_t2s.config)
t2s_m = T2SModel(raw_t2s) t2s_m = T2SModel(raw_t2s)
t2s_m.eval() t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device) t2s = torch.jit.script(t2s_m).to(device)
print('#### script t2s_m ####') print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate) print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s,vits).to(device) gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
gpt_sovits.eval() gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio,16000,32000).to(device) ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
torch._dynamo.mark_dynamic(ssl_content, 2) torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1) torch._dynamo.mark_dynamic(ref_audio_sr, 1)
@ -657,32 +671,28 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
with torch.no_grad(): with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
)
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path) gpt_sovits_export.save(gpt_sovits_path)
print('#### exported gpt_sovits ####') print("#### exported gpt_sovits ####")
@torch.jit.script @torch.jit.script
def parse_audio(ref_audio): def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device) ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device) ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
return ref_audio_16k,ref_audio_sr return ref_audio_16k, ref_audio_sr
@torch.jit.script @torch.jit.script
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor: def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float() return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
class GPT_SoVITS(nn.Module): class GPT_SoVITS(nn.Module):
def __init__(self, t2s:T2SModel,vits:VitsModel): def __init__(self, t2s: T2SModel, vits: VitsModel):
super().__init__() super().__init__()
self.t2s = t2s self.t2s = t2s
self.vits = vits self.vits = vits
@ -709,12 +719,11 @@ class GPT_SoVITS(nn.Module):
def test(): def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file") parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file") parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file") parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory") parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args() args = parser.parse_args()
gpt_path = args.gpt_model gpt_path = args.gpt_model
@ -725,7 +734,7 @@ def test():
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True) # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
# bert = MyBertModel(bert_model) # bert = MyBertModel(bert_model)
my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda') my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
# dict_s1 = torch.load(gpt_path, map_location="cuda") # dict_s1 = torch.load(gpt_path, map_location="cuda")
# raw_t2s = get_raw_t2s_model(dict_s1) # raw_t2s = get_raw_t2s_model(dict_s1)
@ -739,78 +748,79 @@ def test():
# ssl = ExportSSLModel(SSLModel()).to('cuda') # ssl = ExportSSLModel(SSLModel()).to('cuda')
# ssl.eval() # ssl.eval()
ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda') ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
# gpt_sovits = GPT_SoVITS(t2s,vits) # gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda') gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2') ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
ref_seq = torch.LongTensor([ref_seq_id]) ref_seq = torch.LongTensor([ref_seq_id])
ref_bert = ref_bert_T.T.to(ref_seq.device) ref_bert = ref_bert_T.T.to(ref_seq.device)
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2') # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字." text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2') text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
test_bert = tokenizer(text, return_tensors="pt") test_bert = tokenizer(text, return_tensors="pt")
word2ph = [] word2ph = []
for c in text: for c in text:
if c in ['','','','',"?",",","."]: if c in ["", "", "", "", "?", ",", "."]:
word2ph.append(1) word2ph.append(1)
else: else:
word2ph.append(2) word2ph.append(2)
test_bert['word2ph'] = torch.Tensor(word2ph).int() test_bert["word2ph"] = torch.Tensor(word2ph).int()
test_bert = my_bert( test_bert = my_bert(
test_bert['input_ids'].to('cuda'), test_bert["input_ids"].to("cuda"),
test_bert['attention_mask'].to('cuda'), test_bert["attention_mask"].to("cuda"),
test_bert['token_type_ids'].to('cuda'), test_bert["token_type_ids"].to("cuda"),
test_bert['word2ph'].to('cuda') test_bert["word2ph"].to("cuda"),
) )
text_seq = torch.LongTensor([text_seq_id]) text_seq = torch.LongTensor([text_seq_id])
text_bert = text_bert_T.T.to(text_seq.device) text_bert = text_bert_T.T.to(text_seq.device)
print('text_bert:',text_bert.shape,text_bert) print("text_bert:", text_bert.shape, text_bert)
print('test_bert:',test_bert.shape,test_bert) print("test_bert:", test_bert.shape, test_bert)
print(torch.allclose(text_bert.to('cuda'),test_bert)) print(torch.allclose(text_bert.to("cuda"), test_bert))
print('text_seq:',text_seq.shape) print("text_seq:", text_seq.shape)
print('text_bert:',text_bert.shape,text_bert.type()) print("text_bert:", text_bert.shape, text_bert.type())
#[1,N] # [1,N]
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda') ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
print('ref_audio:',ref_audio.shape) print("ref_audio:", ref_audio.shape)
ref_audio_sr = ssl.resample(ref_audio,16000,32000) ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
print('start ssl') print("start ssl")
ssl_content = ssl(ref_audio) ssl_content = ssl(ref_audio)
print('start gpt_sovits:') print("start gpt_sovits:")
print('ssl_content:',ssl_content.shape) print("ssl_content:", ssl_content.shape)
print('ref_audio_sr:',ref_audio_sr.shape) print("ref_audio_sr:", ref_audio_sr.shape)
print('ref_seq:',ref_seq.shape) print("ref_seq:", ref_seq.shape)
ref_seq=ref_seq.to('cuda') ref_seq = ref_seq.to("cuda")
print('text_seq:',text_seq.shape) print("text_seq:", text_seq.shape)
text_seq=text_seq.to('cuda') text_seq = text_seq.to("cuda")
print('ref_bert:',ref_bert.shape) print("ref_bert:", ref_bert.shape)
ref_bert=ref_bert.to('cuda') ref_bert = ref_bert.to("cuda")
print('text_bert:',text_bert.shape) print("text_bert:", text_bert.shape)
text_bert=text_bert.to('cuda') text_bert = text_bert.to("cuda")
top_k = torch.LongTensor([5]).to('cuda') top_k = torch.LongTensor([5]).to("cuda")
with torch.no_grad(): with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
print('start write wav') print("start write wav")
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
import text import text
import json import json
def export_symbel(version="v2"):
if version == "v1":
symbols = text._symbol_to_id_v1 symbols = text._symbol_to_id_v1
with open("onnx/symbols_v1.json", "w") as file: with open("onnx/symbols_v1.json", "w") as file:
json.dump(symbols, file, indent=4) json.dump(symbols, file, indent=4)
@ -819,15 +829,16 @@ def export_symbel(version='v2'):
with open("onnx/symbols_v2.json", "w") as file: with open("onnx/symbols_v2.json", "w") as file:
json.dump(symbols, file, indent=4) json.dump(symbols, file, indent=4)
def main(): def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file") parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file") parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file") parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory") parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model") parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument('--device', help="Device to use") parser.add_argument("--device", help="Device to use")
args = parser.parse_args() args = parser.parse_args()
export( export(
@ -840,9 +851,11 @@ def main():
export_bert_and_ssl=args.export_common_model, export_bert_and_ssl=args.export_common_model,
) )
import inference_webui import inference_webui
if __name__ == "__main__": if __name__ == "__main__":
inference_webui.is_half=False inference_webui.is_half = False
inference_webui.dtype=torch.float32 inference_webui.dtype = torch.float32
main() main()
# test() # test()
@ -32,7 +32,6 @@ now_dir = os.getcwd()
class MelSpectrgram(torch.nn.Module): class MelSpectrgram(torch.nn.Module):
def __init__( def __init__(
self, self,
dtype, dtype,
@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
): ):
super().__init__() super().__init__()
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype) self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device) self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
self.n_fft:int = n_fft self.n_fft: int = n_fft
self.hop_size:int = hop_size self.hop_size: int = hop_size
self.win_size:int = win_size self.win_size: int = win_size
self.center:bool = center self.center: bool = center
def forward(self, y): def forward(self, y):
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
): ):
T_min = fea_ref.size(2) T_min = fea_ref.size(2)
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
cfm_res = cfm_res[:, :, mel2.shape[2] :] cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:] mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:] fea_ref = fea_todo_chunk[:, :, -T_min:]
@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
spec_min = -12 spec_min = -12
spec_max = 2 spec_max = 2
@torch.jit.script @torch.jit.script
def norm_spec(x): def norm_spec(x):
spec_min = -12 spec_min = -12
@ -212,7 +208,6 @@ def denorm_spec(x):
class ExportGPTSovitsHalf(torch.nn.Module): class ExportGPTSovitsHalf(torch.nn.Module):
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3): def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
super().__init__() super().__init__()
self.hps = hps self.hps = hps
@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False, center=False,
) )
# self.dtype = dtype # self.dtype = dtype
self.filter_length:int = hps.data.filter_length self.filter_length: int = hps.data.filter_length
self.sampling_rate:int = hps.data.sampling_rate self.sampling_rate: int = hps.data.sampling_rate
self.hop_length:int = hps.data.hop_length self.hop_length: int = hps.data.hop_length
self.win_length:int = hps.data.win_length self.win_length: int = hps.data.win_length
def forward( def forward(
self, self,
ssl_content, ssl_content,
ref_audio_32k:torch.FloatTensor, ref_audio_32k: torch.FloatTensor,
phoneme_ids0, phoneme_ids0,
phoneme_ids1, phoneme_ids1,
bert1, bert1,
@ -255,21 +250,17 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False, center=False,
).to(ssl_content.dtype) ).to(ssl_content.dtype)
codes = self.vq_model.extract_latent(ssl_content) codes = self.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0) prompt = prompt_semantic.unsqueeze(0)
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
ge = self.vq_model.create_ge(refer) ge = self.vq_model.create_ge(refer)
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
prompt_ = prompt.unsqueeze(0) prompt_ = prompt.unsqueeze(0)
fea_ref = self.vq_model(prompt_, phoneme_ids0, ge) fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
# print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
return fea_ref, fea_todo, mel2 return fea_ref, fea_todo, mel2
class GPTSoVITSV3(torch.nn.Module): class GPTSoVITSV3(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, bigvgan): def __init__(self, gpt_sovits_half, cfm, bigvgan):
super().__init__() super().__init__()
@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
def forward( def forward(
self, self,
ssl_content, ssl_content,
ref_audio_32k:torch.FloatTensor, ref_audio_32k: torch.FloatTensor,
phoneme_ids0:torch.LongTensor, phoneme_ids0: torch.LongTensor,
phoneme_ids1:torch.LongTensor, phoneme_ids1: torch.LongTensor,
bert1, bert1,
bert2, bert2,
top_k: torch.LongTensor, top_k: torch.LongTensor,
@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
): ):
# current_time = datetime.now() # current_time = datetime.now()
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S")) # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2] chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = [] wav_gen_list = []
idx = 0 idx = 0
@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256 # 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
complete_len = chunk_len - fea_todo_chunk.shape[-1] complete_len = chunk_len - fea_todo_chunk.shape[-1]
if complete_len != 0: if complete_len != 0:
fea_todo_chunk = torch.cat(
[
fea_todo_chunk,
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
],
2,
)
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
idx += chunk_len idx += chunk_len
@ -339,17 +339,17 @@ class GPTSoVITSV3(torch.nn.Module):
cfm_res = denorm_spec(cfm_res) cfm_res = denorm_spec(cfm_res)
bigvgan_res = self.bigvgan(cfm_res) bigvgan_res = self.bigvgan(cfm_res)
wav_gen_list.append(bigvgan_res) wav_gen_list.append(bigvgan_res)
wav_gen = torch.cat(wav_gen_list, 2) wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length] return wav_gen[0][0][:wav_gen_length]
def init_bigvgan(): def init_bigvgan():
global bigvgan_model global bigvgan_model
from BigVGAN import bigvgan from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False, use_cuda_kernel=False,
) # if True, RuntimeError: Ninja is required to load C++ extensions ) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode # remove weight norm in the model and set to eval mode
@ -467,10 +467,7 @@ def export_cfm(
cfm = e_cfm.cfm cfm = e_cfm.cfm
B, T = mu.size(0), mu.size(1) B, T = mu.size(0), mu.size(1)
x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
print("x:", x.shape, x.dtype) print("x:", x.shape, x.dtype)
prompt_len = prompt.size(-1) prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x, dtype=mu.dtype) prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@ -565,11 +562,7 @@ def export():
wav16k = wav16k.to(device) wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device) zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch]) wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
codes = sovits.vq_model.extract_latent(ssl_content) codes = sovits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
@ -626,10 +619,7 @@ def export():
"create_ge": refer, "create_ge": refer,
} }
trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
trace_vq_model.save("onnx/ad/vq_model.pt") trace_vq_model.save("onnx/ad/vq_model.pt")
print(fea_ref.shape, fea_ref.dtype, ge.shape) print(fea_ref.shape, fea_ref.dtype, ge.shape)
@ -714,9 +704,7 @@ def export():
idx += chunk_len idx += chunk_len
cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
cfm_resss.append(cfm_res) cfm_resss.append(cfm_res)
continue continue
@ -726,9 +714,7 @@ def export():
with torch.inference_mode(): with torch.inference_mode():
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype) cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
torch._dynamo.mark_dynamic(cmf_res_rand, 2) torch._dynamo.mark_dynamic(cmf_res_rand, 2)
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt") bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res) wav_gen = bigvgan_model(cmf_res)
print("wav_gen:", wav_gen.shape, wav_gen.dtype) print("wav_gen:", wav_gen.shape, wav_gen.dtype)
@ -748,7 +734,6 @@ def test_export(
bigvgan, bigvgan,
output, output,
): ):
# hps = sovits.hps # hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav" ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0 speed = 1.0
@ -773,13 +758,9 @@ def test_export(
wav16k = wav16k.to(device) wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device) zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch]) wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000) ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert( phones1, bert1, norm_text1 = get_phones_and_bert(
@ -799,8 +780,18 @@ def test_export(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time) logger.info("start inference %s", current_time)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2] chunk_len = 934 - fea_ref.shape[2]
print(fea_ref.shape, fea_todo.shape, mel2.shape) print(fea_ref.shape, fea_todo.shape, mel2.shape)
@ -812,7 +803,6 @@ def test_export(
wav_gen_length = fea_todo.shape[2] * 256 wav_gen_length = fea_todo.shape[2] * 256
while 1: while 1:
current_time = datetime.now() current_time = datetime.now()
print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S")) print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@ -861,7 +851,6 @@ def test_export1(
gpt_sovits_v3, gpt_sovits_v3,
output, output,
): ):
# hps = sovits.hps # hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav" ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0 speed = 1.0
@ -886,14 +875,10 @@ def test_export1(
wav16k = wav16k.to(device) wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device) zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch]) wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
print("ssl_content:", ssl_content.shape, ssl_content.dtype) print("ssl_content:", ssl_content.shape, ssl_content.dtype)
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000) ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert( phones1, bert1, norm_text1 = get_phones_and_bert(
@ -913,11 +898,19 @@ def test_export1(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time) logger.info("start inference %s", current_time)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps) wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
print("wav_gen:", wav_gen.shape, wav_gen.dtype) print("wav_gen:", wav_gen.shape, wav_gen.dtype)
wav_gen = torch.cat([wav_gen,zero_wav_torch],0) wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
audio = wav_gen.cpu().detach().numpy() audio = wav_gen.cpu().detach().numpy()
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@ -929,17 +922,16 @@ import time
def test_(): def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
# cfm = ExportCFM(sovits.cfm) # cfm = ExportCFM(sovits.cfm)
# cfm.cfm.estimator = dit # cfm.cfm.estimator = dit
sovits.cfm = None sovits.cfm = None
cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device) cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
# cfm = torch.jit.optimize_for_inference(cfm) # cfm = torch.jit.optimize_for_inference(cfm)
cfm = cfm.half().to(device) cfm = cfm.half().to(device)
cfm.eval() cfm.eval()
logger.info("cfm ok") logger.info("cfm ok")
@ -959,10 +951,7 @@ def test_():
# t2s_m.top_k = 15 # t2s_m.top_k = 15
logger.info("t2s_m ok") logger.info("t2s_m ok")
vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
# vq_model = torch.jit.optimize_for_inference(vq_model) # vq_model = torch.jit.optimize_for_inference(vq_model)
# vq_model = vq_model.half().to(device) # vq_model = vq_model.half().to(device)
vq_model.eval() vq_model.eval()
@ -1020,8 +1009,9 @@ def test_():
# "out2.wav", # "out2.wav",
# ) # )
def test_export_gpt_sovits_v3(): def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device) gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1( # test_export1(
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
# gpt_sovits_v3, # gpt_sovits_v3,
@ -27,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
from module.commons import sequence_mask from module.commons import sequence_mask
class TextEmbedding(nn.Module): class TextEmbedding(nn.Module):
def __init__(self, text_dim, conv_layers=0, conv_mult=2): def __init__(self, text_dim, conv_layers=0, conv_mult=2):
super().__init__() super().__init__()
@ -129,26 +130,24 @@ class DiT(nn.Module):
return ckpt_forward return ckpt_forward
def forward(#x, prompt_x, x_lens, t, style,cond def forward( # x, prompt_x, x_lens, t, style,cond
self,#d is channel,n is T self, # d is channel,n is T
x0: float["b n d"], # nosied input audio # noqa: F722 x0: float["b n d"], # nosied input audio # noqa: F722
cond0: float["b n d"], # masked cond audio # noqa: F722 cond0: float["b n d"], # masked cond audio # noqa: F722
x_lens, x_lens,
time: float["b"] | float[""], # time step # noqa: F821 F722 time: float["b"] | float[""], # time step # noqa: F821 F722
dt_base_bootstrap, dt_base_bootstrap,
text0, # : int["b nt"] # noqa: F722#####condition feature text0, # : int["b nt"] # noqa: F722#####condition feature
use_grad_ckpt=False, # bool use_grad_ckpt=False, # bool
###no-use ###no-use
drop_audio_cond=False, # cfg for cond audio drop_audio_cond=False, # cfg for cond audio
drop_text=False, # cfg for text drop_text=False, # cfg for text
# mask: bool["b n"] | None = None, # noqa: F722 # mask: bool["b n"] | None = None, # noqa: F722
): ):
x = x0.transpose(2, 1)
cond = cond0.transpose(2, 1)
text = text0.transpose(2, 1)
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
batch, seq_len = x.shape[0], x.shape[1] batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0: if time.ndim == 0:
@ -157,8 +156,8 @@ class DiT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time) t = self.time_embed(time)
dt = self.d_embed(dt_base_bootstrap) dt = self.d_embed(dt_base_bootstrap)
t+=dt t += dt
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len) rope = self.rotary_embed.forward_from_seq_len(seq_len)
@ -178,4 +177,4 @@ class DiT(nn.Module):
x = self.norm_out(x, t) x = self.norm_out(x, t)
output = self.proj_out(x) output = self.proj_out(x)
return output return output
@ -391,6 +391,7 @@ class Attention(nn.Module):
# Attention processor # Attention processor
# from torch.nn.attention import SDPBackend # from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True) # torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor: class AttnProcessor:
@ -545,6 +546,7 @@ class JointAttnProcessor:
# DiT Block # DiT Block
class DiTBlock(nn.Module): class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__() super().__init__()
@ -1,6 +1,3 @@
from . import cnhubert, whisper_enc from . import cnhubert, whisper_enc
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
@ -1,10 +1,11 @@
import torch import torch
import os import os
from transformers import logging as tf_logging from transformers import logging as tf_logging
tf_logging.set_verbosity_error() tf_logging.set_verbosity_error()
import logging import logging
logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import ( from transformers import (
@ -19,21 +20,19 @@ cnhubert_base_path = None
class CNHubert(nn.Module): class CNHubert(nn.Module):
def __init__(self, base_path:str=None): def __init__(self, base_path: str = None):
super().__init__() super().__init__()
if base_path is None: if base_path is None:
base_path = cnhubert_base_path base_path = cnhubert_base_path
if os.path.exists(base_path):
...
else:
raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True) self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
def forward(self, x): def forward(self, x):
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"] feats = self.model(input_values)["last_hidden_state"]
return feats return feats
@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2 feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频" assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad(): with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
return feature return feature
@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
i18n = I18nAuto() i18n = I18nAuto()
def synthesize(
GPT_model_path,
SoVITS_model_path,
ref_audio_path,
ref_text_path,
ref_language,
target_text_path,
target_language,
output_path,
):
# Read reference text # Read reference text
with open(ref_text_path, 'r', encoding='utf-8') as file: with open(ref_text_path, "r", encoding="utf-8") as file:
ref_text = file.read() ref_text = file.read()
# Read target text # Read target text
with open(target_text_path, 'r', encoding='utf-8') as file: with open(target_text_path, "r", encoding="utf-8") as file:
target_text = file.read() target_text = file.read()
# Change model weights # Change model weights
@ -21,12 +31,16 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
change_sovits_weights(sovits_path=SoVITS_model_path) change_sovits_weights(sovits_path=SoVITS_model_path)
# Synthesize audio # Synthesize audio
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(target_language),
top_p=1,
temperature=1,
)
result_list = list(synthesis_result) result_list = list(synthesis_result)
if result_list: if result_list:
@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
sf.write(output_wav_path, last_audio_data, last_sampling_rate) sf.write(output_wav_path, last_audio_data, last_sampling_rate)
print(f"Audio saved to {output_wav_path}") print(f"Audio saved to {output_wav_path}")
def main(): def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file") parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file") parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file") parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument(
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
)
parser.add_argument("--target_text", required=True, help="Path to the target text file")
parser.add_argument(
"--target_language",
required=True,
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
help="Language of the target text",
)
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args() args = parser.parse_args()
synthesize(
args.gpt_model,
args.sovits_model,
args.ref_audio,
args.ref_text,
args.ref_language,
args.target_text,
args.target_language,
args.output_path,
)
if __name__ == "__main__":
main() main()
@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
import soundfile as sf import soundfile as sf
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.setWindowTitle('GPT-SoVITS GUI') self.setWindowTitle("GPT-SoVITS GUI")
self.setGeometry(800, 450, 950, 850) self.setGeometry(800, 450, 950, 850)
self.setStyleSheet(""" self.setStyleSheet("""
@ -61,11 +62,12 @@ class GPTSoVITSGUI(QMainWindow):
border: 1px solid #45a049; border: 1px solid #45a049;
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1); box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
} }
""") """)
license_text = (
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
)
license_label = QLabel(license_text) license_label = QLabel(license_text)
license_label.setWordWrap(True) license_label.setWordWrap(True)
@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text = QTextEdit() self.output_text = QTextEdit()
self.output_text.setReadOnly(True) self.output_text.setReadOnly(True)
self.add_drag_drop_events(
[
self.GPT_model_input,
self.SoVITS_model_input,
self.ref_audio_input,
self.ref_text_input,
self.target_text_input,
self.output_input,
]
)
self.synthesize_button = QPushButton("合成") self.synthesize_button = QPushButton("合成")
self.synthesize_button.clicked.connect(self.synthesize) self.synthesize_button.clicked.connect(self.synthesize)
@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
def upload_ref_text(self): def upload_ref_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path: if file_path:
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, "r", encoding="utf-8") as file:
content = file.read() content = file.read()
self.ref_text_input.setText(content) self.ref_text_input.setText(content)
def upload_target_text(self): def upload_target_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path: if file_path:
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, "r", encoding="utf-8") as file:
content = file.read() content = file.read()
self.target_text_input.setText(content) self.target_text_input.setText(content)
@ -284,17 +288,19 @@ class GPTSoVITSGUI(QMainWindow):
change_sovits_weights(sovits_path=SoVITS_model_path) change_sovits_weights(sovits_path=SoVITS_model_path)
self.SoVITS_Path = SoVITS_model_path self.SoVITS_Path = SoVITS_model_path
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox,
)
result_list = list(synthesis_result) result_list = list(synthesis_result)
if result_list: if result_list:
last_sampling_rate, last_audio_data = result_list[-1] last_sampling_rate, last_audio_data = result_list[-1]
output_wav_path = os.path.join(output_path, "output.wav") output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate) sf.write(output_wav_path, last_audio_data, last_sampling_rate)
result = "Audio saved to " + output_wav_path result = "Audio saved to " + output_wav_path
@ -303,8 +309,8 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text.append("处理结果:\n" + result) self.output_text.append("处理结果:\n" + result)
if __name__ == '__main__': if __name__ == "__main__":
app = QApplication(sys.argv) app = QApplication(sys.argv)
mainWin = GPTSoVITSGUI() mainWin = GPTSoVITSGUI()
mainWin.show() mainWin.show()
sys.exit(app.exec_()) sys.exit(app.exec_())
@ -1,17 +1,19 @@
''' """
按中英混合识别 按中英混合识别
按日英混合识别 按日英混合识别
多语种启动切分识别语种 多语种启动切分识别语种
全部按中文识别 全部按中文识别
全部按英文识别 全部按英文识别
全部按日文识别 全部按日文识别
''' """
import random import random
import os import os
import re import re
import logging import logging
import json import json
import sys import sys
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir)) sys.path.append("%s/GPT_SoVITS" % (now_dir))
@ -27,8 +29,10 @@ import torch
try: try:
import gradio.analytics as analytics import gradio.analytics as analytics
analytics.version_check = lambda: None
except:
...
infer_ttswebui = os.environ.get("infer_ttswebui", 9872) infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@ -43,15 +47,15 @@ gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None) sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None) cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None) bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2") version = os.environ.get("version", "v2")
import gradio as gr import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto") language = os.environ.get("language", "Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language) i18n = I18nAuto(language=language)
@ -68,30 +72,30 @@ else:
# device = "cpu" # device = "cpu"
dict_language_v1 = { dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别 i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变 i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别 i18n("日文"): "all_ja", # 全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变 i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变 i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种 i18n("多语种混合"): "auto", # 多语种启动切分识别语种
} }
dict_language_v2 = { dict_language_v2 = {
i18n("中文"): "all_zh",#全部按中文识别 i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变 i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别 i18n("日文"): "all_ja", # 全部按日文识别
i18n("粤语"): "all_yue",#全部按中文识别 i18n("粤语"): "all_yue", # 全部按中文识别
i18n("韩文"): "all_ko",#全部按韩文识别 i18n("韩文"): "all_ko", # 全部按韩文识别
i18n("中英混合"): "zh",#按中英混合识别####不变 i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变 i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("粤英混合"): "yue",#按粤英混合识别####不变 i18n("粤英混合"): "yue", # 按粤英混合识别####不变
i18n("韩英混合"): "ko",#按韩英混合识别####不变 i18n("韩英混合"): "ko", # 按韩英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种 i18n("多语种混合"): "auto", # 多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种 i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
} }
dict_language = dict_language_v1 if version =='v1' else dict_language_v2 dict_language = dict_language_v1 if version == "v1" else dict_language_v2
cut_method = { cut_method = {
i18n("不切"):"cut0", i18n("不切"): "cut0",
i18n("凑四句一切"): "cut1", i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2", i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3", i18n("按中文句号。切"): "cut3",
@ -118,22 +122,33 @@ gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path sovits_path = tts_config.vits_weights_path
version = tts_config.version version = tts_config.version
def inference(
text,
text_lang,
ref_audio_path,
aux_ref_audio_paths,
prompt_text,
prompt_lang,
top_k,
top_p,
temperature,
text_split_method,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
):
seed = -1 if keep_random else seed seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1) actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
inputs={ inputs = {
"text": text, "text": text,
"text_lang": dict_language[text_lang], "text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path, "ref_audio_path": ref_audio_path,
@ -144,12 +159,12 @@ def inference(text, text_lang,
"top_p": top_p, "top_p": top_p,
"temperature": temperature, "temperature": temperature,
"text_split_method": cut_method[text_split_method], "text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size), "batch_size": int(batch_size),
"speed_factor":float(speed_factor), "speed_factor": float(speed_factor),
"split_bucket":split_bucket, "split_bucket": split_bucket,
"return_fragment":False, "return_fragment": False,
"fragment_interval":fragment_interval, "fragment_interval": fragment_interval,
"seed":actual_seed, "seed": actual_seed,
"parallel_infer": parallel_infer, "parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty, "repetition_penalty": repetition_penalty,
"sample_steps": int(sample_steps), "sample_steps": int(sample_steps),
@ -159,11 +174,12 @@ def inference(text, text_lang,
for item in tts_pipeline.run(inputs): for item in tts_pipeline.run(inputs):
yield item, actual_seed yield item, actual_seed
except NO_PROMPT_ERROR: except NO_PROMPT_ERROR:
gr.Warning(i18n('V3不支持无参考文本模式请填写参考文本')) gr.Warning(i18n("V3不支持无参考文本模式请填写参考文本"))
def custom_sort_key(s): def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分 # 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s) parts = re.split("(\d+)", s)
# 将数字部分转换为整数,非数字部分保持不变 # 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts] parts = [int(part) if part.isdigit() else part for part in parts]
return parts return parts
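For reference, the key above implements a natural sort: digit runs are compared as integers rather than character by character, so a checkpoint ending in "e2" sorts before one ending in "e10". A minimal standalone sketch (the sample filenames are made up):

import re

def natural_key(s):
    # Split into alternating non-digit / digit runs; digit runs become ints
    parts = re.split(r"(\d+)", s)
    return [int(p) if p.isdigit() else p for p in parts]

names = ["model_e10.ckpt", "model_e2.ckpt", "model_e1.ckpt"]
print(sorted(names, key=natural_key))
# ['model_e1.ckpt', 'model_e2.ckpt', 'model_e10.ckpt']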
@ -171,52 +187,67 @@ def custom_sort_key(s):
def change_choices(): def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"} return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
"choices": sorted(GPT_names, key=custom_sort_key),
"__type__": "update",
}
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
_ =[[],[]] path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name = [
"GPT_SoVITS/pretrained_models/s2G488k.pth",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
path_sovits_v3,
]
pretrained_gpt_name = [
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
]
_ = [[], []]
for i in range(3): for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i]) if os.path.exists(pretrained_gpt_name[i]):
if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i]) _[0].append(pretrained_gpt_name[i])
pretrained_gpt_name,pretrained_sovits_name = _ if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _
if os.path.exists("./weight.json"): if os.path.exists("./weight.json"):
pass pass
else: else:
with open("./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file) with open("./weight.json", "w", encoding="utf-8") as file:
json.dump({"GPT": {}, "SoVITS": {}}, file)
with open("./weight.json", 'r', encoding="utf-8") as file: with open("./weight.json", "r", encoding="utf-8") as file:
weight_data = file.read() weight_data = file.read()
weight_data=json.loads(weight_data) weight_data = json.loads(weight_data)
gpt_path = os.environ.get( gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name)) sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
sovits_path = os.environ.get( if isinstance(gpt_path, list):
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
if isinstance(gpt_path,list):
gpt_path = gpt_path[0] gpt_path = gpt_path[0]
if isinstance(sovits_path,list): if isinstance(sovits_path, list):
sovits_path = sovits_path[0] sovits_path = sovits_path[0]
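The block above resolves the default checkpoints with a simple precedence: an environment variable wins, then the per-version entry stored in ./weight.json, then the pretrained fallback list (from which the first entry is taken). A rough sketch of that precedence, using a hypothetical helper name:

import json
import os

def resolve_default(env_key, kind, version, fallback_list):
    # Hypothetical helper: os.environ > weight.json[kind][version] > pretrained fallback
    with open("./weight.json", encoding="utf-8") as f:
        stored = json.load(f).get(kind, {})
    chosen = os.environ.get(env_key, stored.get(version, fallback_list))
    return chosen[0] if isinstance(chosen, list) else chosen

Something like resolve_default("gpt_path", "GPT", version, pretrained_gpt_name) would reproduce the gpt_path lookup above.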
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
for path in SoVITS_weight_root + GPT_weight_root:
os.makedirs(path, exist_ok=True)
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True)
def get_weights_names(GPT_weight_root, SoVITS_weight_root): def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names = [i for i in pretrained_sovits_name] SoVITS_names = [i for i in pretrained_sovits_name]
for path in SoVITS_weight_root: for path in SoVITS_weight_root:
for name in os.listdir(path): for name in os.listdir(path):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name)) if name.endswith(".pth"):
SoVITS_names.append("%s/%s" % (path, name))
GPT_names = [i for i in pretrained_gpt_name] GPT_names = [i for i in pretrained_gpt_name]
for path in GPT_weight_root: for path in GPT_weight_root:
for name in os.listdir(path): for name in os.listdir(path):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name)) if name.endswith(".ckpt"):
GPT_names.append("%s/%s" % (path, name))
return SoVITS_names, GPT_names return SoVITS_names, GPT_names
@ -224,72 +255,110 @@ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
global version, dict_language global version, dict_language
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 and not os.path.exists(path_sovits_v3): if if_lora_v3 and not os.path.exists(path_sovits_v3):
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info) gr.Warning(info)
raise FileExistsError(info) raise FileExistsError(info)
tts_pipeline.init_vits_weights(sovits_path) tts_pipeline.init_vits_weights(sovits_path)
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2 dict_language = dict_language_v1 if tts_pipeline.configs.version == "v1" else dict_language_v2
if prompt_language is not None and text_language is not None: if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()): if prompt_language in list(dict_language.keys()):
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language} prompt_text_update, prompt_language_update = (
{"__type__": "update"},
{"__type__": "update", "value": prompt_language},
)
else: else:
prompt_text_update = {'__type__':'update', 'value':''} prompt_text_update = {"__type__": "update", "value": ""}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")} prompt_language_update = {"__type__": "update", "value": i18n("中文")}
if text_language in list(dict_language.keys()): if text_language in list(dict_language.keys()):
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language} text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
else: else:
text_update = {'__type__':'update', 'value':''} text_update = {"__type__": "update", "value": ""}
text_language_update = {'__type__':'update', 'value':i18n("中文")} text_language_update = {"__type__": "update", "value": i18n("中文")}
if model_version=="v3": if model_version == "v3":
visible_sample_steps=True visible_sample_steps = True
visible_inp_refs=False visible_inp_refs = False
else: else:
visible_sample_steps=False visible_sample_steps = False
visible_inp_refs=True visible_inp_refs = True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False} yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},
prompt_text_update,
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "visible": visible_sample_steps},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
{"__type__": "update", "visible": True if model_version == "v3" else False},
)
with open("./weight.json") as f:
data = f.read()
data = json.loads(data)
data["SoVITS"][version] = sovits_path
with open("./weight.json", "w") as f:
f.write(json.dumps(data))
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
data["SoVITS"][version]=sovits_path
with open("./weight.json","w")as f:f.write(json.dumps(data))
with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown( gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+ "<br>"
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
) )
with gr.Column(): with gr.Column():
# with gr.Group(): # with gr.Group():
gr.Markdown(value=i18n("模型切换")) gr.Markdown(value=i18n("模型切换"))
with gr.Row(): with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True) GPT_dropdown = gr.Dropdown(
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True) label=i18n("GPT模型列表"),
choices=sorted(GPT_names, key=custom_sort_key),
value=gpt_path,
interactive=True,
)
SoVITS_dropdown = gr.Dropdown(
label=i18n("SoVITS模型列表"),
choices=sorted(SoVITS_names, key=custom_sort_key),
value=sovits_path,
interactive=True,
)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息")) gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row(): with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath") inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple") inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"), file_count="multiple")
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2) prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row(): with gr.Row():
prompt_language = gr.Dropdown( prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文") label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
) )
with gr.Column(): with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True) ref_text_free = gr.Checkbox(
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")) label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
value=False,
interactive=True,
show_label=True,
)
gr.Markdown(
i18n("使用无参考文本模式时建议使用微调的GPT")
+ "<br>"
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
)
with gr.Column(): with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式")) gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
@ -298,42 +367,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文") label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
) )
with gr.Group(): with gr.Group():
gr.Markdown(value=i18n("推理设置")) gr.Markdown(value=i18n("推理设置"))
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True) batch_size = gr.Slider(
sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True) minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
)
with gr.Row(): with gr.Row():
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True) fragment_interval = gr.Slider(
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True) minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
)
speed_factor = gr.Slider(
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row(): with gr.Row():
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row(): with gr.Row():
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) temperature = gr.Slider(
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True) minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
)
repetition_penalty = gr.Slider(
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
)
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
how_to_cut = gr.Dropdown( how_to_cut = gr.Dropdown(
label=i18n("怎么切"), label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ], choices=[
value=i18n("凑四句一切"), i18n("不切"),
interactive=True, scale=1 i18n("凑四句一切"),
) i18n("凑50字一切"),
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True) i18n("按中文句号。切"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True,
scale=1,
)
super_sampling = gr.Checkbox(
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
)
with gr.Row(): with gr.Row():
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True) split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
value=True,
interactive=True,
show_label=True,
)
with gr.Row(): with gr.Row():
seed = gr.Number(label=i18n("随机种子"), value=-1)
seed = gr.Number(label=i18n("随机种子"),value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音")) output = gr.Audio(label=i18n("输出的语音"))
@ -341,40 +434,67 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button = gr.Button(i18n("合成语音"), variant="primary") inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary") stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click( inference_button.click(
inference, inference,
[ [
text,text_language, inp_ref, inp_refs, text,
prompt_text, prompt_language, text_language,
top_k, top_p, temperature, inp_ref,
how_to_cut, batch_size, inp_refs,
speed_factor, ref_text_free, prompt_text,
split_bucket,fragment_interval, prompt_language,
seed, keep_random, parallel_infer, top_k,
repetition_penalty, sample_steps, super_sampling, top_p,
], temperature,
how_to_cut,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
],
[output, seed], [output, seed],
) )
stop_infer.click(tts_pipeline.stop, [], []) stop_infer.click(tts_pipeline.stop, [], [])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language]) SoVITS_dropdown.change(
change_sovits_weights,
[SoVITS_dropdown, prompt_language, text_language],
[prompt_language, text_language, prompt_text, prompt_language, text, text_language],
)
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], []) GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
with gr.Group(): with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")) gr.Markdown(
value=i18n(
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
)
)
with gr.Row(): with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4) text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column(): with gr.Column():
_how_to_cut = gr.Radio( _how_to_cut = gr.Radio(
label=i18n("怎么切"), label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ], choices=[
value=i18n("凑四句一切"), i18n("不切"),
interactive=True, i18n("凑四句一切"),
) i18n("凑50字一切"),
cut_text= gr.Button(i18n("切分"), variant="primary") i18n("按中文句号。切"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text = gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut): def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]: if len(text_inp.strip()) == 0 or text_inp == []:
return "" return ""
method = get_method(cut_method[how_to_cut]) method = get_method(cut_method[how_to_cut])
return method(text_inp) return method(text_inp)
@ -383,8 +503,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt]) cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")) gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
if __name__ == '__main__': if __name__ == "__main__":
app.queue().launch(#concurrency_count=511, max_size=1022 app.queue().launch( # concurrency_count=511, max_size=1022
server_name="0.0.0.0", server_name="0.0.0.0",
inbrowser=True, inbrowser=True,
share=is_share, share=is_share,
View File
@ -18,7 +18,7 @@ class Encoder(nn.Module):
p_dropout=0.0, p_dropout=0.0,
window_size=4, window_size=4,
isflow=False, isflow=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -56,9 +56,7 @@ class Encoder(nn.Module):
) )
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow: if isflow:
cond_layer = torch.nn.Conv1d( cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight") self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
@ -74,9 +72,7 @@ class Encoder(nn.Module):
x = self.cond_pre(x) x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply( x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
x, g_l, torch.IntTensor([self.hidden_channels])
)
y = self.attn_layers[i](x, x, attn_mask) y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_1[i](x + y) x = self.norm_layers_1[i](x + y)
@ -99,7 +95,7 @@ class Decoder(nn.Module):
p_dropout=0.0, p_dropout=0.0,
proximal_bias=False, proximal_bias=False,
proximal_init=True, proximal_init=True,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -131,9 +127,7 @@ class Decoder(nn.Module):
) )
self.norm_layers_0.append(LayerNorm(hidden_channels)) self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append( self.encdec_attn_layers.append(
MultiHeadAttention( MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
) )
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append( self.ffn_layers.append(
@ -153,9 +147,7 @@ class Decoder(nn.Module):
x: decoder input x: decoder input
h: encoder output h: encoder output
""" """
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None: if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5 rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter( self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight) nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight) nn.init.xavier_uniform_(self.conv_k.weight)
@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None: if self.window_size is not None:
assert ( assert t_s == t_t, "Relative attention is only available for self-attention."
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys( rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits) scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local scores = scores + scores_local
if self.proximal_bias: if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention." assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to( scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
device=scores.device, dtype=scores.dtype
)
if mask is not None: if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4) scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None: if self.block_length is not None:
assert ( assert t_s == t_t, "Local attention is only available for self-attention."
t_s == t_t block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4) scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn) p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value) output = torch.matmul(p_attn, value)
if self.window_size is not None: if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn) relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings( value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
self.emb_rel_v, t_s output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
) output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn return output, p_attn
def _matmul_with_relative_values(self, x, y): def _matmul_with_relative_values(self, x, y):
@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
) )
else: else:
padded_relative_embeddings = relative_embeddings padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
:, slice_start_position:slice_end_position
]
return used_relative_embeddings return used_relative_embeddings
def _relative_position_to_absolute_position(self, x): def _relative_position_to_absolute_position(self, x):
@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1). # Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length]) x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad( x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements. # Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
:, :, :length, length - 1 :
]
return x_final return x_final
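The pad-flatten-reshape above is the usual "skew" trick for turning relative-position logits of shape (batch, heads, length, 2*length - 1) into absolute-position scores of shape (batch, heads, length, length) without gather ops. A shape-only sketch of the whole _relative_position_to_absolute_position step, including the initial one-column pad that precedes this excerpt (tensor values are dummies):

import torch
import torch.nn.functional as F

b, h, L = 2, 4, 5
rel = torch.randn(b, h, L, 2 * L - 1)   # relative logits
x = F.pad(rel, (0, 1))                  # pad last dim -> (b, h, L, 2L)
x = x.view(b, h, L * 2 * L)             # flatten the last two dims
x = F.pad(x, (0, L - 1))                # extra pad so the reshape lines up
x = x.view(b, h, L + 1, 2 * L - 1)[:, :, :L, L - 1 :]
print(x.shape)                          # torch.Size([2, 4, 5, 5])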
def _absolute_position_to_relative_position(self, x): def _absolute_position_to_relative_position(self, x):
@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# padd along column # padd along column
x = F.pad( x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape # add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
def weight_norm_modules(module, name="weight", dim=0): def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module, Depthwise_Separable_TransposeConv1D
):
module.weight_norm() module.weight_norm()
return module return module
else: else:
@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
def remove_weight_norm_modules(module, name="weight"): def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module, Depthwise_Separable_TransposeConv1D
):
module.remove_weight_norm() module.remove_weight_norm()
else: else:
remove_weight_norm(module, name) remove_weight_norm(module, name)
@ -567,7 +523,7 @@ class FFT(nn.Module):
proximal_bias=False, proximal_bias=False,
proximal_init=True, proximal_init=True,
isflow=False, isflow=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -579,9 +535,7 @@ class FFT(nn.Module):
self.proximal_bias = proximal_bias self.proximal_bias = proximal_bias
self.proximal_init = proximal_init self.proximal_init = proximal_init
if isflow: if isflow:
cond_layer = torch.nn.Conv1d( cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight") self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
@ -622,18 +576,14 @@ class FFT(nn.Module):
if g is not None: if g is not None:
g = self.cond_layer(g) g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
device=x.device, dtype=x.dtype
)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
if g is not None: if g is not None:
x = self.cond_pre(x) x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply( x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
x, g_l, torch.IntTensor([self.hidden_channels])
)
y = self.self_attn_layers[i](x, x, self_attn_mask) y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_0[i](x + y) x = self.norm_layers_0[i](x + y)
View File
@ -7,6 +7,7 @@ from module import commons
from typing import Optional from typing import Optional
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5): def __init__(self, channels, eps=1e-5):
super().__init__() super().__init__()
@ -43,7 +44,7 @@ class Encoder(nn.Module):
p_dropout=0.0, p_dropout=0.0,
window_size=4, window_size=4,
isflow=True, isflow=True,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -65,13 +66,9 @@ class Encoder(nn.Module):
if self.gin_channels != 0: if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default # vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = ( self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
)
logging.debug(self.gin_channels, self.cond_layer_idx) logging.debug(self.gin_channels, self.cond_layer_idx)
assert ( assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList() self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList() self.norm_layers_1 = nn.ModuleList()
@ -117,11 +114,13 @@ class Encoder(nn.Module):
# x = self.norm_layers_2[i](x + y) # x = self.norm_layers_2[i](x + y)
# x = x * x_mask # x = x * x_mask
# return x # return x
def forward(self, x, x_mask): def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2): for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
):
y = attn_layers(x, x, attn_mask) y = attn_layers(x, x, attn_mask)
y = self.drop(y) y = self.drop(y)
x = norm_layers_1(x + y) x = norm_layers_1(x + y)
@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None: if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5 rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter( self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight) nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight) nn.init.xavier_uniform_(self.conv_k.weight)
@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight) self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias) self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None): def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
q = self.conv_q(x) q = self.conv_q(x)
k = self.conv_k(c) k = self.conv_k(c)
v = self.conv_v(c) v = self.conv_v(c)
@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
x = self.conv_o(x) x = self.conv_o(x)
return x return x
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None): def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
# reshape [b, d, t] -> [b, n_h, t, d_k] # reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2)) b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@ -223,8 +216,8 @@ class MultiHeadAttention(nn.Module):
relative_weights = self._absolute_position_to_relative_position(p_attn) relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = (output.transpose(2, 3).contiguous().view(b, d, -1)) output = output.transpose(2, 3).contiguous().view(b, d, -1)
return output, p_attn return output, p_attn
def _matmul_with_relative_values(self, x, y): def _matmul_with_relative_values(self, x, y):
@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
def _get_relative_embeddings(self, relative_embeddings, length): def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1 max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops. # Pad first before slice to avoid using cond ops.
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1) pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64)) pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64)) slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
slice_end_position = slice_start_position + 2 * length - 1 slice_end_position = slice_start_position + 2 * length - 1
padded_relative_embeddings = F.pad( padded_relative_embeddings = F.pad(
relative_embeddings, relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
) )
used_relative_embeddings = padded_relative_embeddings[ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
:, slice_start_position:slice_end_position
]
return used_relative_embeddings return used_relative_embeddings
def _relative_position_to_absolute_position(self, x): def _relative_position_to_absolute_position(self, x):
@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1). # Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length]) x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad( x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements. # Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
:, :, :length, length - 1 :
]
return x_final return x_final
def _absolute_position_to_relative_position(self, x): def _absolute_position_to_relative_position(self, x):
@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# padd along column # padd along column
x = F.pad( x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape # add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -351,7 +336,7 @@ class FFN(nn.Module):
x = self.drop(x) x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask)) x = self.conv_2(self.padding(x * x_mask))
return x * x_mask return x * x_mask
def padding(self, x): def padding(self, x):
return self._same_padding(x) return self._same_padding(x)
@ -395,12 +380,6 @@ class MRTE(nn.Module):
ssl_enc = self.c_pre(ssl_enc * ssl_mask) ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask) text_enc = self.text_pre(text * text_mask)
x = ( x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask) x = self.c_post(x * ssl_mask)
return x return x
View File
@ -28,9 +28,7 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q): def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)""" """KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5 kl = (logs_q - logs_p) - 0.5
kl += ( kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl return kl
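Written out per dimension, with sigma = exp(logs), the expression above is the standard closed form for the KL divergence between two Gaussians:

KL(P \,\|\, Q) = \log\frac{\sigma_q}{\sigma_p} + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2} - \frac{1}{2}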
@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float) position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2 num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp( inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
) )
View File
@ -30,6 +30,7 @@
# SOFTWARE. # SOFTWARE.
"""Core vector quantization implementation.""" """Core vector quantization implementation."""
import typing as tp import typing as tp
from einops import rearrange, repeat from einops import rearrange, repeat
@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
): ):
super().__init__() super().__init__()
self.decay = decay self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim) embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size self.codebook_size = codebook_size
@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
# broadcast_tensors(self.buffers()) # broadcast_tensors(self.buffers())
def replace_(self, samples, mask): def replace_(self, samples, mask):
modified_codebook = torch.where( modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook) self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples): def expire_codes_(self, batch_samples):
@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):
def quantize(self, x): def quantize(self, x):
embed = self.embed.t() embed = self.embed.t()
dist = -( dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
embed_ind = dist.max(dim=-1).indices embed_ind = dist.max(dim=-1).indices
return embed_ind return embed_ind
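The expansion inside quantize is the negated squared Euclidean distance between each input row and each codebook entry, so the argmax picks the nearest codebook vector:

\operatorname{embed\_ind}_i = \arg\max_j \, -\bigl(\lVert x_i\rVert^2 - 2\,x_i^\top e_j + \lVert e_j\rVert^2\bigr) = \arg\min_j \lVert x_i - e_j\rVert^2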
@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
embed_sum = x.t() @ embed_onehot embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay) ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = ( cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
* self.cluster_size.sum()
) )
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized) self.embed.data.copy_(embed_normalized)
@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
_codebook_dim: int = default(codebook_dim, dim) _codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim requires_projection = _codebook_dim != dim
self.project_in = ( self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon self.epsilon = epsilon
self.commitment_weight = commitment_weight self.commitment_weight = commitment_weight
@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):
def __init__(self, *, num_quantizers, **kwargs): def __init__(self, *, num_quantizers, **kwargs):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward( def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0 quantized_out = 0.0
residual = x residual = x
@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized return quantized_out, out_indices, out_losses, out_quantized
def encode( def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x residual = x
all_indices = [] all_indices = []
n_q = n_q or len(self.layers) n_q = n_q or len(self.layers)
View File
@ -5,11 +5,14 @@ import torch
import torch.utils.data import torch.utils.data
from tqdm import tqdm from tqdm import tqdm
from module.mel_processing import spectrogram_torch,spec_to_mel_torch from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
import torch.nn.functional as F import torch.nn.functional as F
from tools.my_utils import load_audio from tools.my_utils import load_audio
version = os.environ.get('version',None)
version = os.environ.get("version", None)
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79) # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset): class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
@ -34,7 +37,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for line in lines: for line in lines:
tmp = line.split("\t") tmp = line.split("\t")
if (len(tmp) != 4): if len(tmp) != 4:
continue continue
self.phoneme_data[tmp[0]] = [tmp[1]] self.phoneme_data[tmp[0]] = [tmp[1]]
@ -42,7 +45,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng = len(tmp) leng = len(tmp)
min_num = 100 min_num = 100
if (leng < min_num): if leng < min_num:
self.audiopaths_sid_text = [] self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))): for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp self.audiopaths_sid_text += tmp
@ -67,7 +70,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text): for audiopath in tqdm(self.audiopaths_sid_text):
try: try:
phoneme = self.phoneme_data[audiopath][0] phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ') phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version) phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception: except Exception:
print(f"{audiopath} not in self.phoneme_data !") print(f"{audiopath} not in self.phoneme_data !")
@ -102,7 +105,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath)) spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad(): with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]): if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
@ -120,8 +123,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audio = torch.FloatTensor(audio_array) # /32768 audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, spec = spectrogram_torch(
center=False) audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio_norm return spec, audio_norm
@ -137,12 +141,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return len(self.audiopaths_sid_text) return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel): def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
"first", ssl.shape, wav.shape)
len_mel = mel.shape[1] len_mel = mel.shape[1]
if self.val: if self.val:
reference_mel = mel[:, :len_mel // 3] reference_mel = mel[:, : len_mel // 3]
return reference_mel, ssl, wav, mel return reference_mel, ssl, wav, mel
dir = random.randint(0, 1) dir = random.randint(0, 1)
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2)) sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@ -150,20 +153,29 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if dir == 0: if dir == 0:
reference_mel = mel[:, :sep_point] reference_mel = mel[:, :sep_point]
ssl = ssl[:, :, sep_point:] ssl = ssl[:, :, sep_point:]
wav2 = wav[:, sep_point * self.hop_length:] wav2 = wav[:, sep_point * self.hop_length :]
mel = mel[:, sep_point:] mel = mel[:, sep_point:]
else: else:
reference_mel = mel[:, sep_point:] reference_mel = mel[:, sep_point:]
ssl = ssl[:, :, :sep_point] ssl = ssl[:, :, :sep_point]
wav2 = wav[:, :sep_point * self.hop_length] wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point] mel = mel[:, :sep_point]
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir) ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
""" class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False):
self.return_ids = return_ids self.return_ids = return_ids
@ -175,9 +187,7 @@ class TextAudioSpeakerCollate():
batch: [text_normalized, spec_normalized, wav_normalized, sid] batch: [text_normalized, spec_normalized, wav_normalized, sid]
""" """
# Right zero-pad all one-hot text sequences to max input length # Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort( _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -205,22 +215,24 @@ class TextAudioSpeakerCollate():
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
ssl = row[0] ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2) ssl_lengths[i] = ssl.size(2)
spec = row[1] spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1) spec_lengths[i] = spec.size(1)
wav = row[2] wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1) wav_lengths[i] = wav.size(1)
text = row[3] text = row[3]
text_padded[i, :text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset): class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
""" """
1) loads audio, speaker_id, text pairs 1) loads audio, speaker_id, text pairs
@ -244,7 +256,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for line in lines: for line in lines:
tmp = line.split("\t") tmp = line.split("\t")
if (len(tmp) != 4): if len(tmp) != 4:
continue continue
self.phoneme_data[tmp[0]] = [tmp[1]] self.phoneme_data[tmp[0]] = [tmp[1]]
@ -252,7 +264,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng = len(tmp) leng = len(tmp)
min_num = 100 min_num = 100
if (leng < min_num): if leng < min_num:
self.audiopaths_sid_text = [] self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))): for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp self.audiopaths_sid_text += tmp
@ -277,7 +289,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text): for audiopath in tqdm(self.audiopaths_sid_text):
try: try:
phoneme = self.phoneme_data[audiopath][0] phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ') phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version) phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception: except Exception:
print(f"{audiopath} not in self.phoneme_data !") print(f"{audiopath} not in self.phoneme_data !")
@ -304,15 +316,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths self.lengths = lengths
self.spec_min=-12 self.spec_min = -12
self.spec_max=2 self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x): def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
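With spec_min = -12 and spec_max = 2 as set above, norm_spec is the affine map

\text{norm}(x) = \frac{x - (-12)}{2 - (-12)} \cdot 2 - 1 = \frac{x + 12}{7} - 1,

which maps the range [-12, 2] onto [-1, 1].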
@ -323,7 +336,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath)) spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad(): with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]): if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
@ -338,25 +351,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
return (ssl, spec, mel, text) return (ssl, spec, mel, text)
def get_audio(self, filename): def get_audio(self, filename):
        audio_array = load_audio(filename,self.sampling_rate)#load_audio already normalizes to the -1~1 range, no need to divide by 32768 again        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again
audio=torch.FloatTensor(audio_array)#/32768 audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
        audio_array24 = load_audio(filename,24000)#load_audio already normalizes to the -1~1 range, no need to divide by 32768 again######GPU resampling could be used here to speed this up        audio_array24 = load_audio(
        audio24=torch.FloatTensor(audio_array24)#/32768            filename, 24000
        )  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again######GPU resampling could be used here to speed this up
audio24 = torch.FloatTensor(audio_array24) # /32768
audio_norm24 = audio24 audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0) audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, spec = spectrogram_torch(
self.sampling_rate, self.hop_length, self.win_length, audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
center=False) )
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False) audio_norm24,
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax) self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0) mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel) mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape) # print(1111111,spec.shape,mel.shape)
return spec, mel return spec, mel
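Note: get_audio loads the same file twice — once at self.sampling_rate for the linear spectrogram and once at a fixed 24 kHz for the mel branch. Assuming the 32000/640 and 24000/256 hop settings hinted at by the comments in this file, the two streams run at 50 and 93.75 frames per second, which is where the 1.25 * 1.5 = 1.875 length factor used elsewhere comes from:

    # Assumed rates/hops; the real values come from the hparams, not from this sketch.
    sr_spec, hop_spec = 32000, 640    # linear spectrogram stream -> 50 frames/s
    sr_mel, hop_mel = 24000, 256      # mel stream -> 93.75 frames/s

    spec_fps, mel_fps = sr_spec / hop_spec, sr_mel / hop_mel
    assert spec_fps == 50.0 and mel_fps == 93.75
    # 93.75 / 50 = 1.875 = 1.25 * 1.5, matching max_mel_len and the bridge interpolation factor.
    assert mel_fps / spec_fps == 1.25 * 1.5 == 1.875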
@ -370,9 +393,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return len(self.audiopaths_sid_text) return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3():
""" Zero-pads model inputs and targets
""" class TextAudioSpeakerCollateV3:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False):
self.return_ids = return_ids self.return_ids = return_ids
@ -383,12 +407,10 @@ class TextAudioSpeakerCollateV3():
------ ------
batch: [text_normalized, spec_normalized, wav_normalized, sid] batch: [text_normalized, spec_normalized, wav_normalized, sid]
""" """
#ssl, spec, wav,mel, text # ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length # Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort( _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
torch.LongTensor([x[1].size(1) for x in batch]), # (ssl, spec,mel, text)
dim=0, descending=True)
#(ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1)) max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -402,7 +424,7 @@ class TextAudioSpeakerCollateV3():
# max_wav_len = max([x[2].size(1) for x in batch]) # max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[3].size(0) for x in batch]) max_text_len = max([x[3].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320 max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch)) ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch)) spec_lengths = torch.LongTensor(len(batch))
@ -413,7 +435,7 @@ class TextAudioSpeakerCollateV3():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len) mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len) text_padded = torch.LongTensor(len(batch), max_text_len)
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_() spec_padded.zero_()
@ -426,11 +448,11 @@ class TextAudioSpeakerCollateV3():
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text # ssl, spec, wav,mel, text
ssl = row[0] ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2) ssl_lengths[i] = ssl.size(2)
spec = row[1] spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1) spec_lengths[i] = spec.size(1)
# wav = row[2] # wav = row[2]
@ -438,15 +460,17 @@ class TextAudioSpeakerCollateV3():
# wav_lengths[i] = wav.size(1) # wav_lengths[i] = wav.size(1)
mel = row[2] mel = row[2]
mel_padded[i, :, :mel.size(1)] = mel mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1) mel_lengths[i] = mel.size(1)
text = row[3] text = row[3]
text_padded[i, :text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset): class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
""" """
1) loads audio, speaker_id, text pairs 1) loads audio, speaker_id, text pairs
@ -470,7 +494,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for line in lines: for line in lines:
tmp = line.split("\t") tmp = line.split("\t")
if (len(tmp) != 4): if len(tmp) != 4:
continue continue
self.phoneme_data[tmp[0]] = [tmp[1]] self.phoneme_data[tmp[0]] = [tmp[1]]
@ -478,7 +502,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng = len(tmp) leng = len(tmp)
min_num = 100 min_num = 100
if (leng < min_num): if leng < min_num:
self.audiopaths_sid_text = [] self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))): for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp self.audiopaths_sid_text += tmp
@ -503,7 +527,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text): for audiopath in tqdm(self.audiopaths_sid_text):
try: try:
phoneme = self.phoneme_data[audiopath][0] phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ') phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version) phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception: except Exception:
print(f"{audiopath} not in self.phoneme_data !") print(f"{audiopath} not in self.phoneme_data !")
@ -530,15 +554,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
            assert len(audiopaths_sid_text_new) > 1  # at least enough to fill one batch; TODO here            assert len(audiopaths_sid_text_new) > 1  # at least enough to fill one batch; TODO here
self.audiopaths_sid_text = audiopaths_sid_text_new self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths self.lengths = lengths
self.spec_min=-12 self.spec_min = -12
self.spec_max=2 self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x): def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
@ -546,10 +571,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids) text = torch.FloatTensor(phoneme_ids)
try: try:
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath)) spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad(): with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu") ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]): if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
@ -564,27 +589,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
return (ssl, spec, wav, mel, text) return (ssl, spec, wav, mel, text)
def get_audio(self, filename): def get_audio(self, filename):
        audio_array = load_audio(filename,self.sampling_rate)#load_audio already normalizes to the -1~1 range, no need to divide by 32768 again        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again
audio=torch.FloatTensor(audio_array)#/32768 audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
        audio_array24 = load_audio(filename,24000)#load_audio already normalizes to the -1~1 range, no need to divide by 32768 again######GPU resampling could be used here to speed this up        audio_array24 = load_audio(
        audio24=torch.FloatTensor(audio_array24)#/32768            filename, 24000
        )  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again######GPU resampling could be used here to speed this up
audio24 = torch.FloatTensor(audio_array24) # /32768
audio_norm24 = audio24 audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0) audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, spec = spectrogram_torch(
self.sampling_rate, self.hop_length, self.win_length, audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
center=False) )
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False) audio_norm24,
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax) self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0) mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel) mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape) # print(1111111,spec.shape,mel.shape)
return spec, mel,audio_norm return spec, mel, audio_norm
def get_sid(self, sid): def get_sid(self, sid):
sid = torch.LongTensor([int(sid)]) sid = torch.LongTensor([int(sid)])
@ -596,9 +631,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return len(self.audiopaths_sid_text) return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3b():
""" Zero-pads model inputs and targets
""" class TextAudioSpeakerCollateV3b:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False):
self.return_ids = return_ids self.return_ids = return_ids
@ -609,12 +645,10 @@ class TextAudioSpeakerCollateV3b():
------ ------
batch: [text_normalized, spec_normalized, wav_normalized, sid] batch: [text_normalized, spec_normalized, wav_normalized, sid]
""" """
#ssl, spec, wav,mel, text # ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length # Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort( _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
torch.LongTensor([x[1].size(1) for x in batch]), # (ssl, spec,mel, text)
dim=0, descending=True)
#(ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1)) max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -627,7 +661,7 @@ class TextAudioSpeakerCollateV3b():
max_spec_len = int(2 * ((max_spec_len // 2) + 1)) max_spec_len = int(2 * ((max_spec_len // 2) + 1))
max_wav_len = max([x[2].size(1) for x in batch]) max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[4].size(0) for x in batch]) max_text_len = max([x[4].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320 max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch)) ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch)) spec_lengths = torch.LongTensor(len(batch))
@ -638,7 +672,7 @@ class TextAudioSpeakerCollateV3b():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len) mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len) text_padded = torch.LongTensor(len(batch), max_text_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_() spec_padded.zero_()
@ -651,28 +685,40 @@ class TextAudioSpeakerCollateV3b():
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text # ssl, spec, wav,mel, text
ssl = row[0] ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2) ssl_lengths[i] = ssl.size(2)
spec = row[1] spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1) spec_lengths[i] = spec.size(1)
wav = row[2] wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1) wav_lengths[i] = wav.size(1)
mel = row[3] mel = row[3]
mel_padded[i, :, :mel.size(1)] = mel mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1) mel_lengths[i] = mel.size(1)
text = row[4] text = row[4]
text_padded[i, :text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths return (
ssl_padded,
spec_padded,
mel_padded,
ssl_lengths,
spec_lengths,
text_padded,
text_lengths,
wav_padded,
wav_lengths,
mel_lengths,
)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
""" """
Maintain similar input lengths in a batch. Maintain similar input lengths in a batch.
@ -736,12 +782,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
num_samples_bucket = self.num_samples_per_bucket[i] num_samples_bucket = self.num_samples_per_bucket[i]
rem = num_samples_bucket - len_bucket rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
ids_bucket = ids_bucket[self.rank::self.num_replicas] ids_bucket = ids_bucket[self.rank :: self.num_replicas]
for j in range(len(ids_bucket) // self.batch_size): for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
batches.append(batch) batches.append(batch)
if self.shuffle: if self.shuffle:
@ -768,4 +814,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
return -1 return -1
def __len__(self): def __len__(self):
return self.num_samples // self.batch_size return self.num_samples // self.batch_size
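Note: the __iter__ body above pads each length bucket so it divides evenly across replicas, then hands every num_replicas-th index (starting at rank) to the current process before chunking into batches. A compact sketch of just that padding-and-slicing step, with hypothetical bucket and world sizes:

    # Sketch of per-bucket padding and rank slicing (hypothetical values).
    ids_bucket = list(range(10))      # indices belonging to one length bucket
    num_replicas, rank, batch_size = 4, 1, 3
    num_samples_bucket = 12           # bucket size rounded up to a multiple of num_replicas

    rem = num_samples_bucket - len(ids_bucket)
    # Repeat the bucket (plus a prefix of it) until it reaches num_samples_bucket.
    ids_bucket = ids_bucket + ids_bucket * (rem // len(ids_bucket)) + ids_bucket[: rem % len(ids_bucket)]
    assert len(ids_bucket) == num_samples_bucket

    # Each rank takes a strided slice, so ranks see disjoint, equally sized shards.
    ids_rank = ids_bucket[rank::num_replicas]
    assert ids_rank == [1, 5, 9]
    batches = [ids_rank[j * batch_size : (j + 1) * batch_size] for j in range(len(ids_rank) // batch_size)]
    assert batches == [[1, 5, 9]]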
View File
@ -65,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
torch.exp(-2 * logs) * ((z - m) ** 2) torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term ) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum( l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l return l
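Note: mle_loss above is the masked average negative log-likelihood of z under a diagonal Gaussian N(m, exp(logs)^2) minus the flow's log-determinant; the constant 0.5*log(2*pi) is simply added back after averaging. Dropping the mask and logdet terms, the per-element expression matches the standard Gaussian NLL, which a short check against torch.distributions confirms:

    # Cross-check (no mask, no logdet): the per-element term of mle_loss equals
    # the Gaussian negative log-likelihood -log N(z; m, exp(logs)^2).
    import math
    import torch

    z, m, logs = torch.randn(4), torch.randn(4), torch.randn(4) * 0.1
    per_elem = 0.5 * math.log(2 * math.pi) + logs + 0.5 * torch.exp(-2 * logs) * (z - m) ** 2
    nll = -torch.distributions.Normal(m, torch.exp(logs)).log_prob(z)
    assert torch.allclose(per_elem, nll, atol=1e-6)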
View File
@ -47,9 +47,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
@ -79,20 +77,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
dtype_device = str(spec.dtype) + "_" + str(spec.device) dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn( mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec) spec = spectral_normalize_torch(spec)
return spec return spec
def mel_spectrogram_torch( def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) print("min value is ", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
@ -103,16 +95,10 @@ def mel_spectrogram_torch(
fmax_dtype_device = str(fmax) + "_" + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn( mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
View File
@ -1,4 +1,5 @@
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import math import math
@ -15,6 +16,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.mrte_model import MRTE from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
# from text import symbols # from text import symbols
from text import symbols as symbols_v1 from text import symbols as symbols_v1
from text import symbols2 as symbols_v2 from text import symbols2 as symbols_v2
@ -46,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList() self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2)) self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows): for i in range(n_flows):
self.flows.append( self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip()) self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1) self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv( self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList() self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2)) self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4): for i in range(4):
self.post_flows.append( self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip()) self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1) self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1) self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv( self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1) self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -89,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w) h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask) h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask h_w = self.post_proj(h_w) * x_mask
e_q = ( e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q z_q = e_q
for flow in self.post_flows: for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -100,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1) z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask z0 = (w - u) * x_mask
logdet_tot_q += torch.sum( logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0 logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask) z0, logdet = self.log_flow(z0, x_mask)
@ -115,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows: for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse) z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet logdet_tot = logdet_tot + logdet
nll = ( nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b] return nll + logq # [b]
else: else:
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = ( z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows: for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse) z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1) z0, z1 = torch.split(z, [1, 1], 1)
@ -135,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
def __init__( def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -147,13 +125,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d( self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels) self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d( self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels) self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1) self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -188,7 +162,7 @@ class TextEncoder(nn.Module):
kernel_size, kernel_size,
p_dropout, p_dropout,
latent_channels=192, latent_channels=192,
version = "v2", version="v2",
): ):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
@ -235,26 +209,22 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None): def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y.dtype
)
y = self.ssl_proj(y * y_mask) * y_mask y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask) y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze( text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
if test == 1: if test == 1:
text[:, :] = 0 text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2) text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask) text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge) y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask) y = self.encoder2(y * y_mask, y_mask)
if(speed!=1): if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear") y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1) m, logs = torch.split(stats, self.out_channels, dim=1)
@ -358,9 +328,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
if g != None: if g != None:
g = g.detach() g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
@ -370,14 +338,9 @@ class PosteriorEncoder(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, def __init__(
in_channels, self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
out_channels, ):
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -392,7 +355,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
if(g!=None): if g != None:
g = g.detach() g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
@ -400,6 +363,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
return stats, x_mask return stats, x_mask
class WNEncoder(nn.Module): class WNEncoder(nn.Module):
def __init__( def __init__(
self, self,
@ -432,9 +396,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels) self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask out = self.proj(x) * x_mask
@ -457,9 +419,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates) self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d( self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
@ -479,9 +439,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1)) ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate( for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -634,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11] periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs) self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat): def forward(self, y, y_hat):
@ -736,10 +692,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__() super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0 assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList( self.quantizer_modules = nn.ModuleList(
[ [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
) )
self.n_code_groups = n_code_groups self.n_code_groups = n_code_groups
self.embed_dim = embed_dim self.embed_dim = embed_dim
@ -757,9 +710,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q) z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T, min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape) z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean( loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
(z_q - xin.detach()) ** 2
)
z_q = xin + (z_q - xin).detach() z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2) z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
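Note: the `z_q = xin + (z_q - xin).detach()` line above is the straight-through estimator — the forward pass keeps the quantized value while the backward pass treats quantization as the identity, so gradients still reach the encoder output. A tiny isolated sketch of the trick, using a hypothetical rounding quantizer instead of the codebook lookup above:

    # Straight-through estimator: forward uses the quantized value, backward treats
    # quantization as identity so gradients reach the pre-quantization activations.
    import torch

    xin = torch.tensor([0.3, 1.7], requires_grad=True)
    z_q = torch.round(xin)                       # stand-in for the codebook lookup
    z_st = xin + (z_q - xin).detach()            # straight-through
    z_st.sum().backward()
    assert torch.equal(z_st.detach(), torch.tensor([0.0, 2.0]))
    assert torch.equal(xin.grad, torch.ones(2))  # gradient of identity, not of round()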
@ -799,13 +750,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder( self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
ssl_dim, style_vector_dim=hidden_channels
)
self.encoder = attentions.Encoder( self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q self.n_q = n_q
@ -818,9 +765,7 @@ class CodePredictor(nn.Module):
x = x + g x = x + g
x = self.encoder(x * x_mask, x_mask) x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose( logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
2, 3
)
target = codes[1:].transpose(0, 1) target = codes[1:].transpose(0, 1)
if not infer: if not infer:
logits = logits.reshape(-1, self.dims) logits = logits.reshape(-1, self.dims)
@ -868,8 +813,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
version = "v2", version="v2",
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
@ -900,7 +845,7 @@ class SynthesizerTrn(nn.Module):
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
version = version, version=version,
) )
self.dec = Generator( self.dec = Generator(
inter_channels, inter_channels,
@ -921,12 +866,10 @@ class SynthesizerTrn(nn.Module):
16, 16,
gin_channels=gin_channels, gin_channels=gin_channels,
) )
self.flow = ResidualCouplingBlock( self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
# self.version=os.environ.get("version","v1") # self.version=os.environ.get("version","v1")
if(self.version=="v1"): if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else: else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
@ -943,13 +886,11 @@ class SynthesizerTrn(nn.Module):
self.freeze_quantizer = freeze_quantizer self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths): def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y.dtype if self.version == "v1":
)
if(self.version=="v1"):
ge = self.ref_enc(y * y_mask, y_mask) ge = self.ref_enc(y * y_mask, y_mask)
else: else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask) ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
with autocast(enabled=False): with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad: with maybe_no_grad:
@ -957,24 +898,16 @@ class SynthesizerTrn(nn.Module):
self.ssl_proj.eval() self.ssl_proj.eval()
self.quantizer.eval() self.quantizer.eval()
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer( quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
ssl, layers=[0]
)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate( quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p( x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
quantized, y_lengths, text, text_lengths, ge
)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge) z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments( z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=ge) o = self.dec(z_slice, g=ge)
return ( return (
o, o,
@ -987,24 +920,18 @@ class SynthesizerTrn(nn.Module):
) )
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5): def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y.dtype if self.version == "v1":
)
if(self.version=="v1"):
ge = self.ref_enc(y * y_mask, y_mask) ge = self.ref_enc(y * y_mask, y_mask)
else: else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask) ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0]) quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate( quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p( x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
quantized, y_lengths, text, text_lengths, ge, test=test
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1013,39 +940,34 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p) return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad() @torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1): def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer): def get_ge(refer):
ge = None ge = None
if refer is not None: if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze( refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
commons.sequence_mask(refer_lengths, refer.size(2)), 1 if self.version == "v1":
).to(refer.dtype)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
else: else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge return ge
if(type(refer)==list):
ges=[] if type(refer) == list:
ges = []
for _refer in refer: for _refer in refer:
ge=get_ge(_refer) ge = get_ge(_refer)
ges.append(ge) ges.append(ge)
ge=torch.stack(ges,0).mean(0) ge = torch.stack(ges, 0).mean(0)
else: else:
ge=get_ge(refer) ge = get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate( quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
quantized, size=int(quantized.shape[-1] * 2), mode="nearest" x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge,speed
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1057,11 +979,10 @@ class SynthesizerTrn(nn.Module):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1) return codes.transpose(0, 1)
class CFM(torch.nn.Module): class CFM(torch.nn.Module):
def __init__( def __init__(self, in_channels, dit):
self,
in_channels,dit
):
super().__init__() super().__init__()
self.sigma_min = 1e-6 self.sigma_min = 1e-6
@ -1075,41 +996,54 @@ class CFM(torch.nn.Module):
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0): def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion""" """Forward diffusion"""
B, T = mu.size(0), mu.size(1) B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
prompt_len = prompt.size(-1) prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype) prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len] prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0 x[..., :prompt_len] = 0
mu=mu.transpose(2,1) mu = mu.transpose(2, 1)
t = 0 t = 0
d = 1 / n_timesteps d = 1 / n_timesteps
for j in range(n_timesteps): for j in range(n_timesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args) # v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1) v_pred = self.estimator(
if inference_cfg_rate>1e-5: x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) ).transpose(2, 1)
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate if inference_cfg_rate > 1e-5:
neg = self.estimator(
x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=True,
drop_text=True,
).transpose(2, 1)
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
x = x + d * v_pred x = x + d * v_pred
t = t + d t = t + d
x[:, :, :prompt_len] = 0 x[:, :, :prompt_len] = 0
return x return x
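Note: inference above runs an explicit Euler integration of the learned velocity field from t = 0 to 1 in n_timesteps uniform steps, and when inference_cfg_rate > 0 it extrapolates away from an unconditional prediction (conditions dropped) in classifier-free-guidance style. A stripped-down sketch of that loop with a stand-in velocity function, purely illustrative rather than the DiT estimator:

    # Euler integration with classifier-free guidance, mirroring the loop above.
    # `velocity` is a hypothetical stand-in for self.estimator(...).
    import torch

    def euler_cfg(x0, velocity, n_timesteps=8, cfg_rate=0.5):
        x, t, d = x0.clone(), 0.0, 1.0 / n_timesteps
        for _ in range(n_timesteps):
            v_cond = velocity(x, t, cond=True)         # conditional prediction
            if cfg_rate > 1e-5:
                v_uncond = velocity(x, t, cond=False)  # prediction with dropped conditions
                v = v_cond + (v_cond - v_uncond) * cfg_rate
            else:
                v = v_cond
            x = x + d * v                              # explicit Euler step
            t = t + d
        return x

    # Toy velocity field pulling x toward zero; just to show the loop runs.
    out = euler_cfg(torch.ones(2, 3), lambda x, t, cond: -x, n_timesteps=4, cfg_rate=0.5)
    assert out.shape == (2, 3)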
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt): def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype) t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device) x0 = torch.randn_like(x1, device=mu.device)
vt = x1 - x0 vt = x1 - x0
xt = x0 + t[:, None, None] * vt xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device) dt = torch.zeros_like(t, device=mu.device)
prompt = torch.zeros_like(x1) prompt = torch.zeros_like(x1)
for i in range(b): for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]] prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0 xt[i, :, : prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1 gailv = 0.3 # if ttime()>1736250488 else 0.1
if random.random() < gailv: if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device) base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base) d = 1 / torch.pow(2, base)
d_input = d.clone() d_input = d.clone()
d_input[d_input < 1e-2] = 0 d_input[d_input < 1e-2] = 0
# with torch.no_grad(): # with torch.no_grad():
@ -1117,52 +1051,55 @@ class CFM(torch.nn.Module):
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach() # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1 x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach() # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2 vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach() vt = vt.detach()
dt = 2*d dt = 2 * d
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1) vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
loss = 0 loss = 0
for i in range(b): for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]]) loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
loss /= b loss /= b
return loss return loss
def set_no_grad(net_g): def set_no_grad(net_g):
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
param.requires_grad=False param.requires_grad = False
class SynthesizerTrnV3(nn.Module): class SynthesizerTrnV3(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
""" """
def __init__(self, def __init__(
spec_channels, self,
segment_size, spec_channels,
inter_channels, segment_size,
hidden_channels, inter_channels,
filter_channels, hidden_channels,
n_heads, filter_channels,
n_layers, n_heads,
kernel_size, n_layers,
p_dropout, kernel_size,
resblock, p_dropout,
resblock_kernel_sizes, resblock,
resblock_dilation_sizes, resblock_kernel_sizes,
upsample_rates, resblock_dilation_sizes,
upsample_initial_channel, upsample_rates,
upsample_kernel_sizes, upsample_initial_channel,
n_speakers=0, upsample_kernel_sizes,
gin_channels=0, n_speakers=0,
use_sdp=True, gin_channels=0,
semantic_frame_rate=None, use_sdp=True,
freeze_quantizer=None, semantic_frame_rate=None,
version="v3", freeze_quantizer=None,
**kwargs): version="v3",
**kwargs,
):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@ -1183,132 +1120,133 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version self.version = version
self.model_dim=512 self.model_dim = 512
self.use_sdp = use_sdp self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout) self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) # upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, # self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels) # gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768 ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"] assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz': if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else: else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer( self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
dimension=ssl_dim, self.freeze_quantizer = freeze_quantizer
n_q=1, inter_channels2 = 512
bins=1024 self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
) self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.freeze_quantizer=freeze_quantizer self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
inter_channels2=512 self.cfm = CFM(
self.bridge=nn.Sequential( 100,
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
nn.LeakyReLU() ) # text_dim is condition feature dim
) if self.freeze_quantizer == True:
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if self.freeze_quantizer==True:
set_no_grad(self.ssl_proj) set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer) set_no_grad(self.quantizer)
set_no_grad(self.enc_p) set_no_grad(self.enc_p)
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now def forward(
self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
): # ssl_lengths no need now
with autocast(enabled=False): with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask) ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad: with maybe_no_grad:
if self.freeze_quantizer: if self.freeze_quantizer:
self.ssl_proj.eval()# self.ssl_proj.eval() #
self.quantizer.eval() self.quantizer.eval()
self.enc_p.eval() self.enc_p.eval()
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer( quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
ssl, layers=[0] quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate. fea, y_mask_ = self.wns1(
B=ssl.shape[0] fea, mel_lengths, ge
prompt_len_max = mel_lengths*2/3 ) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn=min(mel.shape[-1],fea.shape[-1]) minn = min(mel.shape[-1], fea.shape[-1])
mel=mel[:,:,:minn] mel = mel[:, :, :minn]
fea=fea[:,:,:minn] fea = fea[:, :, :minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt) cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss return cfm_loss
@torch.no_grad() @torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None,speed=1): def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape) # print(2333333,refer.shape)
# ge=None # ge=None
if(ge==None): if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed==1: if speed == 1:
sizee=int(codes.size(2)*2.5*1.5) sizee = int(codes.size(2) * 2.5 * 1.5)
else: else:
sizee=int(codes.size(2)*2.5*1.5/speed)+1 sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device) y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz': if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea=self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel ####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge) fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge return fea, ge
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1) return codes.transpose(0, 1)
class SynthesizerTrnV3b(nn.Module): class SynthesizerTrnV3b(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
""" """
def __init__(self, def __init__(
spec_channels, self,
segment_size, spec_channels,
inter_channels, segment_size,
hidden_channels, inter_channels,
filter_channels, hidden_channels,
n_heads, filter_channels,
n_layers, n_heads,
kernel_size, n_layers,
p_dropout, kernel_size,
resblock, p_dropout,
resblock_kernel_sizes, resblock,
resblock_dilation_sizes, resblock_kernel_sizes,
upsample_rates, resblock_dilation_sizes,
upsample_initial_channel, upsample_rates,
upsample_kernel_sizes, upsample_initial_channel,
n_speakers=0, upsample_kernel_sizes,
gin_channels=0, n_speakers=0,
use_sdp=True, gin_channels=0,
semantic_frame_rate=None, use_sdp=True,
freeze_quantizer=None, semantic_frame_rate=None,
**kwargs): freeze_quantizer=None,
**kwargs,
):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@@ -1328,47 +1266,52 @@ class SynthesizerTrnV3b(nn.Module):
self.n_speakers = n_speakers self.n_speakers = n_speakers
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.model_dim=512 self.model_dim = 512
self.use_sdp = use_sdp self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout) self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, self.dec = Generator(
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) inter_channels,
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, resblock,
gin_channels=gin_channels) resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768 ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"] assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz': if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else: else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer( self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
dimension=ssl_dim, self.freeze_quantizer = freeze_quantizer
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512 inter_channels2 = 512
self.bridge=nn.Sequential( self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
nn.LeakyReLU() self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
) self.cfm = CFM(
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels) 100,
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim ) # text_dim is condition feature dim
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
with autocast(enabled=False): with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask) ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k # ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None # ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
@@ -1377,51 +1320,59 @@ class SynthesizerTrnV3b(nn.Module):
self.ssl_proj.eval() self.ssl_proj.eval()
self.quantizer.eval() self.quantizer.eval()
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer( quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
ssl, layers=[0] quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge) z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge) o = self.dec(z_slice, g=ge)
fea=self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge) fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea) learned_mel = self.linear_mel(fea)
B=ssl.shape[0] B = ssl.shape[0]
prompt_len_max = mel_lengths*2/3 prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)# prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
minn=min(mel.shape[-1],fea.shape[-1]) minn = min(mel.shape[-1], fea.shape[-1])
mel=mel[:,:,:minn] mel = mel[:, :, :minn]
fea=fea[:,:,:minn] fea = fea[:, :, :minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized return (
commit_loss,
cfm_loss,
F.mse_loss(learned_mel, mel),
o,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
@torch.no_grad() @torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None): def decode_encp(self, codes, text, refer, ge=None):
# print(2333333,refer.shape) # print(2333333,refer.shape)
# ge=None # ge=None
if(ge==None): if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device) y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device) y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz': if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel ####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge) fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge return fea, ge
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1) return codes.transpose(0, 1)
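A note on the length arithmetic in the decode_encp methods above: the ×2 and ×1.875 interpolation factors and the ×2.5×1.5 target length are the same quantity computed two ways. The sketch below is illustrative only and is not part of the commit; the 24 kHz / hop-256 figure is an assumption used to give the resulting 93.75 Hz frame rate a concrete origin.

# Illustrative arithmetic only, mirroring the scale factors visible above.
codes_len = 100                          # semantic tokens at 25 Hz (example value)
quantized_len = codes_len * 2            # F.interpolate(scale_factor=2): 25 Hz -> 50 Hz
fea_len = quantized_len * 1.875          # F.interpolate(scale_factor=1.875): 50 Hz -> 93.75 Hz
target_len = int(codes_len * 2.5 * 1.5)  # the one-step version used for y_lengths1 / sizee
assert fea_len == target_len == 375      # 93.75 Hz mel frames (assuming 24 kHz audio, hop 256)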
View File
@@ -14,6 +14,7 @@ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
# from text import symbols # from text import symbols
from text import symbols as symbols_v1 from text import symbols as symbols_v1
from text import symbols2 as symbols_v2 from text import symbols2 as symbols_v2
@@ -42,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList() self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2)) self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows): for i in range(n_flows):
self.flows.append( self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip()) self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1) self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv( self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList() self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2)) self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4): for i in range(4):
self.post_flows.append( self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip()) self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1) self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1) self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv( self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1) self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@@ -85,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w) h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask) h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask h_w = self.post_proj(h_w) * x_mask
e_q = ( e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q z_q = e_q
for flow in self.post_flows: for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -96,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1) z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask z0 = (w - u) * x_mask
logdet_tot_q += torch.sum( logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0 logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask) z0, logdet = self.log_flow(z0, x_mask)
@@ -111,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows: for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse) z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet logdet_tot = logdet_tot + logdet
nll = ( nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b] return nll + logq # [b]
else: else:
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = ( z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows: for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse) z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1) z0, z1 = torch.split(z, [1, 1], 1)
@@ -131,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
def __init__( def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@@ -143,13 +120,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d( self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels) self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d( self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels) self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1) self.proj = nn.Conv1d(filter_channels, 1, 1)
@@ -232,7 +205,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge, speed=1): def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:]) y_mask = torch.ones_like(y[:1, :1, :])
y = self.ssl_proj(y * y_mask) * y_mask y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask) y = self.encoder_ssl(y * y_mask, y_mask)
@@ -244,8 +217,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge) y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask) y = self.encoder2(y * y_mask, y_mask)
if(speed!=1): if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear") y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask stats = self.proj(y) * y_mask
@@ -331,9 +304,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
if g != None: if g != None:
g = g.detach() g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
@@ -343,14 +314,9 @@ class PosteriorEncoder(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, def __init__(
in_channels, self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
out_channels, ):
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@@ -365,7 +331,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
if(g!=None): if g != None:
g = g.detach() g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
@@ -373,6 +339,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask stats = self.proj(x) * x_mask
return stats, x_mask return stats, x_mask
class WNEncoder(nn.Module): class WNEncoder(nn.Module):
def __init__( def __init__(
self, self,
@@ -405,9 +372,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels) self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask out = self.proj(x) * x_mask
@@ -430,9 +395,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates) self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d( self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
@@ -452,9 +415,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1)) ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate( for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@@ -463,7 +424,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g:Optional[torch.Tensor]=None): def forward(self, x, g: Optional[torch.Tensor] = None):
x = self.conv_pre(x) x = self.conv_pre(x)
if g is not None: if g is not None:
x = x + self.cond(g) x = x + self.cond(g)
@@ -607,9 +568,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11] periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs) self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat): def forward(self, y, y_hat):
@@ -709,10 +668,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__() super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0 assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList( self.quantizer_modules = nn.ModuleList(
[ [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
) )
self.n_code_groups = n_code_groups self.n_code_groups = n_code_groups
self.embed_dim = embed_dim self.embed_dim = embed_dim
@@ -730,9 +686,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q) z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T, min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape) z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean( loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
(z_q - xin.detach()) ** 2
)
z_q = xin + (z_q - xin).detach() z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2) z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@@ -772,13 +726,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder( self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
ssl_dim, style_vector_dim=hidden_channels
)
self.encoder = attentions.Encoder( self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q self.n_q = n_q
@@ -791,9 +741,7 @@ class CodePredictor(nn.Module):
x = x + g x = x + g
x = self.encoder(x * x_mask, x_mask) x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose( logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
2, 3
)
target = codes[1:].transpose(0, 1) target = codes[1:].transpose(0, 1)
if not infer: if not infer:
logits = logits.reshape(-1, self.dims) logits = logits.reshape(-1, self.dims)
@@ -842,7 +790,7 @@ class SynthesizerTrn(nn.Module):
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
version="v2", version="v2",
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
@@ -894,9 +842,7 @@ class SynthesizerTrn(nn.Module):
# 16, # 16,
# gin_channels=gin_channels, # gin_channels=gin_channels,
# ) # )
self.flow = ResidualCouplingBlock( self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
# self.version=os.environ.get("version","v1") # self.version=os.environ.get("version","v1")
if self.version == "v1": if self.version == "v1":
@@ -921,9 +867,9 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer,noise_scale=0.5, speed=1): def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:]) refer_mask = torch.ones_like(refer[:1, :1, :])
if (self.version == "v1"): if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
else: else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@@ -933,10 +879,8 @@ class SynthesizerTrn(nn.Module):
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p( x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
quantized, text, ge, speed
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -949,11 +893,9 @@ class SynthesizerTrn(nn.Module):
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1) return codes.transpose(0, 1)
class CFM(torch.nn.Module): class CFM(torch.nn.Module):
def __init__( def __init__(self, in_channels, dit):
self,
in_channels,dit
):
super().__init__() super().__init__()
# self.sigma_min = 1e-6 # self.sigma_min = 1e-6
@@ -963,27 +905,34 @@ class CFM(torch.nn.Module):
# self.criterion = torch.nn.MSELoss() # self.criterion = torch.nn.MSELoss()
def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0): def forward(
self,
mu: torch.Tensor,
x_lens: torch.LongTensor,
prompt: torch.Tensor,
n_timesteps: torch.LongTensor,
temperature: float = 1.0,
):
"""Forward diffusion""" """Forward diffusion"""
B, T = mu.size(0), mu.size(1) B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
ntimesteps = int(n_timesteps) ntimesteps = int(n_timesteps)
prompt_len = prompt.size(-1) prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype) prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len] prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0 x[..., :prompt_len] = 0.0
mu=mu.transpose(2,1) mu = mu.transpose(2, 1)
t = torch.tensor(0.0,dtype=x.dtype,device=x.device) t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device) d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
for j in range(ntimesteps): for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d # d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args) # v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1) v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
# if inference_cfg_rate>1e-5: # if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) # neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate # v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
@@ -995,47 +944,51 @@ class CFM(torch.nn.Module):
def set_no_grad(net_g): def set_no_grad(net_g):
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
param.requires_grad=False param.requires_grad = False
@torch.jit.script_if_tracing @torch.jit.script_if_tracing
def compile_codes_length(codes): def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device) y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5 return y_lengths1 * 2.5 * 1.5
@torch.jit.script_if_tracing @torch.jit.script_if_tracing
def compile_ref_length(refer): def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths return refer_lengths
class SynthesizerTrnV3(nn.Module): class SynthesizerTrnV3(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
""" """
def __init__(self, def __init__(
spec_channels, self,
segment_size, spec_channels,
inter_channels, segment_size,
hidden_channels, inter_channels,
filter_channels, hidden_channels,
n_heads, filter_channels,
n_layers, n_heads,
kernel_size, n_layers,
p_dropout, kernel_size,
resblock, p_dropout,
resblock_kernel_sizes, resblock,
resblock_dilation_sizes, resblock_kernel_sizes,
upsample_rates, resblock_dilation_sizes,
upsample_initial_channel, upsample_rates,
upsample_kernel_sizes, upsample_initial_channel,
n_speakers=0, upsample_kernel_sizes,
gin_channels=0, n_speakers=0,
use_sdp=True, gin_channels=0,
semantic_frame_rate=None, use_sdp=True,
freeze_quantizer=None, semantic_frame_rate=None,
version="v3", freeze_quantizer=None,
**kwargs): version="v3",
**kwargs,
):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@@ -1056,41 +1009,38 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version self.version = version
self.model_dim=512 self.model_dim = 512
self.use_sdp = use_sdp self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout) self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) # upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, # self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels) # gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768 ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"] assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz': if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else: else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer( self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
dimension=ssl_dim,
n_q=1,
bins=1024
)
freeze_quantizer freeze_quantizer
inter_channels2=512 inter_channels2 = 512
self.bridge=nn.Sequential( self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
nn.LeakyReLU() self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
) self.cfm = CFM(
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels) 100,
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim ) # text_dim is condition feature dim
if freeze_quantizer==True: if freeze_quantizer == True:
set_no_grad(self.ssl_proj) set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer) set_no_grad(self.quantizer)
set_no_grad(self.enc_p) set_no_grad(self.enc_p)
@@ -1098,24 +1048,23 @@ class SynthesizerTrnV3(nn.Module):
def create_ge(self, refer): def create_ge(self, refer):
refer_lengths = compile_ref_length(refer) refer_lengths = compile_ref_length(refer)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge return ge
def forward(self, codes, text,ge,speed=1): def forward(self, codes, text, ge, speed=1):
y_lengths1 = compile_codes_length(codes)
y_lengths1=compile_codes_length(codes)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz': if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
fea=self.bridge(x) fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel ####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge) fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea return fea
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1) return codes.transpose(0, 1)
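The CFM.forward loop shown earlier is a fixed-step Euler integration of the learned velocity field. Below is a minimal sketch of that sampler in isolation, assuming a stand-in estimator(x, t, d, mu) that returns dx/dt; the real call above also passes the prompt frames and sequence lengths, which are dropped here for brevity.

# Minimal sketch (not the committed code): Euler sampling for conditional flow matching.
import torch

def euler_cfm_sample(estimator, mu, in_channels=100, n_timesteps=32):
    # mu: conditioning features shaped [B, T, C_cond], as after mu.transpose(2, 1) above
    B, T = mu.size(0), mu.size(1)
    x = torch.randn(B, in_channels, T)       # start from Gaussian noise at t = 0
    t = torch.zeros(B)                       # integration time, runs from 0 to 1
    d = torch.full((B,), 1.0 / n_timesteps)  # constant Euler step size
    for _ in range(n_timesteps):
        v = estimator(x, t, d, mu)           # predicted velocity at time t, [B, C, T]
        x = x + d[:, None, None] * v         # forward Euler update
        t = t + d
    return x                                 # generated mel frames, [B, in_channels, T]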
View File
@@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList() self.norm_layers = nn.ModuleList()
self.conv_layers.append( self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels)) self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1): for _ in range(n_layers - 1):
@@ -156,9 +152,7 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
if gin_channels != 0: if gin_channels != 0:
cond_layer = torch.nn.Conv1d( cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers): for i in range(n_layers):
@@ -479,9 +473,7 @@ class ConvFlow(nn.Module):
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d( self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_() self.proj.weight.data.zero_()
self.proj.bias.data.zero_() self.proj.bias.data.zero_()
@@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :] unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform( x1, logabsdet = piecewise_rational_quadratic_transform(
@@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k) self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v) self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention( self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.fc = nn.Linear(n_head * d_v, d_model) self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask) output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v) output = output.view(n_head, sz_b, len_x, d_v)
output = ( output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = self.fc(output) output = self.fc(output)
@@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
if mask is not None: if mask is not None:
mask = (mask.int() == 0).squeeze(1) mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1] max_len = x.shape[1]
slf_attn_mask = ( slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
# spectral # spectral
x = self.spectral(x) x = self.spectral(x)
@@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out) mu = self.fc1(enc_out)
logvar = self.fc2(enc_out) logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar)) posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence( kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
loss_kl = kl_divergence.mean() loss_kl = kl_divergence.mean()
z = posterior.rsample() z = posterior.rsample()
@@ -825,9 +807,7 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None: if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to( x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
device=x.device, dtype=x.dtype
)
x_len = torch.sum(x_mask, [1, 2]) x_len = torch.sum(x_mask, [1, 2])
if not self.initialized: if not self.initialized:
self.initialize(x, x_mask) self.initialize(x, x_mask)
@@ -856,9 +836,7 @@ class ActNorm(nn.Module):
v = m_sq - (m**2) v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = ( bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init) self.bias.data.copy_(bias_init)
@@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
self.n_split = n_split self.n_split = n_split
self.no_jacobian = no_jacobian self.no_jacobian = no_jacobian
w_init = torch.linalg.qr( w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
if torch.det(w_init) < 0: if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0] w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init) self.weight = nn.Parameter(w_init)
@@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2]) x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = ( x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
if reverse: if reverse:
if hasattr(self, "weight_inv"): if hasattr(self, "weight_inv"):
View File
@@ -31,32 +31,15 @@ class MRTE(nn.Module):
text_enc = self.text_pre(text * text_mask) text_enc = self.text_pre(text * text_mask)
if test != None: if test != None:
if test == 0: if test == 0:
x = ( x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
elif test == 1: elif test == 1:
x = ssl_enc + ge x = ssl_enc + ge
elif test == 2: elif test == 2:
x = ( x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
else: else:
raise ValueError("test should be 0,1,2") raise ValueError("test should be 0,1,2")
else: else:
x = ( x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask) x = self.c_post(x * ssl_mask)
return x return x
@@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
model_embedding_size=256, model_embedding_size=256,
): ):
super(SpeakerEncoder, self).__init__() super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM( self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU() self.relu = nn.ReLU()
View File
@@ -87,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
raise ValueError( raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B." f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
) )
quantized, codes, commit_loss, quantized_list = self.vq( quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list return quantized, codes, torch.mean(commit_loss), quantized_list
def encode( def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth. """Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer. and returns indices for each quantizer.
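For readers unfamiliar with the module being reformatted above: residual vector quantization stacks several codebooks, each quantizing whatever error the previous stage left behind, and returns one row of indices per quantizer. The toy sketch below is illustrative only and is not the project's ResidualVectorQuantizer (which also handles training and the commitment loss seen elsewhere in this diff).

# Toy residual VQ encoder - illustrative, not the committed implementation.
import torch

def toy_rvq_encode(x, codebooks):
    # x: [N, D] input vectors; codebooks: list of [K, D] tensors, one per quantizer stage
    residual = x
    codes = []
    for cb in codebooks:
        idx = torch.cdist(residual, cb).argmin(dim=1)  # nearest code for each vector
        residual = residual - cb[idx]                  # next stage quantizes the leftover
        codes.append(idx)
    return torch.stack(codes)                          # [n_q, N], one row per quantizer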
View File
@@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative, min_derivative=min_derivative,
**spline_kwargs **spline_kwargs,
) )
return outputs, logabsdet return outputs, logabsdet
@@ -175,8 +175,7 @@ def rational_quadratic_spline(
theta_one_minus_theta = root * (1 - root) theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ( denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
* theta_one_minus_theta
) )
derivative_numerator = input_delta.pow(2) * ( derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2) input_derivatives_plus_one * root.pow(2)
@@ -190,12 +189,9 @@ def rational_quadratic_spline(
theta = (inputs - input_cumwidths) / input_bin_widths theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta) theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * ( numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + ( denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
* theta_one_minus_theta
) )
outputs = input_cumheights + numerator / denominator outputs = input_cumheights + numerator / denominator
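For reference, the expression assembled in rational_quadratic_spline above is the usual monotonic rational-quadratic spline. Reading the code with theta = \xi, input_delta = s_k, input_derivatives = \delta_k, input_derivatives_plus_one = \delta_{k+1}, input_heights = y_{k+1} - y_k and input_cumheights = y_k, the output in each bin is

    y(\xi) = y_k + \frac{(y_{k+1} - y_k)\left[s_k\,\xi^2 + \delta_k\,\xi(1-\xi)\right]}{s_k + (\delta_{k+1} + \delta_k - 2 s_k)\,\xi(1-\xi)},
    \qquad \xi = \frac{x - x_k}{x_{k+1} - x_k}, \quad s_k = \frac{y_{k+1} - y_k}{x_{k+1} - x_k},

which stays monotone as long as the bin widths, bin heights and knot derivatives are positive.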
View File
@@ -1,22 +1,22 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch import torch
import torchaudio import torchaudio
from torch import nn from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model() ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
import os
import json import json
import os
import soundfile
from text import cleaned_text_to_sequence
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to( hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@@ -72,7 +72,7 @@ class T2SEncoder(nn.Module):
super().__init__() super().__init__()
self.encoder = t2s.onnx_encoder self.encoder = t2s.onnx_encoder
self.vits = vits self.vits = vits
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content) codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
@@ -101,22 +101,22 @@ class T2SModel(nn.Module):
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model) # self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num early_stop_num = self.t2s_model.early_stop_num
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1] prefix_len = prompts.shape[1]
#[1,N,512] [1,N] # [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False stop = False
for idx in range(1, 1500): for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example) enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
@@ -130,13 +130,11 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0) return y[:, -idx:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder) # self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo: if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True) export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export( onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder, self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
) )
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return return
@@ -148,13 +146,13 @@ class T2SModel(nn.Module):
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"], output_names=["x", "prompts"],
dynamic_axes={ dynamic_axes={
"ref_seq": {1 : "ref_length"}, "ref_seq": {1: "ref_length"},
"text_seq": {1 : "text_length"}, "text_seq": {1: "text_length"},
"ref_bert": {0 : "ref_length"}, "ref_bert": {0: "ref_length"},
"text_bert": {0 : "text_length"}, "text_bert": {0: "text_length"},
"ssl_content": {2 : "ssl_length"}, "ssl_content": {2: "ssl_length"},
}, },
opset_version=16 opset_version=16,
) )
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
@@ -165,11 +163,11 @@ class T2SModel(nn.Module):
input_names=["x", "prompts"], input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"], output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={ dynamic_axes={
"x": {1 : "x_length"}, "x": {1: "x_length"},
"prompts": {1 : "prompts_length"}, "prompts": {1: "prompts_length"},
}, },
verbose=False, verbose=False,
opset_version=16 opset_version=16,
) )
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
@ -180,38 +178,38 @@ class T2SModel(nn.Module):
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"], output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={ dynamic_axes={
"iy": {1 : "iy_length"}, "iy": {1: "iy_length"},
"ik": {1 : "ik_length"}, "ik": {1: "ik_length"},
"iv": {1 : "iv_length"}, "iv": {1: "iv_length"},
"iy_emb": {1 : "iy_emb_length"}, "iy_emb": {1: "iy_emb_length"},
"ix_example": {1 : "ix_example_length"}, "ix_example": {1: "ix_example_length"},
}, },
verbose=False, verbose=False,
opset_version=16 opset_version=16,
) )
class VitsModel(nn.Module): class VitsModel(nn.Module):
def __init__(self, vits_path): def __init__(self, vits_path):
super().__init__() super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu") dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"] self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1" self.hps["model"]["version"] = "v1"
else: else:
self.hps["model"]["version"] = "v2" self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps) self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz" self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn( self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1, self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length, self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers, n_speakers=self.hps.data.n_speakers,
**self.hps.model **self.hps.model,
) )
self.vq_model.eval() self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False) self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio): def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch( refer = spectrogram_torch(
ref_audio, ref_audio,
@ -219,7 +217,7 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate, self.hps.data.sampling_rate,
self.hps.data.hop_length, self.hps.data.hop_length,
self.hps.data.win_length, self.hps.data.win_length,
center=False center=False,
) )
return self.vq_model(pred_semantic, text_seq, refer)[0, 0] return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
@ -229,18 +227,22 @@ class GptSoVits(nn.Module):
super().__init__() super().__init__()
self.vits = vits self.vits = vits
self.t2s = t2s self.t2s = t2s
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio) audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug: if debug:
import onnxruntime import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(None, { audio1 = sess.run(
"text_seq" : text_seq.detach().cpu().numpy(), None,
"pred_semantic" : pred_semantic.detach().cpu().numpy(), {
"ref_audio" : ref_audio.detach().cpu().numpy() "text_seq": text_seq.detach().cpu().numpy(),
}) "pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1 return audio, audio1
return audio return audio
@ -254,12 +256,12 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "ref_audio"], input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"], output_names=["audio"],
dynamic_axes={ dynamic_axes={
"text_seq": {1 : "text_length"}, "text_seq": {1: "text_length"},
"pred_semantic": {2 : "pred_length"}, "pred_semantic": {2: "pred_length"},
"ref_audio": {1 : "audio_length"}, "ref_audio": {1: "audio_length"},
}, },
opset_version=17, opset_version=17,
verbose=False verbose=False,
) )
@ -277,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
gpt = T2SModel(gpt_path, vits) gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt) gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel() ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)]) ref_seq = torch.LongTensor(
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)]) [
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float() ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float() ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float() ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
try: try:
os.mkdir(f"onnx/{project_name}") os.mkdir(f"onnx/{project_name}")
@ -325,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
} }
MoeVSConfJson = json.dumps(MoeVSConf) MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile: with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4) json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__": if __name__ == "__main__":
@ -340,4 +395,4 @@ if __name__ == "__main__":
exp_path = "nahida" exp_path = "nahida"
export(vits_path, gpt_path, exp_path) export(vits_path, gpt_path, exp_path)
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate) # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
View File
@ -8,12 +8,13 @@ exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part") i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts") all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir") opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir") bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get('version', None) version = os.environ.get("version", None)
import traceback import traceback
import os.path import os.path
from text.cleaner import clean_text from text.cleaner import clean_text
@ -33,13 +34,13 @@ from time import time as ttime
import shutil import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path) dir = os.path.dirname(path)
name=os.path.basename(path) name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part) tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea,tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
@ -53,8 +54,10 @@ if os.path.exists(txt_path) == False:
# device = "mps" # device = "mps"
else: else:
device = "cpu" device = "cpu"
if os.path.exists(bert_pretrained_dir):... if os.path.exists(bert_pretrained_dir):
else:raise FileNotFoundError(bert_pretrained_dir) ...
else:
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir) tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True: if is_half == True:
@ -83,12 +86,10 @@ if os.path.exists(txt_path) == False:
def process(data, res): def process(data, res):
for name, text, lan in data: for name, text, lan in data:
try: try:
name=clean_path(name) name = clean_path(name)
name = os.path.basename(name) name = os.path.basename(name)
print(name) print(name)
phones, word2ph, norm_text = clean_text( phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("￥", ","), lan, version)
text.replace("%", "-").replace("￥", ","), lan, version
)
path_bert = "%s/%s.pt" % (bert_dir, name) path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh": if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph) bert_feature = get_bert_feature(norm_text, word2ph)
@ -128,9 +129,7 @@ if os.path.exists(txt_path) == False:
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"]) # todo.append([name,text,"zh"])
if language in language_v1_to_language_v2.keys(): if language in language_v1_to_language_v2.keys():
todo.append( todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
else: else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m") print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except: except:
View File
@ -2,26 +2,30 @@
import sys import sys
import os import os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir") inp_text = os.environ.get("inp_text")
exp_name= os.environ.get("exp_name") inp_wav_dir = os.environ.get("inp_wav_dir")
i_part= os.environ.get("i_part") exp_name = os.environ.get("exp_name")
all_parts= os.environ.get("all_parts") i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir") opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback import traceback
import numpy as np import numpy as np
from scipy.io import wavfile from scipy.io import wavfile
import librosa import librosa
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from tools.my_utils import load_audio,clean_path from tools.my_utils import load_audio, clean_path
# from config import cnhubert_base_path # from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path # cnhubert.cnhubert_base_path=cnhubert_base_path
@ -36,90 +40,95 @@ from tools.my_utils import load_audio,clean_path
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path) def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part) tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea,tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95 hubert_dir = "%s/4-cnhubert" % (opt_dir)
alpha=0.5 wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda:0" device = "cuda:0"
# elif torch.backends.mps.is_available(): # elif torch.backends.mps.is_available():
# device = "mps" # device = "mps"
else: else:
device = "cpu" device = "cpu"
model=cnhubert.get_model() model = cnhubert.get_model()
# is_half=False # is_half=False
if(is_half==True): if is_half == True:
model=model.half().to(device) model = model.half().to(device)
else: else:
model = model.to(device) model = model.to(device)
nan_fails=[] nan_fails = []
def name2go(wav_name,wav_path):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return def name2go(wav_name, wav_path):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
tmp_audio = load_audio(wav_path, 32000) tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max() tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2: if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max)) print("%s-filtered,%s" % (wav_name, tmp_max))
return return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample( tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
tmp_audio32b, orig_sr=32000, target_sr=16000
)#不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio) tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True): if is_half == True:
tensor_wav16=tensor_wav16.half().to(device) tensor_wav16 = tensor_wav16.half().to(device)
else: else:
tensor_wav16 = tensor_wav16.to(device) tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0: if np.isnan(ssl.detach().numpy()).sum() != 0:
nan_fails.append((wav_name,wav_path)) nan_fails.append((wav_name, wav_path))
print("nan filtered:%s"%wav_name) print("nan filtered:%s" % wav_name)
return return
wavfile.write( wavfile.write(
"%s/%s"%(wav32dir,wav_name), "%s/%s" % (wav32dir, wav_name),
32000, 32000,
tmp_audio32.astype("int16"), tmp_audio32.astype("int16"),
) )
my_save(ssl,hubert_path) my_save(ssl, hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]: with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try: try:
# wav_name,text=line.split("\t") # wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name) wav_name = clean_path(wav_name)
if (inp_wav_dir != "" and inp_wav_dir != None): if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name) wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name) wav_path = "%s/%s" % (inp_wav_dir, wav_name)
else: else:
wav_path=wav_name wav_path = wav_name
wav_name = os.path.basename(wav_name) wav_name = os.path.basename(wav_name)
name2go(wav_name,wav_path) name2go(wav_name, wav_path)
except: except:
print(line,traceback.format_exc()) print(line, traceback.format_exc())
if(len(nan_fails)>0 and is_half==True): if len(nan_fails) > 0 and is_half == True:
is_half=False is_half = False
model=model.float() model = model.float()
for wav in nan_fails: for wav in nan_fails:
try: try:
name2go(wav[0],wav[1]) name2go(wav[0], wav[1])
except: except:
print(wav_name,traceback.format_exc()) print(wav_name, traceback.format_exc())
View File
@ -5,13 +5,15 @@ exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part") i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts") all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir") opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G") pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path") s2config_path = os.environ.get("s2config_path")
if os.path.exists(pretrained_s2G):... if os.path.exists(pretrained_s2G):
else:raise FileNotFoundError(pretrained_s2G) ...
else:
raise FileNotFoundError(pretrained_s2G)
# version=os.environ.get("version","v2") # version=os.environ.get("version","v2")
size = os.path.getsize(pretrained_s2G) size = os.path.getsize(pretrained_s2G)
if size < 82978 * 1024: if size < 82978 * 1024:
@ -25,6 +27,7 @@ elif size < 700 * 1024 * 1024:
else: else:
version = "v3" version = "v3"
import torch import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback import traceback
import sys import sys
@ -33,11 +36,13 @@ now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
import logging import logging
import utils import utils
if version!="v3":
if version != "v3":
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
else: else:
from module.models import SynthesizerTrnV3 as SynthesizerTrn from module.models import SynthesizerTrnV3 as SynthesizerTrn
from tools.my_utils import clean_path from tools.my_utils import clean_path
logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G # from config import pretrained_s2G
@ -66,7 +71,7 @@ if os.path.exists(semantic_path) == False:
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
version=version, version=version,
**hps.model **hps.model,
) )
if is_half == True: if is_half == True:
vq_model = vq_model.half().to(device) vq_model = vq_model.half().to(device)
@ -103,7 +108,7 @@ if os.path.exists(semantic_path) == False:
try: try:
# wav_name,text=line.split("\t") # wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name) wav_name = clean_path(wav_name)
wav_name = os.path.basename(wav_name) wav_name = os.path.basename(wav_name)
# name2go(name,lines1) # name2go(name,lines1)
name2go(wav_name, lines1) name2go(wav_name, lines1)
View File
@ -8,31 +8,37 @@ from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
''' def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
"""
00:v1 00:v1
01:v2 01:v2
02:v3 02:v3
03:v3lora 03:v3lora
''' """
from io import BytesIO from io import BytesIO
def my_save2(fea,path):
def my_save2(fea, path):
bio = BytesIO() bio = BytesIO()
torch.save(fea, bio) torch.save(fea, bio)
bio.seek(0) bio.seek(0)
data = bio.getvalue() data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo data = b"03" + data[2:] ###temp for v3lora only, todo
with open(path, "wb") as f: f.write(data) with open(path, "wb") as f:
f.write(data)
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
try: try:
opt = OrderedDict() opt = OrderedDict()
opt["weight"] = {} opt["weight"] = {}
@ -43,7 +49,7 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
opt["config"] = hps opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps) opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank: if lora_rank:
opt["lora_rank"]=lora_rank opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
else: else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
@ -51,41 +57,48 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
except: except:
return traceback.format_exc() return traceback.format_exc()
head2version={
b'00':["v1","v1",False], head2version = {
b'01':["v2","v2",False], b"00": ["v1", "v1", False],
b'02':["v2","v3",False], b"01": ["v2", "v2", False],
b'03':["v2","v3",True], b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
} }
hash_pretrained_dict={ hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
} }
import hashlib import hashlib
def get_hash_from_file(sovits_path): def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192) with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
hash_md5.update(data) hash_md5.update(data)
return hash_md5.hexdigest() return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path): def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash ###1-if it is pretrained sovits models, by hash
hash=get_hash_from_file(sovits_path) hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict: if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash] return hash_pretrained_dict[hash]
###2-new weights or old weights, by head ###2-new weights or old weights, by head
with open(sovits_path,"rb")as f:version=f.read(2) with open(sovits_path, "rb") as f:
if version!=b"PK": version = f.read(2)
if version != b"PK":
return head2version[version] return head2version[version]
###3-old weights, by file size ###3-old weights, by file size
if_lora_v3=False if_lora_v3 = False
size=os.path.getsize(sovits_path) size = os.path.getsize(sovits_path)
''' """
v1weights:about 82942KB v1weights:about 82942KB
half thr:82978KB half thr:82978KB
v2weights:about 83014KB v2weights:about 83014KB
v3weights:about 750MB v3weights:about 750MB
''' """
if size < 82978 * 1024: if size < 82978 * 1024:
model_version = version = "v1" model_version = version = "v1"
elif size < 700 * 1024 * 1024: elif size < 700 * 1024 * 1024:
@ -93,15 +106,16 @@ def get_sovits_version_from_path_fast(sovits_path):
else: else:
version = "v2" version = "v2"
model_version = "v3" model_version = "v3"
return version,model_version,if_lora_v3 return version, model_version, if_lora_v3
def load_sovits_new(sovits_path): def load_sovits_new(sovits_path):
f=open(sovits_path,"rb") f = open(sovits_path, "rb")
meta=f.read(2) meta = f.read(2)
if meta!="PK": if meta != "PK":
data = b'PK' + f.read() data = b"PK" + f.read()
bio = BytesIO() bio = BytesIO()
bio.write(data) bio.write(data)
bio.seek(0) bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False) return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False) return torch.load(sovits_path, map_location="cpu", weights_only=False)
View File
@ -5,25 +5,24 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse import argparse
import logging import logging
import platform
from pathlib import Path from pathlib import Path
import torch import torch
import platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING) logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high") torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict from collections import OrderedDict
from AR.utils import get_newest_ckpt
from process_ckpt import my_save from process_ckpt import my_save
@ -35,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
if_save_every_weights, if_save_every_weights,
half_weights_save_dir, half_weights_save_dir,
exp_name, exp_name,
**kwargs **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.if_save_latest = if_save_latest self.if_save_latest = if_save_latest
@ -48,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if self._should_save_on_train_epoch_end(trainer): if self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer) monitor_candidates = self._monitor_candidates(trainer)
if ( if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if ( if (
self.if_save_latest == True self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt ): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
@ -73,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
# torch.save( # torch.save(
# print(os.environ) # print(os.environ)
if(os.environ.get("LOCAL_RANK","0")=="0"): if os.environ.get("LOCAL_RANK", "0") == "0":
my_save( my_save(
to_save_od, to_save_od,
"%s/%s-e%s.ckpt" "%s/%s-e%s.ckpt"
@ -110,7 +106,7 @@ def main(args):
dirpath=ckpt_dir, dirpath=ckpt_dir,
) )
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
os.environ["MASTER_ADDR"]="localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["USE_LIBUV"] = "0" os.environ["USE_LIBUV"] = "0"
trainer: Trainer = Trainer( trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"], max_epochs=config["train"]["epochs"],
@ -121,9 +117,9 @@ def main(args):
devices=-1 if torch.cuda.is_available() else 1, devices=-1 if torch.cuda.is_available() else 1,
benchmark=False, benchmark=False,
fast_dev_run=False, fast_dev_run=False,
strategy = DDPStrategy( strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
process_group_backend="nccl" if platform.system() != "Windows" else "gloo" if torch.cuda.is_available()
) if torch.cuda.is_available() else "auto", else "auto",
precision=config["train"]["precision"], precision=config["train"]["precision"],
logger=logger, logger=logger,
num_sanity_val_steps=0, num_sanity_val_steps=0,
@ -131,9 +127,7 @@ def main(args):
use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题! use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题!
) )
model: Text2SemanticLightningModule = Text2SemanticLightningModule( model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
config, output_dir
)
data_module: Text2SemanticDataModule = Text2SemanticDataModule( data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config, config,
View File
@ -1,37 +1,41 @@
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import utils
import os import os
import utils
hps = utils.get_hparams(stage=2) hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO)
from random import randint from random import randint
from module import commons
from module import commons
from module.data_utils import ( from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler, DistributedBucketSampler,
TextAudioSpeakerCollate,
TextAudioSpeakerLoader,
) )
from module.models import ( from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee from process_ckpt import savee
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
@ -47,7 +51,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main(): def main():
if torch.cuda.is_available(): if torch.cuda.is_available():
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
else: else:
@ -75,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group( dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False", init_method="env://?use_libuv=False",
world_size=n_gpus, world_size=n_gpus,
rank=rank, rank=rank,
@ -129,19 +132,27 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True, # batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn) # drop_last=False, collate_fn=collate_fn)
net_g = SynthesizerTrn( net_g = (
hps.data.filter_length // 2 + 1, SynthesizerTrn(
hps.train.segment_size // hps.data.hop_length, hps.data.filter_length // 2 + 1,
n_speakers=hps.data.n_speakers, hps.train.segment_size // hps.data.hop_length,
**hps.model, n_speakers=hps.data.n_speakers,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn( **hps.model,
hps.data.filter_length // 2 + 1, ).cuda(rank)
hps.train.segment_size // hps.data.hop_length, if torch.cuda.is_available()
n_speakers=hps.data.n_speakers, else SynthesizerTrn(
**hps.model, hps.data.filter_length // 2 + 1,
).to(device) hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device) net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
)
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
if not param.requires_grad: if not param.requires_grad:
print(name, "not requires_grad") print(name, "not requires_grad")
@ -194,7 +205,7 @@ def run(rank, n_gpus, hps):
try: # 如果能加载自动resume try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"), utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
net_d, net_d,
optim_d, optim_d,
) # D多半加载没事 ) # D多半加载没事
@ -202,11 +213,11 @@ def run(rank, n_gpus, hps):
logger.info("loaded D") logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"), utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g, net_g,
optim_g, optim_g,
) )
epoch_str+=1 epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1 # epoch_str = 1
# global_step = 0 # global_step = 0
@ -214,37 +225,55 @@ def run(rank, n_gpus, hps):
# traceback.print_exc() # traceback.print_exc()
epoch_str = 1 epoch_str = 1
global_step = 0 global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G): if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G, print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict( net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False, strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict( )
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False, strict=False,
) ),
) ##测试不加载优化器 ) ##测试不加载优化器
if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D): if (
hps.train.pretrained_s2D != ""
and hps.train.pretrained_s2D != None
and os.path.exists(hps.train.pretrained_s2D)
):
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D) logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
print("loaded pretrained %s" % hps.train.pretrained_s2D, print(
"loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict( net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"] torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
) if torch.cuda.is_available() else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
) )
if torch.cuda.is_available()
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
),
) )
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR( scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1 optim_g,
gamma=hps.train.lr_decay,
last_epoch=-1,
) )
scheduler_d = torch.optim.lr_scheduler.ExponentialLR( scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1 optim_d,
gamma=hps.train.lr_decay,
last_epoch=-1,
) )
for _ in range(epoch_str): for _ in range(epoch_str):
scheduler_g.step() scheduler_g.step()
@ -286,9 +315,7 @@ def run(rank, n_gpus, hps):
print("training done") print("training done")
def train_and_evaluate( def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers # scheduler_g, scheduler_d = schedulers
@ -312,17 +339,38 @@ def train_and_evaluate(
text_lengths, text_lengths,
) in enumerate(tqdm(train_loader)): ) in enumerate(tqdm(train_loader)):
if torch.cuda.is_available(): if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( spec, spec_lengths = (
rank, non_blocking=True spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
) )
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( y, y_lengths = (
rank, non_blocking=True y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
) )
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda( text, text_lengths = (
rank, non_blocking=True text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
) )
else: else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device) spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -351,9 +399,7 @@ def train_and_evaluate(
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax, hps.data.mel_fmax,
) )
y_mel = commons.slice_segments( y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1), y_hat.squeeze(1),
hps.data.filter_length, hps.data.filter_length,
@ -365,15 +411,14 @@ def train_and_evaluate(
hps.data.mel_fmax, hps.data.mel_fmax,
) )
y = commons.slice_segments( y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator # Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False): with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g y_d_hat_r,
y_d_hat_g,
) )
loss_disc_all = loss_disc loss_disc_all = loss_disc
optim_d.zero_grad() optim_d.zero_grad()
@ -406,7 +451,8 @@ def train_and_evaluate(
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info( logger.info(
"Train Epoch: {} [{:.0f}%]".format( "Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader) epoch,
100.0 * batch_idx / len(train_loader),
) )
) )
logger.info([x.item() for x in losses] + [global_step, lr]) logger.info([x.item() for x in losses] + [global_step, lr])
@ -430,25 +476,37 @@ def train_and_evaluate(
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict=None image_dict = None
try:###Some people installed the wrong version of matplotlib. try: ###Some people installed the wrong version of matplotlib.
image_dict = { image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy( "slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy() y_mel[0].data.cpu().numpy(),
), ),
"slice/mel_gen": utils.plot_spectrogram_to_numpy( "slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy() y_hat_mel[0].data.cpu().numpy(),
), ),
"all/mel": utils.plot_spectrogram_to_numpy( "all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy() mel[0].data.cpu().numpy(),
), ),
"all/stats_ssl": utils.plot_spectrogram_to_numpy( "all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy() stats_ssl[0].data.cpu().numpy(),
), ),
} }
except:pass except:
if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,) pass
else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,) if image_dict:
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)
else:
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
@ -458,7 +516,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
), ),
) )
utils.save_checkpoint( utils.save_checkpoint(
@ -467,7 +526,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(global_step),
), ),
) )
else: else:
@ -477,7 +537,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
), ),
) )
utils.save_checkpoint( utils.save_checkpoint(
@ -486,7 +547,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(233333333333),
), ),
) )
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
@ -541,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
ssl = ssl.to(device) ssl = ssl.to(device)
text, text_lengths = text.to(device), text_lengths.to(device) text, text_lengths = text.to(device), text_lengths.to(device)
for test in [0, 1]: for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer( y_hat, mask, *_ = (
ssl, spec, spec_lengths, text, text_lengths, test=test generator.module.infer(
) if torch.cuda.is_available() else generator.infer( ssl,
ssl, spec, spec_lengths, text, text_lengths, test=test spec,
spec_lengths,
text,
text_lengths,
test=test,
)
if torch.cuda.is_available()
else generator.infer(
ssl,
spec,
spec_lengths,
text,
text_lengths,
test=test,
)
) )
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
@ -569,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
image_dict.update( image_dict.update(
{ {
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy() y_hat_mel[0].cpu().numpy(),
) ),
} }
) )
audio_dict.update( audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]} {
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
},
) )
image_dict.update( image_dict.update(
{ {
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
mel[0].cpu().numpy() },
)
}
) )
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
View File
@ -1,29 +1,37 @@
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import utils
import os import os
import utils
hps = utils.get_hparams(stage=2) hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO)
from random import randint from random import randint
from module import commons
from module import commons
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import ( from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader, TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
) )
from module.models import ( from module.models import (
SynthesizerTrnV3 as SynthesizerTrn, SynthesizerTrnV3 as SynthesizerTrn,
@ -43,7 +51,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main(): def main():
if torch.cuda.is_available(): if torch.cuda.is_available():
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
else: else:
@ -71,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group( dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False", init_method="env://?use_libuv=False",
world_size=n_gpus, world_size=n_gpus,
rank=rank, rank=rank,
@ -125,17 +132,21 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True, # batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn) # drop_last=False, collate_fn=collate_fn)
net_g = SynthesizerTrn( net_g = (
hps.data.filter_length // 2 + 1, SynthesizerTrn(
hps.train.segment_size // hps.data.hop_length, hps.data.filter_length // 2 + 1,
n_speakers=hps.data.n_speakers, hps.train.segment_size // hps.data.hop_length,
**hps.model, n_speakers=hps.data.n_speakers,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn( **hps.model,
hps.data.filter_length // 2 + 1, ).cuda(rank)
hps.train.segment_size // hps.data.hop_length, if torch.cuda.is_available()
n_speakers=hps.data.n_speakers, else SynthesizerTrn(
**hps.model, hps.data.filter_length // 2 + 1,
).to(device) hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device) # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
# for name, param in net_g.named_parameters(): # for name, param in net_g.named_parameters():
@ -143,7 +154,7 @@ def run(rank, n_gpus, hps):
# print(name, "not requires_grad") # print(name, "not requires_grad")
optim_g = torch.optim.AdamW( optim_g = torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致 filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
hps.train.learning_rate, hps.train.learning_rate,
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps, eps=hps.train.eps,
@ -171,11 +182,11 @@ def run(rank, n_gpus, hps):
# logger.info("loaded D") # logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"), utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g, net_g,
optim_g, optim_g,
) )
epoch_str+=1 epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1 # epoch_str = 1
# global_step = 0 # global_step = 0
@ -183,17 +194,24 @@ def run(rank, n_gpus, hps):
# traceback.print_exc() # traceback.print_exc()
epoch_str = 1 epoch_str = 1
global_step = 0 global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G): if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G, print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict( net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False, strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict( )
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False, strict=False,
) ),
) ##测试不加载优化器 ) ##测试不加载优化器
# if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D): # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
# if rank == 0: # if rank == 0:
@ -209,9 +227,7 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR( scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR( # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
# optim_d, gamma=hps.train.lr_decay, last_epoch=-1 # optim_d, gamma=hps.train.lr_decay, last_epoch=-1
# ) # )
@ -221,7 +237,7 @@ def run(rank, n_gpus, hps):
scaler = GradScaler(enabled=hps.train.fp16_run) scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str) print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1): for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0: if rank == 0:
@ -257,7 +273,16 @@ def run(rank, n_gpus, hps):
def train_and_evaluate( def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers rank,
epoch,
hps,
nets,
optims,
schedulers,
scaler,
loaders,
logger,
writers,
): ):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
@ -281,19 +306,33 @@ def train_and_evaluate(
# text, # text,
# text_lengths, # text_lengths,
# ) in enumerate(tqdm(train_loader)): # ) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)): for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available(): if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( spec, spec_lengths = (
rank, non_blocking=True spec.cuda(
) rank,
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda( non_blocking=True,
rank, non_blocking=True ),
spec_lengths.cuda(
rank,
non_blocking=True,
),
) )
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda( text, text_lengths = (
rank, non_blocking=True text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
) )
else: else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device) spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -304,8 +343,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device) text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt) cfm_loss = net_g(
loss_gen_all=cfm_loss ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad() optim_g.zero_grad()
scaler.scale(loss_gen_all).backward() scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g) scaler.unscale_(optim_g)
@ -315,12 +364,15 @@ def train_and_evaluate(
if rank == 0: if rank == 0:
if global_step % hps.train.log_interval == 0: if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr'] lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
losses = [cfm_loss] losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format( logger.info(
epoch, "Train Epoch: {} [{:.0f}%]".format(
100. * batch_idx / len(train_loader))) epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr]) logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
@ -334,7 +386,8 @@ def train_and_evaluate(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
# images=image_dict, # images=image_dict,
scalars=scalar_dict) scalars=scalar_dict,
)
# if global_step % hps.train.eval_interval == 0: # if global_step % hps.train.eval_interval == 0:
# # evaluate(hps, net_g, eval_loader, writer_eval) # # evaluate(hps, net_g, eval_loader, writer_eval)
@ -344,7 +397,6 @@ def train_and_evaluate(
# # if keep_ckpts > 0: # # if keep_ckpts > 0:
# # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) # # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
@ -354,7 +406,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
), ),
) )
# utils.save_checkpoint( # utils.save_checkpoint(
@ -373,7 +426,8 @@ def train_and_evaluate(
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333) "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
), ),
) )
# utils.save_checkpoint( # utils.save_checkpoint(

View File

@ -1,35 +1,45 @@
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import utils
import os import os
import utils
hps = utils.get_hparams(stage=2) hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO)
from collections import OrderedDict as od
from random import randint from random import randint
from module import commons from module import commons
from peft import LoraConfig, get_peft_model from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import ( from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader, TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
) )
from module.models import ( from module.models import (
SynthesizerTrnV3 as SynthesizerTrn, SynthesizerTrnV3 as SynthesizerTrn,
) )
from peft import LoraConfig, get_peft_model
from process_ckpt import savee from process_ckpt import savee
from collections import OrderedDict as od
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧 ###反正A100fp32更快那试试tf32吧
@ -43,7 +53,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main(): def main():
if torch.cuda.is_available(): if torch.cuda.is_available():
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
else: else:
@ -62,7 +71,7 @@ def main():
def run(rank, n_gpus, hps): def run(rank, n_gpus, hps):
global global_step,no_grad_names,save_root,lora_rank global global_step, no_grad_names, save_root, lora_rank
if rank == 0: if rank == 0:
logger = utils.get_logger(hps.data.exp_dir) logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps) logger.info(hps)
@ -71,7 +80,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group( dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False", init_method="env://?use_libuv=False",
world_size=n_gpus, world_size=n_gpus,
rank=rank, rank=rank,
@ -119,21 +128,24 @@ def run(rank, n_gpus, hps):
persistent_workers=True, persistent_workers=True,
prefetch_factor=4, prefetch_factor=4,
) )
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank) save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root,exist_ok=True) os.makedirs(save_root, exist_ok=True)
lora_rank=int(hps.train.lora_rank) lora_rank = int(hps.train.lora_rank)
lora_config = LoraConfig( lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"], target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank, r=lora_rank,
lora_alpha=lora_rank, lora_alpha=lora_rank,
init_lora_weights=True, init_lora_weights=True,
) )
def get_model(hps):return SynthesizerTrn(
hps.data.filter_length // 2 + 1, def get_model(hps):
hps.train.segment_size // hps.data.hop_length, return SynthesizerTrn(
n_speakers=hps.data.n_speakers, hps.data.filter_length // 2 + 1,
**hps.model, hps.train.segment_size // hps.data.hop_length,
) n_speakers=hps.data.n_speakers,
**hps.model,
)
def get_optim(net_g): def get_optim(net_g):
return torch.optim.AdamW( return torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
@ -141,61 +153,66 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps, eps=hps.train.eps,
) )
def model2cuda(net_g,rank):
def model2cuda(net_g, rank):
if torch.cuda.is_available(): if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
else: else:
net_g = net_g.to(device) net_g = net_g.to(device)
return net_g return net_g
try:# 如果能加载自动resume
try: # 如果能加载自动resume
net_g = get_model(hps) net_g = get_model(hps)
net_g.cfm = get_peft_model(net_g.cfm, lora_config) net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank) net_g = model2cuda(net_g, rank)
optim_g=get_optim(net_g) optim_g = get_optim(net_g)
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(save_root, "G_*.pth"), utils.latest_checkpoint_path(save_root, "G_*.pth"),
net_g, net_g,
optim_g, optim_g,
) )
epoch_str+=1 epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
except: # 如果首次不能加载加载pretrain except: # 如果首次不能加载加载pretrain
# traceback.print_exc() # traceback.print_exc()
epoch_str = 1 epoch_str = 1
global_step = 0 global_step = 0
net_g = get_model(hps) net_g = get_model(hps)
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G): if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0: if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G, print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.load_state_dict( net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False, strict=False,
) ),
) )
net_g.cfm = get_peft_model(net_g.cfm, lora_config) net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank) net_g = model2cuda(net_g, rank)
optim_g = get_optim(net_g) optim_g = get_optim(net_g)
no_grad_names=set() no_grad_names = set()
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
if not param.requires_grad: if not param.requires_grad:
no_grad_names.add(name.replace("module.","")) no_grad_names.add(name.replace("module.", ""))
# print(name, "not requires_grad") # print(name, "not requires_grad")
# print(no_grad_names) # print(no_grad_names)
# os._exit(233333) # os._exit(233333)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR( scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str): for _ in range(epoch_str):
scheduler_g.step() scheduler_g.step()
scaler = GradScaler(enabled=hps.train.fp16_run) scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None net_d = optim_d = scheduler_d = None
print("start training from epoch %s"%epoch_str) print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1): for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0: if rank == 0:
train_and_evaluate( train_and_evaluate(
@ -227,9 +244,8 @@ def run(rank, n_gpus, hps):
scheduler_g.step() scheduler_g.step()
print("training done") print("training done")
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers # scheduler_g, scheduler_d = schedulers
@ -241,18 +257,32 @@ def train_and_evaluate(
global global_step global global_step
net_g.train() net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)): for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available(): if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( spec, spec_lengths = (
rank, non_blocking=True spec.cuda(
) rank,
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda( non_blocking=True,
rank, non_blocking=True ),
spec_lengths.cuda(
rank,
non_blocking=True,
),
) )
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False ssl.requires_grad = False
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda( text, text_lengths = (
rank, non_blocking=True text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
) )
else: else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device) spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -262,8 +292,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device) text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt) cfm_loss = net_g(
loss_gen_all=cfm_loss ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad() optim_g.zero_grad()
scaler.scale(loss_gen_all).backward() scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g) scaler.unscale_(optim_g)
@ -273,18 +313,17 @@ def train_and_evaluate(
if rank == 0: if rank == 0:
if global_step % hps.train.log_interval == 0: if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr'] lr = optim_g.param_groups[0]["lr"]
losses = [cfm_loss] losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format( logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
epoch,
100. * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr]) logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize( utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
scalars=scalar_dict) scalars=scalar_dict,
)
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
@ -294,9 +333,7 @@ def train_and_evaluate(
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(save_root, "G_{}.pth".format(global_step)),
save_root, "G_{}.pth".format(global_step)
),
) )
else: else:
utils.save_checkpoint( utils.save_checkpoint(
@ -304,21 +341,19 @@ def train_and_evaluate(
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join( os.path.join(save_root, "G_{}.pth".format(233333333333)),
save_root, "G_{}.pth".format(233333333333)
),
) )
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict() ckpt = net_g.module.state_dict()
else: else:
ckpt = net_g.state_dict() ckpt = net_g.state_dict()
sim_ckpt=od() sim_ckpt = od()
for key in ckpt: for key in ckpt:
# if "cfm"not in key: # if "cfm"not in key:
# print(key) # print(key)
if key not in no_grad_names: if key not in no_grad_names:
sim_ckpt[key]=ckpt[key].half().cpu() sim_ckpt[key] = ckpt[key].half().cpu()
logger.info( logger.info(
"saving ckpt %s_e%s:%s" "saving ckpt %s_e%s:%s"
% ( % (
@ -326,10 +361,11 @@ def train_and_evaluate(
epoch, epoch,
savee( savee(
sim_ckpt, sim_ckpt,
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank), hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
epoch, epoch,
global_step, global_step,
hps,lora_rank=lora_rank hps,
lora_rank=lora_rank,
), ),
) )
) )
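The LoRA variant above fine-tunes only the cfm sub-module through peft adapters, records every frozen parameter name in no_grad_names, and strips those tensors from the exported weights. A rough sketch of that wrap-and-filter idea with a toy attention block standing in for net_g.cfm (the class, dimensions and print-out are illustrative, and a reasonably recent peft release is assumed):

import torch.nn as nn
from collections import OrderedDict
from peft import LoraConfig, get_peft_model

class TinyAttention(nn.Module):
    # toy stand-in for the attention layers targeted inside net_g.cfm
    def __init__(self, dim=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Sequential(nn.Linear(dim, dim))

    def forward(self, x):
        return self.to_out(self.to_q(x) + self.to_k(x) + self.to_v(x))

lora_rank = 8
lora_config = LoraConfig(
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    r=lora_rank,
    lora_alpha=lora_rank,
    init_lora_weights=True,
)
block = get_peft_model(TinyAttention(), lora_config)  # injects adapters, freezes the base weights

# mirrors no_grad_names: remember everything that stays frozen
no_grad_names = {name for name, p in block.named_parameters() if not p.requires_grad}

# mirrors sim_ckpt: export only the trainable tensors, half precision on CPU
sim_ckpt = OrderedDict((k, v.half().cpu()) for k, v in block.state_dict().items() if k not in no_grad_names)
print(len(no_grad_names), "frozen tensors,", len(sim_ckpt), "exported tensors")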

View File

@ -1 +1 @@
from .langsegmenter import LangSegmenter from .langsegmenter import LangSegmenter

View File

@ -3,38 +3,44 @@ import re
# jieba静音 # jieba静音
import jieba import jieba
jieba.setLogLevel(logging.CRITICAL) jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置 # 更改fast_langdetect大模型位置
from pathlib import Path from pathlib import Path
import fast_langdetect import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
)
)
from split_lang import LangSplitter from split_lang import LangSplitter
def full_en(text): def full_en(text):
pattern = r'^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$' pattern = r"^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text)) return bool(re.match(pattern, text))
def full_cjk(text): def full_cjk(text):
# 来自wiki # 来自wiki
cjk_ranges = [ cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs (0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A (0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B (0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C (0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D (0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E (0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F (0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G (0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H (0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H (0x2EBF0, 0x2EE5D), # CJK Extension H
] ]
pattern = r'[0-9、-〜。!?.!?… ]+$' pattern = r"[0-9、-〜。!?.!?… ]+$"
cjk_text = "" cjk_text = ""
for char in text: for char in text:
@ -45,7 +51,7 @@ def full_cjk(text):
return cjk_text return cjk_text
def split_jako(tag_lang,item): def split_jako(tag_lang, item):
if tag_lang == "ja": if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)" pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else: else:
@ -53,41 +59,40 @@ def split_jako(tag_lang,item):
lang_list: list[dict] = [] lang_list: list[dict] = []
tag = 0 tag = 0
for match in re.finditer(pattern, item['text']): for match in re.finditer(pattern, item["text"]):
if match.start() > tag: if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]}) lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
tag = match.end() tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]}) lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
if tag < len(item['text']): if tag < len(item["text"]):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]}) lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
return lang_list return lang_list
def merge_lang(lang_list, item): def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']: if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]['text'] += item['text'] lang_list[-1]["text"] += item["text"]
else: else:
lang_list.append(item) lang_list.append(item)
return lang_list return lang_list
class LangSegmenter(): class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言 # 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = { DEFAULT_LANG_MAP = {
"zh": "zh", "zh": "zh",
"yue": "zh", # 粤语 "yue": "zh", # 粤语
"wuu": "zh", # 吴语 "wuu": "zh", # 吴语
"zh-cn": "zh", "zh-cn": "zh",
"zh-tw": "x", # 繁体设置为x "zh-tw": "x", # 繁体设置为x
"ko": "ko", "ko": "ko",
"ja": "ja", "ja": "ja",
"en": "en", "en": "en",
} }
def getTexts(text): def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
substr = lang_splitter.split_by_lang(text=text) substr = lang_splitter.split_by_lang(text=text)
@ -95,18 +100,18 @@ class LangSegmenter():
lang_list: list[dict] = [] lang_list: list[dict] = []
for _, item in enumerate(substr): for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text} dict_item = {"lang": item.lang, "text": item.text}
# 处理短英文被识别为其他语言的问题 # 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']): if full_en(dict_item["text"]):
dict_item['lang'] = 'en' dict_item["lang"] = "en"
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list, dict_item)
continue continue
# 处理非日语夹日文的问题(不包含CJK) # 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = [] ja_list: list[dict] = []
if dict_item['lang'] != 'ja': if dict_item["lang"] != "ja":
ja_list = split_jako('ja',dict_item) ja_list = split_jako("ja", dict_item)
if not ja_list: if not ja_list:
ja_list.append(dict_item) ja_list.append(dict_item)
@ -115,8 +120,8 @@ class LangSegmenter():
ko_list: list[dict] = [] ko_list: list[dict] = []
temp_list: list[dict] = [] temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list): for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko': if ko_item["lang"] != "ko":
ko_list = split_jako('ko',ko_item) ko_list = split_jako("ko", ko_item)
if ko_list: if ko_list:
temp_list.extend(ko_list) temp_list.extend(ko_list)
@ -126,28 +131,28 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文 # 未存在非日韩文夹日韩文
if len(temp_list) == 1: if len(temp_list) == 1:
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if dict_item['lang'] == 'x': if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item['text']) cjk_text = full_cjk(dict_item["text"])
if cjk_text: if cjk_text:
dict_item = {'lang':'zh','text':cjk_text} dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list, dict_item)
continue continue
else: else:
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list, dict_item)
continue continue
# 存在非日韩文夹日韩文 # 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if temp_item['lang'] == 'x': if temp_item["lang"] == "x":
cjk_text = full_cjk(dict_item['text']) cjk_text = full_cjk(dict_item["text"])
if cjk_text: if cjk_text:
dict_item = {'lang':'zh','text':cjk_text} dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list, dict_item)
else: else:
lang_list = merge_lang(lang_list,temp_item) lang_list = merge_lang(lang_list, temp_item)
return lang_list return lang_list
if __name__ == "__main__": if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?" text = "MyGO?,你也喜欢まいご吗?"
@ -155,4 +160,3 @@ if __name__ == "__main__":
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。" text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text)) print(LangSegmenter.getTexts(text))
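For readers of the hunks above: LangSegmenter.getTexts returns a list of {"lang", "text"} chunks, merging adjacent runs of the same language and re-tagging short English or CJK-only spans. A hedged usage sketch; the exact split depends on the bundled fast_langdetect model, so the printed output is only indicative:

# assumes the GPT_SoVITS text package is importable and the fast_langdetect
# model is present in the configured pretrained_models cache directory
from text.LangSegmenter import LangSegmenter

for chunk in LangSegmenter.getTexts("MyGO?,你也喜欢まいご吗?"):
    print(chunk["lang"], repr(chunk["text"]))
# indicative output (not guaranteed verbatim):
#   en 'MyGO?,'
#   zh '你也喜欢'
#   ja 'まいご'
#   zh '吗?'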

View File

@ -10,18 +10,19 @@ from text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)} _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)} _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
def cleaned_text_to_sequence(cleaned_text, version=None): def cleaned_text_to_sequence(cleaned_text, version=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
text: string to convert to a sequence text: string to convert to a sequence
Returns: Returns:
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
''' """
if version is None:version=os.environ.get('version', 'v2') if version is None:
if version == "v1": version = os.environ.get("version", "v2")
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text] if version == "v1":
else: phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text] else:
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
return phones
return phones
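The helper above is a plain symbol-to-ID lookup keyed by model version. The same idea with a toy symbol table (the real tables come from the text package's symbols modules):

# toy illustration of cleaned_text_to_sequence: map each phone symbol to its index
toy_symbols = ["_", ",", ".", "AH0", "K", "T", "UNK"]
toy_symbol_to_id = {s: i for i, s in enumerate(toy_symbols)}

def toy_cleaned_text_to_sequence(cleaned_text):
    # unknown symbols are normally replaced by "UNK" upstream in clean_text
    return [toy_symbol_to_id.get(sym, toy_symbol_to_id["UNK"]) for sym in cleaned_text]

print(toy_cleaned_text_to_sequence(["K", "AH0", "T", "."]))  # -> [4, 3, 5, 2]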

View File

@ -98,9 +98,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub( replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
@ -114,7 +112,9 @@ def text_normalize(text):
return dest_text return dest_text
punctuation_set=set(punctuation) punctuation_set = set(punctuation)
def jyuping_to_initials_finals_tones(jyuping_syllables): def jyuping_to_initials_finals_tones(jyuping_syllables):
initials_finals = [] initials_finals = []
tones = [] tones = []
@ -159,12 +159,14 @@ def jyuping_to_initials_finals_tones(jyuping_syllables):
assert len(initials_finals) == len(tones) assert len(initials_finals) == len(tones)
###魔改为辅音+带音调的元音 ###魔改为辅音+带音调的元音
phones=[] phones = []
for a,b in zip(initials_finals,tones): for a, b in zip(initials_finals, tones):
if(b not in [-1,0]):###防止粤语和普通话重合开头加Y如果是标点不加。 if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y如果是标点不加。
todo="%s%s"%(a,b) todo = "%s%s" % (a, b)
else:todo=a else:
if(todo not in punctuation_set):todo="Y%s"%todo todo = a
if todo not in punctuation_set:
todo = "Y%s" % todo
phones.append(todo) phones.append(todo)
# return initials_finals, tones, word2ph # return initials_finals, tones, word2ph
@ -217,4 +219,4 @@ if __name__ == "__main__":
# phones, tones, word2ph = g2p(text) # phones, tones, word2ph = g2p(text)
phones, word2ph = g2p(text) phones, word2ph = g2p(text)
# print(phones, tones, word2ph) # print(phones, tones, word2ph)
print(phones, word2ph) print(phones, word2ph)
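The Cantonese branch above appends the tone digit to every voiced unit and prefixes a "Y" to anything that is not punctuation, so jyutping phones cannot collide with Mandarin phone names. A stand-alone sketch of just that encoding rule (toy punctuation set and inputs, not the full jyutping pipeline):

toy_punctuation_set = {",", ".", "!", "?", "…", "-"}

def encode_cantonese(initials_finals, tones):
    phones = []
    for unit, tone in zip(initials_finals, tones):
        todo = "%s%s" % (unit, tone) if tone not in (-1, 0) else unit
        if todo not in toy_punctuation_set:
            todo = "Y%s" % todo  # namespace prefix that keeps Cantonese and Mandarin phones apart
        phones.append(todo)
    return phones

print(encode_cantonese(["n", "ei", ","], [5, 5, 0]))  # -> ['Yn5', 'Yei5', ',']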

View File

@ -18,6 +18,7 @@ pinyin_to_symbol_map = {
import jieba_fast import jieba_fast
import logging import logging
jieba_fast.setLogLevel(logging.CRITICAL) jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg import jieba_fast.posseg as psg
@ -37,7 +38,7 @@ rep_map = {
"/": ",", "/": ",",
"": "-", "": "-",
"~": "", "~": "",
"":"", "": "",
} }
tone_modifier = ToneSandhi() tone_modifier = ToneSandhi()
@ -49,9 +50,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub( replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
@ -62,17 +61,15 @@ def replace_punctuation_with_en(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub( replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
def replace_consecutive_punctuation(text): def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+' pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r'\1', text) result = re.sub(pattern, r"\1", text)
return result return result
@ -87,9 +84,7 @@ def _get_initials_finals(word):
initials = [] initials = []
finals = [] finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin( orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals): for c, v in zip(orig_initials, orig_finals):
initials.append(c) initials.append(c)
finals.append(v) finals.append(v)
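_get_initials_finals above simply runs pypinyin twice per word, once for the initials and once for tone-3-style finals, and zips the results per character. A minimal sketch of that call pattern (requires the pypinyin package; the sample word is arbitrary):

from pypinyin import Style, lazy_pinyin

word = "你好"
initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)  # consonant onsets
finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)  # rhymes with trailing tone digit

for c, v in zip(initials, finals):
    print(c, v)
# expected along the lines of:
#   n i3
#   h ao3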

View File

@ -19,17 +19,24 @@ pinyin_to_symbol_map = {
import jieba_fast import jieba_fast
import logging import logging
jieba_fast.setLogLevel(logging.CRITICAL) jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg import jieba_fast.posseg as psg
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启 # is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False # is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False
if is_g2pw: if is_g2pw:
# print("当前使用g2pw进行拼音推理") # print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path) parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True) g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
v_to_u=False,
neutral_tone_with_five=True,
)
rep_map = { rep_map = {
"": ",", "": ",",
@ -46,7 +53,7 @@ rep_map = {
"/": ",", "/": ",",
"": "-", "": "-",
"~": "", "~": "",
"":"", "": "",
} }
tone_modifier = ToneSandhi() tone_modifier = ToneSandhi()
@ -58,9 +65,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub( replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
@ -77,9 +82,7 @@ def _get_initials_finals(word):
finals = [] finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin( orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals): for c, v in zip(orig_initials, orig_finals):
initials.append(c) initials.append(c)
@ -87,31 +90,66 @@ def _get_initials_finals(word):
return initials, finals return initials, finals
must_erhua = { must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"}
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
}
not_erhua = { not_erhua = {
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿", "虐儿",
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿", "为儿",
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿", "护儿",
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿", "瞒儿",
"狗儿", "少儿" "救儿",
"替儿",
"有儿",
"一儿",
"我儿",
"俺儿",
"妻儿",
"拐儿",
"聋儿",
"乞儿",
"患儿",
"幼儿",
"孤儿",
"婴儿",
"婴幼儿",
"连体儿",
"脑瘫儿",
"流浪儿",
"体弱儿",
"混血儿",
"蜜雪儿",
"舫儿",
"祖儿",
"美儿",
"应采儿",
"可儿",
"侄儿",
"孙儿",
"侄孙儿",
"女儿",
"男儿",
"红孩儿",
"花儿",
"虫儿",
"马儿",
"鸟儿",
"猪儿",
"猫儿",
"狗儿",
"少儿",
} }
def _merge_erhua(initials: list[str],
finals: list[str],
word: str, def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]:
pos: str) -> list[list[str]]:
""" """
Do erhua. Do erhua.
""" """
# fix er1 # fix er1
for i, phn in enumerate(finals): for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "" and phn == 'er1': if i == len(finals) - 1 and word[i] == "" and phn == "er1":
finals[i] = 'er2' finals[i] = "er2"
# 发音 # 发音
if word not in must_erhua and (word in not_erhua or if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}):
pos in {"a", "j", "nr"}):
return initials, finals return initials, finals
# "……" 等情况直接返回 # "……" 等情况直接返回
@ -124,9 +162,13 @@ def _merge_erhua(initials: list[str],
new_initials = [] new_initials = []
new_finals = [] new_finals = []
for i, phn in enumerate(finals): for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "" and phn in { if (
"er2", "er5" i == len(finals) - 1
} and word[-2:] not in not_erhua and new_finals: and word[i] == ""
and phn in {"er2", "er5"}
and word[-2:] not in not_erhua
and new_finals
):
phn = "er" + new_finals[-1][-1] phn = "er" + new_finals[-1][-1]
new_initials.append(initials[i]) new_initials.append(initials[i])
@ -160,7 +202,7 @@ def _g2p(segments):
# assert len(sub_initials) == len(sub_finals) == len(word) # assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, []) initials = sum(initials, [])
finals = sum(finals, []) finals = sum(finals, [])
print("pypinyin结果",initials,finals) print("pypinyin结果", initials, finals)
else: else:
# g2pw采用整句推理 # g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3) pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
@ -171,19 +213,19 @@ def _g2p(segments):
sub_finals = [] sub_finals = []
now_word_length = pre_word_length + len(word) now_word_length = pre_word_length + len(word)
if pos == 'eng': if pos == "eng":
pre_word_length = now_word_length pre_word_length = now_word_length
continue continue
word_pinyins = pinyins[pre_word_length:now_word_length] word_pinyins = pinyins[pre_word_length:now_word_length]
# 多音字消歧 # 多音字消歧
word_pinyins = correct_pronunciation(word,word_pinyins) word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins: for pinyin in word_pinyins:
if pinyin[0].isalpha(): if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin)) sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin,neutral_tone_with_five=True)) sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else: else:
sub_initials.append(pinyin) sub_initials.append(pinyin)
sub_finals.append(pinyin) sub_finals.append(pinyin)
@ -259,18 +301,18 @@ def replace_punctuation_with_en(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub( replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
def replace_consecutive_punctuation(text): def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+' pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r'\1', text) result = re.sub(pattern, r"\1", text)
return result return result
def text_normalize(text): def text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer() tx = TextNormalizer()
@ -283,6 +325,7 @@ def text_normalize(text):
dest_text = replace_consecutive_punctuation(dest_text) dest_text = replace_consecutive_punctuation(dest_text)
return dest_text return dest_text
# 不排除英文的文本格式化 # 不排除英文的文本格式化
def mix_text_normalize(text): def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
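Both Chinese front-ends normalise punctuation the same way: map exotic marks through rep_map, then delete anything that is neither a CJK ideograph nor a whitelisted punctuation mark. A compact sketch of that filter with an abbreviated mapping (the full tables live in the modules above):

import re

toy_punctuation = ["!", "?", "…", ",", ".", "-"]
toy_rep_map = {"，": ",", "。": ".", "！": "!", "？": "?"}  # abbreviated on purpose

def toy_replace_punctuation(text):
    pattern = re.compile("|".join(re.escape(p) for p in toy_rep_map))
    text = pattern.sub(lambda m: toy_rep_map[m.group()], text)
    # keep CJK characters and the whitelisted punctuation, drop everything else
    return re.sub(r"[^\u4e00-\u9fa5" + "".join(toy_punctuation) + r"]+", "", text)

print(toy_replace_punctuation("Hello，世界！ 123"))  # -> ',世界!'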

View File

@ -19,55 +19,57 @@ special = [
def clean_text(text, language, version=None): def clean_text(text, language, version=None):
if version is None:version=os.environ.get('version', 'v2') if version is None:
version = os.environ.get("version", "v2")
if version == "v1": if version == "v1":
symbols = symbols_v1.symbols symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"} language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else: else:
symbols = symbols_v2.symbols symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"} language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
if(language not in language_module_map): if language not in language_module_map:
language="en" language = "en"
text=" " text = " "
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol, version) return clean_special(text, language, special_s, target_symbol, version)
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]]) language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
if hasattr(language_module,"text_normalize"): if hasattr(language_module, "text_normalize"):
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
else: else:
norm_text=text norm_text = text
if language == "zh" or language=="yue":########## if language == "zh" or language == "yue": ##########
phones, word2ph = language_module.g2p(norm_text) phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph) assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph) assert len(norm_text) == len(word2ph)
elif language == "en": elif language == "en":
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
if len(phones) < 4: if len(phones) < 4:
phones = [','] + phones phones = [","] + phones
word2ph = None word2ph = None
else: else:
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
word2ph = None word2ph = None
phones = ['UNK' if ph not in symbols else ph for ph in phones] phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol, version=None): def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:version=os.environ.get('version', 'v2') if version is None:
version = os.environ.get("version", "v2")
if version == "v1": if version == "v1":
symbols = symbols_v1.symbols symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"} language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else: else:
symbols = symbols_v2.symbols symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"} language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
""" """
特殊静音段sp符号处理 特殊静音段sp符号处理
""" """
text = text.replace(special_s, ",") text = text.replace(special_s, ",")
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]]) language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
new_ph = [] new_ph = []
@ -81,8 +83,9 @@ def clean_special(text, language, special_s, target_symbol, version=None):
def text_to_sequence(text, language, version=None): def text_to_sequence(text, language, version=None):
version = os.environ.get('version',version) version = os.environ.get("version", version)
if version is None:version='v2' if version is None:
version = "v2"
phones = clean_text(text) phones = clean_text(text)
return cleaned_text_to_sequence(phones, version) return cleaned_text_to_sequence(phones, version)
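clean_text above resolves the per-language G2P module at runtime via __import__("text." + name, fromlist=[name]) and only calls text_normalize when the module defines it. A small sketch of that dispatch idea using importlib (running it requires the GPT_SoVITS text package and its model assets; note that the Chinese modules return (phones, word2ph) rather than a flat list):

import importlib

language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

def toy_clean(text, language):
    name = language_module_map.get(language, "english")
    module = importlib.import_module("text." + name)  # same effect as __import__(..., fromlist=[name])
    norm_text = module.text_normalize(text) if hasattr(module, "text_normalize") else text
    return module.g2p(norm_text), norm_text

# usage sketch:
# phones, norm = toy_clean("How are you today?", "en")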

View File

@ -9,17 +9,17 @@ import unicodedata
# 后缀计量单位替换表 # 后缀计量单位替换表
measurement_map = { measurement_map = {
"m": ["meter", "meters"], "m": ["meter", "meters"],
'km': ["kilometer", "kilometers"], "km": ["kilometer", "kilometers"],
"km/h": ["kilometer per hour", "kilometers per hour"], "km/h": ["kilometer per hour", "kilometers per hour"],
"ft": ["feet", "feet"], "ft": ["feet", "feet"],
"L": ["liter", "liters"], "L": ["liter", "liters"],
"tbsp": ["tablespoon", "tablespoons"], "tbsp": ["tablespoon", "tablespoons"],
'tsp': ["teaspoon", "teaspoons"], "tsp": ["teaspoon", "teaspoons"],
"h": ["hour", "hours"], "h": ["hour", "hours"],
"min": ["minute", "minutes"], "min": ["minute", "minutes"],
"s": ["second", "seconds"], "s": ["second", "seconds"],
"°C": ["degree celsius", "degrees celsius"], "°C": ["degree celsius", "degrees celsius"],
"°F": ["degree fahrenheit", "degrees fahrenheit"] "°F": ["degree fahrenheit", "degrees fahrenheit"],
} }
@ -27,41 +27,42 @@ measurement_map = {
_inflect = inflect.engine() _inflect = inflect.engine()
# 转化数字序数词 # 转化数字序数词
_ordinal_number_re = re.compile(r'\b([0-9]+)\. ') _ordinal_number_re = re.compile(r"\b([0-9]+)\. ")
# 我听说好像对于数字正则识别其实用 \d 会好一点 # 我听说好像对于数字正则识别其实用 \d 会好一点
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
# 时间识别 # 时间识别
_time_re = re.compile(r'\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b') _time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")
# 后缀计量单位识别 # 后缀计量单位识别
_measurement_re = re.compile(r'\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b') _measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")
# 前后 £ 识别 ( 写了识别两边某一边的,但是不知道为什么失败了┭┮﹏┭┮ ) # 前后 £ 识别 ( 写了识别两边某一边的,但是不知道为什么失败了┭┮﹏┭┮ )
_pounds_re_start = re.compile(r'£([0-9\.\,]*[0-9]+)') _pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)")
_pounds_re_end = re.compile(r'([0-9\.\,]*[0-9]+)£') _pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£")
# 前后 $ 识别 # 前后 $ 识别
_dollars_re_start = re.compile(r'\$([0-9\.\,]*[0-9]+)') _dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_dollars_re_end = re.compile(r'([(0-9\.\,]*[0-9]+)\$') _dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$")
# 小数的识别 # 小数的识别
_decimal_number_re = re.compile(r'([0-9]+\.\s*[0-9]+)') _decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)")
# 分数识别 (形式 "3/4" ) # 分数识别 (形式 "3/4" )
_fraction_re = re.compile(r'([0-9]+/[0-9]+)') _fraction_re = re.compile(r"([0-9]+/[0-9]+)")
# 序数词识别 # 序数词识别
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
# 数字处理 # 数字处理
_number_re = re.compile(r'[0-9]+') _number_re = re.compile(r"[0-9]+")
def _convert_ordinal(m): def _convert_ordinal(m):
""" """
标准化序数词, 例如: 1. 2. 3. 4. 5. 6. 标准化序数词, 例如: 1. 2. 3. 4. 5. 6.
Examples: Examples:
input: "1. " input: "1. "
output: "1st" output: "1st"
@ -70,24 +71,26 @@ def _convert_ordinal(m):
ordinal = _inflect.ordinal(m.group(1)) ordinal = _inflect.ordinal(m.group(1))
return ordinal + ", " return ordinal + ", "
def _remove_commas(m): def _remove_commas(m):
return m.group(1).replace(',', '') return m.group(1).replace(",", "")
def _expand_time(m): def _expand_time(m):
""" """
24 小时制的时间转换为 12 小时制的时间表示方式 24 小时制的时间转换为 12 小时制的时间表示方式
Examples: Examples:
input: "13:00 / 4:00 / 13:30" input: "13:00 / 4:00 / 13:30"
output: "one o'clock p.m. / four o'clock am. / one thirty p.m." output: "one o'clock p.m. / four o'clock am. / one thirty p.m."
""" """
hours, minutes = map(int, m.group(1, 2)) hours, minutes = map(int, m.group(1, 2))
period = 'a.m.' if hours < 12 else 'p.m.' period = "a.m." if hours < 12 else "p.m."
if hours > 12: if hours > 12:
hours -= 12 hours -= 12
hour_word = _inflect.number_to_words(hours) hour_word = _inflect.number_to_words(hours)
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else '' minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ""
if minutes == 0: if minutes == 0:
return f"{hour_word} o'clock {period}" return f"{hour_word} o'clock {period}"
@ -103,7 +106,7 @@ def _expand_measurement(m):
sign = m.group(3) sign = m.group(3)
ptr = 1 ptr = 1
# 想不到怎么方便的取数字又懒得改正则1.2 反正也是复数读法,干脆直接去掉 "." # 想不到怎么方便的取数字又懒得改正则1.2 反正也是复数读法,干脆直接去掉 "."
num = int(m.group(1).replace(sign, '').replace(".",'')) num = int(m.group(1).replace(sign, "").replace(".", ""))
decimal_part = m.group(2) decimal_part = m.group(2)
# 上面判断的漏洞,比如 0.1 的情况,在这里排除了 # 上面判断的漏洞,比如 0.1 的情况,在这里排除了
if decimal_part == None and num == 1: if decimal_part == None and num == 1:
@ -116,23 +119,24 @@ def _expand_pounds(m):
没找到特别规范的说明和美元的处理一样其实可以把两个合并在一起 没找到特别规范的说明和美元的处理一样其实可以把两个合并在一起
""" """
match = m.group(1) match = m.group(1)
parts = match.split('.') parts = match.split(".")
if len(parts) > 2: if len(parts) > 2:
return match + ' pounds' # Unexpected format return match + " pounds" # Unexpected format
pounds = int(parts[0]) if parts[0] else 0 pounds = int(parts[0]) if parts[0] else 0
pence = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0 pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if pounds and pence: if pounds and pence:
pound_unit = 'pound' if pounds == 1 else 'pounds' pound_unit = "pound" if pounds == 1 else "pounds"
penny_unit = 'penny' if pence == 1 else 'pence' penny_unit = "penny" if pence == 1 else "pence"
return '%s %s and %s %s' % (pounds, pound_unit, pence, penny_unit) return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit)
elif pounds: elif pounds:
pound_unit = 'pound' if pounds == 1 else 'pounds' pound_unit = "pound" if pounds == 1 else "pounds"
return '%s %s' % (pounds, pound_unit) return "%s %s" % (pounds, pound_unit)
elif pence: elif pence:
penny_unit = 'penny' if pence == 1 else 'pence' penny_unit = "penny" if pence == 1 else "pence"
return '%s %s' % (pence, penny_unit) return "%s %s" % (pence, penny_unit)
else: else:
return 'zero pounds' return "zero pounds"
def _expand_dollars(m): def _expand_dollars(m):
""" """
@ -142,23 +146,24 @@ def _expand_dollars(m):
output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents" output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents"
""" """
match = m.group(1) match = m.group(1)
parts = match.split('.') parts = match.split(".")
if len(parts) > 2: if len(parts) > 2:
return match + ' dollars' # Unexpected format return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0 dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0 cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if dollars and cents: if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars' dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = 'cent' if cents == 1 else 'cents' cent_unit = "cent" if cents == 1 else "cents"
return '%s %s and %s %s' % (dollars, dollar_unit, cents, cent_unit) return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars: elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars' dollar_unit = "dollar" if dollars == 1 else "dollars"
return '%s %s' % (dollars, dollar_unit) return "%s %s" % (dollars, dollar_unit)
elif cents: elif cents:
cent_unit = 'cent' if cents == 1 else 'cents' cent_unit = "cent" if cents == 1 else "cents"
return '%s %s' % (cents, cent_unit) return "%s %s" % (cents, cent_unit)
else: else:
return 'zero dollars' return "zero dollars"
# 小数的处理 # 小数的处理
def _expand_decimal_number(m): def _expand_decimal_number(m):
@ -168,11 +173,11 @@ def _expand_decimal_number(m):
output: "thirteen point two three four" output: "thirteen point two three four"
""" """
match = m.group(1) match = m.group(1)
parts = match.split('.') parts = match.split(".")
words = [] words = []
# 遍历字符串中的每个字符 # 遍历字符串中的每个字符
for char in parts[1]: for char in parts[1]:
if char == '.': if char == ".":
words.append("point") words.append("point")
else: else:
words.append(char) words.append(char)
@ -186,7 +191,7 @@ def _expend_fraction(m):
规则2: 如果分子大于 1, 在读分母的时候使用序数词复数读法. 规则2: 如果分子大于 1, 在读分母的时候使用序数词复数读法.
规则3: 当分母为2的时候, 分母读做 half, 并且当分子大于 1 的时候, half 也要用复数读法, 读为 halves. 规则3: 当分母为2的时候, 分母读做 half, 并且当分子大于 1 的时候, half 也要用复数读法, 读为 halves.
Examples: Examples:
| Written | Said | | Written | Said |
|:---:|:---:| |:---:|:---:|
| 1/3 | one third | | 1/3 | one third |
@ -196,39 +201,41 @@ def _expend_fraction(m):
| 3/2 | three halves | | 3/2 | three halves |
""" """
match = m.group(0) match = m.group(0)
numerator, denominator = map(int, match.split('/')) numerator, denominator = map(int, match.split("/"))
numerator_part = _inflect.number_to_words(numerator) numerator_part = _inflect.number_to_words(numerator)
if denominator == 2: if denominator == 2:
if numerator == 1: if numerator == 1:
denominator_part = 'half' denominator_part = "half"
else: else:
denominator_part = 'halves' denominator_part = "halves"
elif denominator == 1: elif denominator == 1:
return f'{numerator_part}' return f"{numerator_part}"
else: else:
denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator)) denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
if numerator > 1: if numerator > 1:
denominator_part += 's' denominator_part += "s"
return f"{numerator_part} {denominator_part}"
return f'{numerator_part} {denominator_part}'
def _expand_ordinal(m): def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0)) return _inflect.number_to_words(m.group(0))
def _expand_number(m): def _expand_number(m):
num = int(m.group(0)) num = int(m.group(0))
if num > 1000 and num < 3000: if num > 1000 and num < 3000:
if num == 2000: if num == 2000:
return 'two thousand' return "two thousand"
elif num > 2000 and num < 2010: elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100) return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0: elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred' return _inflect.number_to_words(num // 100) + " hundred"
else: else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else: else:
return _inflect.number_to_words(num, andword='') return _inflect.number_to_words(num, andword="")
def normalize(text): def normalize(text):
@ -238,7 +245,7 @@ def normalize(text):
""" """
text = re.sub(_ordinal_number_re, _convert_ordinal, text) text = re.sub(_ordinal_number_re, _convert_ordinal, text)
text = re.sub(r'(?<!\d)-|-(?!\d)', ' minus ', text) text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
text = re.sub(_comma_number_re, _remove_commas, text) text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_time_re, _expand_time, text) text = re.sub(_time_re, _expand_time, text)
text = re.sub(_measurement_re, _expand_measurement, text) text = re.sub(_measurement_re, _expand_measurement, text)
@ -251,19 +258,20 @@ def normalize(text):
text = re.sub(_ordinal_re, _expand_ordinal, text) text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text) text = re.sub(_number_re, _expand_number, text)
text = ''.join(char for char in unicodedata.normalize('NFD', text) text = "".join(
if unicodedata.category(char) != 'Mn') # Strip accents char for char in unicodedata.normalize("NFD", text) if unicodedata.category(char) != "Mn"
) # Strip accents
text = re.sub("%", " percent", text) text = re.sub("%", " percent", text)
text = re.sub("[^ A-Za-z'.,?!\-]", "", text) text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
text = re.sub(r"(?i)i\.e\.", "that is", text) text = re.sub(r"(?i)i\.e\.", "that is", text)
text = re.sub(r"(?i)e\.g\.", "for example", text) text = re.sub(r"(?i)e\.g\.", "for example", text)
# 增加纯大写单词拆分 # 增加纯大写单词拆分
text = re.sub(r'(?<!^)(?<![\s])([A-Z])', r' \1', text) text = re.sub(r"(?<!^)(?<![\s])([A-Z])", r" \1", text)
return text return text
if __name__ == '__main__': if __name__ == "__main__":
# 我觉得其实可以把切分结果展示出来只读或者修改不影响传给TTS的实际text # 我觉得其实可以把切分结果展示出来只读或者修改不影响传给TTS的实际text
# 然后让用户确认后再输入给 TTS可以让用户检查自己有没有不标准的输入 # 然后让用户确认后再输入给 TTS可以让用户检查自己有没有不标准的输入
print(normalize("1. test ordinal number 1st")) print(normalize("1. test ordinal number 1st"))
@ -272,4 +280,4 @@ if __name__ == '__main__':
print(normalize("1st, 22nd")) print(normalize("1st, 22nd"))
print(normalize("a test 20h, 1.2s, 1L, 0.1km")) print(normalize("a test 20h, 1.2s, 1L, 0.1km"))
print(normalize("a test of time 4:00, 13:00, 13:30")) print(normalize("a test of time 4:00, 13:00, 13:30"))
print(normalize("a test of temperature 4°F, 23°C, -19°C")) print(normalize("a test of temperature 4°F, 23°C, -19°C"))

View File

@ -11,6 +11,7 @@ from text.symbols2 import symbols
from builtins import str as unicode from builtins import str as unicode
from text.en_normalization.expend import normalize from text.en_normalization.expend import normalize
from nltk.tokenize import TweetTokenizer from nltk.tokenize import TweetTokenizer
word_tokenize = TweetTokenizer().tokenize word_tokenize = TweetTokenizer().tokenize
from nltk import pos_tag from nltk import pos_tag
@ -121,9 +122,9 @@ def replace_phs(phs):
def replace_consecutive_punctuation(text): def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}\s])([{punctuations}])+' pattern = f"([{punctuations}\s])([{punctuations}])+"
result = re.sub(pattern, r'\1', text) result = re.sub(pattern, r"\1", text)
return result return result
@ -182,6 +183,7 @@ def read_dict_new():
return g2p_dict return g2p_dict
def hot_reload_hot(g2p_dict): def hot_reload_hot(g2p_dict):
with open(CMU_DICT_HOT_PATH) as f: with open(CMU_DICT_HOT_PATH) as f:
line = f.readline() line = f.readline()
@ -258,9 +260,12 @@ class en_G2p(G2p):
del self.cmu[word.lower()] del self.cmu[word.lower()]
# 修正多音字 # 修正多音字
self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP') self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ') self.homograph2features["complex"] = (
["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
"JJ",
)
def __call__(self, text): def __call__(self, text):
# tokenization # tokenization
@ -279,7 +284,7 @@ class en_G2p(G2p):
elif len(word) == 1: elif len(word) == 1:
# 单读 A 发音修正, 这里需要原格式 o_word 判断大写 # 单读 A 发音修正, 这里需要原格式 o_word 判断大写
if o_word == "A": if o_word == "A":
pron = ['EY1'] pron = ["EY1"]
else: else:
pron = self.cmu[word][0] pron = self.cmu[word][0]
# g2p_en 原版多音字处理 # g2p_en 原版多音字处理
@ -288,7 +293,7 @@ class en_G2p(G2p):
if pos.startswith(pos1): if pos.startswith(pos1):
pron = pron1 pron = pron1
# pos1比pos长仅出现在read # pos1比pos长仅出现在read
elif len(pos) < len(pos1) and pos == pos1[:len(pos)]: elif len(pos) < len(pos1) and pos == pos1[: len(pos)]:
pron = pron1 pron = pron1
else: else:
pron = pron2 pron = pron2
@ -301,7 +306,6 @@ class en_G2p(G2p):
return prons[:-1] return prons[:-1]
def qryword(self, o_word): def qryword(self, o_word):
word = o_word.lower() word = o_word.lower()
@ -319,7 +323,7 @@ class en_G2p(G2p):
for w in word: for w in word:
# 单读 A 发音修正, 此处不存在大写的情况 # 单读 A 发音修正, 此处不存在大写的情况
if w == "a": if w == "a":
phones.extend(['EY1']) phones.extend(["EY1"])
elif not w.isalpha(): elif not w.isalpha():
phones.extend([w]) phones.extend([w])
else: else:
@ -330,23 +334,23 @@ class en_G2p(G2p):
if re.match(r"^([a-z]+)('s)$", word): if re.match(r"^([a-z]+)('s)$", word):
phones = self.qryword(word[:-2])[:] phones = self.qryword(word[:-2])[:]
# P T K F TH HH 无声辅音结尾 's 发 ['S'] # P T K F TH HH 无声辅音结尾 's 发 ['S']
if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']: if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
phones.extend(['S']) phones.extend(["S"])
# S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z'] # S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z']
elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']: elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
phones.extend(['AH0', 'Z']) phones.extend(["AH0", "Z"])
# B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z'] # B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z']
# AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2 # AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
# ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z'] # ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z']
else: else:
phones.extend(['Z']) phones.extend(["Z"])
return phones return phones
# 尝试进行分词,应对复合词 # 尝试进行分词,应对复合词
comps = wordsegment.segment(word.lower()) comps = wordsegment.segment(word.lower())
# 无法分词的送回去预测 # 无法分词的送回去预测
if len(comps)==1: if len(comps) == 1:
return self.predict(word) return self.predict(word)
# 可以分词的递归处理 # 可以分词的递归处理
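As a reference for the possessive-'s branch above, a small self-contained sketch; the helper name possessive_suffix and the example words are ours, while the phone sets are copied from the hunk:

VOICELESS_FINALS = {"P", "T", "K", "F", "TH", "HH"}
SIBILANT_FINALS = {"S", "Z", "SH", "ZH", "CH", "JH"}


def possessive_suffix(last_phone: str) -> list:
    # Choose the phones appended for a trailing 's based on the stem's final phone.
    if last_phone in VOICELESS_FINALS:
        return ["S"]
    if last_phone in SIBILANT_FINALS:
        return ["AH0", "Z"]
    return ["Z"]  # voiced consonants and vowels


print(possessive_suffix("T"))  # cat's  -> ['S']
print(possessive_suffix("Z"))  # rose's -> ['AH0', 'Z']
print(possessive_suffix("G"))  # dog's  -> ['Z']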
View File
@ -1 +1 @@
from text.g2pw.g2pw import * from text.g2pw.g2pw import *
View File
@ -15,6 +15,7 @@
Credits Credits
This code is modified from https://github.com/GitYCC/g2pW This code is modified from https://github.com/GitYCC/g2pW
""" """
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Tuple from typing import Tuple
@ -23,21 +24,24 @@ import numpy as np
from .utils import tokenize_and_map from .utils import tokenize_and_map
ANCHOR_CHAR = '' ANCHOR_CHAR = ""
def prepare_onnx_input(tokenizer, def prepare_onnx_input(
labels: List[str], tokenizer,
char2phonemes: Dict[str, List[int]], labels: List[str],
chars: List[str], char2phonemes: Dict[str, List[int]],
texts: List[str], chars: List[str],
query_ids: List[int], texts: List[str],
use_mask: bool=False, query_ids: List[int],
window_size: int=None, use_mask: bool = False,
max_len: int=512) -> Dict[str, np.array]: window_size: int = None,
max_len: int = 512,
) -> Dict[str, np.array]:
if window_size is not None: if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts( truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids) window_size=window_size, texts=texts, query_ids=query_ids
)
input_ids = [] input_ids = []
token_type_ids = [] token_type_ids = []
attention_masks = [] attention_masks = []
@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx] query_id = (truncated_query_ids if window_size else query_ids)[idx]
try: try:
tokens, text2token, token2text = tokenize_and_map( tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
tokenizer=tokenizer, text=text)
except Exception: except Exception:
print(f'warning: text "{text}" is invalid') print(f'warning: text "{text}" is invalid')
return {} return {}
text, query_id, tokens, text2token, token2text = _truncate( text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len, max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
text=text, )
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]'] processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
input_id = list( input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
np.array(tokenizer.convert_tokens_to_ids(processed_tokens))) token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int)) attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
query_char = text[query_id] query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \ phoneme_mask = (
if use_mask else [1] * len(labels) [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char) char_id = chars.index(query_char)
position_id = text2token[ position_id = text2token[query_id] + 1 # [CLS] token locate at first place
query_id] + 1 # [CLS] token locate at first place
input_ids.append(input_id) input_ids.append(input_id)
token_type_ids.append(token_type_id) token_type_ids.append(token_type_id)
@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
position_ids.append(position_id) position_ids.append(position_id)
outputs = { outputs = {
'input_ids': np.array(input_ids).astype(np.int64), "input_ids": np.array(input_ids).astype(np.int64),
'token_type_ids': np.array(token_type_ids).astype(np.int64), "token_type_ids": np.array(token_type_ids).astype(np.int64),
'attention_masks': np.array(attention_masks).astype(np.int64), "attention_masks": np.array(attention_masks).astype(np.int64),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32), "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64), "char_ids": np.array(char_ids).astype(np.int64),
'position_ids': np.array(position_ids).astype(np.int64), "position_ids": np.array(position_ids).astype(np.int64),
} }
return outputs return outputs
def _truncate_texts(window_size: int, texts: List[str], def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = [] truncated_texts = []
truncated_query_ids = [] truncated_query_ids = []
for text, query_id in zip(texts, query_ids): for text, query_id in zip(texts, query_ids):
@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
return truncated_texts, truncated_query_ids return truncated_texts, truncated_query_ids
def _truncate(max_len: int, def _truncate(
text: str, max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
query_id: int, ):
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
truncate_len = max_len - 2 truncate_len = max_len - 2
if len(tokens) <= truncate_len: if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text) return (text, query_id, tokens, text2token, token2text)
@ -137,14 +131,16 @@ def _truncate(max_len: int,
start = token2text[token_start][0] start = token2text[token_start][0]
end = token2text[token_end - 1][1] end = token2text[token_end - 1][1]
return (text[start:end], query_id - start, tokens[token_start:token_end], [ return (
i - token_start if i is not None else None text[start:end],
for i in text2token[start:end] query_id - start,
], [(s - start, e - start) for s, e in token2text[token_start:token_end]]) tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
)
def get_phoneme_labels(polyphonic_chars: List[List[str]] def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {} char2phonemes = {}
for char, phoneme in polyphonic_chars: for char, phoneme in polyphonic_chars:
@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
return labels, char2phonemes return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars: List[List[str]] def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
) -> Tuple[List[str], Dict[str, List[int]]]: labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {} char2phonemes = {}
for char, phoneme in polyphonic_chars: for char, phoneme in polyphonic_chars:
if char not in char2phonemes: if char not in char2phonemes:
char2phonemes[char] = [] char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}')) char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
return labels, char2phonemes return labels, char2phonemes
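A self-contained sketch of the mapping get_phoneme_labels above produces; the sample polyphonic_chars entries are made up for illustration and are not taken from POLYPHONIC_CHARS.txt:

# Each entry is [character, candidate phoneme]; a polyphonic character appears once per reading.
polyphonic_chars = [["行", "xing2"], ["行", "hang2"], ["长", "chang2"], ["长", "zhang3"]]

labels = sorted({phoneme for _, phoneme in polyphonic_chars})
char2phonemes = {}
for char, phoneme in polyphonic_chars:
    char2phonemes.setdefault(char, []).append(labels.index(phoneme))

print(labels)               # ['chang2', 'hang2', 'xing2', 'zhang3']
print(char2phonemes["行"])  # [2, 1] -> indices of that character's candidate phonemes in labels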
View File
@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
class G2PWPinyin(Pinyin): class G2PWPinyin(Pinyin):
def __init__(self, model_dir='G2PWModel/', model_source=None, def __init__(
enable_non_tradional_chinese=True, self,
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs): model_dir="G2PWModel/",
model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter( self._g2pw = G2PWOnnxConverter(
model_dir=model_dir, model_dir=model_dir,
style='pinyin', style="pinyin",
model_source=model_source, model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese, enable_non_tradional_chinese=enable_non_tradional_chinese,
) )
self._converter = Converter( self._converter = Converter(
self._g2pw, v_to_u=v_to_u, self._g2pw,
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five, neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi, tone_sandhi=tone_sandhi,
) )
@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):
class Converter(UltimateConverter): class Converter(UltimateConverter):
def __init__(self, g2pw_instance, v_to_u=False, def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
neutral_tone_with_five=False,
tone_sandhi=False, **kwargs):
super(Converter, self).__init__( super(Converter, self).__init__(
v_to_u=v_to_u, v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
neutral_tone_with_five=neutral_tone_with_five, )
tone_sandhi=tone_sandhi, **kwargs)
self._g2pw = g2pw_instance self._g2pw = g2pw_instance
def convert(self, words, style, heteronym, errors, strict, **kwargs): def convert(self, words, style, heteronym, errors, strict, **kwargs):
pys = [] pys = []
if RE_HANS.match(words): if RE_HANS.match(words):
pys = self._to_pinyin(words, style=style, heteronym=heteronym, pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
errors=errors, strict=strict)
post_data = self.post_pinyin(words, heteronym, pys) post_data = self.post_pinyin(words, heteronym, pys)
if post_data is not None: if post_data is not None:
pys = post_data pys = post_data
pys = self.convert_styles( pys = self.convert_styles(pys, words, style, heteronym, errors, strict)
pys, words, style, heteronym, errors, strict)
else: else:
py = self.handle_nopinyin(words, style=style, errors=errors, py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
heteronym=heteronym, strict=strict)
if py: if py:
pys.extend(py) pys.extend(py)
@ -73,13 +75,11 @@ class Converter(UltimateConverter):
g2pw_pinyin = self._g2pw(han) g2pw_pinyin = self._g2pw(han)
if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
return super(Converter, self).convert( return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)
han, Style.TONE, heteronym, errors, strict, **kwargs)
for i, item in enumerate(g2pw_pinyin[0]): for i, item in enumerate(g2pw_pinyin[0]):
if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑 if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
py = super(Converter, self).convert( py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
pinyins.extend(py) pinyins.extend(py)
else: else:
pinyins.append([to_tone(item)]) pinyins.append([to_tone(item)])
@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
if lst: if lst:
new_lst_list.append(lst) new_lst_list.append(lst)
else: else:
new_lst_list.append(['']) new_lst_list.append([""])
return new_lst_list return new_lst_list
@ -127,17 +127,17 @@ def get_dict():
def read_dict(): def read_dict():
polyphonic_dict = {} polyphonic_dict = {}
with open(PP_DICT_PATH,encoding="utf-8") as f: with open(PP_DICT_PATH, encoding="utf-8") as f:
line = f.readline() line = f.readline()
while line: while line:
key, value_str = line.split(':') key, value_str = line.split(":")
value = eval(value_str.strip()) value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value polyphonic_dict[key.strip()] = value
line = f.readline() line = f.readline()
with open(PP_FIX_DICT_PATH,encoding="utf-8") as f: with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
line = f.readline() line = f.readline()
while line: while line:
key, value_str = line.split(':') key, value_str = line.split(":")
value = eval(value_str.strip()) value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value polyphonic_dict[key.strip()] = value
line = f.readline() line = f.readline()
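For context, read_dict above expects one "key: value" pair per line and evals the right-hand side into a Python list; a hedged sketch with a made-up line (the shipped polyphonic.rep / polyphonic-fix.rep contents may differ):

line = "空调: ['kong1', 'tiao2']"  # illustrative line, not copied from the .rep files
key, value_str = line.split(":")
value = eval(value_str.strip())  # the module relies on eval for the list literal
polyphonic_dict = {key.strip(): value}
print(polyphonic_dict)  # {'空调': ['kong1', 'tiao2']}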
View File
@ -2,6 +2,7 @@
# This code is modified from https://github.com/GitYCC/g2pW # This code is modified from https://github.com/GitYCC/g2pW
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
import json import json
import os import os
@ -14,6 +15,7 @@ from typing import Tuple
import numpy as np import numpy as np
import onnxruntime import onnxruntime
onnxruntime.set_default_logger_severity(3) onnxruntime.set_default_logger_severity(3)
from opencc import OpenCC from opencc import OpenCC
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -26,21 +28,23 @@ from .dataset import prepare_onnx_input
from .utils import load_config from .utils import load_config
from ..zh_normalization.char_convert import tranditional_to_simplified from ..zh_normalization.char_convert import tranditional_to_simplified
model_version = '1.1' model_version = "1.1"
def predict(session, onnx_input: Dict[str, Any], def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = [] all_preds = []
all_confidences = [] all_confidences = []
probs = session.run([], { probs = session.run(
"input_ids": onnx_input['input_ids'], [],
"token_type_ids": onnx_input['token_type_ids'], {
"attention_mask": onnx_input['attention_masks'], "input_ids": onnx_input["input_ids"],
"phoneme_mask": onnx_input['phoneme_masks'], "token_type_ids": onnx_input["token_type_ids"],
"char_ids": onnx_input['char_ids'], "attention_mask": onnx_input["attention_masks"],
"position_ids": onnx_input['position_ids'] "phoneme_mask": onnx_input["phoneme_masks"],
})[0] "char_ids": onnx_input["char_ids"],
"position_ids": onnx_input["position_ids"],
},
)[0]
preds = np.argmax(probs, axis=1).tolist() preds = np.argmax(probs, axis=1).tolist()
max_probs = [] max_probs = []
@ -52,17 +56,17 @@ def predict(session, onnx_input: Dict[str, Any],
return all_preds, all_confidences return all_preds, all_confidences
def download_and_decompress(model_dir: str='G2PWModel/'): def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir): if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir) parent_directory = os.path.dirname(model_dir)
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip") zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1") extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory,"G2PWModel") extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...") print("Downloading g2pw model...")
modelscope_url = "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip" modelscope_url = "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r: with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status() r.raise_for_status()
with open(zip_dir, 'wb') as f: with open(zip_dir, "wb") as f:
for chunk in r.iter_content(chunk_size=8192): for chunk in r.iter_content(chunk_size=8192):
if chunk: if chunk:
f.write(chunk) f.write(chunk)
@ -70,17 +74,20 @@ def download_and_decompress(model_dir: str='G2PWModel/'):
print("Extracting g2pw model...") print("Extracting g2pw model...")
with zipfile.ZipFile(zip_dir, "r") as zip_ref: with zipfile.ZipFile(zip_dir, "r") as zip_ref:
zip_ref.extractall(parent_directory) zip_ref.extractall(parent_directory)
os.rename(extract_dir, extract_dir_new) os.rename(extract_dir, extract_dir_new)
return model_dir return model_dir
class G2PWOnnxConverter: class G2PWOnnxConverter:
def __init__(self, def __init__(
model_dir: str='G2PWModel/', self,
style: str='bopomofo', model_dir: str = "G2PWModel/",
model_source: str=None, style: str = "bopomofo",
enable_non_tradional_chinese: bool=False): model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir) uncompress_path = download_and_decompress(model_dir)
sess_options = onnxruntime.SessionOptions() sess_options = onnxruntime.SessionOptions()
@ -88,41 +95,59 @@ class G2PWOnnxConverter:
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 sess_options.intra_op_num_threads = 2
try: try:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
except: except:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider']) self.session_g2pW = onnxruntime.InferenceSession(
self.config = load_config( os.path.join(uncompress_path, "g2pW.onnx"),
config_path=os.path.join(uncompress_path, 'config.py'), sess_options=sess_options,
use_default=True) providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source) self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(uncompress_path, polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
'POLYPHONIC_CHARS.txt') monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path,
'MONOPHONIC_CHARS.txt')
self.polyphonic_chars = [ self.polyphonic_chars = [
line.split('\t') line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
for line in open(polyphonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
] ]
self.non_polyphonic = { self.non_polyphonic = {
'', '', '', '', '', '', '', '', '', '', '', '', '', "",
'', '', '', '', '', '' "",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
} }
self.non_monophonic = {'', ''} self.non_monophonic = {"", ""}
self.monophonic_chars = [ self.monophonic_chars = [
line.split('\t') line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
for line in open(monophonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
] ]
self.labels, self.char2phonemes = get_char_phoneme_labels( self.labels, self.char2phonemes = (
polyphonic_chars=self.polyphonic_chars get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
) if self.config.use_char_phoneme else get_phoneme_labels( if self.config.use_char_phoneme
polyphonic_chars=self.polyphonic_chars) else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
)
self.chars = sorted(list(self.char2phonemes.keys())) self.chars = sorted(list(self.char2phonemes.keys()))
@ -131,41 +156,29 @@ class G2PWOnnxConverter:
if char in self.polyphonic_chars_new: if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char) self.polyphonic_chars_new.remove(char)
self.monophonic_chars_dict = { self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
char: phoneme
for char, phoneme in self.monophonic_chars
}
for char in self.non_monophonic: for char in self.non_monophonic:
if char in self.monophonic_chars_dict: if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char) self.monophonic_chars_dict.pop(char)
self.pos_tags = [ self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
]
with open( with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
os.path.join(uncompress_path,
'bopomofo_to_pinyin_wo_tune_dict.json'),
'r',
encoding='utf-8') as fr:
self.bopomofo_convert_dict = json.load(fr) self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = { self.style_convert_func = {
'bopomofo': lambda x: x, "bopomofo": lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin, "pinyin": self._convert_bopomofo_to_pinyin,
}[style] }[style]
with open( with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
'r',
encoding='utf-8') as fr:
self.char_bopomofo_dict = json.load(fr) self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc: if self.enable_opencc:
self.cc = OpenCC('s2tw') self.cc = OpenCC("s2tw")
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1] tone = bopomofo[-1]
assert tone in '12345' assert tone in "12345"
component = self.bopomofo_convert_dict.get(bopomofo[:-1]) component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component: if component:
return component + tone return component + tone
@ -185,8 +198,7 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent) translated_sentences.append(translated_sent)
sentences = translated_sentences sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data( texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
sentences=sentences)
if len(texts) == 0: if len(texts) == 0:
# sentences no polyphonic words # sentences no polyphonic words
return partial_results return partial_results
@ -199,14 +211,12 @@ class G2PWOnnxConverter:
texts=texts, texts=texts,
query_ids=query_ids, query_ids=query_ids,
use_mask=self.config.use_mask, use_mask=self.config.use_mask,
window_size=None) window_size=None,
)
preds, confidences = predict( preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
if self.config.use_char_phoneme: if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds] preds = [pred.split(" ")[1] for pred in preds]
results = partial_results results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds): for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@ -214,15 +224,12 @@ class G2PWOnnxConverter:
return results return results
def _prepare_data( def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], [] texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences): for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese # pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent) sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin( pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent) partial_result = [None] * len(sent)
for i, char in enumerate(sent): for i, char in enumerate(sent):
if char in self.polyphonic_chars_new: if char in self.polyphonic_chars_new:
@ -230,8 +237,7 @@ class G2PWOnnxConverter:
query_ids.append(i) query_ids.append(i)
sent_ids.append(sent_id) sent_ids.append(sent_id)
elif char in self.monophonic_chars_dict: elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func( partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict: elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0] partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0]) # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
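A hedged usage sketch for G2PWOnnxConverter above; the import path and the printed pinyin are assumptions, and the model files are fetched on first use by download_and_decompress:

from text.g2pw.onnx_api import G2PWOnnxConverter  # module path assumed from this repo's layout

conv = G2PWOnnxConverter(model_dir="G2PWModel/", style="pinyin", enable_non_tradional_chinese=True)
print(conv(["银行在哪里"]))  # e.g. [['yin2', 'hang2', 'zai4', 'na3', 'li3']] (illustrative output)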
View File
@ -15,6 +15,7 @@
Credits Credits
This code is modified from https://github.com/GitYCC/g2pW This code is modified from https://github.com/GitYCC/g2pW
""" """
import os import os
import re import re
@ -24,14 +25,14 @@ def wordize_and_map(text: str):
index_map_from_text_to_word = [] index_map_from_text_to_word = []
index_map_from_word_to_text = [] index_map_from_word_to_text = []
while len(text) > 0: while len(text) > 0:
match_space = re.match(r'^ +', text) match_space = re.match(r"^ +", text)
if match_space: if match_space:
space_str = match_space.group(0) space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str) index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):] text = text[len(space_str) :]
continue continue
match_en = re.match(r'^[a-zA-Z0-9]+', text) match_en = re.match(r"^[a-zA-Z0-9]+", text)
if match_en: if match_en:
en_word = match_en.group(0) en_word = match_en.group(0)
@ -42,7 +43,7 @@ def wordize_and_map(text: str):
index_map_from_text_to_word += [len(words)] * len(en_word) index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word) words.append(en_word)
text = text[len(en_word):] text = text[len(en_word) :]
else: else:
word_start_pos = len(index_map_from_text_to_word) word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1 word_end_pos = word_start_pos + 1
@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
for word, (word_start, word_end) in zip(words, word2text): for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word) word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']: if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
index_map_from_token_to_text.append((word_start, word_end)) index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]') tokens.append("[UNK]")
else: else:
current_word_start = word_start current_word_start = word_start
for word_token in word_tokens: for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token)) word_token_len = len(re.sub(r"^##", "", word_token))
index_map_from_token_to_text.append( index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
(current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len current_word_start = current_word_start + word_token_len
tokens.append(word_token) tokens.append(word_token)
@ -85,53 +85,51 @@ def tokenize_and_map(tokenizer, text: str):
def _load_config(config_path: os.PathLike): def _load_config(config_path: os.PathLike):
import importlib.util import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
spec = importlib.util.spec_from_file_location("__init__", config_path)
config = importlib.util.module_from_spec(spec) config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config) spec.loader.exec_module(config)
return config return config
default_config_dict = { default_config_dict = {
'manual_seed': 1313, "manual_seed": 1313,
'model_source': 'bert-base-chinese', "model_source": "bert-base-chinese",
'window_size': 32, "window_size": 32,
'num_workers': 2, "num_workers": 2,
'use_mask': True, "use_mask": True,
'use_char_phoneme': False, "use_char_phoneme": False,
'use_conditional': True, "use_conditional": True,
'param_conditional': { "param_conditional": {
'affect_location': 'softmax', "affect_location": "softmax",
'bias': True, "bias": True,
'char-linear': True, "char-linear": True,
'pos-linear': False, "pos-linear": False,
'char+pos-second': True, "char+pos-second": True,
'char+pos-second_lowrank': False, "char+pos-second_lowrank": False,
'lowrank_size': 0, "lowrank_size": 0,
'char+pos-second_fm': False, "char+pos-second_fm": False,
'fm_size': 0, "fm_size": 0,
'fix_mode': None, "fix_mode": None,
'count_json': 'train.count.json' "count_json": "train.count.json",
}, },
'lr': 5e-5, "lr": 5e-5,
'val_interval': 200, "val_interval": 200,
'num_iter': 10000, "num_iter": 10000,
'use_focal': False, "use_focal": False,
'param_focal': { "param_focal": {"alpha": 0.0, "gamma": 0.7},
'alpha': 0.0, "use_pos": True,
'gamma': 0.7 "param_pos ": {
"weight": 0.1,
"pos_joint_training": True,
"train_pos_path": "train.pos",
"valid_pos_path": "dev.pos",
"test_pos_path": "test.pos",
}, },
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
} }
def load_config(config_path: os.PathLike, use_default: bool=False): def load_config(config_path: os.PathLike, use_default: bool = False):
config = _load_config(config_path) config = _load_config(config_path)
if use_default: if use_default:
for attr, val in default_config_dict.items(): for attr, val in default_config_dict.items():

View File
import re import re
import os import os
import hashlib import hashlib
try: try:
import pyopenjtalk import pyopenjtalk
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
# 防止win下无法读取模型 # 防止win下无法读取模型
if os.name == 'nt': if os.name == "nt":
python_dir = os.getcwd() python_dir = os.getcwd()
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8") OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)): if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()): if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir)) OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
else: else:
import shutil import shutil
if not os.path.exists('TEMP'):
os.mkdir('TEMP') if not os.path.exists("TEMP"):
os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ja")): if not os.path.exists(os.path.join("TEMP", "ja")):
os.mkdir(os.path.join("TEMP", "ja")) os.mkdir(os.path.join("TEMP", "ja"))
if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")): if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic")) shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), ) shutil.copytree(
pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
os.path.join("TEMP", "ja", "open_jtalk_dic"),
)
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic") OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8") pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)): if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()): if current_file_path[: len(python_dir)].upper() == python_dir.upper():
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir)) current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
else: else:
if not os.path.exists('TEMP'): if not os.path.exists("TEMP"):
os.mkdir('TEMP') os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ja")): if not os.path.exists(os.path.join("TEMP", "ja")):
os.mkdir(os.path.join("TEMP", "ja")) os.mkdir(os.path.join("TEMP", "ja"))
if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")): if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
os.mkdir(os.path.join("TEMP", "ja", "ja_userdic")) os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv")) shutil.copyfile(
os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
)
current_file_path = os.path.join("TEMP", "ja") current_file_path = os.path.join("TEMP", "ja")
def get_hash(fp: str) -> str: def get_hash(fp: str) -> str:
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
with open(fp, "rb") as f: with open(fp, "rb") as f:
@ -51,21 +59,26 @@ try:
USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5") USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
# 如果没有用户词典就生成一个如果有就检查md5如果不一样就重新生成 # 如果没有用户词典就生成一个如果有就检查md5如果不一样就重新生成
if os.path.exists(USERDIC_CSV_PATH): if os.path.exists(USERDIC_CSV_PATH):
if not os.path.exists(USERDIC_BIN_PATH) or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r",encoding='utf-8').read(): if (
not os.path.exists(USERDIC_BIN_PATH)
or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
):
pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH) pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
with open(USERDIC_HASH_PATH, "w", encoding='utf-8') as f: with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
f.write(get_hash(USERDIC_CSV_PATH)) f.write(get_hash(USERDIC_CSV_PATH))
if os.path.exists(USERDIC_BIN_PATH): if os.path.exists(USERDIC_BIN_PATH):
pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH) pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
except Exception: except Exception:
# print(e) # print(e)
import pyopenjtalk import pyopenjtalk
# failed to load user dictionary, ignore. # failed to load user dictionary, ignore.
pass pass
from text.symbols import punctuation from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks: # Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile( _japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@ -123,9 +136,9 @@ def post_replace_ph(ph):
def replace_consecutive_punctuation(text): def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+' pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r'\1', text) result = re.sub(pattern, r"\1", text)
return result return result
@ -152,7 +165,7 @@ def preprocess_jap(text, with_prosody=False):
text += p.split(" ") text += p.split(" ")
if i < len(marks): if i < len(marks):
if marks[i] == " ":# 防止意外的UNK if marks[i] == " ": # 防止意外的UNK
continue continue
text += [marks[i].replace(" ", "")] text += [marks[i].replace(" ", "")]
return text return text
@ -165,6 +178,7 @@ def text_normalize(text):
text = replace_consecutive_punctuation(text) text = replace_consecutive_punctuation(text)
return text return text
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True): def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
"""Extract phoneme + prosoody symbol sequence from input full-context labels. """Extract phoneme + prosoody symbol sequence from input full-context labels.
@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
return phones return phones
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def _numeric_feature_by_regex(regex, s): def _numeric_feature_by_regex(regex, s):
match = re.search(regex, s) match = re.search(regex, s)
@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
return -50 return -50
return int(match.group(1)) return int(match.group(1))
def g2p(norm_text, with_prosody=True): def g2p(norm_text, with_prosody=True):
phones = preprocess_jap(norm_text, with_prosody) phones = preprocess_jap(norm_text, with_prosody)
phones = [post_replace_ph(i) for i in phones] phones = [post_replace_ph(i) for i in phones]
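The user-dictionary block above recompiles userdict.csv only when its MD5 no longer matches the stored hash; a self-contained sketch of that check (function and variable names are ours):

import hashlib
import os


def file_md5(path: str) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest()


def userdict_needs_rebuild(csv_path: str, bin_path: str, md5_path: str) -> bool:
    # Rebuild when the compiled dictionary is missing or the recorded hash is stale.
    if not os.path.exists(bin_path) or not os.path.exists(md5_path):
        return True
    with open(md5_path, "r", encoding="utf-8") as f:
        return f.read() != file_md5(csv_path)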
View File
@ -9,39 +9,43 @@ import importlib
import os import os
# 防止win下无法读取模型 # 防止win下无法读取模型
if os.name == 'nt': if os.name == "nt":
class win_G2p(G2p): class win_G2p(G2p):
def check_mecab(self): def check_mecab(self):
super().check_mecab() super().check_mecab()
spam_spec = importlib.util.find_spec("eunjeon") spam_spec = importlib.util.find_spec("eunjeon")
non_found = spam_spec is None non_found = spam_spec is None
if non_found: if non_found:
print('you have to install eunjeon. install it...') print("you have to install eunjeon. install it...")
else: else:
installpath = spam_spec.submodule_search_locations[0] installpath = spam_spec.submodule_search_locations[0]
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)): if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
import sys import sys
from eunjeon import Mecab as _Mecab from eunjeon import Mecab as _Mecab
class Mecab(_Mecab): class Mecab(_Mecab):
def get_dicpath(installpath): def get_dicpath(installpath):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)): if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
import shutil import shutil
python_dir = os.getcwd()
if (installpath[:len(python_dir)].upper() == python_dir.upper()):
dicpath = os.path.join(os.path.relpath(installpath,python_dir),'data','mecabrc')
else:
if not os.path.exists('TEMP'):
os.mkdir('TEMP')
if not os.path.exists(os.path.join('TEMP', 'ko')):
os.mkdir(os.path.join('TEMP', 'ko'))
if os.path.exists(os.path.join('TEMP', 'ko', 'ko_dict')):
shutil.rmtree(os.path.join('TEMP', 'ko', 'ko_dict'))
shutil.copytree(os.path.join(installpath, 'data'), os.path.join('TEMP', 'ko', 'ko_dict')) python_dir = os.getcwd()
dicpath = os.path.join('TEMP', 'ko', 'ko_dict', 'mecabrc') if installpath[: len(python_dir)].upper() == python_dir.upper():
dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
else:
if not os.path.exists("TEMP"):
os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ko")):
os.mkdir(os.path.join("TEMP", "ko"))
if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))
shutil.copytree(
os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
)
dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
else: else:
dicpath=os.path.abspath(os.path.join(installpath, 'data/mecabrc')) dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
return dicpath return dicpath
def __init__(self, dicpath=get_dicpath(installpath)): def __init__(self, dicpath=get_dicpath(installpath)):
@ -52,97 +56,108 @@ if os.name == 'nt':
G2p = win_G2p G2p = win_G2p
from text.symbols2 import symbols from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals. # This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' _korean_classifiers = (
"군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
)
# List of (hangul, hangul divided) pairs: # List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ _hangul_divided = [
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule (re.compile("%s" % x[0]), x[1])
# ('ㄵ', 'ㄴㅈ'), for x in [
# ('ㄶ', 'ㄴㅎ'), # ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
# ('ㄺ', 'ㄹㄱ'), # ('ㄵ', 'ㄴㅈ'),
# ('ㄻ', 'ㄹㅁ'), # ('ㄶ', 'ㄴㅎ'),
# ('ㄼ', 'ㄹㅂ'), # ('ㄺ', 'ㄹㄱ'),
# ('ㄽ', 'ㄹㅅ'), # ('ㄻ', 'ㄹㅁ'),
# ('ㄾ', 'ㄹㅌ'), # ('ㄼ', 'ㄹㅂ'),
# ('ㄿ', 'ㄹㅍ'), # ('ㄽ', 'ㄹㅅ'),
# ('ㅀ', 'ㄹㅎ'), # ('ㄾ', 'ㄹㅌ'),
# ('ㅄ', 'ㅂㅅ'), # ('ㄿ', 'ㄹㅍ'),
('', 'ㅗㅏ'), # ('ㅀ', 'ㄹㅎ'),
('', 'ㅗㅐ'), # ('ㅄ', 'ㅂㅅ'),
('', 'ㅗㅣ'), ("", "ㅗㅏ"),
('', 'ㅜㅓ'), ("", "ㅗㅐ"),
('', 'ㅜㅔ'), ("", "ㅗㅣ"),
('', 'ㅜㅣ'), ("", "ㅜㅓ"),
('', 'ㅡㅣ'), ("", "ㅜㅔ"),
('', 'ㅣㅏ'), ("", "ㅜㅣ"),
('', 'ㅣㅐ'), ("", "ㅡㅣ"),
('', 'ㅣㅓ'), ("", "ㅣㅏ"),
('', 'ㅣㅔ'), ("", "ㅣㅐ"),
('', 'ㅣㅗ'), ("", "ㅣㅓ"),
('', 'ㅣㅜ') ("", "ㅣㅔ"),
]] ("", "ㅣㅗ"),
("", "ㅣㅜ"),
]
]
# List of (Latin alphabet, hangul) pairs: # List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ _latin_to_hangul = [
('a', '에이'), (re.compile("%s" % x[0], re.IGNORECASE), x[1])
('b', ''), for x in [
('c', ''), ("a", "에이"),
('d', ''), ("b", ""),
('e', ''), ("c", ""),
('f', '에프'), ("d", ""),
('g', ''), ("e", ""),
('h', '에이치'), ("f", "에프"),
('i', '아이'), ("g", ""),
('j', '제이'), ("h", "에이치"),
('k', '케이'), ("i", "아이"),
('l', ''), ("j", "제이"),
('m', ''), ("k", "케이"),
('n', ''), ("l", ""),
('o', ''), ("m", ""),
('p', ''), ("n", ""),
('q', ''), ("o", ""),
('r', '아르'), ("p", ""),
('s', '에스'), ("q", ""),
('t', ''), ("r", "아르"),
('u', ''), ("s", "에스"),
('v', '브이'), ("t", ""),
('w', '더블유'), ("u", ""),
('x', '엑스'), ("v", "브이"),
('y', '와이'), ("w", "더블유"),
('z', '제트') ("x", "엑스"),
]] ("y", "와이"),
("z", "제트"),
]
]
# List of (ipa, lazy ipa) pairs: # List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ _ipa_to_lazy_ipa = [
('t͡ɕ','ʧ'), (re.compile("%s" % x[0], re.IGNORECASE), x[1])
('d͡ʑ','ʥ'), for x in [
('ɲ','n^'), ("t͡ɕ", "ʧ"),
('ɕ','ʃ'), ("d͡ʑ", "ʥ"),
('ʷ','w'), ("ɲ", "n^"),
('ɭ','l`'), ("ɕ", "ʃ"),
('ʎ','ɾ'), ("ʷ", "w"),
('ɣ','ŋ'), ("ɭ", "l`"),
('ɰ','ɯ'), ("ʎ", "ɾ"),
('ʝ','j'), ("ɣ", "ŋ"),
('ʌ','ə'), ("ɰ", "ɯ"),
('ɡ','g'), ("ʝ", "j"),
('\u031a','#'), ("ʌ", "ə"),
('\u0348','='), ("ɡ", "g"),
('\u031e',''), ("\u031a", "#"),
('\u0320',''), ("\u0348", "="),
('\u0339','') ("\u031e", ""),
]] ("\u0320", ""),
("\u0339", ""),
]
]
def fix_g2pk2_error(text): def fix_g2pk2_error(text):
new_text = "" new_text = ""
i = 0 i = 0
while i < len(text) - 4: while i < len(text) - 4:
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == '': if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "":
new_text += text[i:i+3] + ' ' + '' new_text += text[i : i + 3] + " " + ""
i += 5 i += 5
else: else:
new_text += text[i] new_text += text[i]
@ -166,20 +181,20 @@ def divide_hangul(text):
def hangul_number(num, sino=True): def hangul_number(num, sino=True):
'''Reference https://github.com/Kyubyong/g2pK''' """Reference https://github.com/Kyubyong/g2pK"""
num = re.sub(',', '', num) num = re.sub(",", "", num)
if num == '0': if num == "0":
return '' return ""
if not sino and num == '20': if not sino and num == "20":
return '스무' return "스무"
digits = '123456789' digits = "123456789"
names = '일이삼사오육칠팔구' names = "일이삼사오육칠팔구"
digit2name = {d: n for d, n in zip(digits, names)} digit2name = {d: n for d, n in zip(digits, names)}
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉' modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔' decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
@ -188,75 +203,75 @@ def hangul_number(num, sino=True):
i = len(num) - i - 1 i = len(num) - i - 1
if sino: if sino:
if i == 0: if i == 0:
name = digit2name.get(digit, '') name = digit2name.get(digit, "")
elif i == 1: elif i == 1:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일십', '') name = name.replace("일십", "")
else: else:
if i == 0: if i == 0:
name = digit2mod.get(digit, '') name = digit2mod.get(digit, "")
elif i == 1: elif i == 1:
name = digit2dec.get(digit, '') name = digit2dec.get(digit, "")
if digit == '0': if digit == "0":
if i % 4 == 0: if i % 4 == 0:
last_three = spelledout[-min(3, len(spelledout)):] last_three = spelledout[-min(3, len(spelledout)) :]
if ''.join(last_three) == '': if "".join(last_three) == "":
spelledout.append('') spelledout.append("")
continue continue
else: else:
spelledout.append('') spelledout.append("")
continue continue
if i == 2: if i == 2:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일백', '') name = name.replace("일백", "")
elif i == 3: elif i == 3:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일천', '') name = name.replace("일천", "")
elif i == 4: elif i == 4:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일만', '') name = name.replace("일만", "")
elif i == 5: elif i == 5:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일십', '') name = name.replace("일십", "")
elif i == 6: elif i == 6:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일백', '') name = name.replace("일백", "")
elif i == 7: elif i == 7:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
name = name.replace('일천', '') name = name.replace("일천", "")
elif i == 8: elif i == 8:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 9: elif i == 9:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 10: elif i == 10:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 11: elif i == 11:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 12: elif i == 12:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 13: elif i == 13:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 14: elif i == 14:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
elif i == 15: elif i == 15:
name = digit2name.get(digit, '') + '' name = digit2name.get(digit, "") + ""
spelledout.append(name) spelledout.append(name)
return ''.join(elem for elem in spelledout) return "".join(elem for elem in spelledout)
def number_to_hangul(text): def number_to_hangul(text):
'''Reference https://github.com/Kyubyong/g2pK''' """Reference https://github.com/Kyubyong/g2pK"""
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text)) tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
for token in tokens: for token in tokens:
num, classifier = token num, classifier = token
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
spelledout = hangul_number(num, sino=False) spelledout = hangul_number(num, sino=False)
else: else:
spelledout = hangul_number(num, sino=True) spelledout = hangul_number(num, sino=True)
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}') text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")
# digit by digit for remaining digits # digit by digit for remaining digits
digits = '0123456789' digits = "0123456789"
names = '영일이삼사오육칠팔구' names = "영일이삼사오육칠팔구"
for d, n in zip(digits, names): for d, n in zip(digits, names):
text = text.replace(d, n) text = text.replace(d, n)
return text return text
@ -265,19 +280,23 @@ def number_to_hangul(text):
def korean_to_lazy_ipa(text): def korean_to_lazy_ipa(text):
text = latin_to_hangul(text) text = latin_to_hangul(text)
text = number_to_hangul(text) text = number_to_hangul(text)
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text) text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
for regex, replacement in _ipa_to_lazy_ipa: for regex, replacement in _ipa_to_lazy_ipa:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
_g2p=G2p()
_g2p = G2p()
def korean_to_ipa(text): def korean_to_ipa(text):
text = latin_to_hangul(text) text = latin_to_hangul(text)
text = number_to_hangul(text) text = number_to_hangul(text)
text = _g2p(text) text = _g2p(text)
text = fix_g2pk2_error(text) text = fix_g2pk2_error(text)
text = korean_to_lazy_ipa(text) text = korean_to_lazy_ipa(text)
return text.replace('ʧ','').replace('ʥ','') return text.replace("ʧ", "").replace("ʥ", "")
def post_replace_ph(ph): def post_replace_ph(ph):
rep_map = { rep_map = {
@ -301,12 +320,13 @@ def post_replace_ph(ph):
ph = "" ph = ""
return ph return ph
def g2p(text): def g2p(text):
text = latin_to_hangul(text) text = latin_to_hangul(text)
text = _g2p(text) text = _g2p(text)
text = divide_hangul(text) text = divide_hangul(text)
text = fix_g2pk2_error(text) text = fix_g2pk2_error(text)
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text) text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
# text = "".join([post_replace_ph(i) for i in text]) # text = "".join([post_replace_ph(i) for i in text])
text = [post_replace_ph(i) for i in text] text = [post_replace_ph(i) for i in text]
return text return text
@ -314,4 +334,4 @@ def g2p(text):
if __name__ == "__main__": if __name__ == "__main__":
text = "안녕하세요" text = "안녕하세요"
print(g2p(text)) print(g2p(text))
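As a reference for how the (compiled regex, replacement) tables above are applied, a self-contained sketch using just the first three Latin-to-Hangul pairs from the list:

import re

_latin_to_hangul = [(re.compile(p, re.IGNORECASE), r) for p, r in [("a", "에이"), ("b", "비"), ("c", "시")]]


def latin_to_hangul(text):
    for regex, replacement in _latin_to_hangul:
        text = re.sub(regex, replacement, text)
    return text


print(latin_to_hangul("ABC"))  # 에이비시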
View File
@ -1,4 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "", ",", "."] # @是SP停顿 punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-") punctuation.append("-")
View File
@ -1,4 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "", ",", "."] # @是SP停顿 punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-") punctuation.append("-")
@ -395,24 +394,404 @@ arpa = {
"SH", "SH",
} }
ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停' ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停"
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' # ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'} yue_symbols = {
"Yeot3",
"Yip1",
"Yyu3",
"Yeng4",
"Yut5",
"Yaan5",
"Ym5",
"Yaan6",
"Yang1",
"Yun4",
"Yon2",
"Yui5",
"Yun2",
"Yat3",
"Ye",
"Yeot1",
"Yoeng5",
"Yoek2",
"Yam2",
"Yeon6",
"Yu6",
"Yiu3",
"Yaang6",
"Yp5",
"Yai4",
"Yoek4",
"Yit6",
"Yam5",
"Yoeng6",
"Yg1",
"Yk3",
"Yoe4",
"Yam3",
"Yc",
"Yyu4",
"Yyut1",
"Yiu4",
"Ying3",
"Yip3",
"Yaap3",
"Yau3",
"Yan4",
"Yau1",
"Yap4",
"Yk6",
"Yok3",
"Yai1",
"Yeot6",
"Yan2",
"Yoek6",
"Yt1",
"Yoi1",
"Yit5",
"Yn4",
"Yaau3",
"Yau4",
"Yuk6",
"Ys",
"Yuk",
"Yin6",
"Yung6",
"Ya",
"You",
"Yaai5",
"Yau5",
"Yoi3",
"Yaak3",
"Yaat3",
"Ying2",
"Yok5",
"Yeng2",
"Yyut3",
"Yam1",
"Yip5",
"You1",
"Yam6",
"Yaa5",
"Yi6",
"Yek4",
"Yyu2",
"Yuk5",
"Yaam1",
"Yang2",
"Yai",
"Yiu6",
"Yin4",
"Yok4",
"Yot3",
"Yui2",
"Yeoi5",
"Yyun6",
"Yyu5",
"Yoi5",
"Yeot2",
"Yim4",
"Yeoi2",
"Yaan1",
"Yang6",
"Yong1",
"Yaang4",
"Yung5",
"Yeon1",
"Yin2",
"Ya3",
"Yaang3",
"Yg",
"Yk2",
"Yaau5",
"Yut1",
"Yt5",
"Yip4",
"Yung4",
"Yj",
"Yong3",
"Ya1",
"Yg6",
"Yaau6",
"Yit3",
"Yun3",
"Ying1",
"Yn2",
"Yg4",
"Yl",
"Yp3",
"Yn3",
"Yak1",
"Yang5",
"Yoe6",
"You2",
"Yap2",
"Yak2",
"Yt3",
"Yot5",
"Yim2",
"Yi1",
"Yn6",
"Yaat5",
"Yaam3",
"Yoek5",
"Ye3",
"Yeon4",
"Yaa2",
"Yu3",
"Yim6",
"Ym",
"Yoe3",
"Yaai2",
"Ym2",
"Ya6",
"Yeng6",
"Yik4",
"Yot4",
"Yaai4",
"Yyun3",
"Yu1",
"Yoeng1",
"Yaap2",
"Yuk3",
"Yoek3",
"Yeng5",
"Yeoi1",
"Yiu2",
"Yok1",
"Yo1",
"Yoek1",
"Yoeng2",
"Yeon5",
"Yiu1",
"Yoeng4",
"Yuk2",
"Yat4",
"Yg5",
"Yut4",
"Yan6",
"Yin3",
"Yaa6",
"Yap1",
"Yg2",
"Yoe5",
"Yt4",
"Ya5",
"Yo4",
"Yyu1",
"Yak3",
"Yeon2",
"Yong4",
"Ym1",
"Ye2",
"Yaang5",
"Yoi2",
"Yeng3",
"Yn",
"Yyut4",
"Yau",
"Yaak2",
"Yaan4",
"Yek2",
"Yin1",
"Yi5",
"Yoe2",
"Yei5",
"Yaat6",
"Yak5",
"Yp6",
"Yok6",
"Yei2",
"Yaap1",
"Yyut5",
"Yi4",
"Yim1",
"Yk5",
"Ye4",
"Yok2",
"Yaam6",
"Yat2",
"Yon6",
"Yei3",
"Yyu6",
"Yeot5",
"Yk4",
"Yai6",
"Yd",
"Yg3",
"Yei6",
"Yau2",
"Yok",
"Yau6",
"Yung3",
"Yim5",
"Yut6",
"Yit1",
"Yon3",
"Yat1",
"Yaam2",
"Yyut2",
"Yui6",
"Yt2",
"Yek6",
"Yt",
"Ye6",
"Yang3",
"Ying6",
"Yaau1",
"Yeon3",
"Yng",
"Yh",
"Yang4",
"Ying5",
"Yaap6",
"Yoeng3",
"Yyun4",
"You3",
"Yan5",
"Yat5",
"Yot1",
"Yun1",
"Yi3",
"Yaa1",
"Yaap4",
"You6",
"Yaang2",
"Yaap5",
"Yaa3",
"Yaak6",
"Yeng1",
"Yaak1",
"Yo5",
"Yoi4",
"Yam4",
"Yik1",
"Ye1",
"Yai5",
"Yung1",
"Yp2",
"Yui4",
"Yaak4",
"Yung2",
"Yak4",
"Yaat4",
"Yeoi4",
"Yut2",
"Yin5",
"Yaau4",
"Yap6",
"Yb",
"Yaam4",
"Yw",
"Yut3",
"Yong2",
"Yt6",
"Yaai6",
"Yap5",
"Yik5",
"Yun6",
"Yaam5",
"Yun5",
"Yik3",
"Ya2",
"Yyut6",
"Yon4",
"Yk1",
"Yit4",
"Yak6",
"Yaan2",
"Yuk1",
"Yai2",
"Yik2",
"Yaat2",
"Yo3",
"Ykw",
"Yn5",
"Yaa",
"Ye5",
"Yu4",
"Yei1",
"Yai3",
"Yyun5",
"Yip2",
"Yaau2",
"Yiu5",
"Ym4",
"Yeoi6",
"Yk",
"Ym6",
"Yoe1",
"Yeoi3",
"Yon",
"Yuk4",
"Yaai3",
"Yaa4",
"Yot6",
"Yaang1",
"Yei4",
"Yek1",
"Yo",
"Yp",
"Yo6",
"Yp4",
"Yan3",
"Yoi",
"Yap3",
"Yek3",
"Yim3",
"Yz",
"Yot2",
"Yoi6",
"Yit2",
"Yu5",
"Yaan3",
"Yan1",
"Yon5",
"Yp1",
"Yong5",
"Ygw",
"Yak",
"Yat6",
"Ying4",
"Yu2",
"Yf",
"Ya4",
"Yon1",
"You4",
"Yik6",
"Yui1",
"Yaat1",
"Yeot4",
"Yi2",
"Yaai1",
"Yek5",
"Ym3",
"Yong6",
"You5",
"Yyun1",
"Yn1",
"Yo2",
"Yip6",
"Yui3",
"Yaak5",
"Yyun2",
}
# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)###直接这么加yue顺序乱了 # symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)###直接这么加yue顺序乱了
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols)) symbols = sorted(set(symbols))
# print(len(symbols)) # print(len(symbols))
symbols+=["[","]"]##日文新增上升下降调型 symbols += ["[", "]"] ##日文新增上升下降调型
symbols+=sorted(list(ko_symbols)) symbols += sorted(list(ko_symbols))
symbols+=sorted(list(yue_symbols))##新加的yue统一摆在后头#已查过开头加Y后没有重复韩文显然不会重复 symbols += sorted(list(yue_symbols)) ##新加的yue统一摆在后头#已查过开头加Y后没有重复韩文显然不会重复
# print(len(symbols)) # print(len(symbols))
if __name__ == "__main__": if __name__ == "__main__":
print(len(symbols)) print(len(symbols))
''' """
粤语 粤语
732-353=379 732-353=379
韩文+粤语 韩文+粤语
732-322=410 732-322=410
''' """
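
As the comments above note, each newly added language is appended strictly after the already-sorted base set, so existing symbol indices never shift when Korean or Cantonese symbols are introduced. A minimal sketch of why this ordering matters, using made-up symbols rather than the real tables:

base = sorted(set(["a", "b", "ch", "zh"]))           # stands in for [pad] + c + v + ja_symbols + ...
old_ids = {s: i for i, s in enumerate(base)}

extended = base + ["[", "]"] + sorted(["ㄴ", "ㄱ"])  # new symbols only ever go at the end
new_ids = {s: i for i, s in enumerate(extended)}

assert all(new_ids[s] == old_ids[s] for s in base)   # IDs of the pre-existing symbols are unchanged
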
View File
@ -510,12 +510,7 @@ class ToneSandhi:
# e.g. 走了, 看着, 去过 # e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
elif ( elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words:
len(word) > 1
and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里 # e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -525,25 +520,18 @@ class ToneSandhi:
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# 个做量词 # 个做量词
elif ( elif (
ge_idx >= 1 ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "": ) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5" finals[ge_idx] = finals[ge_idx][:-1] + "5"
else: else:
if ( if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word) word_list = self._split_word(word)
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list): for i, word in enumerate(word_list):
# conventional neural in Chinese # conventional neural in Chinese
if ( if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5" finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, []) finals = sum(finals_list, [])
return finals return finals
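
All of the neutral-tone rules above reduce to the same string operation: in FINALS_TONE3 style each final ends with its tone digit, so rewriting the last character to "5" marks the syllable as neutral. A standalone illustration (not the class method; the values are illustrative):

finals = ["uo1", "i3"]               # roughly what FINALS_TONE3 yields for 桌子
finals[-1] = finals[-1][:-1] + "5"   # 子 takes the neutral tone
print(finals)                        # ['uo1', 'i5']
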
@ -561,9 +549,7 @@ class ToneSandhi:
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零 # "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all( if word.find("") != -1 and all([item.isnumeric() for item in word if item != ""]):
[item.isnumeric() for item in word if item != ""]
):
return finals return finals
# "一" between reduplication words shold be yi5, e.g. 看一看 # "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]: elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
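
The number-sequence branch above relies on str.isnumeric(), which is True for Chinese numerals as well as ASCII digits. A hedged, standalone version of that test:

word = "二一零"
is_number_sequence = "一" in word and all(ch.isnumeric() for ch in word if ch != "一")
print(is_number_sequence)  # True, so the finals are returned unchanged
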
@ -697,13 +683,10 @@ class ToneSandhi:
return new_seg return new_seg
# the first and the second words are all_tone_three # the first and the second words are all_tone_three
def _merge_continuous_three_tones( def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
@ -715,10 +698,7 @@ class ToneSandhi:
and not merge_last[i - 1] and not merge_last[i - 1]
): ):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if ( if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
@ -732,13 +712,10 @@ class ToneSandhi:
return len(word) == 2 and word[0] == word[1] return len(word) == 2 and word[0] == word[1]
# the last char of first word and the first char of second word is tone_three # the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2( def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
@ -750,10 +727,7 @@ class ToneSandhi:
and not merge_last[i - 1] and not merge_last[i - 1]
): ):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if ( if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
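
Both merge passes start from the same per-word tone lookup: lazy_pinyin with style=Style.FINALS_TONE3 returns finals ending in a tone digit, and a word is "all tone three" when every final ends in "3". A sketch of that lookup, assuming pypinyin is installed and using a hypothetical (word, pos) segmentation:

from pypinyin import Style, lazy_pinyin

seg = [("老虎", "n"), ("吃", "v")]  # hypothetical (word, pos) pairs
sub_finals_list = [
    lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
]
all_tone_three = [all(f[-1] == "3" for f in finals) for finals in sub_finals_list]
print(all_tone_three)  # expected [True, False]: 老虎 is lao3 hu3, 吃 is chi1
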
File diff suppressed because one or more lines are too long
View File

@ -21,25 +21,29 @@ from .num import verbalize_digit
def _time_num2str(num_string: str) -> str: def _time_num2str(num_string: str) -> str:
"""A special case for verbalizing number in time.""" """A special case for verbalizing number in time."""
result = num2str(num_string.lstrip('0')) result = num2str(num_string.lstrip("0"))
if num_string.startswith('0'): if num_string.startswith("0"):
result = DIGITS['0'] + result result = DIGITS["0"] + result
return result return result
# 时刻表达式 # 时刻表达式
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])' RE_TIME = re.compile(
r':([0-5][0-9])' r"([0-1]?[0-9]|2[0-3])"
r'(:([0-5][0-9]))?') r":([0-5][0-9])"
r"(:([0-5][0-9]))?"
)
# 时间范围如8:30-12:30 # 时间范围如8:30-12:30
RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])' RE_TIME_RANGE = re.compile(
r':([0-5][0-9])' r"([0-1]?[0-9]|2[0-3])"
r'(:([0-5][0-9]))?' r":([0-5][0-9])"
r'(~|-)' r"(:([0-5][0-9]))?"
r'([0-1]?[0-9]|2[0-3])' r"(~|-)"
r':([0-5][0-9])' r"([0-1]?[0-9]|2[0-3])"
r'(:([0-5][0-9]))?') r":([0-5][0-9])"
r"(:([0-5][0-9]))?"
)
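
For reference, the hour/minute/second groups captured by the reformatted RE_TIME pattern can be checked in isolation; a small demo with the same pattern compiled standalone (RE_TIME_DEMO is a throwaway name, not part of the module):

import re

RE_TIME_DEMO = re.compile(r"([0-1]?[0-9]|2[0-3])" r":([0-5][0-9])" r"(:([0-5][0-9]))?")
m = RE_TIME_DEMO.search("会议在8:30开始")
print(m.group(1), m.group(2), m.group(4))  # 8 30 None (no seconds in the input)
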
def replace_time(match) -> str: def replace_time(match) -> str:
@ -62,31 +66,33 @@ def replace_time(match) -> str:
second_2 = match.group(9) second_2 = match.group(9)
result = f"{num2str(hour)}" result = f"{num2str(hour)}"
if minute.lstrip('0'): if minute.lstrip("0"):
if int(minute) == 30: if int(minute) == 30:
result += "" result += ""
else: else:
result += f"{_time_num2str(minute)}" result += f"{_time_num2str(minute)}"
if second and second.lstrip('0'): if second and second.lstrip("0"):
result += f"{_time_num2str(second)}" result += f"{_time_num2str(second)}"
if is_range: if is_range:
result += "" result += ""
result += f"{num2str(hour_2)}" result += f"{num2str(hour_2)}"
if minute_2.lstrip('0'): if minute_2.lstrip("0"):
if int(minute) == 30: if int(minute) == 30:
result += "" result += ""
else: else:
result += f"{_time_num2str(minute_2)}" result += f"{_time_num2str(minute_2)}"
if second_2 and second_2.lstrip('0'): if second_2 and second_2.lstrip("0"):
result += f"{_time_num2str(second_2)}" result += f"{_time_num2str(second_2)}"
return result return result
RE_DATE = re.compile(r'(\d{4}|\d{2})年' RE_DATE = re.compile(
r'((0?[1-9]|1[0-2])月)?' r"(\d{4}|\d{2})年"
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?') r"((0?[1-9]|1[0-2])月)?"
r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"
)
def replace_date(match) -> str: def replace_date(match) -> str:
@ -110,8 +116,7 @@ def replace_date(match) -> str:
# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期 # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
RE_DATE2 = re.compile( RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
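
RE_DATE2 relies on the backreference \2, so both separators must be the same character; a quick standalone check of that behaviour (RE_DATE2_DEMO is a throwaway name):

import re

RE_DATE2_DEMO = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")
print(bool(RE_DATE2_DEMO.match("2024-04-01")))  # True
print(bool(RE_DATE2_DEMO.match("2024-04/01")))  # False: separators differ
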
def replace_date2(match) -> str: def replace_date2(match) -> str:
View File
@ -18,10 +18,7 @@ from pypinyin.constants import SUPPORT_UCS4
# 全角半角转换 # 全角半角转换
# 英文字符全角 -> 半角映射表 (num: 52) # 英文字符全角 -> 半角映射表 (num: 52)
F2H_ASCII_LETTERS = { F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}
ord(char) + 65248: ord(char)
for char in string.ascii_letters
}
# 英文字符半角 -> 全角映射表 # 英文字符半角 -> 全角映射表
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
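
These mapping tables are keyed by code point, so they plug directly into str.translate; a minimal usage sketch:

import string

F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}
print("ＧＰＴ".translate(F2H_ASCII_LETTERS))  # GPT
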
@ -37,26 +34,29 @@ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()} H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
# 空格 (num: 1) # 空格 (num: 1)
F2H_SPACE = {'\u3000': ' '} F2H_SPACE = {"\u3000": " "}
H2F_SPACE = {' ': '\u3000'} H2F_SPACE = {" ": "\u3000"}
# 非"有拼音的汉字"的字符串可用于NSW提取 # 非"有拼音的汉字"的字符串可用于NSW提取
if SUPPORT_UCS4: if SUPPORT_UCS4:
RE_NSW = re.compile(r'(?:[^' RE_NSW = re.compile(
r'\u3007' # r"(?:[^"
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] r"\u3007" #
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] r"\U00020000-\U0002A6DF" # CJK扩展B:[20000-2A6DF]
r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] r"\U0002A703-\U0002B73F" # CJK扩展C:[2A700-2B73F]
r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] r"\U0002B740-\U0002B81D" # CJK扩展D:[2B740-2B81D]
r'])+') r"\U0002F80A-\U0002FA1F" # CJK兼容扩展:[2F800-2FA1F]
r"])+"
)
else: else:
RE_NSW = re.compile( # pragma: no cover RE_NSW = re.compile( # pragma: no cover
r'(?:[^' r"(?:[^"
r'\u3007' # r"\u3007" #
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
r'])+') r"])+"
)
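
RE_NSW matches maximal runs of characters outside the listed CJK ranges, which is how non-Chinese spans (digits, ASCII, units) are pulled out of a sentence for normalization. A hedged demo using the narrow-build variant of the character class for brevity (RE_NSW_DEMO is a throwaway name):

import re

RE_NSW_DEMO = re.compile(r"(?:[^\u3007\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff])+")
print(RE_NSW_DEMO.findall("今天是2024/04/01,气温23℃"))  # ['2024/04/01,', '23℃']
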
Some files were not shown because too many files have changed in this diff