Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)

Commit dec3df3282 (parent a893a4e283)

    ruff format --line-length 120 --target-version py39
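
How the change was produced (a sketch based only on the commit message above; the trailing path argument and the config stanza are assumptions, not part of this commit):

    # run from the repository root
    ruff format --line-length 120 --target-version py39 .

    # hypothetical pyproject.toml equivalent:
    # [tool.ruff]
    # line-length = 120
    # target-version = "py39"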
@@ -1,5 +1,8 @@
 # Download moda ASR related models
 from modelscope import snapshot_download
-model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
+model_dir = snapshot_download(
+    "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
+)
+model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
+model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
@@ -4,14 +4,11 @@ import itertools
 import math
 import random
 from random import shuffle
-from typing import Iterator
-from typing import Optional
-from typing import TypeVar
+from typing import Iterator, Optional, TypeVar

 import torch
 import torch.distributed as dist
-from torch.utils.data import Dataset
-from torch.utils.data import Sampler
+from torch.utils.data import Dataset, Sampler

 __all__ = [
     "DistributedBucketSampler",
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
         if torch.cuda.is_available():
             torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
-            raise ValueError(
-                "Invalid rank {}, rank should be in the interval"
-                " [0, {}]".format(rank, num_replicas - 1)
-            )
+            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
         self.drop_last = drop_last
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
-        if (
-            self.drop_last and len(self.dataset) % self.num_replicas != 0
-        ):  # type: ignore[arg-type]
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
             # Split to nearest available length that is evenly divisible.
             # This is to ensure each rank receives the same amount of data when
             # using this Sampler.
             self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas)
-                / self.num_replicas  # type: ignore[arg-type]
+                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
             )
         else:
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas
+                len(self.dataset) / self.num_replicas,
             )  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             grouped_batch_size = self.batch_size * self.num_replicas
             shuffled_bucket = list(itertools.chain(*shuffled_bucket))
             n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
-            batches = [
-                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
-                for b in range(n_batch)
-            ]
+            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
             shuffle(batches)
             indices = list(itertools.chain(*batches))
         else:
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[
-                    :padding_size
-                ]
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
         else:
             # remove tail of data to make it evenly divisible.
             indices = indices[: self.total_size]
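
Aside: a quick sanity check of the drop_last arithmetic reformatted above (a toy sketch, not part of the commit; `num_samples` is a hypothetical helper mirroring the two branches of DistributedBucketSampler.__init__):

    import math

    def num_samples(dataset_len: int, num_replicas: int, drop_last: bool) -> int:
        # drop the ragged tail so every rank gets the same count
        if drop_last and dataset_len % num_replicas != 0:
            return math.ceil((dataset_len - num_replicas) / num_replicas)
        # otherwise pad up to the next multiple of num_replicas
        return math.ceil(dataset_len / num_replicas)

    assert num_samples(10, 4, drop_last=True) == 2   # tail dropped, 2 items per rank
    assert num_samples(10, 4, drop_last=False) == 3  # padded to 12, 3 items per rank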
@@ -1,9 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
 # reference: https://github.com/lifeiteng/vall-e
 from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
 from AR.data.bucket_sampler import DistributedBucketSampler
 from AR.data.dataset import Text2SemanticDataset
-from torch.utils.data import DataLoader


 class Text2SemanticDataModule(LightningDataModule):
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
         #   pad_val=self.config['data']['pad_val'])

     def train_dataloader(self):
-        batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
-        batch_size = max(min(batch_size,len(self._train_dataset)//4),1)  # prevent nothing from being saved
+        batch_size = (
+            self.config["train"]["batch_size"] // 2
+            if self.config["train"].get("if_dpo", False) is True
+            else self.config["train"]["batch_size"]
+        )
+        batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # prevent nothing from being saved
         sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
         return DataLoader(
             self._train_dataset,
@@ -2,18 +2,16 @@
 # reference: https://github.com/lifeiteng/vall-e

 # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
-import traceback
 import os
-from typing import Dict
-from typing import List
+import traceback
+from typing import Dict, List

 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset

-version = os.environ.get('version',None)
+version = os.environ.get("version", None)

 from text import cleaned_text_to_sequence
@@ -32,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0

     padded_sequences = []
     for seq, length in zip(sequences, seq_lengths):
-        padding = (
-            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
-        )
+        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
         padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
         padded_sequences.append(padded_seq)
     batch = np.stack(padded_sequences)
@@ -59,12 +55,16 @@ class Text2SemanticDataset(Dataset):
         super().__init__()

         self.semantic_data = pd.read_csv(
-            semantic_path, delimiter="\t", encoding="utf-8"
+            semantic_path,
+            delimiter="\t",
+            encoding="utf-8",
         )
         # get dict
         self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
         self.path3 = "%s/3-bert" % (
-            os.path.dirname(phoneme_path)
+            os.path.dirname(
+                phoneme_path,
+            )
         )  # "%s/3-bert"%exp_dir#bert_dir
         self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
         assert os.path.exists(self.path2)
@@ -125,7 +125,7 @@ class Text2SemanticDataset(Dataset):
         for i in range(semantic_data_len):
             # iterate through in order first
             # get str
-            item_name = self.semantic_data.iloc[i,0]
+            item_name = self.semantic_data.iloc[i, 0]
             # print(self.phoneme_data)
             try:
                 phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -135,7 +135,7 @@ class Text2SemanticDataset(Dataset):
                 num_not_in += 1
                 continue

-            semantic_str = self.semantic_data.iloc[i,1]
+            semantic_str = self.semantic_data.iloc[i, 1]
             # get token list
             semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
             # (T); does it need to become (1, T)? -> no, because we need its len
@@ -156,9 +156,7 @@ class Text2SemanticDataset(Dataset):
                 num_not_in += 1
                 continue
            # if len(phoneme_ids) >400:  ###########2: just change it to a constant limit of semantic/2.5
-            if (
-                len(phoneme_ids) > self.max_sec * self.hz / 2.5
-            ):  ###########2: just change it to a constant limit of semantic/2.5
+            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  ###########2: just change it to a constant limit of semantic/2.5
                 num_deleted_ps += 1
                 continue
             # if len(semantic_ids) > 1000:###########3
@@ -167,9 +165,7 @@ class Text2SemanticDataset(Dataset):

             ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

-            if (
-                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
-            ):  ##########4 # 3~25 # how many phones per second
+            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  ##########4 # 3~25 # how many phones per second
                 num_deleted_ps += 1
                 # print(item_name)
                 continue
@@ -192,12 +188,12 @@ class Text2SemanticDataset(Dataset):
             print(f"there are {num_not_in} semantic datas not in phoneme datas")
         if num_deleted_bigger > 0:
             print(
-                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
+                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
             )
         if num_deleted_ps > 0:
             # 4702 for LibriTTS; LibriTTS is labeled data. Does it need filtering? => yes, there are extreme values of 100
             print(
-                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
+                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
             )
         """
         there are 31 semantic datas not in phoneme datas
@@ -304,7 +300,10 @@ if __name__ == "__main__":

    batch_size = 12
    dataloader = DataLoader(
-        dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
+        dataset,
+        batch_size=batch_size,
+        collate_fn=dataset.collate,
+        shuffle=False,
    )
    for i, batch in enumerate(dataloader):
        if i % 1000 == 0:
@@ -9,10 +9,12 @@ from typing import Dict

 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam


 class Text2SemanticLightningModule(LightningModule):
     def __init__(self, config, output_dir, is_train=True):
         super().__init__()
@@ -24,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
                )
            )
        if is_train:
@@ -36,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
         scheduler = self.lr_schedulers()
-        forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
+        forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
         loss, acc = forward(
             batch["phoneme_ids"],
             batch["phoneme_ids_len"],
@@ -114,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
         model_parameters = self.model.parameters()
         parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
         lm_opt = ScaledAdam(
             model_parameters,
             lr=0.01,
@@ -9,6 +9,7 @@ from typing import Dict

 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model_onnx import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
@@ -25,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
-                )
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
+                ),
            )
        if is_train:
            self.automatic_optimization = False
@@ -80,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
         model_parameters = self.model.parameters()
         parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
         lm_opt = ScaledAdam(
             model_parameters,
             lr=0.01,
@@ -2,25 +2,24 @@
 # reference: https://github.com/lifeiteng/vall-e
 import math
 from typing import List, Optional
-import torch
-from tqdm import tqdm
-
-from AR.models.utils import make_pad_mask, make_pad_mask_left
-from AR.models.utils import (
-    topk_sampling,
-    sample,
-    dpo_loss,
-    make_reject_y,
-    get_batch_logps
-)
-from AR.modules.embedding import SinePositionalEmbedding
-from AR.modules.embedding import TokenEmbedding
-from AR.modules.transformer import LayerNorm
-from AR.modules.transformer import TransformerEncoder
-from AR.modules.transformer import TransformerEncoderLayer
+
+import torch
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy
+from tqdm import tqdm
+
+from AR.models.utils import (
+    dpo_loss,
+    get_batch_logps,
+    make_pad_mask,
+    make_pad_mask_left,
+    make_reject_y,
+    sample,
+    topk_sampling,
+)
+from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer

 default_config = {
     "embedding_dim": 512,
@@ -34,10 +33,17 @@ default_config = {
     "EOS": 1024,
 }

+
 # @torch.jit.script  ## if used, the first inference is very slow and the inference speed is unstable
 # Efficient implementation equivalent to the following:
-def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
-    B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
+def scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
     if scale is None:
         scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
     else:
@@ -57,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
     if attn_mask.dtype == torch.bool:
         attn_weight.masked_fill_(attn_mask, 0)
     else:
-        attn_mask[attn_mask!=float("-inf")] =0
-        attn_mask[attn_mask==float("-inf")] =1
+        attn_mask[attn_mask != float("-inf")] = 0
+        attn_mask[attn_mask == float("-inf")] = 1
         attn_weight.masked_fill_(attn_mask, 0)

     return attn_weight @ value

+
 @torch.jit.script
 class T2SMLP:
     def __init__(self, w1, b1, w2, b2):
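
Aside: the hand-rolled scaled_dot_product_attention above follows the textbook definition. In the usual notation (mine, not the commit's), with M the additive bias built from attn_mask and d_k the per-head key dimension:

    \mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_k}} + M \right) V

The masked_fill_ after the softmax appears to re-zero weights at masked positions so that fully padded rows cannot leak into the output.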
@@ -112,7 +119,11 @@ class T2SBlock:
         self.false = torch.tensor(False, dtype=torch.bool)

     @torch.jit.ignore
-    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+    def to_mask(
+        self,
+        x: torch.Tensor,
+        padding_mask: Optional[torch.Tensor],
+    ):
         if padding_mask is None:
             return x
@@ -121,9 +132,13 @@ class T2SBlock:
         else:
             return x * padding_mask

-    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
-
-
+    def process_prompt(
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         batch_size = q.shape[0]
@@ -147,9 +162,7 @@ class T2SBlock:
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

         x = x + attn
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
             x,
@@ -160,7 +173,14 @@ class T2SBlock:
         )
         return x, k_cache, v_cache

-    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
+    def decode_next_token(
+        self,
+        x: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_mask: torch.Tensor = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         k_cache = torch.cat([k_cache, k], dim=1)
@@ -174,7 +194,6 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-
         if torch_sdpa:
             attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
         else:
@@ -185,7 +204,11 @@ class T2SBlock:

         x = x + attn
         x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            x,
+            [self.hidden_dim],
+            self.norm_w1,
+            self.norm_b1,
+            self.norm_eps1,
         )
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
@@ -200,17 +223,19 @@ class T2SBlock:

 @torch.jit.script
 class T2STransformer:
-    def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
-        self.num_blocks : int = num_blocks
+    def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
+        self.num_blocks: int = num_blocks
         self.blocks = blocks

     def process_prompt(
-        self, x:torch.Tensor, attn_mask : torch.Tensor,
-        padding_mask : Optional[torch.Tensor]=None,
-        torch_sdpa:bool=True
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        torch_sdpa: bool = True,
     ):
-        k_cache : List[torch.Tensor] = []
-        v_cache : List[torch.Tensor] = []
+        k_cache: List[torch.Tensor] = []
+        v_cache: List[torch.Tensor] = []
         for i in range(self.num_blocks):
             x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
             k_cache.append(k_cache_)
@@ -218,14 +243,17 @@ class T2STransformer:
         return x, k_cache, v_cache

     def decode_next_token(
-        self, x:torch.Tensor,
+        self,
+        x: torch.Tensor,
         k_cache: List[torch.Tensor],
         v_cache: List[torch.Tensor],
-        attn_mask : torch.Tensor=None,
-        torch_sdpa:bool=True
+        attn_mask: torch.Tensor = None,
+        torch_sdpa: bool = True,
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
+                x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
+            )
         return x, k_cache, v_cache

@@ -247,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
         self.ar_text_embedding = TokenEmbedding(
-            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.phoneme_vocab_size,
+            self.p_dropout,
         )
         self.ar_text_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )
         self.ar_audio_embedding = TokenEmbedding(
-            self.embedding_dim, self.vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.vocab_size,
+            self.p_dropout,
         )
         self.ar_audio_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )

         self.h = TransformerEncoder(
@@ -291,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.linear1.weight,
                 layer.linear1.bias,
                 layer.linear2.weight,
-                layer.linear2.bias
+                layer.linear2.bias,
             )

             block = T2SBlock(
@@ -307,7 +345,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.norm1.eps,
                 layer.norm2.weight,
                 layer.norm2.bias,
-                layer.norm2.eps
+                layer.norm2.eps,
             )

             blocks.append(block)
@@ -385,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
         logits = self.ar_predict_layer(xy_dec[:, x_len:])

         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
+            x, x_lens, reject_y, reject_y_lens, bert_feature
+        )

         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -506,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
                (x_len, 0),
                value=False,
            )
-            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-                y.device
-            )
+            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)

            xy_dec, _ = self.h(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = topk_sampling(
-                logits, top_k=top_k, top_p=1.0, temperature=temperature
-            )
+            samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)

            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
@@ -540,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
         return y

     def pad_y_eos(self, y, y_mask_int, eos_id):
-        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
-            y_mask_int, (0, 1), value=1
-        )
+        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # shift targets by one position
         return targets[:, :-1], targets[:, 1:]

     def infer_panel_batch_infer(
         self,
-        x:List[torch.LongTensor],  ##### all text tokens
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        x: List[torch.LongTensor],  ##### all text tokens
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  #### reference audio tokens
+        bert_feature: List[torch.LongTensor],
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
@@ -561,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
     ):
         if prompts is None:
             print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
-            return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
+            return self.infer_panel_naive_batched(
+                x,
+                x_lens,
+                prompts,
+                bert_feature,
+                top_k=top_k,
+                top_p=top_p,
+                early_stop_num=early_stop_num,
+                temperature=temperature,
+                **kwargs,
+            )

-
-        max_len = kwargs.get("max_len",x_lens.max())
+        max_len = kwargs.get("max_len", x_lens.max())
         x_list = []
         for x_item, bert_item in zip(x, bert_feature):
             # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@@ -572,10 +615,11 @@ class Text2SemanticDecoder(nn.Module):
             x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
             x_item = self.ar_text_position(x_item).squeeze(0)
             # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
-            x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
+            x_item = (
+                F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
+            )  ### padding left
             x_list.append(x_item)
-        x:torch.Tensor = torch.stack(x_list, dim=0)

+        x: torch.Tensor = torch.stack(x_list, dim=0)

         # AR Decoder
         y = prompts
@@ -592,12 +636,10 @@ class Text2SemanticDecoder(nn.Module):
             y_emb = self.ar_audio_embedding(y)
             y_len = y_emb.shape[1]
             prefix_len = y.shape[1]
-            y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
+            y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
             y_pos = self.ar_audio_position(y_emb)
             xy_pos = torch.concat([x, y_pos], dim=1)
-
-

         ##### create mask #####
         bsz = x.shape[0]
         src_len = x_len + y_len
@@ -619,7 +661,7 @@ class Text2SemanticDecoder(nn.Module):
             value=False,
         )

-        causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
+        causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
         # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)  ### [b, x+y, x+y]
         ### the line above is wrong; it lets padded tokens be "seen"
@@ -637,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):

         padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)

-        attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
+        attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
         attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()

-
         # the correct attn_mask should look like this:
         #   | pad_len | x_len | y_len |
         # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -653,25 +694,22 @@ class Text2SemanticDecoder(nn.Module):
         # [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
         # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]

-
         ###### decode #####
-        y_list = [None]*y.shape[0]
+        y_list = [None] * y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
-        idx_list = [None]*y.shape[0]
+        idx_list = [None] * y.shape[0]
         for idx in tqdm(range(1500)):
             if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])

             if idx == 0:
-                attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
+                attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
                 logits = logits[:, :-1]
             else:
-                attn_mask = F.pad(attn_mask,(0,1),value=False)
+                attn_mask = F.pad(attn_mask, (0, 1), value=False)

             samples = sample(
                 logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
@@ -682,13 +720,12 @@ class Text2SemanticDecoder(nn.Module):
             ####### remove sequences that have finished generating from the batch, to further reduce computation
             tokens = torch.argmax(logits, dim=-1)
             reserved_idx_of_batch_for_y = None
-            if (self.EOS in samples[:, 0]) or \
-                (self.EOS in tokens):  ### stop once EOS has been generated
-                l1 = samples[:, 0]==self.EOS
-                l2 = tokens==self.EOS
+            if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  ### stop once EOS has been generated
+                l1 = samples[:, 0] == self.EOS
+                l2 = tokens == self.EOS
                 l = l1.logical_or(l2)
-                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
-                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
+                removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l == False)[0]
                 # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
                 for i in removed_idx_of_batch_for_y:
                     batch_index = batch_idx_map[i]
@@ -702,13 +739,12 @@ class Text2SemanticDecoder(nn.Module):
                 # index = torch.LongTensor(batch_idx_map).to(y.device)
                 y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
                 attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
-                if k_cache is not None :
+                if k_cache is not None:
                     for i in range(len(k_cache)):
                         k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
                         v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)

-
-            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
+            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
                 print("use early stop num:", early_stop_num)
                 stop = True
                 for i, batch_index in enumerate(batch_idx_map):
@@ -720,7 +756,7 @@ class Text2SemanticDecoder(nn.Module):
                     stop = True

             if stop:
-                if y.shape[1]==0:
+                if y.shape[1] == 0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                     print("bad zero prediction")
                 print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -728,34 +764,38 @@ class Text2SemanticDecoder(nn.Module):

            ####################### update next step ###################################
            y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)

-        if (None in idx_list):
+        if None in idx_list:
            for i in range(x.shape[0]):
                if idx_list[i] is None:
-                    idx_list[i] = 1500-1  ### if EOS was never generated, fall back to the maximum length
+                    idx_list[i] = 1500 - 1  ### if EOS was never generated, fall back to the maximum length

        if ref_free:
-            return y_list, [0]*x.shape[0]
+            return y_list, [0] * x.shape[0]
        # print(idx_list)
        return y_list, idx_list

-    def infer_panel_naive_batched(self,
-        x:List[torch.LongTensor],  ##### all text tokens
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+    def infer_panel_naive_batched(
+        self,
+        x: List[torch.LongTensor],  ##### all text tokens
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  #### reference audio tokens
+        bert_feature: List[torch.LongTensor],
        top_k: int = -100,
        top_p: int = 100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
        repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
    ):
        y_list = []
        idx_list = []
        for i in range(len(x)):
-            y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
+            y, idx = self.infer_panel_naive(
+                x[i].unsqueeze(0),
                x_lens[i],
                prompts[i].unsqueeze(0) if prompts is not None else None,
                bert_feature[i].unsqueeze(0),
@@ -764,7 +804,8 @@ class Text2SemanticDecoder(nn.Module):
                early_stop_num,
                temperature,
                repetition_penalty,
-                **kwargs)
+                **kwargs,
+            )
            y_list.append(y[0])
            idx_list.append(idx)
@@ -772,16 +813,16 @@ class Text2SemanticDecoder(nn.Module):

     def infer_panel_naive(
         self,
-        x:torch.LongTensor,  ##### all text tokens
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:torch.LongTensor,
+        x: torch.LongTensor,  ##### all text tokens
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  #### reference audio tokens
+        bert_feature: torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -826,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
            (x_len, 0),
            value=False,
        )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
-                                .unsqueeze(0)\
-                                .expand(bsz*self.num_head, -1, -1)\
-                                .view(bsz, self.num_head, src_len, src_len)\
+        xy_attn_mask = (
+            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+            .unsqueeze(0)
+            .expand(bsz * self.num_head, -1, -1)
+            .view(bsz, self.num_head, src_len, src_len)
            .to(device=x.device, dtype=torch.bool)
+        )

        for idx in tqdm(range(1500)):
            if xy_attn_mask is not None:
@@ -838,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
            else:
                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)

-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])

            if idx == 0:
                xy_attn_mask = None
-            if(idx<11):  ### require at least 10 predicted tokens before allowing a stop (0.4s)
+            if idx < 11:  ### require at least 10 predicted tokens before allowing a stop (0.4s)
                logits = logits[:, :-1]

            samples = sample(
@@ -868,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):

            ####################### update next step ###################################
            y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)

        if ref_free:
            return y[:, :-1], 0
        return y[:, :-1], idx

     def infer_panel(
         self,
-        x:torch.LongTensor,  ##### all text tokens
-        x_lens:torch.LongTensor,
-        prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:torch.LongTensor,
+        x: torch.LongTensor,  ##### all text tokens
+        x_lens: torch.LongTensor,
+        prompts: torch.LongTensor,  #### reference audio tokens
+        bert_feature: torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
     ):
-        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
+        return self.infer_panel_naive(
+            x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
+        )
@@ -1,16 +1,13 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import torch
-
-from AR.modules.embedding_onnx import SinePositionalEmbedding
-from AR.modules.embedding_onnx import TokenEmbedding
-from AR.modules.transformer_onnx import LayerNorm
-from AR.modules.transformer_onnx import TransformerEncoder
-from AR.modules.transformer_onnx import TransformerEncoderLayer
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy

+from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
+from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer

 default_config = {
     "embedding_dim": 512,
     "hidden_dim": 512,
@@ -25,12 +22,13 @@ default_config = {

 inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()

+
 def logits_to_probs(
     logits,
-    previous_tokens = None,
+    previous_tokens=None,
     temperature: float = 1.0,
-    top_k = None,
-    top_p = None,
+    top_k=None,
+    top_p=None,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
@@ -38,19 +36,27 @@ def logits_to_probs(
     previous_tokens = previous_tokens.long()
     score = torch.gather(logits, dim=0, index=previous_tokens)
     score = torch.where(
-        score < 0, score * repetition_penalty, score / repetition_penalty
+        score < 0,
+        score * repetition_penalty,
+        score / repetition_penalty,
     )
     logits.scatter_(dim=0, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+            torch.nn.functional.softmax(
+                sorted_logits,
+                dim=-1,
+            ),
+            dim=-1,
         )
         sorted_indices_to_remove = cum_probs > top_p
         sorted_indices_to_remove[0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=0,
+            index=sorted_indices,
+            src=sorted_indices_to_remove,
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -66,7 +72,7 @@ def logits_to_probs(


 def multinomial_sample_one_no_sync(
-    probs_sort
+    probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
     q = torch.randn_like(probs_sort)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@@ -78,7 +84,9 @@ def sample(
     **sampling_kwargs,
 ):
     probs = logits_to_probs(
-        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
+        logits=logits,
+        previous_tokens=previous_tokens,
+        **sampling_kwargs,
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -98,8 +106,18 @@ class OnnxEncoder(nn.Module):


 class T2SFirstStageDecoder(nn.Module):
-    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
-                 top_k, early_stop_num, num_layers):
+    def __init__(
+        self,
+        ar_audio_embedding,
+        ar_audio_position,
+        h,
+        ar_predict_layer,
+        loss_fct,
+        ar_accuracy_metric,
+        top_k,
+        early_stop_num,
+        num_layers,
+    ):
         super().__init__()
         self.ar_audio_embedding = ar_audio_embedding
         self.ar_audio_position = ar_audio_position
@@ -113,8 +131,8 @@ class T2SFirstStageDecoder(nn.Module):

     def forward(self, x, prompt):
         y = prompt
-        x_example = x[:,:,0] * 0.0
-        #N, 1, 512
+        x_example = x[:, :, 0] * 0.0
+        # N, 1, 512
         cache = {
             "all_stage": self.num_layers,
             "k": None,
@@ -131,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):

         xy_pos = torch.concat([x, y_pos], dim=1)

-        y_example = y_pos[:,:,0] * 0.0
-        x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
+        y_example = y_pos[:, :, 0] * 0.0
+        x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
         y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
         y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
-            torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
+            torch.ones_like(
+                y_example.transpose(0, 1),
+                dtype=torch.int64,
+            ),
+            dim=0,
         )
         y_attn_mask = y_attn_mask > 0
@@ -144,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
         x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
         y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
-        cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
-            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
-        cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
-            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
+        cache["k"] = (
+            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
+            .unsqueeze(1)
+            .repeat(self.num_layers, 1, 1, 1)
+        )
+        cache["v"] = (
+            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
+            .unsqueeze(1)
+            .repeat(self.num_layers, 1, 1, 1)
+        )

         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -159,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):


 class T2SStageDecoder(nn.Module):
-    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
-                 top_k, early_stop_num, num_layers):
+    def __init__(
+        self,
+        ar_audio_embedding,
+        ar_audio_position,
+        h,
+        ar_predict_layer,
+        loss_fct,
+        ar_accuracy_metric,
+        top_k,
+        early_stop_num,
+        num_layers,
+    ):
         super().__init__()
         self.ar_audio_embedding = ar_audio_embedding
         self.ar_audio_position = ar_audio_position
@@ -183,14 +221,18 @@ class T2SStageDecoder(nn.Module):
         }

         y_emb = torch.cat(
-            [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+            [
+                cache["y_emb"],
+                self.ar_audio_embedding(y[:, -1:]),
+            ],
+            1,
         )
         cache["y_emb"] = y_emb
         y_pos = self.ar_audio_position(y_emb)

         xy_pos = y_pos[:, -1:]

-        y_example = y_pos[:,:,0] * 0.0
+        y_example = y_pos[:, :, 0] * 0.0

         xy_attn_mask = torch.cat([x_example, y_example], dim=1)
         xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@@ -249,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):

     def init_onnx(self):
         self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
-        self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
-            self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
-            self.num_layers)
-        self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
-            self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
-            self.num_layers)
+        self.first_stage_decoder = T2SFirstStageDecoder(
+            self.ar_audio_embedding,
+            self.ar_audio_position,
+            self.h,
+            self.ar_predict_layer,
+            self.loss_fct,
+            self.ar_accuracy_metric,
+            self.top_k,
+            self.early_stop_num,
+            self.num_layers,
+        )
+        self.stage_decoder = T2SStageDecoder(
+            self.ar_audio_embedding,
+            self.ar_audio_position,
+            self.h,
+            self.ar_predict_layer,
+            self.loss_fct,
+            self.ar_accuracy_metric,
+            self.top_k,
+            self.early_stop_num,
+            self.num_layers,
+        )

     def forward(self, x, prompts, bert_feature):
         early_stop_num = self.early_stop_num
@@ -285,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
         y = prompts
         prefix_len = y.shape[1]
         x_len = x.shape[1]
-        x_example = x[:,:,0] * 0.0
+        x_example = x[:, :, 0] * 0.0
         x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
         x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@@ -302,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
         if cache["first_infer"] == 1:
             y_emb = self.ar_audio_embedding(y)
         else:
-            y_emb = torch.cat(
-                [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
-            )
+            y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
         cache["y_emb"] = y_emb
         y_pos = self.ar_audio_position(y_emb)
         if cache["first_infer"] == 1:
@@ -316,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
            x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
-                (x_len, 0), value=False
+                (x_len, 0),
+                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
        else:
@@ -1,8 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
 # reference: https://github.com/lifeiteng/vall-e
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
-from typing import Tuple


 def sequence_mask(length, max_length=None):
     if max_length is None:
@@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     n = lengths.size(0)
     seq_range = torch.arange(0, max_len, device=lengths.device)
     expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
-    expaned_lengths -= (max_len-lengths).unsqueeze(-1)
+    expaned_lengths -= (max_len - lengths).unsqueeze(-1)

-    return expaned_lengths<0
+    return expaned_lengths < 0


 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
-    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+    logits,
+    top_k=0,
+    top_p=1.0,
+    filter_value=-float("Inf"),
+    min_tokens_to_keep=1,
 ):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
     Args:
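
Aside: a toy check of make_pad_mask_left as reconstructed above (a sketch; it assumes `lengths` holds per-row valid lengths with padding on the left):

    import torch

    lengths = torch.tensor([2, 4])
    max_len = 4
    seq_range = torch.arange(0, max_len)
    expanded = seq_range.unsqueeze(0).repeat(2, 1) - (max_len - lengths).unsqueeze(-1)
    mask = expanded < 0
    # row 0 (length 2): [True, True, False, False]   -> the two left-pad slots are masked
    # row 1 (length 4): [False, False, False, False] -> nothing masked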
@@ -105,9 +111,7 @@ def top_k_top_p_filtering(
         sorted_indices_to_remove[..., 0] = 0

         # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value
     return logits
@@ -156,19 +160,21 @@ def logits_to_probs(
     previous_tokens = previous_tokens.long()
     score = torch.gather(logits, dim=1, index=previous_tokens)
     score = torch.where(
-        score < 0, score * repetition_penalty, score / repetition_penalty
+        score < 0,
+        score * repetition_penalty,
+        score / repetition_penalty,
     )
     logits.scatter_(dim=1, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
+        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
         sorted_indices_to_remove = cum_probs > top_p
         sorted_indices_to_remove[:, 0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=1, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1,
+            index=sorted_indices,
+            src=sorted_indices_to_remove,
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
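
Aside: the repetition-penalty update just reformatted applies, for each previously generated token with score s and penalty p:

    s' = \begin{cases} s \cdot p & s < 0 \\ s / p & s \ge 0 \end{cases}

so for p > 1 the score always moves toward lower probability (the usual CTRL-style penalty; the case split is needed because logits can be negative).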
@@ -176,7 +182,7 @@ def logits_to_probs(

     if top_k is not None:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v[: , -1].unsqueeze(-1)
+        pivot = v[:, -1].unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)

     probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -188,18 +194,19 @@ def sample(
     previous_tokens: Optional[torch.Tensor] = None,
     **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(
-        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
-    )
+    probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

-def dpo_loss(policy_chosen_logps: torch.FloatTensor,
+
+def dpo_loss(
+    policy_chosen_logps: torch.FloatTensor,
     policy_rejected_logps: torch.FloatTensor,
     reference_chosen_logps: torch.FloatTensor,
     reference_rejected_logps: torch.FloatTensor,
     beta: float,
-    reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    reference_free: bool = False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
     pi_logratios = policy_chosen_logps - policy_rejected_logps
     ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,

     return losses.mean(), chosen_rewards, rejected_rewards

-def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+def get_batch_logps(
+    logits_target: torch.FloatTensor,
+    logits_reject: torch.FloatTensor,
+    labels_target: torch.LongTensor,
+    labels_reject: torch.LongTensor,
+    average_log_prob: bool = False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
     # dummy token; we'll ignore the losses on these tokens later

-    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
-    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+    per_token_logps_target = torch.gather(
+        logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
+    ).squeeze(2)
+    per_token_logps_reject = torch.gather(
+        logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
+    ).squeeze(2)

     return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)


 def make_reject_y(y_o, y_lens):
     def repeat_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
-        pre = y[:range_idx[0]]
-        shf = y[range_idx[1]:]
-        range_text = y[range_idx[0]:range_idx[1]]
+        pre = y[: range_idx[0]]
+        shf = y[range_idx[1] :]
+        range_text = y[range_idx[0] : range_idx[1]]
         new_y = torch.cat([pre, range_text, range_text, shf])
         return new_y

     def lost_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
-        pre = y[:range_idx[0]]
-        shf = y[range_idx[1]:]
-        range_text = y[range_idx[0]:range_idx[1]]
+        pre = y[: range_idx[0]]
+        shf = y[range_idx[1] :]
+        range_text = y[range_idx[0] : range_idx[1]]
         new_y = torch.cat([pre, shf])
         return new_y

     bs = len(y_lens)
     reject_y = []
     reject_y_lens = []
     for b in range(bs):
-        process_item_idx = torch.randint(0, 1, size=(1, ))[0]
+        process_item_idx = torch.randint(0, 1, size=(1,))[0]
         if process_item_idx == 0:
             new_y = repeat_P(y_o[b])
             reject_y.append(new_y)
             reject_y_lens.append(len(new_y))
-        elif process_item_idx==1:
+        elif process_item_idx == 1:
             new_y = lost_P(y_o[b])
             reject_y.append(new_y)
             reject_y_lens.append(len(new_y))
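
Aside: dpo_loss and get_batch_logps above follow the standard Direct Preference Optimization objective (a sketch of the usual formula; y_w / y_l denote the chosen / rejected sequences, beta is the `beta` argument, and the reference log-ratio drops out when reference_free=True):

    \mathcal{L}_{\mathrm{DPO}} = -\log \sigma\!\Big( \beta \big[ (\log \pi_\theta(y_w) - \log \pi_\theta(y_l)) - (\log \pi_{\mathrm{ref}}(y_w) - \log \pi_{\mathrm{ref}}(y_l)) \big] \Big)

get_batch_logps supplies the per-sequence log-probabilities by gathering token-level log-softmax values at the label indices and summing over time.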
@@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
         pad_length = max_length - reject_y_lens[b]
         reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)

-    reject_y = torch.stack(reject_y, dim = 0)
+    reject_y = torch.stack(reject_y, dim=0)
     reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)

     return reject_y, reject_y_lens
@@ -1,17 +1,14 @@
 # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
-from typing import Optional
-from typing import Tuple
+from typing import Optional, Tuple
+
 import torch
 from torch import Tensor
-from torch.nn import Linear
-from torch.nn import Module
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.init import xavier_uniform_
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter

-from torch.nn import functional as F
 from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

 F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

     """
+
     __constants__ = ["batch_first"]
     bias_k: Optional[torch.Tensor]
     bias_v: Optional[torch.Tensor]
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

         if add_bias_kv:
             self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
         if linear1_cls == Linear:
             if not self._qkv_same_embed_dim:
                 self.q_proj_weight = Parameter(
-                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty((embed_dim, embed_dim), **factory_kwargs),
                 )
                 self.k_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
+                    torch.empty((embed_dim, self.kdim), **factory_kwargs),
                 )
                 self.v_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
+                    torch.empty((embed_dim, self.vdim), **factory_kwargs),
                 )
                 self.register_parameter("in_proj_weight", None)
             else:
                 self.in_proj_weight = Parameter(
-                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
                 )
                 self.register_parameter("q_proj_weight", None)
                 self.register_parameter("k_proj_weight", None)
                 self.register_parameter("v_proj_weight", None)

             if bias:
-                self.in_proj_bias = Parameter(
-                    torch.empty(3 * embed_dim, **factory_kwargs)
-                )
+                self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
             else:
                 self.register_parameter("in_proj_bias", None)
             self.out_proj = NonDynamicallyQuantizableLinear(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
+                embed_dim,
+                embed_dim,
+                bias=bias,
+                **factory_kwargs,
             )

             self._reset_parameters()
@ -150,7 +147,10 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
@ -164,7 +164,10 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
@ -261,28 +264,26 @@ class MultiheadAttention(Module):
|
||||
if key_padding_mask is not None:
|
||||
_kpm_dtype = key_padding_mask.dtype
|
||||
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
key_padding_mask
|
||||
key_padding_mask,
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||
why_not_fast_path = ""
|
||||
if not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif (
|
||||
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
|
||||
):
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
)
|
||||
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
)
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif not self.batch_first:
|
||||
@ -300,9 +301,7 @@ class MultiheadAttention(Module):
|
||||
elif attn_mask is not None:
|
||||
why_not_fast_path = "attn_mask was not None"
|
||||
elif query.is_nested and key_padding_mask is not None:
|
||||
why_not_fast_path = (
|
||||
"key_padding_mask is not supported with NestedTensor input"
|
||||
)
|
||||
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
|
||||
elif self.num_heads % 2 == 1:
|
||||
why_not_fast_path = "num_heads is odd"
|
||||
elif torch.is_autocast_enabled():
|
||||
@ -322,20 +321,10 @@ class MultiheadAttention(Module):
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif not all(
|
||||
[
|
||||
(x is None or x.is_cuda or "cpu" in str(x.device))
|
||||
for x in tensor_args
|
||||
]
|
||||
):
|
||||
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
|
||||
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
||||
elif torch.is_grad_enabled() and any(
|
||||
[x is not None and x.requires_grad for x in tensor_args]
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
|
||||
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
|
||||
if not why_not_fast_path:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
@ -350,11 +339,7 @@ class MultiheadAttention(Module):
|
||||
key_padding_mask if key_padding_mask is not None else attn_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
1
|
||||
if key_padding_mask is not None
|
||||
else 0
|
||||
if attn_mask is not None
|
||||
else None,
|
||||
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
|
@@ -1,13 +1,10 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

@@ -46,9 +43,7 @@ class MultiheadAttention(Module):
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -59,18 +54,30 @@ class MultiheadAttention(Module):
        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                    torch.empty(
                        (embed_dim, embed_dim),
                        **factory_kwargs,
                    )
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                    torch.empty(
                        (embed_dim, self.kdim),
                        **factory_kwargs,
                    )
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                    torch.empty(
                        (embed_dim, self.vdim),
                        **factory_kwargs,
                    )
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                    torch.empty(
                        (3 * embed_dim, embed_dim),
                        **factory_kwargs,
                    )
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
@@ -78,13 +85,11 @@ class MultiheadAttention(Module):

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                    torch.empty(3 * embed_dim, **factory_kwargs),
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )
            self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

            self._reset_parameters()
        else:
@@ -92,7 +97,10 @@ class MultiheadAttention(Module):
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                    embed_dim,
                    3 * embed_dim,
                    bias=bias,
                    **factory_kwargs,
                )
                self.in_proj_weight = self.in_proj_linear.weight

@@ -106,7 +114,10 @@ class MultiheadAttention(Module):
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
                embed_dim,
                embed_dim,
                bias=bias,
                **factory_kwargs,
            )

            if self.bias_k is not None:

@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
            return
        pe = torch.zeros(x.size(1), self.embedding_dim)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.embedding_dim)
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

@@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
        self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))

    def extend_pe(self, x):
        position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
        position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
        scpe = (position * self.div_term).unsqueeze(0)
        pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
        pe = pe.contiguous().view(1, -1, self.embedding_dim)
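Both SinePositionalEmbedding variants above fill the standard sinusoidal table, pe[pos, 2i] = sin(pos / 10000^(2i/d)) and pe[pos, 2i+1] = cos(pos / 10000^(2i/d)); the ONNX variant just derives positions with cumsum so the exported graph stays shape-agnostic. A self-contained sketch of the same table (sizes are illustrative):

import math
import torch

def sine_positional_table(max_len: int = 8, d_model: int = 4) -> torch.Tensor:
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe  # shape (max_len, d_model)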
@@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
            lr = self.end_lr

        else:
            decay_ratio = (self._current_step - self.warmup_steps) / (
                self.total_steps - self.warmup_steps
            )
            decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            if decay_ratio < 0.0 or decay_ratio > 1.0:
                raise RuntimeError(
                    "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
                )
                raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
            coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
            lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)

@@ -70,7 +66,13 @@ if __name__ == "__main__":
    m = nn.Linear(10, 10)
    opt = Adam(m.parameters(), lr=1e-4)
    s = WarmupCosineLRSchedule(
        opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
        opt,
        1e-6,
        2e-4,
        1e-6,
        warmup_steps=2000,
        total_steps=20000,
        current_step=0,
    )
    lrs = []
    for i in range(25000):
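The schedule above is the usual linear-warmup-then-cosine-decay curve: after warmup, lr = end_lr + 0.5 * (1 + cos(pi * decay_ratio)) * (peak_lr - end_lr). A standalone sketch of the same curve (function and argument names are illustrative, not the class API):

import math

def warmup_cosine_lr(step: int, peak_lr: float, end_lr: float, warmup_steps: int, total_steps: int) -> float:
    if step < warmup_steps:
        # linear warmup toward peak_lr (the class ramps from an init_lr; starting at 0 keeps the sketch short)
        return peak_lr * step / warmup_steps
    decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * min(decay_ratio, 1.0)))
    return end_lr + coeff * (peak_lr - end_lr)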
@@ -16,8 +16,7 @@
import contextlib
import logging
from collections import defaultdict
from typing import List
from typing import Tuple
from typing import List, Tuple

import torch
from torch import Tensor
@@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
        batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
@@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i])
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()
@@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
            # group. class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack([
                torch.zeros_like(p) if p.grad is None else p.grad for p in batch
            ])
            grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
            p_stacked.grad = grad
            stacked_params_dict[key] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
        for (stacked_params, _state, _names), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])

@@ -177,12 +167,11 @@ class ScaledAdam(BatchedOptimizer):
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True, ):

        show_dominant_parameters=True,
    ):
        assert parameters_names is not None, (
            "Please prepare parameters_names,"
            "which is a List[List[str]]. Each List[str] is for a group"
            "and each str is for a parameter")
            "Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
@@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period, )
            clipping_update_period=clipping_update_period,
        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
@@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):

        batch = True

        for group, group_params_names in zip(self.param_groups,
                                             self.parameters_names):

            with self.batched_params(group["params"],
                                     group_params_names) as batches:

        for group, group_params_names in zip(self.param_groups, self.parameters_names):
            with self.batched_params(group["params"], group_params_names) as batches:
                # batches is list of pairs (stacked_param, state). stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.

                if (len(batches[0][1]) ==
                        0):  # if len(first state) == 0: not yet initialized
                if len(batches[0][1]) == 0:  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)
@@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )
                        raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
                    # State initialization
                    if len(state) == 0:
                        self._init_state(group, p, state)
@@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
        # parameter-change "delta", which combines all forms of
        # update. this is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format)
        state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)

        batch_size = p.shape[0]
        numel = p.numel() // batch_size
@@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
            param_rms = (
                (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
            param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            state["scale_grads"] = torch.zeros(size_update_period,
                                               *param_rms.shape, **kwargs)
            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)

        # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format)
        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_clipping_scale(self,
                            group: dict,
                            tuples: List[Tuple[Tensor, dict, List[str]]]
                            ) -> float:
    def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.
@@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
        clipping_update_period = group["clipping_update_period"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
        for p, state, param_names in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients")
                raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"])**2).sum()
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device)
            first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
@@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
                    (clipping_update_period // 4) * n, )
                    (clipping_update_period // 4) * n,
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (first_state["num_clipped"] * 100.0 /
                               clipping_update_period
                               if "num_clipped" in first_state else 0.0)
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
@@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                    "Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warn(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
                if self.show_dominant_parameters:
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans

    def _show_gradient_dominating_parameter(
            self, tuples: List[Tuple[Tensor, dict, List[str]]],
            tot_sumsq: Tensor):
    def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
        """
        Show information of the parameter which dominates tot_sumsq.

@@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
        for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
            batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
                dim=list(range(1, batch_grad.ndim)))
            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))

            for name, sumsq_orig, rms, grad in zip(batch_param_names,
            for name, sumsq_orig, rms, grad in zip(
                batch_param_names,
                batch_sumsq_orig,
                batch_rms_orig, batch_grad):

                batch_rms_orig,
                batch_grad,
            ):
                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0), )
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True, )
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (dominant_proportion, dominant_sumsq, dominant_rms,
         dominant_grad, ) = sorted_by_proportion[dominant_param_name]
        logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}")
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )

    def _step_one_batch(self,
                        group: dict,
                        p: Tensor,
                        state: dict,
                        clipping_scale: float):
    def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
@@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True)
            scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_((p**2)
                                .mean(dim=list(range(1, p.ndim)), keepdim=True)
                                .sqrt())
                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
@@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):

        state["step"] = step + 1

    def _size_update(self,
    def _size_update(
        self,
        group: dict,
        scale_grads: Tensor,
        p: Tensor,
        state: dict) -> None:
        state: dict,
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
        # faster decay at this level.
        beta2_corr = beta2**size_update_period

        scale_exp_avg_sq = state[
            "scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr, )  # shape is (batch_size, 1, 1, ...)
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
@@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (-size_lr * (bias_correction2**0.5) *
                      scale_grads.sum(dim=0) / denom)
        scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms
@@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        this_step = state["step"] - (state["zero_step"]
                                     if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2**(this_step + 1)
        this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):

        # bias_correction2 is like in Adam. Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2**(state["step"] + 1)
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
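The batching trick above is what keeps ScaledAdam's vectorized update cheap: parameters are bucketed by (dtype, shape), so each bucket can be stacked into a single tensor and stepped with one fused kernel instead of one launch per parameter. A minimal sketch of just that grouping step (illustrative only, not the optimizer's API):

from collections import defaultdict
import torch

def group_by_dtype_and_shape(params):
    batches = defaultdict(list)  # (dtype_as_str, *shape) -> list of same-shaped tensors
    for p in params:
        key = (str(p.dtype), *p.shape)
        batches[key].append(p)
    # each bucket can now be stacked and updated in one vectorized operation
    return {key: torch.stack(batch) for key, batch in batches.items()}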
@@ -24,18 +24,18 @@ def multi_head_attention_forward_patched(
    dropout_p: float,
    out_proj_weight,
    out_proj_bias,
    training = True,
    key_padding_mask = None,
    need_weights = True,
    attn_mask = None,
    use_separate_proj_weight = False,
    q_proj_weight = None,
    k_proj_weight = None,
    v_proj_weight = None,
    static_k = None,
    static_v = None,
    average_attn_weights = True,
    is_causal = False,
    training=True,
    key_padding_mask=None,
    need_weights=True,
    attn_mask=None,
    use_separate_proj_weight=False,
    q_proj_weight=None,
    k_proj_weight=None,
    v_proj_weight=None,
    static_k=None,
    static_v=None,
    average_attn_weights=True,
    is_causal=False,
    cache=None,
):
    r"""
@@ -155,9 +155,7 @@ def multi_head_attention_forward_patched(
            cache=cache,
        )

    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )
    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
@@ -210,45 +208,33 @@ def multi_head_attention_forward_patched(
        # longer causal.
        is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    assert embed_dim == embed_dim_to_check, (
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    )
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        assert key.shape[:2] == value.shape[:2], (
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
        )
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
@@ -311,9 +297,7 @@ def multi_head_attention_forward_patched(
                f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
            )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
@@ -337,34 +321,26 @@ def multi_head_attention_forward_patched(
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        assert static_k.size(0) == bsz * num_heads, (
            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        )
        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        assert static_v.size(0) == bsz * num_heads, (
            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        )
        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
@@ -380,9 +356,7 @@ def multi_head_attention_forward_patched(
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
@@ -401,14 +375,10 @@ def multi_head_attention_forward_patched(
        B, Nt, E = q.shape
        q_scaled = q / math.sqrt(E)

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"
        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -417,9 +387,7 @@ def multi_head_attention_forward_patched(

        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

@@ -448,13 +416,9 @@ def multi_head_attention_forward_patched(
        v = v.view(bsz, num_heads, src_len, head_dim)

        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)

        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
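What makes this module take effect is the assignment shown earlier, F.multi_head_attention_forward = multi_head_attention_forward_patched: nn.MultiheadAttention keeps calling the functional entry point, but the patched version threads an extra cache argument through for incremental decoding. A toy illustration of the same monkey-patching pattern (wrapper name and behavior are hypothetical, not the repo's patched kernel):

import torch.nn.functional as F

_original_mha_forward = F.multi_head_attention_forward

def mha_forward_with_hook(*args, cache=None, **kwargs):
    # stand-in for the patched kernel: fall back to the stock implementation,
    # ignoring the extra cache argument this sketch does not use
    return _original_mha_forward(*args, **kwargs)

F.multi_head_attention_forward = mha_forward_with_hook  # every nn.MultiheadAttention now routes here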
@@ -3,6 +3,7 @@ from torch.nn.functional import (
    _canonical_mask,
)


def multi_head_attention_forward_patched(
    query,
    key,
@@ -31,7 +32,6 @@ def multi_head_attention_forward_patched(
    is_causal: bool = False,
    cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:

    # set up shape vars
    _, _, embed_dim = query.shape
    attn_mask = _canonical_mask(
@@ -77,12 +77,8 @@ def multi_head_attention_forward_patched(
    q = q.view(num_heads, -1, head_dim).unsqueeze(0)
    k = k.view(num_heads, -1, head_dim).unsqueeze(0)
    v = v.view(num_heads, -1, head_dim).unsqueeze(0)
    attn_output = scaled_dot_product_attention(
        q, k, v, attn_mask, dropout_p, is_causal
    )
    attn_output = (
        attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
    )
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(-1, 1, attn_output.size(1))
@@ -58,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
        # floors), should be expectation-preserving.
        floor = -0.043637
        ceil = 1.2
        d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
            deriv
        )
        d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
        if __name__ == "__main__":
            # for self-testing only.
            assert d_scaled.min() >= 0.0
@@ -150,13 +148,9 @@ def _compute_scale_factor(
    else:
        # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
        # x_abs_mean < min_abs.
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
            min=0, max=max_factor
        )
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )
    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)

    return below_threshold - above_threshold

@@ -178,18 +172,16 @@ def _compute_sign_factor(
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)
        factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if self.proportion_positive <= max_positive, else can be
        # as large as -max_factor.
        factor2 = (
            (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
        factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
            min=0, max=max_factor
        )
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
@@ -317,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
            return _no_op(x)


def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish
    """
    balancer = ActivationBalancer(
        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
    )
    balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
    return nn.Sequential(
        balancer,
        DoubleSwish(),
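The floor/ceil trick in the first hunk above quantizes the activation's derivative into a uint8 for the backward pass, with random dithering so the rounding is expectation-preserving. The function being differentiated is double-swish which, per the upstream icefall code this file is adapted from, is x * sigmoid(x - 1.0); a plain sketch without the memory-saving autograd machinery:

import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # x * sigmoid(x - 1.0); its derivative stays within roughly [-0.043637, 1.2],
    # which is why the quantized backward above can map it into a uint8
    return x * torch.sigmoid(x - 1.0)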
@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
            )

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)


class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
@@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )
        self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
@@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
                raise AssertionError("only bool and floating types of key_padding_mask are supported")

        if self.norm_first:
            x = x + self._sa_block(

@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
            )

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)


class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
@@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):

class TransformerEncoderLayer(nn.Module):
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
@@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )
        self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
@@ -30,9 +30,7 @@ class GruutPhonemizer:
            "«": "«",
            "»": "»",
        }
        self._punctuation_regexp: str = (
            rf"([{''.join(self._special_cases_dict.keys())}])"
        )
        self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"

    def _normalize_punctuation(self, text: str) -> str:
        text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@@ -53,13 +51,8 @@ class GruutPhonemizer:

    def phonemize(self, text: str, espeak: bool = False) -> str:
        text_to_phonemize: str = self._normalize_punctuation(text)
        sents: List[Sentence] = [
            sent
            for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
        ]
        words: List[str] = [
            self._convert_punctuation(word) for word in itertools.chain(*sents)
        ]
        sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
        words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
        return " ".join(words)

    def transform(self, phonemes):

@@ -3,7 +3,9 @@
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
IPA_LETTERS = (
    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
@@ -2,12 +2,12 @@ import re


def str2bool(str):
    return True if str.lower() == 'true' else False
    return True if str.lower() == "true" else False


def get_newest_ckpt(string_list):
    # Regex pattern that matches the epoch/step numbers in each filename
    pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
    pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"

    # Extract the numeric info from each string and build a list of tuples
    extracted_info = []
@@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
            step = int(match.group(2))
            extracted_info.append((epoch, step, string))
    # Sort by the epoch number, then by the step number
    sorted_info = sorted(
        extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
    sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
    # Take the newest ckpt filename
    newest_ckpt = sorted_info[0][2]
    return newest_ckpt
@@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
# Returns the text when the file exists and is non-empty
def check_txt_file(file_path):
    try:
        with open(file_path, 'r') as file:
        with open(file_path, "r") as file:
            text = file.readline().strip()
            assert text.strip() != ''
            assert text.strip() != ""
            return text
    except Exception:
        return False
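The epoch=/step= pattern above turns checkpoint selection into a plain sort on the extracted numbers. A short usage sketch (filenames are made up, and the import path is assumed from this repo's layout):

from AR.utils import get_newest_ckpt  # assumed import path

ckpts = ["epoch=4-step=1000.ckpt", "epoch=12-step=500.ckpt", "epoch=12-step=900.ckpt"]
print(get_newest_ckpt(ckpts))  # -> "epoch=12-step=900.ckpt"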
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Initialize modules for espnet2 neural networks."""
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
|
||||
|
||||
|
||||
def write_args(args, path):
|
||||
args_dict = dict(
|
||||
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
|
||||
)
|
||||
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
|
||||
with open(path, "a") as args_file:
|
||||
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
||||
args_file.write(
|
||||
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
|
||||
)
|
||||
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
|
||||
args_file.write("==> Cmd:\n")
|
||||
args_file.write(str(sys.argv))
|
||||
args_file.write("\n==> args:\n")
|
||||
|
@ -23,9 +23,7 @@ class Snake(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
|
@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||
activation_results = anti_alias_activation_cuda.forward(
|
||||
inputs, up_ftr, down_ftr, alpha, beta
|
||||
)
|
||||
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
|
||||
|
||||
return activation_results
|
||||
|
||||
@ -61,17 +59,11 @@ class Activation1d(nn.Module):
|
||||
if self.act.__class__.__name__ == "Snake":
|
||||
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||
else:
|
||||
beta = (
|
||||
self.act.beta.data
|
||||
) # Snakebeta uses different params for alpha and beta
|
||||
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
|
||||
alpha = self.act.alpha.data
|
||||
if (
|
||||
not self.act.alpha_logscale
|
||||
): # Exp baked into cuda kernel, cancel it out with a log
|
||||
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
|
||||
         alpha = torch.log(alpha)
         beta = torch.log(beta)
-        x = FusedAntiAliasActivation.apply(
-            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
-        )
+        x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
         return x

@@ -58,17 +58,13 @@ def load():
         srcpath / "anti_alias_activation.cpp",
         srcpath / "anti_alias_activation_cuda.cu",
     ]
-    anti_alias_activation_cuda = _cpp_extention_load_helper(
-        "anti_alias_activation_cuda", sources, extra_cuda_flags
-    )
+    anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)

     return anti_alias_activation_cuda


 def _get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
-    )
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")

@@ -27,9 +27,7 @@ else:
 # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
 # https://adefossez.github.io/julius/julius/lowpass.html
 # LICENSE is in incl_licenses directory.
-def kaiser_sinc_filter1d(
-    cutoff, half_width, kernel_size
-):  # return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
     even = kernel_size % 2 == 0
     half_size = kernel_size // 2

@@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
     def __init__(self, ratio=2, kernel_size=None):
         super().__init__()
         self.ratio = ratio
-        self.kernel_size = (
-            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-        )
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
         self.stride = ratio
         self.pad = self.kernel_size // ratio - 1
         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
-        self.pad_right = (
-            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
-        )
-        filter = kaiser_sinc_filter1d(
-            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
-        )
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
         self.register_buffer("filter", filter)

     # x: [B, C, T]
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
         _, C, _ = x.shape

         x = F.pad(x, (self.pad, self.pad), mode="replicate")
-        x = self.ratio * F.conv_transpose1d(
-            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
-        )
+        x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
         x = x[..., self.pad_left : -self.pad_right]

         return x

@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
     def __init__(self, ratio=2, kernel_size=None):
         super().__init__()
         self.ratio = ratio
-        self.kernel_size = (
-            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-        )
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
         self.lowpass = LowPassFilter1d(
             cutoff=0.5 / ratio,
             half_width=0.6 / ratio,
@@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
         )
         self.convs2.apply(init_weights)

-        self.num_layers = len(self.convs1) + len(
-            self.convs2
-        )  # Total number of conv layers
+        self.num_layers = len(self.convs1) + len(self.convs2)  # Total number of conv layers

         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
         if activation == "snake":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.Snake(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
         elif activation == "snakebeta":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.SnakeBeta(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
         if activation == "snake":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.Snake(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
         elif activation == "snakebeta":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.SnakeBeta(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
@@ -283,9 +265,7 @@ class BigVGAN(
         self.num_upsamples = len(h.upsample_rates)

         # Pre-conv
-        self.conv_pre = weight_norm(
-            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
-        )
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))

         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
         if h.resblock == "1":
@@ -293,9 +273,7 @@ class BigVGAN(
         elif h.resblock == "2":
             resblock_class = AMPBlock2
         else:
-            raise ValueError(
-                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
-            )
+            raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")

         # Transposed conv-based upsamplers. does not apply anti-aliasing
         self.ups = nn.ModuleList()
@@ -320,22 +298,14 @@ class BigVGAN(
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
-            ):
-                self.resblocks.append(
-                    resblock_class(h, ch, k, d, activation=h.activation)
-                )
+            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))

         # Post-conv
         activation_post = (
             activations.Snake(ch, alpha_logscale=h.snake_logscale)
             if h.activation == "snake"
-            else (
-                activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
-                if h.activation == "snakebeta"
-                else None
-            )
+            else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
         )
         if activation_post is None:
             raise NotImplementedError(
@@ -346,9 +316,7 @@ class BigVGAN(

         # Whether to use bias for the final conv_post. Default to True for backward compatibility
         self.use_bias_at_final = h.get("use_bias_at_final", True)
-        self.conv_post = weight_norm(
-            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
-        )
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))

         # Weight initialization
         for i in range(len(self.ups)):
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
-        )
+        self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         self.mpd_reshapes = h.mpd_reshapes
         print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
-                for rs in self.mpd_reshapes
-            ]
+            [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
         )

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
         super().__init__()

         self.resolution = resolution
-        assert (
-            len(self.resolution) == 3
-        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
         self.lrelu_slope = 0.1

         norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
         if hasattr(cfg, "mrd_use_spectral_norm"):
-            print(
-                f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
-            )
-            norm_f = (
-                weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
-            )
+            print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
+            norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
         self.d_mult = cfg.discriminator_channel_mult
         if hasattr(cfg, "mrd_channel_mult"):
             print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
-        )
+        self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
     def __init__(self, cfg, debug=False):
         super().__init__()
         self.resolutions = cfg.resolutions
-        assert (
-            len(self.resolutions) == 3
-        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+        assert len(self.resolutions) == 3, (
+            f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
         )
+        self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
         convs = lambda: nn.ModuleList(
             [
                 weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
-                ),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
             ]
         )
         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])

-        self.conv_post = weight_norm(
-            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
-        )
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))

     def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
         # Remove DC offset
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
         super().__init__()
         # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
         self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
-        )
+        self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
-
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
 # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
 # LICENSE is in incl_licenses directory.
 class DiscriminatorCQT(nn.Module):
-    def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
+    def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
         super().__init__()
         self.cfg = cfg

@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):

         in_chs = min(self.filters_scale * self.filters, self.max_filters)
         for i, dilation in enumerate(self.dilations):
-            out_chs = min(
-                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
-            )
+            out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
             self.convs.append(
                 weight_norm(
                     nn.Conv2d(
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
                     in_chs,
                     out_chs,
                     kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-                    padding=self.get_2d_padding(
-                        (self.kernel_size[0], self.kernel_size[0])
-                    ),
+                    padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                 )
             )
         )
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
         # Multi-scale params to loop over
         self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
         self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
-        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
-            "cqtd_bins_per_octaves", [24, 36, 48]
-        )
+        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])

         self.discriminators = nn.ModuleList(
             [
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
             ]
         )

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
-
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
         super().__init__()
         self.discrimiantor = nn.ModuleList(list_discriminator)

-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
-
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []

@@ -35,9 +35,7 @@ def inference(a, h):
     with torch.no_grad():
         for i, filname in enumerate(filelist):
             # Load the ground truth audio and resample if necessary
-            wav, sr = librosa.load(
-                os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
-            )
+            wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
             wav = torch.FloatTensor(wav).to(device)
             # Compute mel spectrogram from the ground truth audio
             x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@@ -48,9 +46,7 @@ def inference(a, h):
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype("int16")

-            output_file = os.path.join(
-                a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
-            )
+            output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
             write(output_file, h.sampling_rate, audio)
             print(output_file)

@@ -61,9 +61,7 @@ def inference(a, h):
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype("int16")

-            output_file = os.path.join(
-                a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
-            )
+            output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
             write(output_file, h.sampling_rate, audio)
             print(output_file)

@@ -122,9 +122,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
         B, C, T = wav.shape

         if match_stride:
-            assert (
-                hop_length == window_length // 4
-            ), "For match_stride, hop must equal n_fft // 4"
+            assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
             right_pad = math.ceil(T / hop_length) * hop_length - T
             pad = (window_length - hop_length) // 2
         else:
@@ -154,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
         magnitude = torch.abs(stft)

         nf = magnitude.shape[2]
-        mel_basis = self.get_mel_filters(
-            self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
-        )
+        mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
         mel_basis = torch.from_numpy(mel_basis).to(wav.device)
         mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
         mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@@ -181,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
         """

         loss = 0.0
-        for n_mels, fmin, fmax, s in zip(
-            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
-        ):
+        for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
             kwargs = {
                 "n_mels": n_mels,
                 "fmin": fmin,
@@ -196,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):

             x_mels = self.mel_spectrogram(x, **kwargs)
             y_mels = self.mel_spectrogram(y, **kwargs)
-            x_logmels = torch.log(
-                x_mels.clamp(min=self.clamp_eps).pow(self.pow)
-            ) / torch.log(torch.tensor(10.0))
-            y_logmels = torch.log(
-                y_mels.clamp(min=self.clamp_eps).pow(self.pow)
-            ) / torch.log(torch.tensor(10.0))
+            x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
+            y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))

             loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
             loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
@@ -210,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):


 # Loss functions
-def feature_loss(
-    fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
-) -> torch.Tensor:
-
+def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
@@ -225,7 +212,6 @@ def feature_loss(
 def discriminator_loss(
     disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
 ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-
     loss = 0
     r_losses = []
     g_losses = []
@@ -242,7 +228,6 @@ def discriminator_loss(
 def generator_loss(
     disc_outputs: List[torch.Tensor],
 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
@@ -86,9 +86,7 @@ def mel_spectrogram(
     key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

     if key not in mel_basis_cache:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
         hann_window_cache[key] = torch.hann_window(win_size).to(device)

@@ -96,9 +94,7 @@ def mel_spectrogram(
     hann_window = hann_window_cache[key]

     padding = (n_fft - hop_size) // 2
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1), (padding, padding), mode="reflect"
-    ).squeeze(1)
+    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)

     spec = torch.stft(
         y,
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):

     with open(a.input_training_file, "r", encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
-            for x in fi.read().split("\n")
-            if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
         ]
         print(f"first training file: {training_files[0]}")

     with open(a.input_validation_file, "r", encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
-            for x in fi.read().split("\n")
-            if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
         ]
         print(f"first validation file: {validation_files[0]}")

@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
                 for x in fi.read().split("\n")
                 if len(x) > 0
             ]
-            print(
-                f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
-            )
+            print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
             list_unseen_validation_files.append(unseen_validation_files)

     return training_files, validation_files, list_unseen_validation_files
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):

             print("[INFO] checking dataset integrity...")
             for i in tqdm(range(len(self.audio_files))):
-                assert os.path.exists(
-                    self.audio_files[i]
-                ), f"{self.audio_files[i]} not found"
+                assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"

-    def __getitem__(
-        self, index: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
         try:
             filename = self.audio_files[index]

@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
             # Obtain randomized audio chunk
             if source_sampling_rate != self.sampling_rate:
                 # Adjust segment size to crop if the source sr is different
-                target_segment_size = math.ceil(
-                    self.segment_size
-                    * (source_sampling_rate / self.sampling_rate)
-                )
+                target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
             else:
                 target_segment_size = self.segment_size

             # Compute upper bound index for the random chunk
-            random_chunk_upper_bound = max(
-                0, audio.shape[0] - target_segment_size
-            )
+            random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)

             # Crop or pad audio to obtain random chunk with target_segment_size
             if audio.shape[0] >= target_segment_size:
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
             else:
                 # For fine-tuning, assert that the waveform is in the defined sampling_rate
                 # Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
-                assert (
-                    source_sampling_rate == self.sampling_rate
-                ), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
+                assert source_sampling_rate == self.sampling_rate, (
+                    f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
+                )

             # Cast ndarray to torch tensor
             audio = torch.FloatTensor(audio)
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
                     mel = mel[:, :, mel_start : mel_start + frames_per_seg]
                     audio = audio[
                         :,
-                        mel_start
-                        * self.hop_size : (mel_start + frames_per_seg)
-                        * self.hop_size,
+                        mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
                     ]

                 # Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
                 # NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
                 # To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
-                mel = torch.nn.functional.pad(
-                    mel, (0, frames_per_seg - mel.size(2)), "constant"
-                )
-                audio = torch.nn.functional.pad(
-                    audio, (0, self.segment_size - audio.size(1)), "constant"
-                )
+                mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
+                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")

             # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
             mel_loss = mel_spectrogram(
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):

             # Shape sanity checks
             assert (
-                audio.shape[1] == mel.shape[2] * self.hop_size
-                and audio.shape[1] == mel_loss.shape[2] * self.hop_size
-            ), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
+                audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
+            ), (
+                f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
+            )

             return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
             if self.fine_tuning:
                 raise e  # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
             else:
-                print(
-                    f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
-                )
+                print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
                 return self[random.randrange(len(self))]

     def __len__(self):
@@ -3,6 +3,7 @@

 import os
 import sys
+
 # to import modules from parent_dir
 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
     data = torch.rand((10, 10, 200), device="cuda")

     # Check activations.Snake cuda vs. torch
-    fused_anti_alias_activation = activation1d.Activation1d(
-        activation=Snake(10), fused=True
-    ).cuda()
+    fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
     fused_activation_output = fused_anti_alias_activation(data)

-    torch_anti_alias_activation = activation1d.Activation1d(
-        activation=Snake(10), fused=False
-    ).cuda()
+    torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
     torch_activation_output = torch_anti_alias_activation(data)

     test_result = (fused_activation_output - torch_activation_output).abs()

@@ -3,6 +3,7 @@

 import os
 import sys
+
 # to import modules from parent_dir
 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
     data = torch.rand((10, 10, 200), device="cuda")

     # Check activations, Snake CUDA vs. Torch
-    fused_anti_alias_activation = activation1d.Activation1d(
-        activation=SnakeBeta(10), fused=True
-    ).cuda()
+    fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
     fused_activation_output = fused_anti_alias_activation(data)

-    torch_anti_alias_activation = activation1d.Activation1d(
-        activation=SnakeBeta(10), fused=False
-    ).cuda()
+    torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
     torch_activation_output = torch_anti_alias_activation(data)

     test_result = (fused_activation_output - torch_activation_output).abs()
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
     )


-
 if __name__ == "__main__":
     from alias_free_activation.cuda import load

@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):


 def get_mel(x, h):
-    return mel_spectrogram(
-        x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
-    )
+    return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)


 def load_checkpoint(filepath, device):
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Test script to check CUDA kernel correctness."
-    )
+    parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
     parser.add_argument(
         "--checkpoint_file",
         type=str,
@@ -109,9 +105,7 @@ if __name__ == "__main__":
         diff += test_result.mean(dim=-1).item()

     diff /= num_sample
-    if (
-        diff <= 2e-3
-    ):  # We can expect a small difference (~1e-3) which does not affect perceptual quality
+    if diff <= 2e-3:  # We can expect a small difference (~1e-3) which does not affect perceptual quality
         print(
             f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
             f"\n > mean_difference={diff}"
@@ -175,8 +169,8 @@ if __name__ == "__main__":
     audio_second = audio_length_total / h.sampling_rate
     khz_original = audio_length_total / toc_total_original / 1000
     khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
-    vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
-    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
+    vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
+    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)

     # Print results
     print(
@@ -77,24 +77,18 @@ def train(rank, a, h):
     # Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
     # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
     if h.get("use_mbd_instead_of_mrd", False):  # Switch to MBD
-        print(
-            "[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
-        )
+        print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
        # Variable name is kept as "mrd" for backward compatibility & minimal code change
         mrd = MultiBandDiscriminator(h).to(device)
     elif h.get("use_cqtd_instead_of_mrd", False):  # Switch to CQTD
-        print(
-            "[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
-        )
+        print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
         mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
     else:  # Fallback to original MRD in BigVGAN-v1
         mrd = MultiResolutionDiscriminator(h).to(device)

     # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
     if h.get("use_multiscale_melloss", False):
-        print(
-            "[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
-        )
+        print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
         fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
             sampling_rate=h.sampling_rate
         )  # NOTE: accepts waveform as input
@@ -114,9 +108,7 @@ def train(rank, a, h):

     if os.path.isdir(a.checkpoint_path):
         # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
-        cp_g = scan_checkpoint(
-            a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
-        )
+        cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
         cp_do = scan_checkpoint(
             a.checkpoint_path,
             prefix="do_",
@@ -143,9 +135,7 @@ def train(rank, a, h):
         mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
         mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)

-    optim_g = torch.optim.AdamW(
-        generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
-    )
+    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
     optim_d = torch.optim.AdamW(
         itertools.chain(mrd.parameters(), mpd.parameters()),
         h.learning_rate,
@@ -156,12 +146,8 @@ def train(rank, a, h):
         optim_g.load_state_dict(state_dict_do["optim_g"])
         optim_d.load_state_dict(state_dict_do["optim_d"])

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=h.lr_decay, last_epoch=last_epoch
-    )
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=h.lr_decay, last_epoch=last_epoch
-    )
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

     # Define training and validation datasets

@@ -169,9 +155,7 @@ def train(rank, a, h):
     unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
     Example: trained on LibriTTS, validate on VCTK
     """
-    training_filelist, validation_filelist, list_unseen_validation_filelist = (
-        get_dataset_filelist(a)
-    )
+    training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)

     trainset = MelDataset(
         training_filelist,
@@ -324,33 +308,26 @@ def train(rank, a, h):
                     h.fmax_for_loss,
                 )
                 min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
-                val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
+                val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()

                 # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
-                if (
-                    "nonspeech" not in mode
-                ):  # Skips if the name of dataset (in mode string) contains "nonspeech"
-
+                if "nonspeech" not in mode:  # Skips if the name of dataset (in mode string) contains "nonspeech"
                     # Resample to 16000 for pesq
                     y_16k = pesq_resampler(y)
                     y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
                     y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
-                    y_g_hat_int_16k = (
-                        (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
-                    )
+                    y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
                     val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")

                 # MRSTFT calculation
                 min_t = min(y.size(-1), y_g_hat.size(-1))
-                val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
+                val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()

                 # Log audio and figures to Tensorboard
                 if j % a.eval_subsample == 0:  # Subsample every nth from validation set
                     if steps >= 0:
                         sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
-                        if (
-                            a.save_audio
-                        ):  # Also save audio to disk if --save_audio is set to True
+                        if a.save_audio:  # Also save audio to disk if --save_audio is set to True
                             save_audio(
                                 y[0],
                                 os.path.join(
@@ -373,9 +350,7 @@ def train(rank, a, h):
                             steps,
                             h.sampling_rate,
                         )
-                        if (
-                            a.save_audio
-                        ):  # Also save audio to disk if --save_audio is set to True
+                        if a.save_audio:  # Also save audio to disk if --save_audio is set to True
                             save_audio(
                                 y_g_hat[0, 0],
                                 os.path.join(
@@ -487,15 +462,11 @@ def train(rank, a, h):

                 # MPD
                 y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
-                loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
-                    y_df_hat_r, y_df_hat_g
-                )
+                loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

                 # MRD
                 y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
-                loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
-                    y_ds_hat_r, y_ds_hat_g
-                )
+                loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

                 loss_disc_all = loss_disc_s + loss_disc_f

@@ -505,17 +476,11 @@ def train(rank, a, h):
                 # Whether to freeze D for initial training steps
                 if steps >= a.freeze_step:
                     loss_disc_all.backward()
-                    grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
-                        mpd.parameters(), clip_grad_norm
-                    )
-                    grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
-                        mrd.parameters(), clip_grad_norm
-                    )
+                    grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
+                    grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
                     optim_d.step()
                 else:
-                    print(
-                        f"[WARNING] skipping D training for the first {a.freeze_step} steps"
-                    )
+                    print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
                     grad_norm_mpd = 0.0
                     grad_norm_mrd = 0.0

@@ -523,9 +488,7 @@ def train(rank, a, h):
                 optim_g.zero_grad()

                 # L1 Mel-Spectrogram Loss
-                lambda_melloss = h.get(
-                    "lambda_melloss", 45.0
-                )  # Defaults to 45 in BigVGAN-v1 if not set
+                lambda_melloss = h.get("lambda_melloss", 45.0)  # Defaults to 45 in BigVGAN-v1 if not set
                 if h.get("use_multiscale_melloss", False):  # uses wav <y, y_g_hat> for loss
                     loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
                 else:  # Uses mel <y_mel, y_g_hat_mel> for loss
@@ -542,27 +505,19 @@ def train(rank, a, h):
                 loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

                 if steps >= a.freeze_step:
-                    loss_gen_all = (
-                        loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
-                    )
+                    loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
                 else:
-                    print(
-                        f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
-                    )
+                    print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
                     loss_gen_all = loss_mel

                 loss_gen_all.backward()
-                grad_norm_g = torch.nn.utils.clip_grad_norm_(
-                    generator.parameters(), clip_grad_norm
-                )
+                grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
                 optim_g.step()

                 if rank == 0:
                     # STDOUT logging
                     if steps % a.stdout_interval == 0:
-                        mel_error = (
-                            loss_mel.item() / lambda_melloss
-                        )  # Log training mel regression loss to stdout
+                        mel_error = loss_mel.item() / lambda_melloss  # Log training mel regression loss to stdout
                         print(
                             f"Steps: {steps:d}, "
                             f"Gen Loss Total: {loss_gen_all:4.3f}, "
@@ -577,11 +532,7 @@ def train(rank, a, h):
                         checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
                         save_checkpoint(
                             checkpoint_path,
-                            {
-                                "generator": (
-                                    generator.module if h.num_gpus > 1 else generator
-                                ).state_dict()
-                            },
+                            {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
                         )
                         checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
                         save_checkpoint(
@@ -598,9 +549,7 @@ def train(rank, a, h):

                     # Tensorboard summary logging
                     if steps % a.summary_interval == 0:
-                        mel_error = (
-                            loss_mel.item() / lambda_melloss
-                        )  # Log training mel regression loss to tensorboard
+                        mel_error = loss_mel.item() / lambda_melloss  # Log training mel regression loss to tensorboard
                         sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
                         sw.add_scalar("training/mel_spec_error", mel_error, steps)
                         sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@@ -612,12 +561,8 @@ def train(rank, a, h):
                         sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
                         sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
                         sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
-                        sw.add_scalar(
-                            "training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
-                        )
-                        sw.add_scalar(
-                            "training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
-                        )
+                        sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
+                        sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
                         sw.add_scalar("training/epoch", epoch + 1, steps)

                     # Validation
@@ -660,9 +605,7 @@ def train(rank, a, h):
         scheduler_d.step()

         if rank == 0:
-            print(
-                f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
-            )
+            print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")


 def main():
@@ -674,12 +617,8 @@ def main():

     parser.add_argument("--input_wavs_dir", default="LibriTTS")
     parser.add_argument("--input_mels_dir", default="ft_dataset")
-    parser.add_argument(
-        "--input_training_file", default="tests/LibriTTS/train-full.txt"
-    )
-    parser.add_argument(
-        "--input_validation_file", default="tests/LibriTTS/val-full.txt"
-    )
+    parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
+    parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")

     parser.add_argument(
         "--list_input_unseen_wavs_dir",
File diff suppressed because it is too large
@@ -1,9 +1,9 @@

 import os
 import sys
 import threading

 from tqdm import tqdm

 now_dir = os.getcwd()
 sys.path.append(now_dir)

@@ -19,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_

 from tools.i18n.i18n import I18nAuto, scan_language_list

-language=os.environ.get("language","Auto")
-language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
+language = os.environ.get("language", "Auto")
+language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
-punctuation = set(['!', '?', '…', ',', '.', '-'])
+punctuation = set(["!", "?", "…", ",", ".", "-"])

-def get_first(text:str) -> str:
+
+def get_first(text: str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
     text = re.split(pattern, text)[0].strip()
     return text

-def merge_short_text_in_array(texts:str, threshold:int) -> list:
+
+def merge_short_text_in_array(texts: str, threshold: int) -> list:
     if (len(texts)) < 2:
         return texts
     result = []
@@ -39,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
         if len(text) >= threshold:
             result.append(text)
             text = ""
-    if (len(text) > 0):
+    if len(text) > 0:
         if len(result) == 0:
             result.append(text)
         else:
@@ -47,28 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
     return result


-
-
-
 class TextPreprocessor:
-    def __init__(self, bert_model:AutoModelForMaskedLM,
-                 tokenizer:AutoTokenizer, device:torch.device):
+    def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
         self.bert_model = bert_model
         self.tokenizer = tokenizer
         self.device = device
         self.bert_lock = threading.RLock()

-    def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
-        print(f'############ {i18n("切分文本")} ############')
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+        print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        print(f'############ {i18n("提取文本Bert特征")} ############')
+        print(f"############ {i18n('提取文本Bert特征')} ############")
         for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-            if phones is None or norm_text=="":
+            if phones is None or norm_text == "":
                 continue
-            res={
+            res = {
                 "phones": phones,
                 "bert_features": bert_features,
                 "norm_text": norm_text,
@@ -76,11 +74,11 @@ class TextPreprocessor:
             result.append(res)
         return result

-    def pre_seg_text(self, text:str, lang:str, text_split_method:str):
+    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
         if len(text) == 0:
             return []
-        if (text[0] not in splits and len(get_first(text)) < 4):
+        if text[0] not in splits and len(get_first(text)) < 4:
             text = "。" + text if lang != "en" else "." + text
         print(i18n("实际输入的目标文本:"))
         print(text)
@@ -96,18 +94,18 @@ class TextPreprocessor:
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []

-
         for text in _texts:
             # 解决输入目标文本的空行导致报错的问题
-            if (len(text.strip()) == 0):
+            if len(text.strip()) == 0:
                 continue
             if not re.sub("\W+", "", text):
                 # 检测一下,如果是纯符号,就跳过。
                 continue
-            if (text[-1] not in splits): text += "。" if lang != "en" else "."
+            if text[-1] not in splits:
+                text += "。" if lang != "en" else "."

             # 解决句子过长导致Bert报错的问题
-            if (len(text) > 510):
+            if len(text) > 510:
                 texts.extend(split_big_text(text))
             else:
                 texts.append(text)
@@ -116,10 +114,12 @@ class TextPreprocessor:
         print(texts)
         return texts

-    def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
+    def segment_and_extract_feature_for_text(
+        self, text: str, language: str, version: str = "v1"
+    ) -> Tuple[list, torch.Tensor, str]:
         return self.get_phones_and_bert(text, language, version)

-    def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
+    def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
         with self.bert_lock:
             if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
                 # language = language.replace("all_","")
@@ -127,17 +127,17 @@ class TextPreprocessor:
                 while "  " in formattext:
                     formattext = formattext.replace("  ", " ")
                 if language == "all_zh":
-                    if re.search(r'[A-Za-z]', formattext):
-                        formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                    if re.search(r"[A-Za-z]", formattext):
+                        formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                         formattext = chinese.mix_text_normalize(formattext)
-                        return self.get_phones_and_bert(formattext,"zh",version)
+                        return self.get_phones_and_bert(formattext, "zh", version)
                     else:
                         phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                         bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-                elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
-                    formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
+                    formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
-                    return self.get_phones_and_bert(formattext,"yue",version)
+                    return self.get_phones_and_bert(formattext, "yue", version)
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = torch.zeros(
@@ -145,8 +145,8 @@ class TextPreprocessor:
                         dtype=torch.float32,
                     ).to(self.device)
             elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
-                textlist=[]
-                langlist=[]
+                textlist = []
+                langlist = []
                 if language == "auto":
                     for tmp in LangSegmenter.getTexts(text):
                         langlist.append(tmp["lang"])
@@ -179,15 +179,14 @@ class TextPreprocessor:
                     bert_list.append(bert)
                 bert = torch.cat(bert_list, dim=1)
                 phones = sum(phones_list, [])
-                norm_text = ''.join(norm_text_list)
+                norm_text = "".join(norm_text_list)

             if not final and len(phones) < 6:
-                return self.get_phones_and_bert("." + text,language,version,final=True)
+                return self.get_phones_and_bert("." + text, language, version, final=True)

             return phones, bert, norm_text

-
-    def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
+    def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
         with torch.no_grad():
             inputs = self.tokenizer(text, return_tensors="pt")
             for i in inputs:
@@ -202,14 +201,14 @@ class TextPreprocessor:
         phone_level_feature = torch.cat(phone_level_feature, dim=0)
         return phone_level_feature.T

-    def clean_text_inf(self, text:str, language:str, version:str="v2"):
-        language = language.replace("all_","")
+    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
+        language = language.replace("all_", "")
         phones, word2ph, norm_text = clean_text(text, language, version)
         phones = cleaned_text_to_sequence(phones, version)
         return phones, word2ph, norm_text

-    def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
-        language=language.replace("all_","")
+    def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
+        language = language.replace("all_", "")
     if language == "zh":
             feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
         else:
@@ -220,10 +219,9 @@ class TextPreprocessor:

         return feature

-
-    def filter_text(self,texts):
-        _text=[]
-        if all(text in [None, " ", "\n",""] for text in texts):
+    def filter_text(self, texts):
+        _text = []
+        if all(text in [None, " ", "\n", ""] for text in texts):
             raise ValueError(i18n("请输入有效文本"))
         for text in texts:
             if text in [None, " ", ""]:
@@ -232,9 +230,8 @@ class TextPreprocessor:
                 _text.append(text)
         return _text

-
-    def replace_consecutive_punctuation(self,text):
-        punctuations = ''.join(re.escape(p) for p in punctuation)
-        pattern = f'([{punctuations}])([{punctuations}])+'
-        result = re.sub(pattern, r'\1', text)
+    def replace_consecutive_punctuation(self, text):
+        punctuations = "".join(re.escape(p) for p in punctuation)
+        pattern = f"([{punctuations}])([{punctuations}])+"
+        result = re.sub(pattern, r"\1", text)
         return result
@@ -1,40 +1,56 @@
 import re
 from typing import Callable

-punctuation = set(['!', '?', '…', ',', '.', '-'," "])
+punctuation = set(["!", "?", "…", ",", ".", "-", " "])
 METHODS = dict()

-def get_method(name:str)->Callable:
+
+def get_method(name: str) -> Callable:
     method = METHODS.get(name, None)
     if method is None:
         raise ValueError(f"Method {name} not found")
     return method

-def get_method_names()->list:
+
+def get_method_names() -> list:
     return list(METHODS.keys())


 def register_method(name):
     def decorator(func):
         METHODS[name] = func
         return func

     return decorator

-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
+splits = {
+    ",",
+    "。",
+    "?",
+    "!",
+    ",",
+    ".",
+    "?",
+    "!",
+    "~",
+    ":",
+    ":",
+    "—",
+    "…",
+}


 def split_big_text(text, max_len=510):
     # 定义全角和半角标点符号
     punctuation = "".join(splits)

     # 切割文本
-    segments = re.split('([' + punctuation + '])', text)
+    segments = re.split("([" + punctuation + "])", text)

     # 初始化结果列表和当前片段
     result = []
-    current_segment = ''
+    current_segment = ""

     for segment in segments:
         # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
@@ -51,7 +67,6 @@ def split_big_text(text, max_len=510):
     return result


-
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if todo_text[-1] not in splits:
@@ -90,7 +105,7 @@ def cut1(inp):
     if len(split_idx) > 1:
         opts = []
         for idx in range(len(split_idx) - 1):
-            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
+            opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
     else:
         opts = [inp]
     opts = [item for item in opts if not set(item).issubset(punctuation)]
@@ -123,6 +138,7 @@ def cut2(inp):
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

+
 # 按中文句号。切
 @register_method("cut3")
 def cut3(inp):
@@ -131,26 +147,28 @@ def cut3(inp):
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

-#按英文句号.切
+
+# 按英文句号.切
 @register_method("cut4")
 def cut4(inp):
     inp = inp.strip("\n")
-    opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
+    opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

+
 # 按标点符号切
 # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
 @register_method("cut5")
 def cut5(inp):
     inp = inp.strip("\n")
-    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
+    punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
     mergeitems = []
     items = []

     for i, char in enumerate(inp):
         if char in punds:
-            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+            if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                 items.append(char)
             else:
                 items.append(char)
@@ -166,8 +184,6 @@ def cut5(inp):
     return "\n".join(opt)


-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     method = get_method("cut5")
     print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
@ -1,6 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.insert(0, now_dir)
|
||||
from text.g2pw import G2PWPinyin
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
||||
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
@ -32,6 +32,7 @@ default_config = {
    "EOS": 1024,
}


def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
    config = dict_s1["config"]
    config["model"]["dropout"] = float(config["model"]["dropout"])
@ -40,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
    t2s_model = t2s_model.eval()
    return t2s_model


@torch.jit.script
def logits_to_probs(
    logits,
@ -56,39 +58,35 @@ def logits_to_probs(
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=1, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits.scatter_(dim=1, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[:, 0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=1, index=sorted_indices, src=sorted_indices_to_remove
        )
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v[: , -1].unsqueeze(-1)
        pivot = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
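# Small self-contained illustration of the nucleus (top-p) filtering step used in
# logits_to_probs above; the logit values are made up for the demo.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > 0.9
sorted_indices_to_remove[:, 0] = False  # always keep the best token
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
print(torch.nn.functional.softmax(logits.masked_fill(indices_to_remove, -float("Inf")), dim=-1))
# probability mass survives only on the tokens inside the 0.9 nucleus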
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
    # Does multinomial sampling without a cuda synchronization
    q = torch.randn_like(probs_sort)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


@torch.jit.script
def sample(
    logits,
@ -99,15 +97,20 @@ def sample(
    repetition_penalty: float = 1.0,
):
    probs = logits_to_probs(
        logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
        logits=logits,
        previous_tokens=previous_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
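# multinomial_sample_one_no_sync above avoids a CUDA sync by turning categorical
# sampling into an argmax over noise-divided probabilities. The widely used variant of
# the trick (e.g. in gpt-fast) draws Exponential(1) noise, which makes argmax(p / q) an
# exact categorical sample; a sketch of that variant for comparison:
import torch

def multinomial_sample_one_exponential(probs: torch.Tensor) -> torch.Tensor:
    q = torch.empty_like(probs).exponential_(1)  # Exponential(1) noise
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)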
@torch.jit.script
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
    hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
    hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -157,6 +160,7 @@ class DictToAttrRecursive(dict):
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


@torch.jit.script
class T2SMLP:
    def __init__(self, w1, b1, w2, b2):
@ -170,6 +174,7 @@ class T2SMLP:
        x = F.linear(x, self.w2, self.b2)
        return x


@torch.jit.script
class T2SBlock:
    def __init__(
@ -205,7 +210,7 @@ class T2SBlock:
        self.false = torch.tensor(False, dtype=torch.bool)

    @torch.jit.ignore
    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
    def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
        if padding_mask is None:
            return x

@ -214,7 +219,7 @@ class T2SBlock:
        else:
            return x * padding_mask

    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        batch_size = q.shape[0]
@ -231,22 +236,20 @@ class T2SBlock:

        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)

        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

        if padding_mask is not None:
            for i in range(batch_size):
                # mask = padding_mask[i,:,0]
                if self.false.device!= padding_mask.device:
                if self.false.device != padding_mask.device:
                    self.false = self.false.to(padding_mask.device)
                idx = torch.where(padding_mask[i,:,0]==self.false)[0]
                x_item = x[i,idx,:].unsqueeze(0)
                attn_item = attn[i,idx,:].unsqueeze(0)
                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
                x_item = x[i, idx, :].unsqueeze(0)
                attn_item = attn[i, idx, :].unsqueeze(0)
                x_item = x_item + attn_item
                x_item = F.layer_norm(
                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
                )
                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
                x_item = x_item + self.mlp.forward(x_item)
                x_item = F.layer_norm(
                    x_item,
@ -255,13 +258,11 @@ class T2SBlock:
                    self.norm_b2,
                    self.norm_eps2,
                )
                x[i,idx,:] = x_item.squeeze(0)
                x[i, idx, :] = x_item.squeeze(0)
            x = self.to_mask(x, padding_mask)
        else:
            x = x + attn
            x = F.layer_norm(
                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
            )
            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
            x = x + self.mlp.forward(x)
            x = F.layer_norm(
                x,
@ -272,7 +273,7 @@ class T2SBlock:
            )
        return x, k_cache, v_cache

    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
    def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

        k_cache = torch.cat([k_cache, k], dim=1)
@ -288,14 +289,12 @@ class T2SBlock:

        attn = F.scaled_dot_product_attention(q, k, v)

        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
        attn = F.linear(attn, self.out_w, self.out_b)

        x = x + attn
        x = F.layer_norm(
            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
        )
        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
        x = x + self.mlp.forward(x)
        x = F.layer_norm(
            x,
@ -306,37 +305,35 @@ class T2SBlock:
        )
        return x, k_cache, v_cache
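# Toy illustration of the KV-cache pattern in T2SBlock.decode_next_token above: each
# decoding step appends one position to the cached keys/values instead of re-encoding
# the whole prefix. Shapes below are hypothetical.
import torch

k_cache = torch.zeros(1, 10, 512)  # [batch, seq_len_so_far, hidden]
k_new = torch.zeros(1, 1, 512)     # key for the single new token
k_cache = torch.cat([k_cache, k_new], dim=1)
print(k_cache.shape)               # torch.Size([1, 11, 512]) — the cache grows by one step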
@torch.jit.script
class T2STransformer:
    def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
        self.num_blocks : int = num_blocks
    def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
        self.num_blocks: int = num_blocks
        self.blocks = blocks

    def process_prompt(
        self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
        k_cache : list[torch.Tensor] = []
        v_cache : list[torch.Tensor] = []
    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        k_cache: list[torch.Tensor] = []
        v_cache: list[torch.Tensor] = []
        for i in range(self.num_blocks):
            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
            k_cache.append(k_cache_)
            v_cache.append(v_cache_)
        return x, k_cache, v_cache

    def decode_next_token(
        self, x:torch.Tensor,
        k_cache: list[torch.Tensor],
        v_cache: list[torch.Tensor]):
    def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
        for i in range(self.num_blocks):
            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
        return x, k_cache, v_cache


class VitsModel(nn.Module):
    def __init__(self, vits_path):
        super().__init__()
        # dict_s2 = torch.load(vits_path,map_location="cpu")
        dict_s2 = torch.load(vits_path)
        self.hps = dict_s2["config"]
        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            self.hps["model"]["version"] = "v1"
        else:
            self.hps["model"]["version"] = "v2"
@ -347,7 +344,7 @@ class VitsModel(nn.Module):
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model
            **self.hps.model,
        )
        self.vq_model.eval()
        self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@ -359,12 +356,13 @@ class VitsModel(nn.Module):
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False
            center=False,
        )
        return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
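# Standalone sketch of the version check VitsModel.__init__ performs: it inspects the
# phoneme-embedding row count in the checkpoint to tell v1 (322 symbols) from v2
# weights. The path is hypothetical, so the lines stay commented.
# import torch
# dict_s2 = torch.load("path/to/sovits.pth", map_location="cpu")
# n_symbols = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
# version = "v1" if n_symbols == 322 else "v2"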
class T2SModel(nn.Module):
    def __init__(self,raw_t2s:Text2SemanticLightningModule):
    def __init__(self, raw_t2s: Text2SemanticLightningModule):
        super(T2SModel, self).__init__()
        self.model_dim = raw_t2s.model.model_dim
        self.embedding_dim = raw_t2s.model.embedding_dim
@ -373,7 +371,7 @@ class T2SModel(nn.Module):
        self.vocab_size = raw_t2s.model.vocab_size
        self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
        # self.p_dropout = float(raw_t2s.model.p_dropout)
        self.EOS:int = int(raw_t2s.model.EOS)
        self.EOS: int = int(raw_t2s.model.EOS)
        self.norm_first = raw_t2s.model.norm_first
        assert self.EOS == self.vocab_size - 1
        self.hz = 50
@ -392,12 +390,7 @@ class T2SModel(nn.Module):

        for i in range(self.num_layers):
            layer = h.layers[i]
            t2smlp = T2SMLP(
                layer.linear1.weight,
                layer.linear1.bias,
                layer.linear2.weight,
                layer.linear2.bias
            )
            t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)

            block = T2SBlock(
                self.num_head,
@ -412,7 +405,7 @@ class T2SModel(nn.Module):
                layer.norm1.eps,
                layer.norm2.weight,
                layer.norm2.bias,
                layer.norm2.eps
                layer.norm2.eps,
            )

            blocks.append(block)
@ -426,19 +419,26 @@ class T2SModel(nn.Module):
        self.top_k = int(raw_t2s.config["inference"]["top_k"])
        self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])

    def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
    def forward(
        self,
        prompts: LongTensor,
        ref_seq: LongTensor,
        text_seq: LongTensor,
        ref_bert: torch.Tensor,
        text_bert: torch.Tensor,
        top_k: LongTensor,
    ):
        bert = torch.cat([ref_bert.T, text_bert.T], 1)
        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert = bert.unsqueeze(0)

        x = self.ar_text_embedding(all_phoneme_ids)
        x = x + self.bert_proj(bert.transpose(1, 2))
        x:torch.Tensor = self.ar_text_position(x)
        x: torch.Tensor = self.ar_text_position(x)

        early_stop_num = self.early_stop_num

        #[1,N,512] [1,N]
        # [1,N,512] [1,N]
        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
        y = prompts
        # x_example = x[:,:,0] * 0.0
@ -464,11 +464,13 @@ class T2SModel(nn.Module):
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
            .unsqueeze(0)\
            .expand(bsz*self.num_head, -1, -1)\
            .view(bsz, self.num_head, src_len, src_len)\
        xy_attn_mask = (
            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
            .unsqueeze(0)
            .expand(bsz * self.num_head, -1, -1)
            .view(bsz, self.num_head, src_len, src_len)
            .to(device=x.device, dtype=torch.bool)
        )

        idx = 0
        top_k = int(top_k)
@ -480,17 +482,19 @@ class T2SModel(nn.Module):
        samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
        y = torch.concat([y, samples], dim=1)
        y_emb = self.ar_audio_embedding(y[:, -1:])
        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
            :, y_len + idx
        ].to(dtype=y_emb.dtype, device=y_emb.device)

        stop = False
        # for idx in range(1, 50):
        for idx in range(1, 1500):
            #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
            # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
            # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
            xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
            logits = self.ar_predict_layer(xy_dec[:, -1])

            if(idx<11):###至少预测出10个token不然不给停止(0.4s)
            if idx < 11:  ###至少预测出10个token不然不给停止(0.4s)
                logits = logits[:, :-1]

            samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
@ -507,20 +511,22 @@ class T2SModel(nn.Module):
                break

            y_emb = self.ar_audio_embedding(y[:, -1:])
            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
                :, y_len + idx
            ].to(dtype=y_emb.dtype, device=y_emb.device)

        y[0,-1] = 0
        y[0, -1] = 0

        return y[:, -idx:].unsqueeze(0)

bert_path = os.environ.get(
    "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)

bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
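# The decoding loop above hides the EOS logit for the first ten steps (logits[:, :-1])
# so synthesis cannot stop in under roughly 0.4 s; EOS is the last vocabulary entry
# (self.EOS == vocab_size - 1). A toy restatement with a made-up vocab size:
import torch

logits = torch.randn(1, 1025)  # 1024 semantic tokens + EOS at index 1024
idx = 3                        # early in the loop
if idx < 11:
    logits = logits[:, :-1]    # drop the EOS column entirely
print(logits.shape)            # torch.Size([1, 1024])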
@torch.jit.script
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
    phone_level_feature = []
    for i in range(word2ph.shape[0]):
        repeat_feature = res[i].repeat(word2ph[i].item(), 1)
@ -529,39 +535,45 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
    # [sum(word2ph), 1024]
    return phone_level_feature


class MyBertModel(torch.nn.Module):
    def __init__(self, bert_model):
        super(MyBertModel, self).__init__()
        self.bert = bert_model

    def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
    ):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
        res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
        return build_phone_level_feature(res, word2ph)


class SSLModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ssl = cnhubert.get_model().model

    def forward(self, ref_audio_16k)-> torch.Tensor:
    def forward(self, ref_audio_16k) -> torch.Tensor:
        ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
        return ssl_content


class ExportSSLModel(torch.nn.Module):
    def __init__(self,ssl:SSLModel):
    def __init__(self, ssl: SSLModel):
        super().__init__()
        self.ssl = ssl

    def forward(self, ref_audio:torch.Tensor):
    def forward(self, ref_audio: torch.Tensor):
        return self.ssl(ref_audio)

    @torch.jit.export
    def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
        audio = resamplex(ref_audio,src_sr,dst_sr).float()
    def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
        audio = resamplex(ref_audio, src_sr, dst_sr).float()
        return audio
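# build_phone_level_feature above repeats each character-level BERT vector word2ph[i]
# times so the text feature aligns with the phoneme sequence. The same effect can be
# had with repeat_interleave; a toy check with made-up shapes:
import torch

res = torch.randn(3, 1024)                       # one vector per character
word2ph = torch.tensor([2, 1, 2], dtype=torch.long)
phone_level = torch.repeat_interleave(res, word2ph, dim=0)
print(phone_level.shape)                         # torch.Size([5, 1024]) == [sum(word2ph), 1024]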
def export_bert(output_path):
    tokenizer = AutoTokenizer.from_pretrained(bert_path)

@ -569,33 +581,34 @@ def export_bert(output_path):
    ref_bert_inputs = tokenizer(text, return_tensors="pt")
    word2ph = []
    for c in text:
        if c in [',','。',':','?',",",".","?"]:
        if c in [",", "。", ":", "?", ",", ".", "?"]:
            word2ph.append(1)
        else:
            word2ph.append(2)
    ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
    ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()

    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
    my_bert_model = MyBertModel(bert_model)

    ref_bert_inputs = {
        'input_ids': ref_bert_inputs['input_ids'],
        'attention_mask': ref_bert_inputs['attention_mask'],
        'token_type_ids': ref_bert_inputs['token_type_ids'],
        'word2ph': ref_bert_inputs['word2ph']
        "input_ids": ref_bert_inputs["input_ids"],
        "attention_mask": ref_bert_inputs["attention_mask"],
        "token_type_ids": ref_bert_inputs["token_type_ids"],
        "word2ph": ref_bert_inputs["word2ph"],
    }

    torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
    torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)

    my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
    my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
    output_path = os.path.join(output_path, "bert_model.pt")
    my_bert_model.save(output_path)
    print('#### exported bert ####')
    print("#### exported bert ####")

def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):

def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"目录已创建: {output_path}")
@ -605,21 +618,22 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    if export_bert_and_ssl:
        s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
        s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
        ssl_path = os.path.join(output_path, "ssl_model.pt")
        torch.jit.script(s).save(ssl_path)
        print('#### exported ssl ####')
        print("#### exported ssl ####")
        export_bert(output_path)
    else:
        s = ExportSSLModel(ssl)

    print(f"device: {device}")

    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
    ref_bert = ref_bert_T.T.to(ref_seq.device)
    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
        "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
    )
    text_seq = torch.LongTensor([text_seq_id]).to(device)
    text_bert = text_bert_T.T.to(text_seq.device)
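# torch._dynamo.mark_dynamic(tensor, dim), used above before tracing, flags a dimension
# as variable-length so the exported graph is not specialized to the example's sequence
# length. Minimal sketch with hypothetical shapes:
import torch

input_ids = torch.randint(0, 100, (1, 37))  # dim 1 is the token length
torch._dynamo.mark_dynamic(input_ids, 1)    # don't bake 37 into the exported graph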
@ -633,18 +647,18 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
    # dict_s1 = torch.load(gpt_path, map_location=device)
    dict_s1 = torch.load(gpt_path)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print('#### get_raw_t2s_model ####')
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)
    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    t2s = torch.jit.script(t2s_m).to(device)
    print('#### script t2s_m ####')
    print("#### script t2s_m ####")

    print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
    gpt_sovits.eval()

    ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
    ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)

    torch._dynamo.mark_dynamic(ssl_content, 2)
    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
@ -657,32 +671,28 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be

    with torch.no_grad():
        gpt_sovits_export = torch.jit.trace(
            gpt_sovits,
            example_inputs=(
                ssl_content,
                ref_audio_sr,
                ref_seq,
                text_seq,
                ref_bert,
                text_bert,
                top_k))
            gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
        )

    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
    gpt_sovits_export.save(gpt_sovits_path)
    print('#### exported gpt_sovits ####')
    print("#### exported gpt_sovits ####")


@torch.jit.script
def parse_audio(ref_audio):
    ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
    ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
    return ref_audio_16k,ref_audio_sr
    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()  # .to(ref_audio.device)
    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float()  # .to(ref_audio.device)
    return ref_audio_16k, ref_audio_sr


@torch.jit.script
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
    return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()


class GPT_SoVITS(nn.Module):
    def __init__(self, t2s:T2SModel,vits:VitsModel):
    def __init__(self, t2s: T2SModel, vits: VitsModel):
        super().__init__()
        self.t2s = t2s
        self.vits = vits
@ -709,12 +719,11 @@ class GPT_SoVITS(nn.Module):

def test():
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
    parser.add_argument('--output_path', required=True, help="Path to the output directory")

    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
    parser.add_argument("--output_path", required=True, help="Path to the output directory")

    args = parser.parse_args()
    gpt_path = args.gpt_model
@ -725,7 +734,7 @@ def test():
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
    # bert = MyBertModel(bert_model)
    my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
    my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")

    # dict_s1 = torch.load(gpt_path, map_location="cuda")
    # raw_t2s = get_raw_t2s_model(dict_s1)
@ -739,78 +748,79 @@ def test():

    # ssl = ExportSSLModel(SSLModel()).to('cuda')
    # ssl.eval()
    ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
    ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")

    # gpt_sovits = GPT_SoVITS(t2s,vits)
    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")

    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id])
    ref_bert = ref_bert_T.T.to(ref_seq.device)
    # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
    text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."

    text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")

    test_bert = tokenizer(text, return_tensors="pt")
    word2ph = []
    for c in text:
        if c in [',','。',':','?',"?",",","."]:
        if c in [",", "。", ":", "?", "?", ",", "."]:
            word2ph.append(1)
        else:
            word2ph.append(2)
    test_bert['word2ph'] = torch.Tensor(word2ph).int()
    test_bert["word2ph"] = torch.Tensor(word2ph).int()

    test_bert = my_bert(
        test_bert['input_ids'].to('cuda'),
        test_bert['attention_mask'].to('cuda'),
        test_bert['token_type_ids'].to('cuda'),
        test_bert['word2ph'].to('cuda')
        test_bert["input_ids"].to("cuda"),
        test_bert["attention_mask"].to("cuda"),
        test_bert["token_type_ids"].to("cuda"),
        test_bert["word2ph"].to("cuda"),
    )

    text_seq = torch.LongTensor([text_seq_id])
    text_bert = text_bert_T.T.to(text_seq.device)

    print('text_bert:',text_bert.shape,text_bert)
    print('test_bert:',test_bert.shape,test_bert)
    print(torch.allclose(text_bert.to('cuda'),test_bert))
    print("text_bert:", text_bert.shape, text_bert)
    print("test_bert:", test_bert.shape, test_bert)
    print(torch.allclose(text_bert.to("cuda"), test_bert))

    print('text_seq:',text_seq.shape)
    print('text_bert:',text_bert.shape,text_bert.type())
    print("text_seq:", text_seq.shape)
    print("text_bert:", text_bert.shape, text_bert.type())

    #[1,N]
    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
    print('ref_audio:',ref_audio.shape)
    # [1,N]
    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
    print("ref_audio:", ref_audio.shape)

    ref_audio_sr = ssl.resample(ref_audio,16000,32000)
    print('start ssl')
    ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
    print("start ssl")
    ssl_content = ssl(ref_audio)

    print('start gpt_sovits:')
    print('ssl_content:',ssl_content.shape)
    print('ref_audio_sr:',ref_audio_sr.shape)
    print('ref_seq:',ref_seq.shape)
    ref_seq=ref_seq.to('cuda')
    print('text_seq:',text_seq.shape)
    text_seq=text_seq.to('cuda')
    print('ref_bert:',ref_bert.shape)
    ref_bert=ref_bert.to('cuda')
    print('text_bert:',text_bert.shape)
    text_bert=text_bert.to('cuda')
    print("start gpt_sovits:")
    print("ssl_content:", ssl_content.shape)
    print("ref_audio_sr:", ref_audio_sr.shape)
    print("ref_seq:", ref_seq.shape)
    ref_seq = ref_seq.to("cuda")
    print("text_seq:", text_seq.shape)
    text_seq = text_seq.to("cuda")
    print("ref_bert:", ref_bert.shape)
    ref_bert = ref_bert.to("cuda")
    print("text_bert:", text_bert.shape)
    text_bert = text_bert.to("cuda")

    top_k = torch.LongTensor([5]).to('cuda')
    top_k = torch.LongTensor([5]).to("cuda")

    with torch.no_grad():
        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
    print('start write wav')
    print("start write wav")
    soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)


import text
import json

def export_symbel(version='v2'):
    if version=='v1':

def export_symbel(version="v2"):
    if version == "v1":
        symbols = text._symbol_to_id_v1
        with open("onnx/symbols_v1.json", "w") as file:
            json.dump(symbols, file, indent=4)
@ -819,15 +829,16 @@ def export_symbel(version='v2'):
        with open("onnx/symbols_v2.json", "w") as file:
            json.dump(symbols, file, indent=4)


def main():
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
    parser.add_argument('--output_path', required=True, help="Path to the output directory")
    parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
    parser.add_argument('--device', help="Device to use")
    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
    parser.add_argument("--output_path", required=True, help="Path to the output directory")
    parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
    parser.add_argument("--device", help="Device to use")

    args = parser.parse_args()
    export(
@ -840,9 +851,11 @@ def main():
        export_bert_and_ssl=args.export_common_model,
    )


import inference_webui

if __name__ == "__main__":
    inference_webui.is_half=False
    inference_webui.dtype=torch.float32
    inference_webui.is_half = False
    inference_webui.dtype = torch.float32
    main()
    # test()
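# Hedged sketch of consuming what export_symbel writes: the JSON is a plain
# symbol-to-id mapping, so loading it back is a one-liner. The file only exists after
# the export has run, so the lines stay commented.
# import json
# with open("onnx/symbols_v2.json", "r") as f:
#     symbol_to_id = json.load(f)
# print(len(symbol_to_id))  # number of phoneme symbols in the v2 table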
@ -32,7 +32,6 @@ now_dir = os.getcwd()


class MelSpectrgram(torch.nn.Module):

    def __init__(
        self,
        dtype,
@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
    ):
        super().__init__()
        self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
        self.n_fft:int = n_fft
        self.hop_size:int = hop_size
        self.win_size:int = win_size
        self.center:bool = center
        self.n_fft: int = n_fft
        self.hop_size: int = hop_size
        self.win_size: int = win_size
        self.center: bool = center

    def forward(self, y):
        y = torch.nn.functional.pad(
@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
    ):
        T_min = fea_ref.size(2)
        fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
        cfm_res = self.cfm(
            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
        )
        cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
        cfm_res = cfm_res[:, :, mel2.shape[2] :]
        mel2 = cfm_res[:, :, -T_min:]
        fea_ref = fea_todo_chunk[:, :, -T_min:]
@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
spec_min = -12
spec_max = 2


@torch.jit.script
def norm_spec(x):
    spec_min = -12
@ -212,7 +208,6 @@ def denorm_spec(x):
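# norm_spec/denorm_spec above map mel values between [spec_min, spec_max] = [-12, 2]
# and [-1, 1]. Both bodies are cut off by the hunks, so the pair below is an assumed
# reconstruction consistent with the constants shown:
import torch

def norm_spec_sketch(x: torch.Tensor) -> torch.Tensor:
    spec_min, spec_max = -12.0, 2.0
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1

def denorm_spec_sketch(x: torch.Tensor) -> torch.Tensor:
    spec_min, spec_max = -12.0, 2.0
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min

assert torch.allclose(denorm_spec_sketch(norm_spec_sketch(torch.tensor([-5.0]))), torch.tensor([-5.0]))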
class ExportGPTSovitsHalf(torch.nn.Module):

    def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
        super().__init__()
        self.hps = hps
@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
            center=False,
        )
        # self.dtype = dtype
        self.filter_length:int = hps.data.filter_length
        self.sampling_rate:int = hps.data.sampling_rate
        self.hop_length:int = hps.data.hop_length
        self.win_length:int = hps.data.win_length
        self.filter_length: int = hps.data.filter_length
        self.sampling_rate: int = hps.data.sampling_rate
        self.hop_length: int = hps.data.hop_length
        self.win_length: int = hps.data.win_length

    def forward(
        self,
        ssl_content,
        ref_audio_32k:torch.FloatTensor,
        ref_audio_32k: torch.FloatTensor,
        phoneme_ids0,
        phoneme_ids1,
        bert1,
@ -255,18 +250,14 @@ class ExportGPTSovitsHalf(torch.nn.Module):
            center=False,
        ).to(ssl_content.dtype)

        codes = self.vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompt = prompt_semantic.unsqueeze(0)
        # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        pred_semantic = self.t2s_m(
            prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
        )
        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
        # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        ge = self.vq_model.create_ge(refer)
        # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):

        return fea_ref, fea_todo, mel2


class GPTSoVITSV3(torch.nn.Module):
    def __init__(self, gpt_sovits_half, cfm, bigvgan):
        super().__init__()
@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
    def forward(
        self,
        ssl_content,
        ref_audio_32k:torch.FloatTensor,
        phoneme_ids0:torch.LongTensor,
        phoneme_ids1:torch.LongTensor,
        ref_audio_32k: torch.FloatTensor,
        phoneme_ids0: torch.LongTensor,
        phoneme_ids1: torch.LongTensor,
        bert1,
        bert2,
        top_k: torch.LongTensor,
@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
    ):
        # current_time = datetime.now()
        # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
        )
        chunk_len = 934 - fea_ref.shape[2]
        wav_gen_list = []
        idx = 0
@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
            # 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
            complete_len = chunk_len - fea_todo_chunk.shape[-1]
            if complete_len != 0:
                fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
                fea_todo_chunk = torch.cat(
                    [
                        fea_todo_chunk,
                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
                    ],
                    2,
                )

            cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
            idx += chunk_len
@ -343,13 +343,13 @@ class GPTSoVITSV3(torch.nn.Module):
        wav_gen = torch.cat(wav_gen_list, 2)
        return wav_gen[0][0][:wav_gen_length]
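# The chunked vocoding above works on windows of 934 mel frames: each chunk reuses
# fea_ref as the prompt, zero-pads the tail to a fixed width, and BigVGAN upsamples
# 256 samples per frame (wav_gen_length = fea_todo.shape[2] * 256). Toy arithmetic
# with a hypothetical prompt length:
frames_per_chunk = 934
upsample_factor = 256
fea_ref_len = 134  # hypothetical prompt length
chunk_len = frames_per_chunk - fea_ref_len
print(chunk_len, chunk_len * upsample_factor)  # 800 frames -> 204800 samples per chunk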
def init_bigvgan():
    global bigvgan_model
    from BigVGAN import bigvgan

    bigvgan_model = bigvgan.BigVGAN.from_pretrained(
        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
        % (now_dir,),
        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
        use_cuda_kernel=False,
    )  # if True, RuntimeError: Ninja is required to load C++ extensions
    # remove weight norm in the model and set to eval mode
@ -467,10 +467,7 @@ def export_cfm(
    cfm = e_cfm.cfm

    B, T = mu.size(0), mu.size(1)
    x = (
        torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
        * temperature
    )
    x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
    print("x:", x.shape, x.dtype)
    prompt_len = prompt.size(-1)
    prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@ -565,11 +562,7 @@ def export():
    wav16k = wav16k.to(device)
    zero_wav_torch = zero_wav_torch.to(device)
    wav16k = torch.cat([wav16k, zero_wav_torch])
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
        "last_hidden_state"
    ].transpose(
        1, 2
    )  # .float()
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
    codes = sovits.vq_model.extract_latent(ssl_content)
    prompt_semantic = codes[0, 0]
    prompt = prompt_semantic.unsqueeze(0).to(device)
@ -626,10 +619,7 @@ def export():
        "create_ge": refer,
    }

    trace_vq_model = torch.jit.trace_module(
        sovits.vq_model, inputs, optimize=True
    )
    trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
    trace_vq_model.save("onnx/ad/vq_model.pt")

    print(fea_ref.shape, fea_ref.dtype, ge.shape)
@ -714,9 +704,7 @@ def export():

        idx += chunk_len

        cfm_res, fea_ref, mel2 = export_cfm_(
            fea_ref, fea_todo_chunk, mel2, sample_steps
        )
        cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
        cfm_resss.append(cfm_res)
        continue

@ -726,9 +714,7 @@ def export():
    with torch.inference_mode():
        cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
        torch._dynamo.mark_dynamic(cmf_res_rand, 2)
        bigvgan_model_ = torch.jit.trace(
            bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
        )
        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
        bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
        wav_gen = bigvgan_model(cmf_res)
        print("wav_gen:", wav_gen.shape, wav_gen.dtype)
@ -748,7 +734,6 @@ def test_export(
    bigvgan,
    output,
):
    # hps = sovits.hps
    ref_wav_path = "onnx/ad/ref.wav"
    speed = 1.0
@ -773,13 +758,9 @@ def test_export(
    wav16k = wav16k.to(device)
    zero_wav_torch = zero_wav_torch.to(device)
    wav16k = torch.cat([wav16k, zero_wav_torch])
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
        "last_hidden_state"
    ].transpose(
        1, 2
    )  # .float()
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()

    ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
    ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
    ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()

    phones1, bert1, norm_text1 = get_phones_and_bert(
@ -799,8 +780,18 @@ def test_export(

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logger.info("start inference %s", current_time)
    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
    print(
        ssl_content.shape,
        ref_audio_32k.shape,
        phoneme_ids0.shape,
        phoneme_ids1.shape,
        bert1.shape,
        bert2.shape,
        top_k.shape,
    )
    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
        ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
    )
    chunk_len = 934 - fea_ref.shape[2]
    print(fea_ref.shape, fea_todo.shape, mel2.shape)

@ -812,7 +803,6 @@ def test_export(
    wav_gen_length = fea_todo.shape[2] * 256

    while 1:
        current_time = datetime.now()
        print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
        fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@ -861,7 +851,6 @@ def test_export1(
    gpt_sovits_v3,
    output,
):
    # hps = sovits.hps
    ref_wav_path = "onnx/ad/ref.wav"
    speed = 1.0
@ -886,14 +875,10 @@ def test_export1(
    wav16k = wav16k.to(device)
    zero_wav_torch = zero_wav_torch.to(device)
    wav16k = torch.cat([wav16k, zero_wav_torch])
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
        "last_hidden_state"
    ].transpose(
        1, 2
    )  # .float()
    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
    print("ssl_content:", ssl_content.shape, ssl_content.dtype)

    ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
    ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
    ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()

    phones1, bert1, norm_text1 = get_phones_and_bert(
@ -913,11 +898,19 @@ def test_export1(

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    logger.info("start inference %s", current_time)
    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
    print(
        ssl_content.shape,
        ref_audio_32k.shape,
        phoneme_ids0.shape,
        phoneme_ids1.shape,
        bert1.shape,
        bert2.shape,
        top_k.shape,
    )
    wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
    print("wav_gen:", wav_gen.shape, wav_gen.dtype)

    wav_gen = torch.cat([wav_gen,zero_wav_torch],0)
    wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)

    audio = wav_gen.cpu().detach().numpy()
    logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@ -929,7 +922,6 @@ import time


def test_():
    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")

    # cfm = ExportCFM(sovits.cfm)
@ -959,10 +951,7 @@ def test_():
    # t2s_m.top_k = 15
    logger.info("t2s_m ok")

    vq_model: torch.jit.ScriptModule = torch.jit.load(
        "onnx/ad/vq_model.pt", map_location=device
    )
    vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
    # vq_model = torch.jit.optimize_for_inference(vq_model)
    # vq_model = vq_model.half().to(device)
    vq_model.eval()
@ -1020,8 +1009,9 @@ def test_():
    # "out2.wav",
    # )


def test_export_gpt_sovits_v3():
    gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device)
    gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
    # test_export1(
    #     "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
    #     gpt_sovits_v3,
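# Unlike torch.jit.trace, torch.jit.trace_module (used for vq_model above) takes a dict
# mapping method names to example inputs, so several methods can be traced into one
# ScriptModule. A minimal runnable sketch with a toy module:
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

    def half_it(self, x):
        return x / 2

traced = torch.jit.trace_module(Toy(), {"forward": torch.ones(2), "half_it": torch.ones(2)})
print(traced.half_it(torch.tensor([4.0])))  # tensor([2.])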
@ -27,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (

from module.commons import sequence_mask


class TextEmbedding(nn.Module):
    def __init__(self, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
@ -129,8 +130,8 @@ class DiT(nn.Module):

        return ckpt_forward

    def forward(#x, prompt_x, x_lens, t, style,cond
        self,#d is channel,n is T
    def forward(  # x, prompt_x, x_lens, t, style,cond
        self,  # d is channel,n is T
        x0: float["b n d"],  # nosied input audio  # noqa: F722
        cond0: float["b n d"],  # masked cond audio  # noqa: F722
        x_lens,
@ -142,13 +143,11 @@ class DiT(nn.Module):
        drop_audio_cond=False,  # cfg for cond audio
        drop_text=False,  # cfg for text
        # mask: bool["b n"] | None = None,  # noqa: F722
    ):
        x=x0.transpose(2,1)
        cond=cond0.transpose(2,1)
        text=text0.transpose(2,1)
        mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
        x = x0.transpose(2, 1)
        cond = cond0.transpose(2, 1)
        text = text0.transpose(2, 1)
        mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)

        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
@ -157,8 +156,8 @@ class DiT(nn.Module):
        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        dt = self.d_embed(dt_base_bootstrap)
        t+=dt
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
        t += dt
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)
@ -391,6 +391,7 @@ class Attention(nn.Module):

# Attention processor


# from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor:
@ -545,6 +546,7 @@ class JointAttnProcessor:

# DiT Block


class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()
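# sequence_mask(x_lens, max_length), used in DiT.forward above, builds a
# [batch, max_length] boolean mask that is True inside each sequence's valid length.
# The standard implementation is a broadcast comparison, sketched here:
import torch

def sequence_mask_sketch(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    positions = torch.arange(max_length, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

print(sequence_mask_sketch(torch.tensor([2, 4]), 5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])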
@ -1,6 +1,3 @@
from . import cnhubert, whisper_enc

content_module_map = {
    'cnhubert': cnhubert,
    'whisper': whisper_enc
}
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
@ -1,10 +1,11 @@
import torch
import os
from transformers import logging as tf_logging

tf_logging.set_verbosity_error()

import logging

logging.getLogger("numba").setLevel(logging.WARNING)

from transformers import (
@ -19,21 +20,19 @@ cnhubert_base_path = None


class CNHubert(nn.Module):
    def __init__(self, base_path:str=None):
    def __init__(self, base_path: str = None):
        super().__init__()
        if base_path is None:
            base_path = cnhubert_base_path
        if os.path.exists(base_path):...
        else:raise FileNotFoundError(base_path)
        if os.path.exists(base_path):
            ...
        else:
            raise FileNotFoundError(base_path)
        self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            base_path, local_files_only=True
        )
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)

    def forward(self, x):
        input_values = self.feature_extractor(
            x, return_tensors="pt", sampling_rate=16000
        ).input_values.to(x.device)
        input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats

@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
    feature_len = mel.shape[-1] // 2
    assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
    with torch.no_grad():
        feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
            :1, :feature_len, :
        ].transpose(1, 2)
        feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
    return feature
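# Hedged usage sketch for the CNHubert wrapper above: it expects 16 kHz mono audio and
# returns HuBERT hidden states at roughly 50 frames per second. The weight path and
# one-second dummy input are hypothetical, so the lines stay commented.
# import torch
# model = CNHubert("GPT_SoVITS/pretrained_models/chinese-hubert-base").eval()
# wav16k = torch.randn(1, 16000)  # 1 s of fake 16 kHz audio
# with torch.no_grad():
#     feats = model(wav16k)       # -> [1, ~50, 768]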
@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights

i18n = I18nAuto()

def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):

def synthesize(
    GPT_model_path,
    SoVITS_model_path,
    ref_audio_path,
    ref_text_path,
    ref_language,
    target_text_path,
    target_language,
    output_path,
):
    # Read reference text
    with open(ref_text_path, 'r', encoding='utf-8') as file:
    with open(ref_text_path, "r", encoding="utf-8") as file:
        ref_text = file.read()

    # Read target text
    with open(target_text_path, 'r', encoding='utf-8') as file:
    with open(target_text_path, "r", encoding="utf-8") as file:
        target_text = file.read()

    # Change model weights
@ -21,11 +31,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
    change_sovits_weights(sovits_path=SoVITS_model_path)

    # Synthesize audio
    synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
    synthesis_result = get_tts_wav(
        ref_wav_path=ref_audio_path,
        prompt_text=ref_text,
        prompt_language=i18n(ref_language),
        text=target_text,
        text_language=i18n(target_language), top_p=1, temperature=1)
        text_language=i18n(target_language),
        top_p=1,
        temperature=1,
    )

    result_list = list(synthesis_result)

@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        print(f"Audio saved to {output_wav_path}")


def main():
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
    parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
    parser.add_argument('--target_text', required=True, help="Path to the target text file")
    parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
    parser.add_argument('--output_path', required=True, help="Path to the output directory")
    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
    parser.add_argument(
        "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
    )
    parser.add_argument("--target_text", required=True, help="Path to the target text file")
    parser.add_argument(
        "--target_language",
        required=True,
        choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
        help="Language of the target text",
    )
    parser.add_argument("--output_path", required=True, help="Path to the output directory")

    args = parser.parse_args()

    synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
    synthesize(
        args.gpt_model,
        args.sovits_model,
        args.ref_audio,
        args.ref_text,
        args.ref_language,
        args.target_text,
        args.target_language,
        args.output_path,
    )

if __name__ == '__main__':

if __name__ == "__main__":
    main()
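# The CLI above simply forwards its arguments to synthesize(), so the same call can be
# made directly from Python. Every path below is a placeholder.
# synthesize(
#     GPT_model_path="models/my_gpt.ckpt",
#     SoVITS_model_path="models/my_sovits.pth",
#     ref_audio_path="ref.wav",
#     ref_text_path="ref.txt",
#     ref_language="中文",
#     target_text_path="target.txt",
#     target_language="中文",
#     output_path="output",
# )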
@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
import soundfile as sf

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto()

from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
    def __init__(self):
        super().__init__()

        self.setWindowTitle('GPT-SoVITS GUI')
        self.setWindowTitle("GPT-SoVITS GUI")
        self.setGeometry(800, 450, 950, 850)

        self.setStyleSheet("""
@ -65,7 +66,8 @@ class GPTSoVITSGUI(QMainWindow):

        license_text = (
            "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
            "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
            "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
        )
        license_label = QLabel(license_text)
        license_label.setWordWrap(True)

@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
        self.output_text = QTextEdit()
        self.output_text.setReadOnly(True)

        self.add_drag_drop_events([
        self.add_drag_drop_events(
            [
                self.GPT_model_input,
                self.SoVITS_model_input,
                self.ref_audio_input,
                self.ref_text_input,
                self.target_text_input,
                self.output_input,
        ])
            ]
        )

        self.synthesize_button = QPushButton("合成")
        self.synthesize_button.clicked.connect(self.synthesize)
@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
    def upload_ref_text(self):
        file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
        if file_path:
            with open(file_path, 'r', encoding='utf-8') as file:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
            self.ref_text_input.setText(content)

    def upload_target_text(self):
        file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
        if file_path:
            with open(file_path, 'r', encoding='utf-8') as file:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
            self.target_text_input.setText(content)

@ -284,11 +288,13 @@ class GPTSoVITSGUI(QMainWindow):
        change_sovits_weights(sovits_path=SoVITS_model_path)
        self.SoVITS_Path = SoVITS_model_path

        synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
        synthesis_result = get_tts_wav(
            ref_wav_path=ref_audio_path,
            prompt_text=ref_text,
            prompt_language=language_combobox,
            text=target_text,
            text_language=target_language_combobox)
            text_language=target_language_combobox,
        )

        result_list = list(synthesis_result)

@ -303,7 +309,7 @@ class GPTSoVITSGUI(QMainWindow):
        self.output_text.append("处理结果:\n" + result)


if __name__ == '__main__':
if __name__ == "__main__":
    app = QApplication(sys.argv)
    mainWin = GPTSoVITSGUI()
    mainWin.show()
File diff suppressed because it is too large
Load Diff
@ -1,17 +1,19 @@
'''
"""
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
"""

import random
import os
import re
import logging
import json
import sys

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
@ -27,8 +29,10 @@ import torch

try:
    import gradio.analytics as analytics
    analytics.version_check = lambda:None
except:...

    analytics.version_check = lambda: None
except:
    ...

infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@ -43,15 +47,15 @@ gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
version = os.environ.get("version", "v2")

import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list

language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)


@ -68,30 +72,30 @@ else:
    # device = "cpu"

dict_language_v1 = {
    i18n("中文"): "all_zh",#全部按中文识别
    i18n("英文"): "en",#全部按英文识别#######不变
    i18n("日文"): "all_ja",#全部按日文识别
    i18n("中英混合"): "zh",#按中英混合识别####不变
    i18n("日英混合"): "ja",#按日英混合识别####不变
    i18n("多语种混合"): "auto",#多语种启动切分识别语种
    i18n("中文"): "all_zh",  # 全部按中文识别
    i18n("英文"): "en",  # 全部按英文识别#######不变
    i18n("日文"): "all_ja",  # 全部按日文识别
    i18n("中英混合"): "zh",  # 按中英混合识别####不变
    i18n("日英混合"): "ja",  # 按日英混合识别####不变
    i18n("多语种混合"): "auto",  # 多语种启动切分识别语种
}
dict_language_v2 = {
    i18n("中文"): "all_zh",#全部按中文识别
    i18n("英文"): "en",#全部按英文识别#######不变
    i18n("日文"): "all_ja",#全部按日文识别
    i18n("粤语"): "all_yue",#全部按中文识别
    i18n("韩文"): "all_ko",#全部按韩文识别
    i18n("中英混合"): "zh",#按中英混合识别####不变
    i18n("日英混合"): "ja",#按日英混合识别####不变
    i18n("粤英混合"): "yue",#按粤英混合识别####不变
    i18n("韩英混合"): "ko",#按韩英混合识别####不变
    i18n("多语种混合"): "auto",#多语种启动切分识别语种
    i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
    i18n("中文"): "all_zh",  # 全部按中文识别
    i18n("英文"): "en",  # 全部按英文识别#######不变
    i18n("日文"): "all_ja",  # 全部按日文识别
    i18n("粤语"): "all_yue",  # 全部按中文识别
    i18n("韩文"): "all_ko",  # 全部按韩文识别
    i18n("中英混合"): "zh",  # 按中英混合识别####不变
    i18n("日英混合"): "ja",  # 按日英混合识别####不变
    i18n("粤英混合"): "yue",  # 按粤英混合识别####不变
    i18n("韩英混合"): "ko",  # 按韩英混合识别####不变
    i18n("多语种混合"): "auto",  # 多语种启动切分识别语种
    i18n("多语种混合(粤语)"): "auto_yue",  # 多语种启动切分识别语种
}
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
dict_language = dict_language_v1 if version == "v1" else dict_language_v2

cut_method = {
    i18n("不切"):"cut0",
    i18n("不切"): "cut0",
    i18n("凑四句一切"): "cut1",
    i18n("凑50字一切"): "cut2",
    i18n("按中文句号。切"): "cut3",
@ -118,22 +122,33 @@ gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version

def inference(text, text_lang,

def inference(
    text,
    text_lang,
    ref_audio_path,
    aux_ref_audio_paths,
    prompt_text,
    prompt_lang, top_k,
    top_p, temperature,
    text_split_method, batch_size,
    speed_factor, ref_text_free,
    split_bucket,fragment_interval,
    seed, keep_random, parallel_infer,
    repetition_penalty, sample_steps, super_sampling,
):

    prompt_lang,
    top_k,
    top_p,
    temperature,
    text_split_method,
    batch_size,
    speed_factor,
    ref_text_free,
    split_bucket,
    fragment_interval,
    seed,
    keep_random,
    parallel_infer,
    repetition_penalty,
    sample_steps,
    super_sampling,
):
    seed = -1 if keep_random else seed
    actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
    inputs={
    inputs = {
        "text": text,
        "text_lang": dict_language[text_lang],
        "ref_audio_path": ref_audio_path,
@ -144,12 +159,12 @@ def inference(text, text_lang,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": cut_method[text_split_method],
        "batch_size":int(batch_size),
        "speed_factor":float(speed_factor),
        "split_bucket":split_bucket,
        "return_fragment":False,
        "fragment_interval":fragment_interval,
        "seed":actual_seed,
        "batch_size": int(batch_size),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "return_fragment": False,
        "fragment_interval": fragment_interval,
        "seed": actual_seed,
        "parallel_infer": parallel_infer,
        "repetition_penalty": repetition_penalty,
        "sample_steps": int(sample_steps),
@ -159,11 +174,12 @@ def inference(text, text_lang,
        for item in tts_pipeline.run(inputs):
            yield item, actual_seed
    except NO_PROMPT_ERROR:
        gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!'))
        gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))


def custom_sort_key(s):
    # 使用正则表达式提取字符串中的数字部分和非数字部分
    parts = re.split('(\d+)', s)
    parts = re.split("(\d+)", s)
    # 将数字部分转换为整数,非数字部分保持不变
    parts = [int(part) if part.isdigit() else part for part in parts]
    return parts
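As an aside, the split-and-convert trick in custom_sort_key gives natural ordering of versioned filenames; a small self-contained demonstration (example data invented for illustration):

import re

def custom_sort_key(s):
    # Split into digit and non-digit runs, turning the digit runs into ints.
    parts = re.split(r"(\d+)", s)
    return [int(part) if part.isdigit() else part for part in parts]

names = ["model_v10.pth", "model_v2.pth", "model_v1.pth"]
print(sorted(names, key=custom_sort_key))
# ['model_v1.pth', 'model_v2.pth', 'model_v10.pth'] -- plain sorted() would put v10 before v2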
@ -171,52 +187,67 @@ def custom_sort_key(s):

def change_choices():
    SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
        "choices": sorted(GPT_names, key=custom_sort_key),
        "__type__": "update",
    }

path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]

_ =[[],[]]
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name = [
    "GPT_SoVITS/pretrained_models/s2G488k.pth",
    "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
    path_sovits_v3,
]
pretrained_gpt_name = [
    "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
    "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
    "GPT_SoVITS/pretrained_models/s1v3.ckpt",
]

_ = [[], []]
for i in range(3):
    if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
    if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
    if os.path.exists(pretrained_gpt_name[i]):
        _[0].append(pretrained_gpt_name[i])
    if os.path.exists(pretrained_sovits_name[i]):
        _[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _


if os.path.exists("./weight.json"):
    pass
else:
    with open("./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
    with open("./weight.json", "w", encoding="utf-8") as file:
        json.dump({"GPT": {}, "SoVITS": {}}, file)

with open("./weight.json", 'r', encoding="utf-8") as file:
with open("./weight.json", "r", encoding="utf-8") as file:
    weight_data = file.read()
    weight_data=json.loads(weight_data)
    gpt_path = os.environ.get(
        "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
    sovits_path = os.environ.get(
        "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
    if isinstance(gpt_path,list):
    weight_data = json.loads(weight_data)
    gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
    sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
    if isinstance(gpt_path, list):
        gpt_path = gpt_path[0]
    if isinstance(sovits_path,list):
    if isinstance(sovits_path, list):
        sovits_path = sovits_path[0]


SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
for path in SoVITS_weight_root + GPT_weight_root:
    os.makedirs(path, exist_ok=True)

SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root:
    os.makedirs(path,exist_ok=True)

def get_weights_names(GPT_weight_root, SoVITS_weight_root):
    SoVITS_names = [i for i in pretrained_sovits_name]
    for path in SoVITS_weight_root:
        for name in os.listdir(path):
            if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
            if name.endswith(".pth"):
                SoVITS_names.append("%s/%s" % (path, name))
    GPT_names = [i for i in pretrained_gpt_name]
    for path in GPT_weight_root:
        for name in os.listdir(path):
            if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
            if name.endswith(".ckpt"):
                GPT_names.append("%s/%s" % (path, name))
    return SoVITS_names, GPT_names
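For reference, the weight.json consulted above is a two-level map from model family to version to checkpoint path. A hand-written sketch of plausible contents (paths hypothetical; the file starts out empty and is filled in as weights are switched):

import json

weights = {
    "GPT": {"v2": "GPT_weights_v2/demo-e15.ckpt"},           # hypothetical entry
    "SoVITS": {"v2": "SoVITS_weights_v2/demo_e8_s200.pth"},  # hypothetical entry
}
with open("./weight.json", "w", encoding="utf-8") as f:
    json.dump(weights, f)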
@ -224,72 +255,110 @@ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)


from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):


def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
    global version, dict_language
    version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)

    if if_lora_v3 and not os.path.exists(path_sovits_v3):
        info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
        info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
        gr.Warning(info)
        raise FileExistsError(info)

    tts_pipeline.init_vits_weights(sovits_path)

    dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
    dict_language = dict_language_v1 if tts_pipeline.configs.version == "v1" else dict_language_v2
    if prompt_language is not None and text_language is not None:
        if prompt_language in list(dict_language.keys()):
            prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
            prompt_text_update, prompt_language_update = (
                {"__type__": "update"},
                {"__type__": "update", "value": prompt_language},
            )
        else:
            prompt_text_update = {'__type__':'update', 'value':''}
            prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
            prompt_text_update = {"__type__": "update", "value": ""}
            prompt_language_update = {"__type__": "update", "value": i18n("中文")}
        if text_language in list(dict_language.keys()):
            text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
            text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
        else:
            text_update = {'__type__':'update', 'value':''}
            text_language_update = {'__type__':'update', 'value':i18n("中文")}
        if model_version=="v3":
            visible_sample_steps=True
            visible_inp_refs=False
            text_update = {"__type__": "update", "value": ""}
            text_language_update = {"__type__": "update", "value": i18n("中文")}
        if model_version == "v3":
            visible_sample_steps = True
            visible_inp_refs = False
        else:
            visible_sample_steps=False
            visible_inp_refs=True
        yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
            visible_sample_steps = False
            visible_inp_refs = True
        yield (
            {"__type__": "update", "choices": list(dict_language.keys())},
            {"__type__": "update", "choices": list(dict_language.keys())},
            prompt_text_update,
            prompt_language_update,
            text_update,
            text_language_update,
            {"__type__": "update", "visible": visible_sample_steps},
            {"__type__": "update", "visible": visible_inp_refs},
            {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
            {"__type__": "update", "visible": True if model_version == "v3" else False},
        )

    with open("./weight.json") as f:
        data = f.read()
    data = json.loads(data)
    data["SoVITS"][version] = sovits_path
    with open("./weight.json", "w") as f:
        f.write(json.dumps(data))

    with open("./weight.json")as f:
        data=f.read()
        data=json.loads(data)
        data["SoVITS"][version]=sovits_path
    with open("./weight.json","w")as f:f.write(json.dumps(data))
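A note on the raw {"__type__": "update"} dictionaries yielded above: under Gradio 3.x conventions they are the payload that gr.update() builds, so the two forms below should be interchangeable as event-handler outputs (sketch, assuming that Gradio line):

import gradio as gr

manual = {"__type__": "update", "visible": True, "value": "中文"}
helper = gr.update(visible=True, value="中文")  # builds the same update payload in Gradio 3.x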
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
        + "<br>"
        + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
    )

    with gr.Column():
        # with gr.Group():
        gr.Markdown(value=i18n("模型切换"))
        with gr.Row():
            GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
            SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
            GPT_dropdown = gr.Dropdown(
                label=i18n("GPT模型列表"),
                choices=sorted(GPT_names, key=custom_sort_key),
                value=gpt_path,
                interactive=True,
            )
            SoVITS_dropdown = gr.Dropdown(
                label=i18n("SoVITS模型列表"),
                choices=sorted(SoVITS_names, key=custom_sort_key),
                value=sovits_path,
                interactive=True,
            )
            refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
            refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])

    with gr.Row():
        with gr.Column():
            gr.Markdown(value=i18n("*请上传并填写参考信息"))
            with gr.Row():
                inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
                inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
                inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"), file_count="multiple")
            prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
            with gr.Row():
                prompt_language = gr.Dropdown(
                    label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
                )
                with gr.Column():
                    ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
                    gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
                    ref_text_free = gr.Checkbox(
                        label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
                        value=False,
                        interactive=True,
                        show_label=True,
                    )
                    gr.Markdown(
                        i18n("使用无参考文本模式时建议使用微调的GPT")
                        + "<br>"
                        + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
                    )

        with gr.Column():
            gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
@ -298,42 +367,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                    label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
                )

    with gr.Group():
        gr.Markdown(value=i18n("推理设置"))
        with gr.Row():

            with gr.Column():
                with gr.Row():
                    batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
                    sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
                    batch_size = gr.Slider(
                        minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
                    )
                    sample_steps = gr.Radio(
                        label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
                    )
                with gr.Row():
                    fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
                    speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
                    fragment_interval = gr.Slider(
                        minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
                    )
                    speed_factor = gr.Slider(
                        minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
                    )
                with gr.Row():
                    top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
                    top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
                    top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
                    top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
                with gr.Row():
                    temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
                    repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
                    temperature = gr.Slider(
                        minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
                    )
                    repetition_penalty = gr.Slider(
                        minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
                    )

            with gr.Column():
                with gr.Row():
                    how_to_cut = gr.Dropdown(
                        label=i18n("怎么切"),
                        choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
                        choices=[
                            i18n("不切"),
                            i18n("凑四句一切"),
                            i18n("凑50字一切"),
                            i18n("按中文句号。切"),
                            i18n("按英文句号.切"),
                            i18n("按标点符号切"),
                        ],
                        value=i18n("凑四句一切"),
                        interactive=True, scale=1
                        interactive=True,
                        scale=1,
                    )
                    super_sampling = gr.Checkbox(
                        label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
                    )
                    super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)

                with gr.Row():
                    parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
                    split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
                    split_bucket = gr.Checkbox(
                        label=i18n("数据分桶(并行推理时会降低一点计算量)"),
                        value=True,
                        interactive=True,
                        show_label=True,
                    )

                with gr.Row():
                    seed = gr.Number(label=i18n("随机种子"),value=-1)
                    seed = gr.Number(label=i18n("随机种子"), value=-1)
                    keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)

        output = gr.Audio(label=i18n("输出的语音"))
@ -341,40 +434,67 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
            inference_button = gr.Button(i18n("合成语音"), variant="primary")
            stop_infer = gr.Button(i18n("终止合成"), variant="primary")

        inference_button.click(
            inference,
            [
                text,text_language, inp_ref, inp_refs,
                prompt_text, prompt_language,
                top_k, top_p, temperature,
                how_to_cut, batch_size,
                speed_factor, ref_text_free,
                split_bucket,fragment_interval,
                seed, keep_random, parallel_infer,
                repetition_penalty, sample_steps, super_sampling,
                text,
                text_language,
                inp_ref,
                inp_refs,
                prompt_text,
                prompt_language,
                top_k,
                top_p,
                temperature,
                how_to_cut,
                batch_size,
                speed_factor,
                ref_text_free,
                split_bucket,
                fragment_interval,
                seed,
                keep_random,
                parallel_infer,
                repetition_penalty,
                sample_steps,
                super_sampling,
            ],
            [output, seed],
        )
        stop_infer.click(tts_pipeline.stop, [], [])
        SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
        SoVITS_dropdown.change(
            change_sovits_weights,
            [SoVITS_dropdown, prompt_language, text_language],
            [prompt_language, text_language, prompt_text, prompt_language, text, text_language],
        )
        GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])

    with gr.Group():
        gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
        gr.Markdown(
            value=i18n(
                "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
            )
        )
        with gr.Row():
            text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
            with gr.Column():
                _how_to_cut = gr.Radio(
                    label=i18n("怎么切"),
                    choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
                    choices=[
                        i18n("不切"),
                        i18n("凑四句一切"),
                        i18n("凑50字一切"),
                        i18n("按中文句号。切"),
                        i18n("按英文句号.切"),
                        i18n("按标点符号切"),
                    ],
                    value=i18n("凑四句一切"),
                    interactive=True,
                )
                cut_text= gr.Button(i18n("切分"), variant="primary")
                cut_text = gr.Button(i18n("切分"), variant="primary")

        def to_cut(text_inp, how_to_cut):
            if len(text_inp.strip()) == 0 or text_inp==[]:
            if len(text_inp.strip()) == 0 or text_inp == []:
                return ""
            method = get_method(cut_method[how_to_cut])
            return method(text_inp)
@ -383,8 +503,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
        cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
    gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))

if __name__ == '__main__':
    app.queue().launch(#concurrency_count=511, max_size=1022
if __name__ == "__main__":
    app.queue().launch(  # concurrency_count=511, max_size=1022
        server_name="0.0.0.0",
        inbrowser=True,
        share=is_share,
@ -18,7 +18,7 @@ class Encoder(nn.Module):
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
@ -56,9 +56,7 @@ class Encoder(nn.Module):
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if isflow:
            cond_layer = torch.nn.Conv1d(
                kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
            )
            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = kwargs["gin_channels"]
@ -74,9 +72,7 @@ class Encoder(nn.Module):
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
@ -99,7 +95,7 @@ class Decoder(nn.Module):
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
@ -131,9 +127,7 @@ class Decoder(nn.Module):
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
@ -153,9 +147,7 @@ class Decoder(nn.Module):
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                assert t_s == t_t, "Local attention is only available for self-attention."
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):


def weight_norm_modules(module, name="weight", dim=0):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
        module, Depthwise_Separable_TransposeConv1D
    ):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
        module.weight_norm()
        return module
    else:
@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):


def remove_weight_norm_modules(module, name="weight"):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
        module, Depthwise_Separable_TransposeConv1D
    ):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
        module.remove_weight_norm()
    else:
        remove_weight_norm(module, name)
@ -567,7 +523,7 @@ class FFT(nn.Module):
        proximal_bias=False,
        proximal_init=True,
        isflow=False,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
@ -579,9 +535,7 @@ class FFT(nn.Module):
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        if isflow:
            cond_layer = torch.nn.Conv1d(
                kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
            )
            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = kwargs["gin_channels"]
@ -622,18 +576,14 @@ class FFT(nn.Module):
        if g is not None:
            g = self.cond_layer(g)

        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        x = x * x_mask
        for i in range(self.n_layers):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_l, torch.IntTensor([self.hidden_channels])
                )
                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)
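The pad-and-reshape trick collapsed onto single lines above (in _relative_position_to_absolute_position) converts relative-position logits of shape [b, h, l, 2l-1] into absolute scores [b, h, l, l] with no gather ops. A standalone shape sketch mirroring the reformatted code:

import torch
import torch.nn.functional as F

def rel_to_abs(x):
    # x: [batch, heads, length, 2*length - 1] relative-position logits
    batch, heads, length, _ = x.size()
    x = F.pad(x, [0, 1])  # one extra column so each row shifts by one after the reshape
    x_flat = x.view([batch, heads, length * 2 * length])
    x_flat = F.pad(x_flat, [0, length - 1])  # tail padding to complete the skew
    # reshape, then slice away the padded elements
    return x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]

print(rel_to_abs(torch.randn(1, 2, 4, 7)).shape)  # torch.Size([1, 2, 4, 4])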
@ -7,6 +7,7 @@ from module import commons

from typing import Optional


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
@ -43,7 +44,7 @@ class Encoder(nn.Module):
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
@ -65,13 +66,9 @@ class Encoder(nn.Module):
        if self.gin_channels != 0:
            self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
            # vits2 says 3rd block, so idx is 2 by default
            self.cond_layer_idx = (
                kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
            )
            self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
            logging.debug(self.gin_channels, self.cond_layer_idx)
            assert (
                self.cond_layer_idx < self.n_layers
            ), "cond_layer_idx should be less than n_layers"
            assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
@ -121,7 +118,9 @@ class Encoder(nn.Module):
    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
        for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        ):
            y = attn_layers(x, x, attn_mask)
            y = self.drop(y)
            x = norm_layers_1(x + y)
@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
    def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)
@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
    def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, _ = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@ -224,7 +217,7 @@ class MultiHeadAttention(nn.Module):
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)

        output = (output.transpose(2, 3).contiguous().view(b, d, -1))
        output = output.transpose(2, 3).contiguous().view(b, d, -1)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
        pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
        pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
        pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))

        slice_end_position = slice_start_position + 2 * length - 1
        padded_relative_embeddings = F.pad(
            relative_embeddings,
            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
        )
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
        """
        batch, heads, length, _ = x.size()
        # padd along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -395,12 +380,6 @@ class MRTE(nn.Module):

        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)
        x = (
            self.cross_attention(
                ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
            )
            + ssl_enc
            + ge
        )
        x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
        x = self.c_post(x * ssl_mask)
        return x
@ -28,9 +28,7 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    return kl
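In closed form, the kl_divergence helper above computes the per-dimension Gaussian KL, with logs_p and logs_q denoting log standard deviations:

\mathrm{KL}(P \,\|\, Q) = \log\frac{\sigma_q}{\sigma_p} - \frac{1}{2} + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2}

which matches (logs_q - logs_p) - 0.5 plus the 0.5 * (exp(2 logs_p) + (m_p - m_q)^2) * exp(-2 logs_q) term line by line.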
@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
@ -30,6 +30,7 @@
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange, repeat
@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
            uniform_init if not kmeans_init else torch.zeros
        )
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size
@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind
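The quantize one-liner above is the usual expansion of squared Euclidean distance, ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2, negated so that max picks the nearest code. A quick consistency check added for this review:

import torch

x = torch.randn(5, 8)          # 5 vectors to quantize
codebook = torch.randn(16, 8)  # 16 codebook entries
embed = codebook.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
nearest = dist.max(dim=-1).indices
# agrees with explicit pairwise distances (up to floating-point ties)
assert torch.equal(nearest, torch.cdist(x, codebook).argmin(dim=-1))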
@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(
        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
    ):
    def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
        quantized_out = 0.0
        residual = x

@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized

    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
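For readers skimming the hunks above: ResidualVectorQuantization quantizes in stages, each layer encoding the residual the previous layers left behind. A simplified sketch of the forward loop, assuming each layer returns (quantized, indices, loss) as VectorQuantization does in this file:

def residual_quantize(x, layers, n_q=None):
    quantized_out = 0.0
    residual = x
    all_indices = []
    for layer in layers[: n_q or len(layers)]:
        quantized, indices, _ = layer(residual)  # assumed per-layer signature
        residual = residual - quantized          # the next layer sees only what is left
        quantized_out = quantized_out + quantized
        all_indices.append(indices)
    return quantized_out, all_indices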
@ -5,11 +5,14 @@ import torch
import torch.utils.data
from tqdm import tqdm

from module.mel_processing import spectrogram_torch,spec_to_mel_torch
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from text import cleaned_text_to_sequence
import torch.nn.functional as F
from tools.my_utils import load_audio
version = os.environ.get('version',None)

version = os.environ.get("version", None)


# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
@ -34,7 +37,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):

        for line in lines:
            tmp = line.split("\t")
            if (len(tmp) != 4):
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

@ -42,7 +45,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if (leng < min_num):
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
@ -67,7 +70,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(' ')
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
@ -102,7 +105,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
        with torch.no_grad():
            ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
            if (ssl.shape[-1] != spec.shape[-1]):
            if ssl.shape[-1] != spec.shape[-1]:
                typee = ssl.dtype
                ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
            ssl.requires_grad = False
@ -120,8 +123,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)
        return spec, audio_norm

@ -137,12 +141,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        return len(self.audiopaths_sid_text)

    def random_slice(self, ssl, wav, mel):
        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
            "first", ssl.shape, wav.shape)
        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)

        len_mel = mel.shape[1]
        if self.val:
            reference_mel = mel[:, :len_mel // 3]
            reference_mel = mel[:, : len_mel // 3]
            return reference_mel, ssl, wav, mel
        dir = random.randint(0, 1)
        sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@ -150,20 +153,29 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
        if dir == 0:
            reference_mel = mel[:, :sep_point]
            ssl = ssl[:, :, sep_point:]
            wav2 = wav[:, sep_point * self.hop_length:]
            wav2 = wav[:, sep_point * self.hop_length :]
            mel = mel[:, sep_point:]
        else:
            reference_mel = mel[:, sep_point:]
            ssl = ssl[:, :, :sep_point]
            wav2 = wav[:, :sep_point * self.hop_length]
            wav2 = wav[:, : sep_point * self.hop_length]
            mel = mel[:, :sep_point]

        assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
            ssl.shape,
            wav.shape,
            wav2.shape,
            mel.shape,
            sep_point,
            self.hop_length,
            sep_point * self.hop_length,
            dir,
        )
        return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
    """ Zero-pads model inputs and targets
    """


class TextAudioSpeakerCollate:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids
@ -175,9 +187,7 @@ class TextAudioSpeakerCollate():
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]),
            dim=0, descending=True)
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)

        max_ssl_len = max([x[0].size(2) for x in batch])
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -205,22 +215,24 @@ class TextAudioSpeakerCollate():
            row = batch[ids_sorted_decreasing[i]]

            ssl = row[0]
            ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
            ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, :spec.size(1)] = spec
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, :wav.size(1)] = wav
            wav_padded[i, :, : wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            text = row[3]
            text_padded[i, :text.size(0)] = text
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

        return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths


class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
@ -244,7 +256,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

        for line in lines:
            tmp = line.split("\t")
            if (len(tmp) != 4):
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

@ -252,7 +264,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if (leng < min_num):
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
@ -277,7 +289,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(' ')
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
@ -304,15 +316,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
        assert len(audiopaths_sid_text_new) > 1  # 至少能凑够batch size,这里todo
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths
        self.spec_min=-12
        self.spec_max=2
        self.spec_min = -12
        self.spec_max = 2

        self.filter_length_mel = self.win_length_mel = 1024
        self.hop_length_mel = 256
        self.n_mel_channels = 100
        self.sampling_rate_mel = 24000
        self.mel_fmin = 0
        self.mel_fmax = None

        self.filter_length_mel=self.win_length_mel=1024
        self.hop_length_mel=256
        self.n_mel_channels=100
        self.sampling_rate_mel=24000
        self.mel_fmin=0
        self.mel_fmax=None
    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
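norm_spec above maps mel values from the assumed dynamic range [spec_min, spec_max] = [-12, 2] linearly onto [-1, 1]; for instance:

def norm_spec(x, spec_min=-12, spec_max=2):
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1

print(norm_spec(-12), norm_spec(-5), norm_spec(2))  # -1.0 0.0 1.0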
@ -323,7 +336,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
        spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
        with torch.no_grad():
            ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
            if (ssl.shape[-1] != spec.shape[-1]):
            if ssl.shape[-1] != spec.shape[-1]:
                typee = ssl.dtype
                ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
            ssl.requires_grad = False
@ -338,25 +351,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
        return (ssl, spec, mel, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
        audio=torch.FloatTensor(audio_array)#/32768
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
        audio24=torch.FloatTensor(audio_array24)#/32768
        audio_array24 = load_audio(
            filename, 24000
        )  # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
        audio24 = torch.FloatTensor(audio_array24)  # /32768
        audio_norm24 = audio24
        audio_norm24 = audio_norm24.unsqueeze(0)

        spec = spectrogram_torch(audio_norm, self.filter_length,
                                 self.sampling_rate, self.hop_length, self.win_length,
                                 center=False)
        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)

        spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
        mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
        spec1 = spectrogram_torch(
            audio_norm24,
            self.filter_length_mel,
            self.sampling_rate_mel,
            self.hop_length_mel,
            self.win_length_mel,
            center=False,
        )
        mel = spec_to_mel_torch(
            spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
        )
        mel = torch.squeeze(mel, 0)
        mel=self.norm_spec(mel)
        mel = self.norm_spec(mel)
        # print(1111111,spec.shape,mel.shape)
        return spec, mel

@ -370,9 +393,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

    def __len__(self):
        return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3():
    """ Zero-pads model inputs and targets
    """


class TextAudioSpeakerCollateV3:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids
@ -383,12 +407,10 @@ class TextAudioSpeakerCollateV3():
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        #ssl, spec, wav,mel, text
        # ssl, spec, wav,mel, text
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]),
            dim=0, descending=True)
        #(ssl, spec,mel, text)
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
        # (ssl, spec,mel, text)
        max_ssl_len = max([x[0].size(2) for x in batch])

        max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -402,7 +424,7 @@ class TextAudioSpeakerCollateV3():
        # max_wav_len = max([x[2].size(1) for x in batch])

        max_text_len = max([x[3].size(0) for x in batch])
        max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
        max_mel_len = int(max_ssl_len1 * 1.25 * 1.5)  ###24000/256,32000/640=16000/320

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
@ -426,11 +448,11 @@ class TextAudioSpeakerCollateV3():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
# wav = row[2]
|
||||
@ -438,15 +460,17 @@ class TextAudioSpeakerCollateV3():
|
||||
# wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[2]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
|
||||
|
||||
|
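A hypothetical usage sketch for the loader and collate above; the constructor argument and batch size are illustrative placeholders, not part of this diff:

from torch.utils.data import DataLoader

dataset = TextAudioSpeakerLoaderV3(hparams)  # hparams is assumed/illustrative
loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=TextAudioSpeakerCollateV3())
for ssl, spec, mel, ssl_lens, spec_lens, text, text_lens, mel_lens in loader:
    pass  # tensors arrive zero-padded, sorted by descending spec length
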
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@ -470,7 +494,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):

for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]

@ -478,7 +502,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@ -503,7 +527,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@ -530,15 +554,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1  # at least enough to fill one batch; TODO
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min=-12
self.spec_max=2
self.spec_min = -12
self.spec_max = 2

self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None

self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

@ -546,10 +571,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@ -564,27 +589,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
return (ssl, spec, wav, mel, text)

def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)# load_audio already normalizes to the -1~1 range, no need to divide by 32768 again
audio=torch.FloatTensor(audio_array)#/32768
audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again
audio = torch.FloatTensor(audio_array)  # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
audio_array24 = load_audio(filename,24000)# load_audio already normalizes to the -1~1 range, no need to divide by 32768 again ###### resampling here could be done on the GPU for speed
audio24=torch.FloatTensor(audio_array24)#/32768
audio_array24 = load_audio(
filename, 24000
)  # load_audio already normalizes to the -1~1 range, no need to divide by 32768 again ###### resampling here could be done on the GPU for speed
audio24 = torch.FloatTensor(audio_array24)  # /32768
audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0)

spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)


spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
spec1 = spectrogram_torch(
audio_norm24,
self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel)
mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape)
return spec, mel,audio_norm
return spec, mel, audio_norm

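As the comment in `get_audio` notes, the second `load_audio` call only exists to obtain a 24 kHz copy for the mel branch. A hedged sketch of the GPU resampling it suggests; torchaudio is an assumed extra dependency here, and `orig_freq=32000` is illustrative (in this loader it would be `self.sampling_rate`):

import torchaudio

# Resample the already-decoded waveform on the GPU instead of decoding the file twice.
resample = torchaudio.transforms.Resample(orig_freq=32000, new_freq=24000).cuda()
audio_norm24 = resample(audio_norm.cuda())  # shape stays [1, T] -> [1, T']
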
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
@ -596,9 +631,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):

def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3b():
""" Zero-pads model inputs and targets
"""


class TextAudioSpeakerCollateV3b:
"""Zero-pads model inputs and targets"""

def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -609,12 +645,10 @@ class TextAudioSpeakerCollateV3b():
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
#ssl, spec, wav,mel, text
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
#(ssl, spec,mel, text)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])

max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -627,7 +661,7 @@ class TextAudioSpeakerCollateV3b():
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[4].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5)  ###24000/256,32000/640=16000/320

ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
@ -651,28 +685,40 @@ class TextAudioSpeakerCollateV3b():
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)

spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)

wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)

mel = row[3]
mel_padded[i, :, :mel.size(1)] = mel
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)

text = row[4]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)

return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return (
ssl_padded,
spec_padded,
mel_padded,
ssl_lengths,
spec_lengths,
text_padded,
text_lengths,
wav_padded,
wav_lengths,
mel_lengths,
)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
Maintain similar input lengths in a batch.
@ -736,12 +782,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
num_samples_bucket = self.num_samples_per_bucket[i]

rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]

ids_bucket = ids_bucket[self.rank::self.num_replicas]
ids_bucket = ids_bucket[self.rank :: self.num_replicas]

for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
batches.append(batch)

if self.shuffle:

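The bucket-padding hunk above repeats a bucket's index list until it reaches `num_samples_bucket`, so every rank sees the same number of samples; a minimal standalone sketch of the same arithmetic (list repetition plus a partial slice), using made-up numbers:

ids_bucket = [0, 1, 2]   # 3 samples in this bucket
num_samples_bucket = 8   # target after padding
rem = num_samples_bucket - len(ids_bucket)  # 5
padded = ids_bucket + ids_bucket * (rem // len(ids_bucket)) + ids_bucket[: rem % len(ids_bucket)]
assert padded == [0, 1, 2, 0, 1, 2, 0, 1] and len(padded) == num_samples_bucket
# each rank then takes a strided slice: padded[rank::num_replicas]
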
@ -65,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l

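For orientation, the hunk above is the per-element Gaussian negative log-likelihood of a flow model; written out (a standard identity, not something this diff states), with $\sigma = e^{\mathrm{logs}}$:

$$-\log \mathcal{N}(z;\, m,\, \sigma^2) \;=\; \tfrac{1}{2}\log(2\pi) \;+\; \mathrm{logs} \;+\; \tfrac{1}{2}\, e^{-2\,\mathrm{logs}}\,(z - m)^2$$

averaged over the masked positions, minus the mean log-determinant of the flow Jacobian.
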
@ -47,9 +47,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),
@ -79,20 +77,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec


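The `mel_basis` / `hann_window` dicts above are a memoization keyed on (fmax, dtype, device), so the filterbank is built once per configuration rather than per call; a minimal standalone sketch of the same pattern (the function name and cache dict are illustrative):

import torch
from librosa.filters import mel as librosa_mel_fn

_mel_cache = {}

def cached_mel_basis(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    key = f"{fmax}_{spec.dtype}_{spec.device}"
    if key not in _mel_cache:
        m = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        _mel_cache[key] = torch.from_numpy(m).to(dtype=spec.dtype, device=spec.device)
    return _mel_cache[key]
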
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
@ -103,16 +95,10 @@ def mel_spectrogram_torch(
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

y = torch.nn.functional.pad(
y.unsqueeze(1),

@ -1,4 +1,5 @@
import warnings

warnings.filterwarnings("ignore")
import math

@ -15,6 +16,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer

# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
@ -46,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())

self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())

self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

@ -89,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -100,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -115,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -135,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):


class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()

self.in_channels = in_channels
@ -147,13 +125,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels

self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)

@ -188,7 +162,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
version="v2",
):
super().__init__()
self.out_channels = out_channels
@ -235,26 +209,22 @@ class TextEncoder(nn.Module):

self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)

y = self.ssl_proj(y * y_mask) * y_mask

y = self.encoder_ssl(y * y_mask, y_mask)

text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
@ -358,9 +328,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -370,14 +338,9 @@ class PosteriorEncoder(nn.Module):


class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -392,7 +355,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
@ -400,6 +363,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask
return stats, x_mask


class WNEncoder(nn.Module):
def __init__(
self,
@ -432,9 +396,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)

def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -457,9 +419,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

self.ups = nn.ModuleList()
@ -479,9 +439,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))

self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -634,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]

discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)

def forward(self, y, y_hat):
@ -736,10 +692,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@ -757,9 +710,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
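The two `loss` terms above are the usual VQ-VAE objective (a standard reading, not something this diff spells out): a commitment term weighted by 0.25 plus a codebook term, with gradients routed past the non-differentiable lookup by the straight-through line `z_q = xin + (z_q - xin).detach()`:

$$\mathcal{L}_{\mathrm{vq}} \;=\; 0.25\,\big\| \mathrm{sg}[z_q] - x \big\|_2^2 \;+\; \big\| z_q - \mathrm{sg}[x] \big\|_2^2$$

where sg denotes stop-gradient (here, `.detach()`).
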
@ -799,13 +750,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout

self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)

self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@ -818,9 +765,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -868,8 +813,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
version="v2",
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@ -900,7 +845,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
version=version,
)
self.dec = Generator(
inter_channels,
@ -921,12 +866,10 @@ class SynthesizerTrn(nn.Module):
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
@ -943,13 +886,11 @@ class SynthesizerTrn(nn.Module):
self.freeze_quantizer = freeze_quantizer

def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@ -957,24 +898,16 @@ class SynthesizerTrn(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])

if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)

z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
return (
o,
@ -987,24 +920,18 @@ class SynthesizerTrn(nn.Module):
)

def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)

ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1013,39 +940,34 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)

@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
if (self.version == "v1"):
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if(type(refer)==list):
ges=[]

if type(refer) == list:
ges = []
for _refer in refer:
ge=get_ge(_refer)
ge = get_ge(_refer)
ges.append(ge)
ge=torch.stack(ges,0).mean(0)
ge = torch.stack(ges, 0).mean(0)
else:
ge=get_ge(refer)
ge = get_ge(refer)

y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge,speed
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1057,11 +979,10 @@ class SynthesizerTrn(nn.Module):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)


class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
def __init__(self, in_channels, dit):
super().__init__()
self.sigma_min = 1e-6

@ -1075,41 +996,54 @@ class CFM(torch.nn.Module):
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0
mu=mu.transpose(2,1)
mu = mu.transpose(2, 1)
t = 0
d = 1 / n_timesteps
for j in range(n_timesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
if inference_cfg_rate>1e-5:
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
v_pred = self.estimator(
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
).transpose(2, 1)
if inference_cfg_rate > 1e-5:
neg = self.estimator(
x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=True,
drop_text=True,
).transpose(2, 1)
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0
return x

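The loop above is a fixed-step Euler integration of a learned velocity field with classifier-free guidance; a minimal sketch of the same update rule, with a stand-in `velocity` function (illustrative, not this model's estimator):

def euler_cfg(velocity, x, n_steps, cfg_rate=0.0):
    """x: [B, C, T] starting noise; velocity(x, t, drop_cond) -> tensor like x."""
    d = 1.0 / n_steps
    t = 0.0
    for _ in range(n_steps):
        v = velocity(x, t, drop_cond=False)
        if cfg_rate > 1e-5:
            neg = velocity(x, t, drop_cond=True)  # unconditional branch
            v = v + (v - neg) * cfg_rate          # classifier-free guidance
        x = x + d * v                             # Euler step
        t = t + d
    return x
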
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
x0 = torch.randn_like(x1, device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
dt = torch.zeros_like(t, device=mu.device)
prompt = torch.zeros_like(x1)
for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
xt[i, :, : prompt_lens[i]] = 0
gailv = 0.3  # if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base)
d = 1 / torch.pow(2, base)
d_input = d.clone()
d_input[d_input < 1e-2] = 0
# with torch.no_grad():
@ -1117,29 +1051,32 @@ class CFM(torch.nn.Module):
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach()
dt = 2*d
dt = 2 * d

vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
loss = 0
for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
loss /= b

return loss


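For context, the training objective above is conditional flow matching along a straight interpolation path; in the usual notation (a standard formulation, matching the `xt = x0 + t * vt` line rather than anything this diff states explicitly):

$$x_t = (1 - t)\,x_0 + t\,x_1, \qquad v_t = \frac{dx_t}{dt} = x_1 - x_0$$

and the estimator is regressed onto $v_t$, with the loss masked to the non-prompt region `[prompt_lens[i], x_lens[i])` of each sample.
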
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
param.requires_grad = False

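An equivalent, slightly more idiomatic freeze (a suggestion, not part of this diff): `name` is unused above, so iterating `parameters()` with `requires_grad_` does the same job:

def set_no_grad(net_g):
    for param in net_g.parameters():
        param.requires_grad_(False)  # freeze in place; identical effect
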
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""

def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -1161,8 +1098,8 @@ class SynthesizerTrnV3(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):

**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -1183,110 +1120,111 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels
self.version = version

self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)  ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)


ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if self.freeze_quantizer==True:
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
)  # text_dim is condition feature dim
if self.freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)

def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
def forward(
self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
):  # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()#
self.ssl_proj.eval()  #
self.quantizer.eval()
self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  ##BCT
fea, y_mask_ = self.wns1(
fea, mel_lengths, ge
)  ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss

@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None,speed=1):
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
if speed==1:
sizee=int(codes.size(2)*2.5*1.5)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1:
sizee = int(codes.size(2) * 2.5 * 1.5)
else:
sizee=int(codes.size(2)*2.5*1.5/speed)+1
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge

def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)


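The frame-rate bookkeeping in `decode_encp` is worth spelling out (my reading of the constants, not stated in the diff): 25 Hz semantic codes are doubled to 50 Hz, and the target mel runs at 24000 / 256 = 93.75 frames per second, which is 93.75 / 50 = 1.875 times the 50 Hz feature rate, hence `scale_factor=1.875` and `sizee = codes * 2 * 1.875 = codes * 2.5 * 1.5`. A quick check:

codes_len = 100                  # 25 Hz semantic frames (4 s of audio)
feat_len = codes_len * 2         # 50 Hz after the nearest-neighbor upsample
mel_len = int(feat_len * 1.875)  # 93.75 Hz mel frames: 375
assert mel_len == int(codes_len * 2.5 * 1.5)
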
||||
class SynthesizerTrnV3b(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
@ -1307,8 +1245,8 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
**kwargs):
|
||||
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@ -1328,47 +1266,52 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.model_dim=512
|
||||
self.model_dim = 512
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
gin_channels=gin_channels)
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ['25hz', "50hz"]
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == '25hz':
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(
|
||||
dimension=ssl_dim,
|
||||
n_q=1,
|
||||
bins=1024
|
||||
)
|
||||
self.freeze_quantizer=freeze_quantizer
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
|
||||
inter_channels2=512
|
||||
self.bridge=nn.Sequential(
|
||||
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
|
||||
nn.LeakyReLU()
|
||||
)
|
||||
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
|
||||
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
|
||||
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
|
||||
inter_channels2 = 512
|
||||
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
||||
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
||||
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
||||
self.cfm = CFM(
|
||||
100,
|
||||
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim

def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
@ -1377,51 +1320,59 @@ class SynthesizerTrnV3b(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea)
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need
return (
commit_loss,
cfm_loss,
F.mse_loss(learned_mel, mel),
o,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)

@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
def decode_encp(self, codes, text, refer, ge=None):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge

def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)
@ -14,6 +14,7 @@ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer

# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
@ -42,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())

self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())

self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

@ -85,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -96,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -111,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -131,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):


class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()

self.in_channels = in_channels
@ -143,13 +120,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels

self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)

@ -232,7 +205,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y_mask = torch.ones_like(y[:1, :1, :])

y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
@ -244,8 +217,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)

y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

stats = self.proj(y) * y_mask
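
# [Editorial note, an explanatory sketch and not part of this diff] The speed
# branch above implements time-stretching at the feature level: the encoder
# output is linearly resampled to roughly len/speed frames (speed > 1 shortens
# the utterance, speed < 1 lengthens it), and y_mask is resized with
# nearest-neighbour interpolation so downstream masking still lines up.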
@ -331,9 +304,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -343,14 +314,9 @@ class PosteriorEncoder(nn.Module):


class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -365,7 +331,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
@ -373,6 +339,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask
return stats, x_mask


class WNEncoder(nn.Module):
def __init__(
self,
@ -405,9 +372,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)

def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -430,9 +395,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

self.ups = nn.ModuleList()
@ -452,9 +415,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))

self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -463,7 +424,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

def forward(self, x, g:Optional[torch.Tensor]=None):
def forward(self, x, g: Optional[torch.Tensor] = None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@ -607,9 +568,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]

discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)

def forward(self, y, y_hat):
@ -709,10 +668,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@ -730,9 +686,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
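
# [Editorial note, not part of this diff] The grouped quantizer above follows
# the standard VQ-VAE recipe: `loss` adds a commitment term
# 0.25 * mean((sg[z_q] - x)^2) to a codebook term mean((z_q - sg[x])^2), and
# `z_q = xin + (z_q - xin).detach()` is the straight-through estimator that
# passes gradients through the non-differentiable nearest-code lookup.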
@ -772,13 +726,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout

self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)

self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@ -791,9 +741,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -842,7 +790,7 @@ class SynthesizerTrn(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v2",
**kwargs
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@ -894,9 +842,7 @@ class SynthesizerTrn(nn.Module):
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

# self.version=os.environ.get("version","v1")
if self.version == "v1":
@ -921,9 +867,9 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)

def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@ -933,9 +879,7 @@ class SynthesizerTrn(nn.Module):
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

x, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge, speed
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)

z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

@ -949,11 +893,9 @@ class SynthesizerTrn(nn.Module):
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)


class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
def __init__(self, in_channels, dit):
super().__init__()
# self.sigma_min = 1e-6

@ -963,27 +905,34 @@ class CFM(torch.nn.Module):

# self.criterion = torch.nn.MSELoss()

def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0):
def forward(
self,
mu: torch.Tensor,
x_lens: torch.LongTensor,
prompt: torch.Tensor,
n_timesteps: torch.LongTensor,
temperature: float = 1.0,
):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype)
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)

ntimesteps = int(n_timesteps)

prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0
mu=mu.transpose(2,1)
t = torch.tensor(0.0,dtype=x.dtype,device=x.device)
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
mu = mu.transpose(2, 1)
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d

for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
# if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
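
# [Editorial note, not part of this diff] The loop above is fixed-step Euler
# integration of the conditional-flow-matching ODE: starting from Gaussian
# noise x (with the prompt region copied into prompt_x and zeroed out in x),
# the DiT estimator predicts a velocity v_pred at time t, and the state is
# presumably advanced by x <- x + d * v_pred and t <- t + d in the lines that
# fall outside this hunk.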
@ -995,24 +944,28 @@ class CFM(torch.nn.Module):

def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
param.requires_grad = False


@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
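
# [Editorial note, not part of this diff] The 2.5 * 1.5 (= 3.75) factor matches
# the two upsamplings applied elsewhere in this file: quantized codes are
# interpolated by 2x and the bridge features by a further 1.875x, so a code
# sequence of length L corresponds to about 3.75 * L mel frames.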

@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths


class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""

def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -1034,8 +987,8 @@ class SynthesizerTrnV3(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):

**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -1056,41 +1009,38 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels
self.version = version

self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)


ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if freeze_quantizer==True:
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
@ -1098,19 +1048,18 @@ class SynthesizerTrnV3(nn.Module):
def create_ge(self, refer):
refer_lengths = compile_ref_length(refer)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge

def forward(self, codes, text,ge,speed=1):

y_lengths1=compile_codes_length(codes)
def forward(self, codes, text, ge, speed=1):
y_lengths1 = compile_codes_length(codes)

quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea
@ -1118,4 +1067,4 @@ class SynthesizerTrnV3(nn.Module):
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)
@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):

self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
@ -156,9 +152,7 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout)

if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

for i in range(n_layers):
@ -479,9 +473,7 @@ class ConvFlow(nn.Module):

self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()

@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]

unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]

x1, logabsdet = piecewise_rational_quadratic_transform(
@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)

self.attention = ScaledDotProductAttention(
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)

self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout)
@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask)

output = output.view(n_head, sz_b, len_x, d_v)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)

output = self.fc(output)

@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
if mask is not None:
mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1]
slf_attn_mask = (
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None

# spectral
x = self.spectral(x)
@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out)
logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
loss_kl = kl_divergence.mean()

z = posterior.rsample()
@ -825,9 +807,7 @@ class ActNorm(nn.Module):

def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
device=x.device, dtype=x.dtype
)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
@ -856,9 +836,7 @@ class ActNorm(nn.Module):
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

bias_init = (
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

self.bias.data.copy_(bias_init)
@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
self.n_split = n_split
self.no_jacobian = no_jacobian

w_init = torch.linalg.qr(
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2])

x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = (
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)

if reverse:
if hasattr(self, "weight_inv"):
@ -31,32 +31,15 @@ class MRTE(nn.Module):
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
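
# [Editorial note, not part of this diff] MRTE fuses the two streams with
# cross-attention: the SSL (speech) features attend over the text encoding, and
# the result is added residually to ssl_enc plus the global style embedding ge.
# The `test` argument only selects ablation paths (1: drop text entirely,
# 2: zero out the speech queries).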
@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()

@ -87,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
return quantized, codes, torch.mean(commit_loss), quantized_list

def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
**spline_kwargs,
)
return outputs, logabsdet

@ -175,8 +175,7 @@ def rational_quadratic_spline(

theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
@ -190,12 +189,9 @@ def rational_quadratic_spline(
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)

numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
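
# [Editorial note, not part of this diff] These expressions implement the
# monotonic rational-quadratic spline of Durkan et al., "Neural Spline Flows":
# within a bin, y = y_k + h_k * (s*theta^2 + d_k*theta*(1-theta)) /
# (s + (d_k + d_{k+1} - 2*s)*theta*(1-theta)), where theta is the relative
# position inside the bin, s the bin slope (input_delta), and d_k the knot
# derivatives.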
@ -1,22 +1,22 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
import os
import json
import os

import soundfile
from text import cleaned_text_to_sequence


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -101,22 +101,22 @@ class T2SModel(nn.Module):
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model)
# self.t2s_model = torch.jit.script(self.t2s_model)

def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num

#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

prefix_len = prompts.shape[1]

#[1,N,512] [1,N]
# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
@ -130,13 +130,11 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0)

def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
@ -148,13 +146,13 @@ class T2SModel(nn.Module):
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1 : "ref_length"},
"text_seq": {1 : "text_length"},
"ref_bert": {0 : "ref_length"},
"text_bert": {0 : "text_length"},
"ssl_content": {2 : "ssl_length"},
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

@ -165,11 +163,11 @@ class T2SModel(nn.Module):
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1 : "x_length"},
"prompts": {1 : "prompts_length"},
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

@ -180,23 +178,23 @@ class T2SModel(nn.Module):
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1 : "iy_length"},
"ik": {1 : "ik_length"},
"iv": {1 : "iv_length"},
"iy_emb": {1 : "iy_emb_length"},
"ix_example": {1 : "ix_example_length"},
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)


class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
@ -207,7 +205,7 @@ class VitsModel(nn.Module):
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@ -219,7 +217,7 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]

@ -235,12 +233,16 @@ class GptSoVits(nn.Module):
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime

sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio

@ -254,12 +256,12 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1 : "text_length"},
"pred_semantic": {2 : "pred_length"},
"ref_audio": {1 : "audio_length"},
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False
verbose=False,
)


@ -277,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()

try:
os.mkdir(f"onnx/{project_name}")
@ -325,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
}

MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)


if __name__ == "__main__":
@ -12,8 +12,9 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch

is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get('version', None)
version = os.environ.get("version", None)
import traceback
import os.path
from text.cleaner import clean_text
@ -33,13 +34,13 @@ from time import time as ttime
import shutil


def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
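
# [Editorial note, not part of this diff] my_save is the workaround named in
# its own comment: torch.save can fail on paths containing Chinese characters,
# so the checkpoint is first written under an ASCII temp name in the working
# directory and then moved into place with shutil.move, which handles Unicode
# paths.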

txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
@ -53,8 +54,10 @@ if os.path.exists(txt_path) == False:
# device = "mps"
else:
device = "cpu"
if os.path.exists(bert_pretrained_dir):...
else:raise FileNotFoundError(bert_pretrained_dir)
if os.path.exists(bert_pretrained_dir):
...
else:
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
@ -83,12 +86,10 @@ if os.path.exists(txt_path) == False:
def process(data, res):
for name, text, lan in data:
try:
name=clean_path(name)
name = clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("¥", ","), lan, version
)
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
@ -128,9 +129,7 @@ if os.path.exists(txt_path) == False:
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
if language in language_v1_to_language_v2.keys():
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
else:
print(f"\033[33m[Warning] The {language = } of {wav_name} is not supported for training.\033[0m")
except:
@ -2,26 +2,30 @@

import sys
import os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")

inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")

opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch

is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()

import traceback
import numpy as np
from scipy.io import wavfile
import librosa

now_dir = os.getcwd()
sys.path.append(now_dir)
from tools.my_utils import load_audio,clean_path
from tools.my_utils import load_audio, clean_path

# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
@ -36,90 +40,95 @@ from tools.my_utils import load_audio,clean_path

from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)


def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))

hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)

maxx=0.95
alpha=0.5
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)

maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
model=cnhubert.get_model()
model = cnhubert.get_model()
# is_half=False
if(is_half==True):
model=model.half().to(device)
if is_half == True:
model = model.half().to(device)
else:
model = model.to(device)

nan_fails=[]
def name2go(wav_name,wav_path):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return
nan_fails = []


def name2go(wav_name, wav_path):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
tmp_audio = librosa.resample(
tmp_audio32b, orig_sr=32000, target_sr=16000
)# not a resampling issue
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # not a resampling issue
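# [Editorial note, not part of this diff] The two mixes above blend a
# peak-normalized copy of the signal with the raw signal (weighted by alpha),
# scaled once to int16 range (32768) for the 32 kHz wav that is written out,
# and once to a much smaller amplitude (the 1145.14 constant, whose purpose is
# not documented in the source) for the 16 kHz copy fed to the cnhubert model.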
tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True):
tensor_wav16=tensor_wav16.half().to(device)
if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append((wav_name,wav_path))
print("nan filtered:%s"%wav_name)
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
nan_fails.append((wav_name, wav_path))
print("nan filtered:%s" % wav_name)
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl,hubert_path)
my_save(ssl, hubert_path)

with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")

for line in lines[int(i_part)::int(all_parts)]:
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")

for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name)
if (inp_wav_dir != "" and inp_wav_dir != None):
wav_name = clean_path(wav_name)
if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
wav_path = "%s/%s" % (inp_wav_dir, wav_name)

else:
wav_path=wav_name
wav_path = wav_name
wav_name = os.path.basename(wav_name)
name2go(wav_name,wav_path)
name2go(wav_name, wav_path)
except:
print(line,traceback.format_exc())
print(line, traceback.format_exc())

if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
if len(nan_fails) > 0 and is_half == True:
is_half = False
model = model.float()
for wav in nan_fails:
try:
name2go(wav[0],wav[1])
name2go(wav[0], wav[1])
except:
print(wav_name,traceback.format_exc())
print(wav_name, traceback.format_exc())
@ -10,8 +10,10 @@ opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")

if os.path.exists(pretrained_s2G):...
else:raise FileNotFoundError(pretrained_s2G)
if os.path.exists(pretrained_s2G):
...
else:
raise FileNotFoundError(pretrained_s2G)
# version=os.environ.get("version","v2")
size = os.path.getsize(pretrained_s2G)
if size < 82978 * 1024:
@ -25,6 +27,7 @@ elif size < 700 * 1024 * 1024:
else:
version = "v3"
import torch

is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import sys
@ -33,11 +36,13 @@ now_dir = os.getcwd()
sys.path.append(now_dir)
import logging
import utils
if version!="v3":

if version != "v3":
from module.models import SynthesizerTrn
else:
from module.models import SynthesizerTrnV3 as SynthesizerTrn
from tools.my_utils import clean_path

logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G

@ -66,7 +71,7 @@ if os.path.exists(semantic_path) == False:
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
version=version,
**hps.model
**hps.model,
)
if is_half == True:
vq_model = vq_model.half().to(device)
@ -103,7 +108,7 @@ if os.path.exists(semantic_path) == False:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name)
wav_name = clean_path(wav_name)
wav_name = os.path.basename(wav_name)
# name2go(name,lines1)
name2go(wav_name, lines1)
@ -8,31 +8,37 @@ from tools.i18n.i18n import I18nAuto

i18n = I18nAuto()

def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))

'''
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))


"""
00:v1
01:v2
02:v3
03:v3lora


'''
"""
from io import BytesIO
def my_save2(fea,path):


def my_save2(fea, path):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo
with open(path, "wb") as f: f.write(data)
data = b"03" + data[2:] ###temp for v3lora only, todo
with open(path, "wb") as f:
f.write(data)

def savee(ckpt, name, epoch, steps, hps,lora_rank=None):

def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@ -43,7 +49,7 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"]=lora_rank
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
@ -51,41 +57,48 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
except:
return traceback.format_exc()

head2version={
b'00':["v1","v1",False],
b'01':["v2","v2",False],
b'02':["v2","v3",False],
b'03':["v2","v3",True],

head2version = {
b"00": ["v1", "v1", False],
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
}
hash_pretrained_dict={
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
}
import hashlib


def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192)
with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()


def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash=get_hash_from_file(sovits_path)
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
with open(sovits_path,"rb")as f:version=f.read(2)
if version!=b"PK":
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3=False
size=os.path.getsize(sovits_path)
'''
if_lora_v3 = False
size = os.path.getsize(sovits_path)
"""
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
'''
"""
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
@ -93,15 +106,16 @@ def get_sovits_version_from_path_fast(sovits_path):
else:
version = "v2"
model_version = "v3"
return version,model_version,if_lora_v3
return version, model_version, if_lora_v3
def load_sovits_new(sovits_path):
|
||||
f=open(sovits_path,"rb")
|
||||
meta=f.read(2)
|
||||
if meta!="PK":
|
||||
data = b'PK' + f.read()
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != "PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
bio.seek(0)
|
||||
return torch.load(bio, map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path,map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path, map_location="cpu", weights_only=False)
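The fast path above keys off the first two bytes of the checkpoint: genuine zip-based torch files begin with b"PK", while my_save2 overwrites those bytes with a version tag. A minimal sketch of reading the tag back, reusing the head2version table defined in this file:

def sniff_version(sovits_path):
    with open(sovits_path, "rb") as f:
        head = f.read(2)  # only the 2-byte header is needed
    if head == b"PK":
        return None  # untagged torch zip: fall back to the hash/size heuristics
    return head2version.get(head)  # e.g. b"03" -> ["v2", "v3", True] (v3 LoRA)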

@ -5,25 +5,24 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
import platform
from pathlib import Path

import torch
import platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
from pytorch_lightning.strategies import DDPStrategy

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt

from collections import OrderedDict

from AR.utils import get_newest_ckpt
from process_ckpt import my_save


@ -35,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self.if_save_latest = if_save_latest

@ -48,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if (
self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt

@ -73,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
# torch.save(
# print(os.environ)
if(os.environ.get("LOCAL_RANK","0")=="0"):
if os.environ.get("LOCAL_RANK", "0") == "0":
my_save(
to_save_od,
"%s/%s-e%s.ckpt"

@ -110,7 +106,7 @@ def main(args):
dirpath=ckpt_dir,
)
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
os.environ["MASTER_ADDR"]="localhost"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["USE_LIBUV"] = "0"
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],

@ -121,9 +117,9 @@ def main(args):
devices=-1 if torch.cuda.is_available() else 1,
benchmark=False,
fast_dev_run=False,
strategy = DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
) if torch.cuda.is_available() else "auto",
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
if torch.cuda.is_available()
else "auto",
precision=config["train"]["precision"],
logger=logger,
num_sanity_val_steps=0,

@ -131,9 +127,7 @@ def main(args):
use_distributed_sampler=False,  # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题!
)

model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)

data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,

@ -1,37 +1,41 @@
import warnings

warnings.filterwarnings("ignore")
import utils
import os

import utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging

logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons

from module import commons
from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler,
TextAudioSpeakerCollate,
TextAudioSpeakerLoader,
)
from module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee

torch.backends.cudnn.benchmark = False

@ -47,7 +51,6 @@ device = "cpu"  # cuda以外的设备,等mps优化后加入


def main():

if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:

@ -75,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,

@ -129,19 +132,27 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)

net_g = SynthesizerTrn(
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)

net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
)
for name, param in net_g.named_parameters():
if not param.requires_grad:
print(name, "not requires_grad")

@ -194,7 +205,7 @@ def run(rank, n_gpus, hps):

try:  # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
net_d,
optim_d,
)  # D多半加载没事

@ -202,11 +213,11 @@ def run(rank, n_gpus, hps):
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0

@ -214,37 +225,55 @@ def run(rank, n_gpus, hps):
# traceback.print_exc()
epoch_str = 1
global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict(
)
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
)  ##测试不加载优化器
if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
if (
hps.train.pretrained_s2D != ""
and hps.train.pretrained_s2D != None
and os.path.exists(hps.train.pretrained_s2D)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
print("loaded pretrained %s" % hps.train.pretrained_s2D,
print(
"loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
) if torch.cuda.is_available() else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
)
if torch.cuda.is_available()
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
),
)

# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
optim_g,
gamma=hps.train.lr_decay,
last_epoch=-1,
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
optim_d,
gamma=hps.train.lr_decay,
last_epoch=-1,
)
for _ in range(epoch_str):
scheduler_g.step()

@ -286,9 +315,7 @@ def run(rank, n_gpus, hps):
print("training done")


def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers

@ -312,17 +339,38 @@ def train_and_evaluate(
text_lengths,
) in enumerate(tqdm(train_loader)):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)

@ -351,9 +399,7 @@ def train_and_evaluate(
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,

@ -365,15 +411,14 @@ def train_and_evaluate(
hps.data.mel_fmax,
)

y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
)  # slice
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice

# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
y_d_hat_r,
y_d_hat_g,
)
loss_disc_all = loss_disc
optim_d.zero_grad()

@ -406,7 +451,8 @@ def train_and_evaluate(
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader)
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])

@ -430,25 +476,37 @@ def train_and_evaluate(
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict=None
try:###Some people installed the wrong version of matplotlib.
image_dict = None
try:  ###Some people installed the wrong version of matplotlib.
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy()
y_mel[0].data.cpu().numpy(),
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
y_hat_mel[0].data.cpu().numpy(),
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
mel[0].data.cpu().numpy(),
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
stats_ssl[0].data.cpu().numpy(),
),
}
except:pass
if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,)
else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,)
except:
pass
if image_dict:
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)
else:
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:

@ -458,7 +516,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
),
)
utils.save_checkpoint(

@ -467,7 +526,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(global_step),
),
)
else:

@ -477,7 +537,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
),
)
utils.save_checkpoint(

@ -486,7 +547,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(233333333333),
),
)
if rank == 0 and hps.train.if_save_every_weights == True:

@ -541,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
ssl = ssl.to(device)
text, text_lengths = text.to(device), text_lengths.to(device)
for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
) if torch.cuda.is_available() else generator.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
y_hat, mask, *_ = (
generator.module.infer(
ssl,
spec,
spec_lengths,
text,
text_lengths,
test=test,
)
if torch.cuda.is_available()
else generator.infer(
ssl,
spec,
spec_lengths,
text,
text_lengths,
test=test,
)
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

@ -569,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
image_dict.update(
{
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
y_hat_mel[0].cpu().numpy(),
),
}
)
audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
{
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
},
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
},
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

@ -1,29 +1,37 @@
import warnings

warnings.filterwarnings("ignore")
import utils
import os

import utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging

logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons

from module import commons
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,

@ -43,7 +51,6 @@ device = "cpu"  # cuda以外的设备,等mps优化后加入


def main():

if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:

@ -71,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,

@ -125,17 +132,21 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)

net_g = SynthesizerTrn(
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)

# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
# for name, param in net_g.named_parameters():

@ -143,7 +154,7 @@ def run(rank, n_gpus, hps):
#     print(name, "not requires_grad")

optim_g = torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
filter(lambda p: p.requires_grad, net_g.parameters()),  ###默认所有层lr一致
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,

@ -171,11 +182,11 @@ def run(rank, n_gpus, hps):
# logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0

@ -183,17 +194,24 @@ def run(rank, n_gpus, hps):
# traceback.print_exc()
epoch_str = 1
global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict(
)
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
)  ##测试不加载优化器
# if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
#     if rank == 0:

@ -209,9 +227,7 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
#     optim_d, gamma=hps.train.lr_decay, last_epoch=-1
# )

@ -221,7 +237,7 @@ def run(rank, n_gpus, hps):

scaler = GradScaler(enabled=hps.train.fp16_run)

net_d=optim_d=scheduler_d=None
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:

@ -257,7 +273,16 @@ def run(rank, n_gpus, hps):


def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
rank,
epoch,
hps,
nets,
optims,
schedulers,
scaler,
loaders,
logger,
writers,
):
net_g, net_d = nets
optim_g, optim_d = optims

@ -281,19 +306,33 @@ def train_and_evaluate(
#         text,
#         text_lengths,
# ) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)

@ -304,8 +343,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device)

with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)

@ -315,12 +364,15 @@ def train_and_evaluate(

if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100. * batch_idx / len(train_loader)))
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])

scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}

@ -334,7 +386,8 @@ def train_and_evaluate(
writer=writer,
global_step=global_step,
# images=image_dict,
scalars=scalar_dict)
scalars=scalar_dict,
)

# if global_step % hps.train.eval_interval == 0:
#     # evaluate(hps, net_g, eval_loader, writer_eval)

@ -344,7 +397,6 @@ def train_and_evaluate(
#     #   if keep_ckpts > 0:
#     #       utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)

global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:

@ -354,7 +406,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
),
)
# utils.save_checkpoint(

@ -373,7 +426,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
),
)
# utils.save_checkpoint(
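The v3 trainer's inner step is the standard torch AMP recipe: scale the loss, unscale before clipping, then let the scaler decide whether the step is safe. A self-contained sketch of that recipe on a toy model (not the repo's net_g):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 1)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=torch.cuda.is_available())

with autocast(enabled=torch.cuda.is_available()):
    loss = model(torch.randn(4, 8)).pow(2).mean()
optim.zero_grad()
scaler.scale(loss).backward()   # scaled loss keeps fp16 gradients representable
scaler.unscale_(optim)          # unscale before measuring/clipping gradient norms
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optim)              # skipped automatically if any gradient overflowed
scaler.update()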

@ -1,35 +1,45 @@
import warnings

warnings.filterwarnings("ignore")
import utils
import os

import utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging

logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from collections import OrderedDict as od
from random import randint

from module import commons
from peft import LoraConfig, get_peft_model
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
)
from peft import LoraConfig, get_peft_model
from process_ckpt import savee
from collections import OrderedDict as od

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快,那试试tf32吧

@ -43,7 +53,6 @@ device = "cpu"  # cuda以外的设备,等mps优化后加入


def main():

if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:

@ -62,7 +71,7 @@ def main():


def run(rank, n_gpus, hps):
global global_step,no_grad_names,save_root,lora_rank
global global_step, no_grad_names, save_root, lora_rank
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)

@ -71,7 +80,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,

@ -119,21 +128,24 @@ def run(rank, n_gpus, hps):
persistent_workers=True,
prefetch_factor=4,
)
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
os.makedirs(save_root,exist_ok=True)
lora_rank=int(hps.train.lora_rank)
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root, exist_ok=True)
lora_rank = int(hps.train.lora_rank)
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
def get_model(hps):return SynthesizerTrn(

def get_model(hps):
return SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)

def get_optim(net_g):
return torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()),  ###默认所有层lr一致

@ -141,61 +153,66 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas,
eps=hps.train.eps,
)
def model2cuda(net_g,rank):

def model2cuda(net_g, rank):
if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
return net_g
try:# 如果能加载自动resume

try:  # 如果能加载自动resume
net_g = get_model(hps)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g=get_optim(net_g)
net_g = model2cuda(net_g, rank)
optim_g = get_optim(net_g)
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(save_root, "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
except:  # 如果首次不能加载,加载pretrain
# traceback.print_exc()
epoch_str = 1
global_step = 0
net_g = get_model(hps)
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
net_g = model2cuda(net_g, rank)
optim_g = get_optim(net_g)

no_grad_names=set()
no_grad_names = set()
for name, param in net_g.named_parameters():
if not param.requires_grad:
no_grad_names.add(name.replace("module.",""))
no_grad_names.add(name.replace("module.", ""))
# print(name, "not requires_grad")
# print(no_grad_names)
# os._exit(233333)

scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
for _ in range(epoch_str):
scheduler_g.step()

scaler = GradScaler(enabled=hps.train.fp16_run)

net_d=optim_d=scheduler_d=None
print("start training from epoch %s"%epoch_str)
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(

@ -227,9 +244,8 @@ def run(rank, n_gpus, hps):
scheduler_g.step()
print("training done")

def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):

def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers

@ -241,18 +257,32 @@ def train_and_evaluate(
global global_step

net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)

@ -262,8 +292,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device)

with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)

@ -273,18 +313,17 @@ def train_and_evaluate(

if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])

scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict)
scalars=scalar_dict,
)

global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:

@ -294,9 +333,7 @@ def train_and_evaluate(
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(global_step)
),
os.path.join(save_root, "G_{}.pth".format(global_step)),
)
else:
utils.save_checkpoint(

@ -304,21 +341,19 @@ def train_and_evaluate(
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(233333333333)
),
os.path.join(save_root, "G_{}.pth".format(233333333333)),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
sim_ckpt=od()
sim_ckpt = od()
for key in ckpt:
# if "cfm"not in key:
#     print(key)
if key not in no_grad_names:
sim_ckpt[key]=ckpt[key].half().cpu()
sim_ckpt[key] = ckpt[key].half().cpu()
logger.info(
"saving ckpt %s_e%s:%s"
% (

@ -326,10 +361,11 @@ def train_and_evaluate(
epoch,
savee(
sim_ckpt,
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
epoch,
global_step,
hps,lora_rank=lora_rank
hps,
lora_rank=lora_rank,
),
)
)
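The LoRA trainer above freezes the base model and trains only the adapters that peft injects into the attention projections; whatever stays frozen is recorded in no_grad_names and skipped at export time. A minimal sketch of the same peft wiring on a hypothetical attention block that happens to use the target_modules names from the config above:

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class TinyAttn(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Sequential(nn.Linear(dim, dim))


rank = 32
cfg = LoraConfig(target_modules=["to_k", "to_q", "to_v", "to_out.0"], r=rank, lora_alpha=rank)
wrapped = get_peft_model(TinyAttn(), cfg)  # lora_alpha == r gives a scaling factor of 1
wrapped.print_trainable_parameters()       # only the injected LoRA A/B matrices require grad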

@ -3,19 +3,25 @@ import re

# jieba静音
import jieba

jieba.setLogLevel(logging.CRITICAL)

# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))

fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
)
)


from split_lang import LangSplitter


def full_en(text):
pattern = r'^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
pattern = r"^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text))

@ -34,7 +40,7 @@ def full_cjk(text):
(0x2EBF0, 0x2EE5D),  # CJK Extension H
]

pattern = r'[0-9、-〜。!?.!?… ]+$'
pattern = r"[0-9、-〜。!?.!?… ]+$"

cjk_text = ""
for char in text:

@ -45,7 +51,7 @@ def full_cjk(text):
return cjk_text


def split_jako(tag_lang,item):
def split_jako(tag_lang, item):
if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else:

@ -53,28 +59,28 @@ def split_jako(tag_lang,item):

lang_list: list[dict] = []
tag = 0
for match in re.finditer(pattern, item['text']):
for match in re.finditer(pattern, item["text"]):
if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})

tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})

if tag < len(item['text']):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
if tag < len(item["text"]):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})

return lang_list


def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]['text'] += item['text']
if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]["text"] += item["text"]
else:
lang_list.append(item)
return lang_list
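merge_lang coalesces adjacent segments that share a language tag, so the final list never contains two neighbouring entries with the same lang. A quick illustrative check with toy segments:

segments = []
for seg in [{"lang": "zh", "text": "你好"}, {"lang": "zh", "text": "世界"}, {"lang": "en", "text": "hi"}]:
    segments = merge_lang(segments, seg)
print(segments)  # [{'lang': 'zh', 'text': '你好世界'}, {'lang': 'en', 'text': 'hi'}]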


class LangSegmenter():
class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = {
"zh": "zh",

@ -87,7 +93,6 @@ class LangSegmenter():
"en": "en",
}

def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
substr = lang_splitter.split_by_lang(text=text)

@ -95,18 +100,18 @@ class LangSegmenter():
lang_list: list[dict] = []

for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text}
dict_item = {"lang": item.lang, "text": item.text}

# 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']):
dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item)
if full_en(dict_item["text"]):
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item)
continue

# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if dict_item["lang"] != "ja":
ja_list = split_jako("ja", dict_item)

if not ja_list:
ja_list.append(dict_item)

@ -115,8 +120,8 @@ class LangSegmenter():
ko_list: list[dict] = []
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)

if ko_list:
temp_list.extend(ko_list)

@ -126,26 +131,26 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
# 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
continue
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue

# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if temp_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
return lang_list

@ -155,4 +160,3 @@ if __name__ == "__main__":

text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text))

@ -10,18 +10,19 @@ from text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}


def cleaned_text_to_sequence(cleaned_text, version=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
if version is None:version=os.environ.get('version', 'v2')
"""
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
else:
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]

|
||||
|
||||
|
@ -98,9 +98,7 @@ def replace_punctuation(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
@ -114,7 +112,9 @@ def text_normalize(text):
|
||||
return dest_text
|
||||
|
||||
|
||||
punctuation_set=set(punctuation)
|
||||
punctuation_set = set(punctuation)
|
||||
|
||||
|
||||
def jyuping_to_initials_finals_tones(jyuping_syllables):
|
||||
initials_finals = []
|
||||
tones = []
|
||||
@ -159,12 +159,14 @@ def jyuping_to_initials_finals_tones(jyuping_syllables):
|
||||
assert len(initials_finals) == len(tones)
|
||||
|
||||
###魔改为辅音+带音调的元音
|
||||
phones=[]
|
||||
for a,b in zip(initials_finals,tones):
|
||||
if(b not in [-1,0]):###防止粤语和普通话重合开头加Y,如果是标点,不加。
|
||||
todo="%s%s"%(a,b)
|
||||
else:todo=a
|
||||
if(todo not in punctuation_set):todo="Y%s"%todo
|
||||
phones = []
|
||||
for a, b in zip(initials_finals, tones):
|
||||
if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y,如果是标点,不加。
|
||||
todo = "%s%s" % (a, b)
|
||||
else:
|
||||
todo = a
|
||||
if todo not in punctuation_set:
|
||||
todo = "Y%s" % todo
|
||||
phones.append(todo)
|
||||
|
||||
# return initials_finals, tones, word2ph
|
||||
|
@ -18,6 +18,7 @@ pinyin_to_symbol_map = {

import jieba_fast
import logging

jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg

@ -37,7 +38,7 @@ rep_map = {
"/": ",",
"—": "-",
"~": "…",
"~":"…",
"~": "…",
}

tone_modifier = ToneSandhi()
@ -49,9 +50,7 @@ def replace_punctuation(text):

replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

return replaced_text

@ -62,17 +61,15 @@ def replace_punctuation_with_en(text):

replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)

return replaced_text

def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result

@ -87,9 +84,7 @@ def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)
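A quick illustration of what replace_consecutive_punctuation does; this is a self-contained sketch with a toy punctuation list, not the module's actual punctuation import:

import re

punctuation = [",", ".", "!", "?", "…"]

def replace_consecutive_punctuation(text):
    # Collapse any run of punctuation marks down to its first mark.
    punctuations = "".join(re.escape(p) for p in punctuation)
    pattern = f"([{punctuations}])([{punctuations}])+"
    return re.sub(pattern, r"\1", text)

print(replace_consecutive_punctuation("wait!?!…"))  # -> "wait!"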
@ -19,17 +19,24 @@ pinyin_to_symbol_map = {

import jieba_fast
import logging

jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg

# is_g2pw_str = os.environ.get("is_g2pw", "True")  ## on by default
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True  # True if is_g2pw_str.lower() == 'true' else False
if is_g2pw:
# print("using g2pw for pinyin inference")
from text.g2pw import G2PWPinyin, correct_pronunciation

parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
v_to_u=False,
neutral_tone_with_five=True,
)

rep_map = {
":": ",",
@ -46,7 +53,7 @@ rep_map = {
"/": ",",
"—": "-",
"~": "…",
"~":"…",
"~": "…",
}

tone_modifier = ToneSandhi()
@ -58,9 +65,7 @@ def replace_punctuation(text):

replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

return replaced_text

@ -77,9 +82,7 @@ def _get_initials_finals(word):
finals = []

orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)

for c, v in zip(orig_initials, orig_finals):
initials.append(c)
@ -87,31 +90,66 @@ def _get_initials_finals(word):
return initials, finals

must_erhua = {
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
}
must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"}
not_erhua = {
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
"狗儿", "少儿"
"虐儿",
"为儿",
"护儿",
"瞒儿",
"救儿",
"替儿",
"有儿",
"一儿",
"我儿",
"俺儿",
"妻儿",
"拐儿",
"聋儿",
"乞儿",
"患儿",
"幼儿",
"孤儿",
"婴儿",
"婴幼儿",
"连体儿",
"脑瘫儿",
"流浪儿",
"体弱儿",
"混血儿",
"蜜雪儿",
"舫儿",
"祖儿",
"美儿",
"应采儿",
"可儿",
"侄儿",
"孙儿",
"侄孙儿",
"女儿",
"男儿",
"红孩儿",
"花儿",
"虫儿",
"马儿",
"鸟儿",
"猪儿",
"猫儿",
"狗儿",
"少儿",
}
def _merge_erhua(initials: list[str],
finals: list[str],
word: str,
pos: str) -> list[list[str]]:

def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]:
"""
Do erhua.
"""
# fix er1
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
finals[i] = 'er2'
if i == len(finals) - 1 and word[i] == "儿" and phn == "er1":
finals[i] = "er2"

# pronunciation
if word not in must_erhua and (word in not_erhua or
pos in {"a", "j", "nr"}):
if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}):
return initials, finals

# return directly for cases like "……"
@ -124,9 +162,13 @@ def _merge_erhua(initials: list[str],
new_initials = []
new_finals = []
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "儿" and phn in {
"er2", "er5"
} and word[-2:] not in not_erhua and new_finals:
if (
i == len(finals) - 1
and word[i] == "儿"
and phn in {"er2", "er5"}
and word[-2:] not in not_erhua
and new_finals
):
phn = "er" + new_finals[-1][-1]

new_initials.append(initials[i])
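The reformatted condition above retones a trailing 儿: it keeps the er syllable but gives it the tone of the preceding final. A standalone sketch of just that step, with toy inputs and without the must_erhua/not_erhua guards of the real function:

def retone_trailing_er(initials, finals, word):
    # A final 儿 realized as er2/er5 takes the tone of the previous final,
    # e.g. ["ua1", "er2"] for "花儿" becomes ["ua1", "er1"].
    new_initials, new_finals = [], []
    for i, phn in enumerate(finals):
        if i == len(finals) - 1 and word[i] == "儿" and phn in {"er2", "er5"} and new_finals:
            phn = "er" + new_finals[-1][-1]
        new_initials.append(initials[i])
        new_finals.append(phn)
    return new_initials, new_finals

print(retone_trailing_er(["h", ""], ["ua1", "er2"], "花儿"))
# -> (['h', ''], ['ua1', 'er1'])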
@ -160,7 +202,7 @@ def _g2p(segments):
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
print("pypinyin结果",initials,finals)
print("pypinyin结果", initials, finals)
else:
# g2pw runs inference on the whole sentence
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
@ -171,19 +213,19 @@ def _g2p(segments):
sub_finals = []
now_word_length = pre_word_length + len(word)

if pos == 'eng':
if pos == "eng":
pre_word_length = now_word_length
continue

word_pinyins = pinyins[pre_word_length:now_word_length]

# disambiguate polyphonic characters
word_pinyins = correct_pronunciation(word,word_pinyins)
word_pinyins = correct_pronunciation(word, word_pinyins)

for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin,neutral_tone_with_five=True))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
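The initial/final split above leans on pypinyin's tone-convert helpers, which this repo already imports elsewhere. A small standalone check of what they return for a TONE3 syllable:

from pypinyin.contrib.tone_convert import to_initials, to_finals_tone3

syllable = "zhong1"  # TONE3-style pinyin, as produced by lazy_pinyin(..., style=Style.TONE3)
print(to_initials(syllable))                                   # -> "zh"
print(to_finals_tone3(syllable, neutral_tone_with_five=True))  # -> "ong1"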
@ -259,18 +301,18 @@ def replace_punctuation_with_en(text):

replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)

return replaced_text

def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result

def text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
@ -283,6 +325,7 @@ def text_normalize(text):
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text

# text normalization that does not strip English
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization

@ -19,55 +19,57 @@ special = [

def clean_text(text, language, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

if(language not in language_module_map):
language="en"
text=" "
if language not in language_module_map:
language = "en"
text = " "
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol, version)
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
if hasattr(language_module,"text_normalize"):
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
if hasattr(language_module, "text_normalize"):
norm_text = language_module.text_normalize(text)
else:
norm_text=text
if language == "zh" or language=="yue":##########
norm_text = text
if language == "zh" or language == "yue": ##########
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
elif language == "en":
phones = language_module.g2p(norm_text)
if len(phones) < 4:
phones = [','] + phones
phones = [","] + phones
word2ph = None
else:
phones = language_module.g2p(norm_text)
word2ph = None
phones = ['UNK' if ph not in symbols else ph for ph in phones]
phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text

def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

"""
Handle the special silence-segment (sp) symbol
"""
text = text.replace(special_s, ",")
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text)
new_ph = []
@ -81,8 +83,9 @@ def clean_special(text, language, special_s, target_symbol, version=None):

def text_to_sequence(text, language, version=None):
version = os.environ.get('version',version)
if version is None:version='v2'
version = os.environ.get("version", version)
if version is None:
version = "v2"
phones = clean_text(text)
return cleaned_text_to_sequence(phones, version)
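A hypothetical driver for the cleaner pipeline above, assuming the GPT_SoVITS/text package is importable and its models are in place; the sentence and printed values are illustrative only:

from text.cleaner import clean_text

phones, word2ph, norm_text = clean_text("今天天气不错。", "zh", version="v2")
print(norm_text)   # normalized sentence
print(phones[:8])  # phoneme symbols, e.g. ['j', 'in1', 't', 'ian1', ...]
print(word2ph)     # phones-per-character alignment; len(word2ph) == len(norm_text)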
@ -9,17 +9,17 @@ import unicodedata
# replacement table for trailing measurement units
measurement_map = {
"m": ["meter", "meters"],
'km': ["kilometer", "kilometers"],
"km": ["kilometer", "kilometers"],
"km/h": ["kilometer per hour", "kilometers per hour"],
"ft": ["feet", "feet"],
"L": ["liter", "liters"],
"tbsp": ["tablespoon", "tablespoons"],
'tsp': ["teaspoon", "teaspoons"],
"tsp": ["teaspoon", "teaspoons"],
"h": ["hour", "hours"],
"min": ["minute", "minutes"],
"s": ["second", "seconds"],
"°C": ["degree celsius", "degrees celsius"],
"°F": ["degree fahrenheit", "degrees fahrenheit"]
"°F": ["degree fahrenheit", "degrees fahrenheit"],
}

@ -27,37 +27,38 @@ measurement_map = {
_inflect = inflect.engine()

# convert numeric ordinals
_ordinal_number_re = re.compile(r'\b([0-9]+)\. ')
_ordinal_number_re = re.compile(r"\b([0-9]+)\. ")

# \d might actually work a bit better for these digit regexes

_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")

# time recognition
_time_re = re.compile(r'\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b')
_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")

# trailing measurement-unit recognition
_measurement_re = re.compile(r'\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b')
_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")

# match £ before or after the amount (a one-sided pattern was attempted but failed for unclear reasons)
_pounds_re_start = re.compile(r'£([0-9\.\,]*[0-9]+)')
_pounds_re_end = re.compile(r'([0-9\.\,]*[0-9]+)£')
_pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)")
_pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£")

# match $ before or after the amount
_dollars_re_start = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_dollars_re_end = re.compile(r'([(0-9\.\,]*[0-9]+)\$')
_dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$")

# decimal recognition
_decimal_number_re = re.compile(r'([0-9]+\.\s*[0-9]+)')
_decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)")

# fraction recognition (the "3/4" form)
_fraction_re = re.compile(r'([0-9]+/[0-9]+)')
_fraction_re = re.compile(r"([0-9]+/[0-9]+)")

# ordinal recognition
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")

# plain number handling
_number_re = re.compile(r'[0-9]+')
_number_re = re.compile(r"[0-9]+")

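To see what _measurement_re actually matches, a quick standalone check using the same pattern as above:

import re

_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")

for text in ["run 5km today", "wait 10 min", "3.5h of audio"]:
    m = _measurement_re.search(text)
    print(text, "->", m.group(1) if m else None)
# -> "5km", None (a space between number and unit breaks the match), "3.5h"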
def _convert_ordinal(m):
"""
@ -70,8 +71,10 @@ def _convert_ordinal(m):
ordinal = _inflect.ordinal(m.group(1))
return ordinal + ", "

def _remove_commas(m):
return m.group(1).replace(',', '')
return m.group(1).replace(",", "")

def _expand_time(m):
"""
@ -82,12 +85,12 @@ def _expand_time(m):
output: "one o'clock p.m. / four o'clock a.m. / one thirty p.m."
"""
hours, minutes = map(int, m.group(1, 2))
period = 'a.m.' if hours < 12 else 'p.m.'
period = "a.m." if hours < 12 else "p.m."
if hours > 12:
hours -= 12

hour_word = _inflect.number_to_words(hours)
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ''
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ""

if minutes == 0:
return f"{hour_word} o'clock {period}"
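The time expansion rests on inflect's number_to_words; a compact sketch of the same idea, independent of the module's regex plumbing:

import inflect

_inflect = inflect.engine()

def expand_time(hours: int, minutes: int) -> str:
    # "13:30" -> "one thirty p.m.", "09:00" -> "nine o'clock a.m."
    period = "a.m." if hours < 12 else "p.m."
    if hours > 12:
        hours -= 12
    hour_word = _inflect.number_to_words(hours)
    if minutes == 0:
        return f"{hour_word} o'clock {period}"
    return f"{hour_word} {_inflect.number_to_words(minutes)} {period}"

print(expand_time(13, 30))  # -> "one thirty p.m."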
@ -103,7 +106,7 @@ def _expand_measurement(m):
sign = m.group(3)
ptr = 1
# no convenient way to pull the number out, and too lazy to change the regex; 1.2 reads as a plural anyway, so just drop the "."
num = int(m.group(1).replace(sign, '').replace(".",''))
num = int(m.group(1).replace(sign, "").replace(".", ""))
decimal_part = m.group(2)
# closes a hole in the check above: cases like 0.1 are excluded here
if decimal_part == None and num == 1:
@ -116,23 +119,24 @@ def _expand_pounds(m):
no authoritative reference found; handled the same way as dollars, and the two could really be merged
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' pounds' # Unexpected format
return match + " pounds"  # Unexpected format
pounds = int(parts[0]) if parts[0] else 0
pence = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if pounds and pence:
pound_unit = 'pound' if pounds == 1 else 'pounds'
penny_unit = 'penny' if pence == 1 else 'pence'
return '%s %s and %s %s' % (pounds, pound_unit, pence, penny_unit)
pound_unit = "pound" if pounds == 1 else "pounds"
penny_unit = "penny" if pence == 1 else "pence"
return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit)
elif pounds:
pound_unit = 'pound' if pounds == 1 else 'pounds'
return '%s %s' % (pounds, pound_unit)
pound_unit = "pound" if pounds == 1 else "pounds"
return "%s %s" % (pounds, pound_unit)
elif pence:
penny_unit = 'penny' if pence == 1 else 'pence'
return '%s %s' % (pence, penny_unit)
penny_unit = "penny" if pence == 1 else "pence"
return "%s %s" % (pence, penny_unit)
else:
return 'zero pounds'
return "zero pounds"

def _expand_dollars(m):
"""
@ -142,23 +146,24 @@ def _expand_dollars(m):
output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents"
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' dollars' # Unexpected format
return match + " dollars"  # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s and %s %s' % (dollars, dollar_unit, cents, cent_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return 'zero dollars'
return "zero dollars"
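Note the split-and-pad trick in both currency expanders: ".3" is padded with ljust so "32.3" reads as 30 cents, not 3. A simplified standalone check of that logic (unit pluralization omitted):

def expand_dollars(amount: str) -> str:
    # "32.3" -> "32 dollars and 30 cents"; ljust pads the fraction to two digits.
    parts = amount.split(".")
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        return "%s dollars and %s cents" % (dollars, cents)
    return "%s dollars" % dollars

print(expand_dollars("32.3"))  # -> "32 dollars and 30 cents"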
# decimal handling
def _expand_decimal_number(m):
@ -168,11 +173,11 @@ def _expand_decimal_number(m):
output: "thirteen point two three four"
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
words = []
# walk every character of the fractional part
for char in parts[1]:
if char == '.':
if char == ".":
words.append("point")
else:
words.append(char)
@ -196,39 +201,41 @@ def _expend_fraction(m):
| 3/2 | three halves |
"""
match = m.group(0)
numerator, denominator = map(int, match.split('/'))
numerator, denominator = map(int, match.split("/"))

numerator_part = _inflect.number_to_words(numerator)
if denominator == 2:
if numerator == 1:
denominator_part = 'half'
denominator_part = "half"
else:
denominator_part = 'halves'
denominator_part = "halves"
elif denominator == 1:
return f'{numerator_part}'
return f"{numerator_part}"
else:
denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
if numerator > 1:
denominator_part += 's'
denominator_part += "s"

return f"{numerator_part} {denominator_part}"

return f'{numerator_part} {denominator_part}'
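_expend_fraction builds readings from inflect's ordinal form of the denominator; the same branching in a standalone sketch:

import inflect

_inflect = inflect.engine()

def read_fraction(numerator: int, denominator: int) -> str:
    # 3/4 -> "three fourths"; 1/2 -> "one half"; 3/2 -> "three halves"
    numerator_part = _inflect.number_to_words(numerator)
    if denominator == 2:
        denominator_part = "half" if numerator == 1 else "halves"
    elif denominator == 1:
        return numerator_part
    else:
        denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
        if numerator > 1:
            denominator_part += "s"
    return f"{numerator_part} {denominator_part}"

print(read_fraction(3, 4))  # -> "three fourths"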
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))

def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
return "two thousand"
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num, andword="")

def normalize(text):
@ -238,7 +245,7 @@ def normalize(text):
"""

text = re.sub(_ordinal_number_re, _convert_ordinal, text)
text = re.sub(r'(?<!\d)-|-(?!\d)', ' minus ', text)
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_time_re, _expand_time, text)
text = re.sub(_measurement_re, _expand_measurement, text)
@ -251,19 +258,20 @@ def normalize(text):
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)

text = ''.join(char for char in unicodedata.normalize('NFD', text)
if unicodedata.category(char) != 'Mn') # Strip accents
text = "".join(
char for char in unicodedata.normalize("NFD", text) if unicodedata.category(char) != "Mn"
) # Strip accents

text = re.sub("%", " percent", text)
text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
text = re.sub(r"(?i)i\.e\.", "that is", text)
text = re.sub(r"(?i)e\.g\.", "for example", text)
# added: split runs of uppercase letters so initialisms are read letter by letter
text = re.sub(r'(?<!^)(?<![\s])([A-Z])', r' \1', text)
text = re.sub(r"(?<!^)(?<![\s])([A-Z])", r" \1", text)
return text

if __name__ == '__main__':
if __name__ == "__main__":
# the segmentation result could be displayed (read-only, or with edits that don't affect the text actually sent to TTS)
# and confirmed by the user before TTS, so they can check for non-standard input
print(normalize("1. test ordinal number 1st"))
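The year-style branch in _expand_number is easiest to see on concrete values; a quick check with inflect, mirroring those branches:

import inflect

_inflect = inflect.engine()

def expand_number(num: int) -> str:
    # Years between 1000 and 3000 get pairwise reading: 1984 -> "nineteen eighty-four".
    if 1000 < num < 3000:
        if num == 2000:
            return "two thousand"
        if 2000 < num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        if num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    return _inflect.number_to_words(num, andword="")

print(expand_number(1984))  # -> "nineteen eighty-four"
print(expand_number(2005))  # -> "two thousand five"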
@ -11,6 +11,7 @@ from text.symbols2 import symbols
from builtins import str as unicode
from text.en_normalization.expend import normalize
from nltk.tokenize import TweetTokenizer

word_tokenize = TweetTokenizer().tokenize
from nltk import pos_tag

@ -121,9 +122,9 @@ def replace_phs(phs):

def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}\s])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}\s])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result

@ -182,6 +183,7 @@ def read_dict_new():

return g2p_dict

def hot_reload_hot(g2p_dict):
with open(CMU_DICT_HOT_PATH) as f:
line = f.readline()
@ -258,9 +260,12 @@ class en_G2p(G2p):
del self.cmu[word.lower()]

# fix heteronyms
self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')

self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
self.homograph2features["complex"] = (
["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
"JJ",
)

def __call__(self, text):
# tokenization
@ -279,7 +284,7 @@ class en_G2p(G2p):
elif len(word) == 1:
# fix the pronunciation of a standalone "A"; the original-case o_word is needed to detect uppercase
if o_word == "A":
pron = ['EY1']
pron = ["EY1"]
else:
pron = self.cmu[word][0]
# g2p_en's original heteronym handling
@ -288,7 +293,7 @@ class en_G2p(G2p):
if pos.startswith(pos1):
pron = pron1
# pos1 longer than pos only occurs for "read"
elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
elif len(pos) < len(pos1) and pos == pos1[: len(pos)]:
pron = pron1
else:
pron = pron2
@ -301,7 +306,6 @@ class en_G2p(G2p):

return prons[:-1]

def qryword(self, o_word):
word = o_word.lower()

@ -319,7 +323,7 @@ class en_G2p(G2p):
for w in word:
# fix the pronunciation of a standalone "a"; no uppercase occurs here
if w == "a":
phones.extend(['EY1'])
phones.extend(["EY1"])
elif not w.isalpha():
phones.extend([w])
else:
@ -330,23 +334,23 @@ class en_G2p(G2p):
if re.match(r"^([a-z]+)('s)$", word):
phones = self.qryword(word[:-2])[:]
# after voiceless consonants P T K F TH HH, 's is pronounced ['S']
if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
phones.extend(['S'])
if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
phones.extend(["S"])
# after sibilants S Z SH ZH CH JH, 's is pronounced ['IH1', 'Z'] or ['AH0', 'Z']
elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
phones.extend(['AH0', 'Z'])
elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
phones.extend(["AH0", "Z"])
# after voiced consonants B D G DH V M N NG L R W Y, 's is pronounced ['Z']
# AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
# ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2: after vowels, 's is pronounced ['Z']
else:
phones.extend(['Z'])
phones.extend(["Z"])
return phones

# try word segmentation to handle compound words
comps = wordsegment.segment(word.lower())

# if it can't be segmented, send it back to the model to predict
if len(comps)==1:
if len(comps) == 1:
return self.predict(word)

# if it can be segmented, handle the parts recursively
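The possessive-'s rules above follow standard English voicing assimilation; a minimal standalone version for a single word's phone list:

def add_possessive_s(phones: list[str]) -> list[str]:
    # cat's -> /S/, dog's -> /Z/, fox's -> /AH0 Z/, decided by the final phone.
    if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
        return phones + ["S"]
    if phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
        return phones + ["AH0", "Z"]
    return phones + ["Z"]

print(add_possessive_s(["K", "AE1", "T"]))       # cat's  -> [..., 'S']
print(add_possessive_s(["D", "AO1", "G"]))       # dog's  -> [..., 'Z']
print(add_possessive_s(["F", "AA1", "K", "S"]))  # fox's  -> [..., 'AH0', 'Z']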
@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""

from typing import Dict
from typing import List
from typing import Tuple
@ -23,21 +24,24 @@ import numpy as np

from .utils import tokenize_and_map

ANCHOR_CHAR = '▁'
ANCHOR_CHAR = "▁"

def prepare_onnx_input(tokenizer,
def prepare_onnx_input(
tokenizer,
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
window_size=window_size, texts=texts, query_ids=query_ids
)
input_ids = []
token_type_ids = []
attention_masks = []
@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]

try:
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}

text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)

processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]

input_id = list(
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[
query_id] + 1 # [CLS] token locate at first place
position_id = text2token[query_id] + 1  # [CLS] token locate at first place

input_ids.append(input_id)
token_type_ids.append(token_type_id)
@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
position_ids.append(position_id)

outputs = {
'input_ids': np.array(input_ids).astype(np.int64),
'token_type_ids': np.array(token_type_ids).astype(np.int64),
'attention_masks': np.array(attention_masks).astype(np.int64),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64),
'position_ids': np.array(position_ids).astype(np.int64),
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),
}
return outputs

def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
return truncated_texts, truncated_query_ids

def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
def _truncate(
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
@ -137,14 +131,16 @@ def _truncate(max_len: int,
start = token2text[token_start][0]
end = token2text[token_end - 1][1]

return (text[start:end], query_id - start, tokens[token_start:token_end], [
i - token_start if i is not None else None
for i in text2token[start:end]
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
)

def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
return labels, char2phonemes

def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
return labels, char2phonemes
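What get_phoneme_labels computes, shown on a toy polyphonic table (same logic as above, condensed with setdefault):

from typing import Dict, List, Tuple

def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
    # A sorted label space plus, per character, the indices of the
    # phonemes that character can take.
    labels = sorted(set(phoneme for _, phoneme in polyphonic_chars))
    char2phonemes: Dict[str, List[int]] = {}
    for char, phoneme in polyphonic_chars:
        char2phonemes.setdefault(char, []).append(labels.index(phoneme))
    return labels, char2phonemes

labels, char2phonemes = get_phoneme_labels([["行", "xing2"], ["行", "hang2"], ["了", "le5"]])
print(labels)         # -> ['hang2', 'le5', 'xing2']
print(char2phonemes)  # -> {'行': [2, 0], '了': [1]}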
@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")

class G2PWPinyin(Pinyin):
def __init__(self, model_dir='G2PWModel/', model_source=None,
def __init__(
self,
model_dir="G2PWModel/",
model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style='pinyin',
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self._converter = Converter(
self._g2pw, v_to_u=v_to_u,
self._g2pw,
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi,
)
@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):

class Converter(UltimateConverter):
def __init__(self, g2pw_instance, v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False, **kwargs):
def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
super(Converter, self).__init__(
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi, **kwargs)
v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
)

self._g2pw = g2pw_instance

def convert(self, words, style, heteronym, errors, strict, **kwargs):
pys = []
if RE_HANS.match(words):
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
errors=errors, strict=strict)
pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
post_data = self.post_pinyin(words, heteronym, pys)
if post_data is not None:
pys = post_data

pys = self.convert_styles(
pys, words, style, heteronym, errors, strict)
pys = self.convert_styles(pys, words, style, heteronym, errors, strict)

else:
py = self.handle_nopinyin(words, style=style, errors=errors,
heteronym=heteronym, strict=strict)
py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
if py:
pys.extend(py)

@ -73,13 +75,11 @@ class Converter(UltimateConverter):
g2pw_pinyin = self._g2pw(han)

if not g2pw_pinyin: # fall back to pypinyin's original logic for characters g2pw doesn't support
return super(Converter, self).convert(
han, Style.TONE, heteronym, errors, strict, **kwargs)
return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)

for i, item in enumerate(g2pw_pinyin[0]):
if item is None: # fall back to pypinyin's original logic for characters g2pw doesn't support
py = super(Converter, self).convert(
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
pinyins.extend(py)
else:
pinyins.append([to_tone(item)])
@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
if lst:
new_lst_list.append(lst)
else:
new_lst_list.append([''])
new_lst_list.append([""])

return new_lst_list

@ -127,17 +127,17 @@ def get_dict():

def read_dict():
polyphonic_dict = {}
with open(PP_DICT_PATH,encoding="utf-8") as f:
with open(PP_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
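read_dict eval()s each value string; for the simple "key: ['py1', 'py2']" lines the .rep files appear to hold, ast.literal_eval is a safer drop-in parse. A sketch under that assumption about the line format:

import ast

def parse_rep_line(line: str):
    # "一: ['yi1']" -> ("一", ['yi1']); literal_eval rejects arbitrary code.
    key, value_str = line.split(":", 1)
    return key.strip(), ast.literal_eval(value_str.strip())

print(parse_rep_line("一: ['yi1']"))  # -> ('一', ['yi1'])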
@ -2,6 +2,7 @@
# This code is modified from https://github.com/GitYCC/g2pW

import warnings

warnings.filterwarnings("ignore")
import json
import os
@ -14,6 +15,7 @@ from typing import Tuple

import numpy as np
import onnxruntime

onnxruntime.set_default_logger_severity(3)
from opencc import OpenCC
from transformers import AutoTokenizer
@ -26,21 +28,23 @@ from .dataset import prepare_onnx_input
from .utils import load_config
from ..zh_normalization.char_convert import tranditional_to_simplified

model_version = '1.1'
model_version = "1.1"

def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
"input_ids": onnx_input['input_ids'],
"token_type_ids": onnx_input['token_type_ids'],
"attention_mask": onnx_input['attention_masks'],
"phoneme_mask": onnx_input['phoneme_masks'],
"char_ids": onnx_input['char_ids'],
"position_ids": onnx_input['position_ids']
})[0]
probs = session.run(
[],
{
"input_ids": onnx_input["input_ids"],
"token_type_ids": onnx_input["token_type_ids"],
"attention_mask": onnx_input["attention_masks"],
"phoneme_mask": onnx_input["phoneme_masks"],
"char_ids": onnx_input["char_ids"],
"position_ids": onnx_input["position_ids"],
},
)[0]

preds = np.argmax(probs, axis=1).tolist()
max_probs = []
@ -52,17 +56,17 @@ def predict(session, onnx_input: Dict[str, Any],
return all_preds, all_confidences

def download_and_decompress(model_dir: str='G2PWModel/'):
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, 'wb') as f:
with open(zip_dir, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
@ -75,12 +79,15 @@ def download_and_decompress(model_dir: str='G2PWModel/'):

return model_dir

class G2PWOnnxConverter:
def __init__(self,
model_dir: str='G2PWModel/',
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)

sess_options = onnxruntime.SessionOptions()
@ -88,41 +95,59 @@ class G2PWOnnxConverter:
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
try:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
except:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
self.config = load_config(
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)

self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese

self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)

polyphonic_chars_path = os.path.join(uncompress_path,
'POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(uncompress_path,
'MONOPHONIC_CHARS.txt')
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split('\t')
for line in open(polyphonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.non_polyphonic = {
'一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
'肖', '瘙', '誒', '泊', '听', '噢'
"一",
"不",
"和",
"咋",
"嗲",
"剖",
"差",
"攢",
"倒",
"難",
"奔",
"勁",
"拗",
"肖",
"瘙",
"誒",
"泊",
"听",
"噢",
}
self.non_monophonic = {'似', '攢'}
self.non_monophonic = {"似", "攢"}
self.monophonic_chars = [
line.split('\t')
for line in open(monophonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
polyphonic_chars=self.polyphonic_chars)
self.labels, self.char2phonemes = (
get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
if self.config.use_char_phoneme
else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
)

self.chars = sorted(list(self.char2phonemes.keys()))

@ -131,41 +156,29 @@ class G2PWOnnxConverter:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)

self.monophonic_chars_dict = {
char: phoneme
for char, phoneme in self.monophonic_chars
}
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)

self.pos_tags = [
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
]
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]

with open(
os.path.join(uncompress_path,
'bopomofo_to_pinyin_wo_tune_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]

with open(
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)

if self.enable_opencc:
self.cc = OpenCC('s2tw')
self.cc = OpenCC("s2tw")

def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
assert tone in "12345"
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
@ -185,8 +198,7 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences

texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences=sentences)
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# no polyphonic words in these sentences
return partial_results
@ -199,14 +211,12 @@ class G2PWOnnxConverter:
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
window_size=None)
window_size=None,
)

preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
preds = [pred.split(" ")[1] for pred in preds]

results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@ -214,15 +224,12 @@ class G2PWOnnxConverter:

return results

def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works better for Simplified Chinese than for Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
@ -230,8 +237,7 @@ class G2PWOnnxConverter:
query_ids.append(i)
sent_ids.append(sent_id)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(
self.monophonic_chars_dict[char])
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
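End to end, the wrapper above plugs g2pw into pypinyin's API; a hypothetical usage sketch, assuming the G2PW model files and a BERT checkpoint are available at the paths this repo uses elsewhere:

from pypinyin import Style
from text.g2pw import G2PWPinyin

g2pw = G2PWPinyin(
    model_dir="GPT_SoVITS/text/G2PWModel",
    model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
    v_to_u=False,
    neutral_tone_with_five=True,
)
# Sentence-level inference lets BERT context pick the right reading of 行.
print(g2pw.lazy_pinyin("银行很行", style=Style.TONE3))
# -> something like ['yin2', 'hang2', 'hen3', 'xing2']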
@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""

import os
import re

@ -24,14 +25,14 @@ def wordize_and_map(text: str):
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
match_space = re.match(r"^ +", text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
text = text[len(space_str) :]
continue

match_en = re.match(r'^[a-zA-Z0-9]+', text)
match_en = re.match(r"^[a-zA-Z0-9]+", text)
if match_en:
en_word = match_en.group(0)

@ -42,7 +43,7 @@ def wordize_and_map(text: str):
index_map_from_text_to_word += [len(words)] * len(en_word)

words.append(en_word)
text = text[len(en_word):]
text = text[len(en_word) :]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)

if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
tokens.append("[UNK]")
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
word_token_len = len(re.sub(r"^##", "", word_token))
index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)

@ -85,53 +85,51 @@ def tokenize_and_map(tokenizer, text: str):

def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)

spec = importlib.util.spec_from_file_location("__init__", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config

default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
"manual_seed": 1313,
"model_source": "bert-base-chinese",
"window_size": 32,
"num_workers": 2,
"use_mask": True,
"use_char_phoneme": False,
"use_conditional": True,
"param_conditional": {
"affect_location": "softmax",
"bias": True,
"char-linear": True,
"pos-linear": False,
"char+pos-second": True,
"char+pos-second_lowrank": False,
"lowrank_size": 0,
"char+pos-second_fm": False,
"fm_size": 0,
"fix_mode": None,
"count_json": "train.count.json",
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
"lr": 5e-5,
"val_interval": 200,
"num_iter": 10000,
"use_focal": False,
"param_focal": {"alpha": 0.0, "gamma": 0.7},
"use_pos": True,
"param_pos ": {
"weight": 0.1,
"pos_joint_training": True,
"train_pos_path": "train.pos",
"valid_pos_path": "dev.pos",
"test_pos_path": "test.pos",
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}

def load_config(config_path: os.PathLike, use_default: bool=False):
def load_config(config_path: os.PathLike, use_default: bool = False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
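_load_config imports an arbitrary config.py by file path via importlib; the same pattern in isolation, with an illustrative path:

import importlib.util

def load_module_from_path(path: str):
    # Execute a .py file and return it as a module object.
    spec = importlib.util.spec_from_file_location("config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# For a file containing `window_size = 32` (the path is hypothetical):
# cfg = load_module_from_path("/tmp/config.py")
# print(cfg.window_size)  # -> 32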
@ -2,43 +2,51 @@
|
||||
import re
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
import pyopenjtalk
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
|
||||
# 防止win下无法读取模型
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
python_dir = os.getcwd()
|
||||
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
|
||||
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
|
||||
if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
|
||||
else:
|
||||
import shutil
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
|
||||
shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
|
||||
shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), )
|
||||
shutil.copytree(
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
|
||||
os.path.join("TEMP", "ja", "open_jtalk_dic"),
|
||||
)
|
||||
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
|
||||
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
|
||||
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
|
||||
if current_file_path[: len(python_dir)].upper() == python_dir.upper():
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
|
||||
else:
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
|
||||
os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
|
||||
shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"))
|
||||
shutil.copyfile(
|
||||
os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
|
||||
os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
|
||||
)
|
||||
current_file_path = os.path.join("TEMP", "ja")
|
||||
|
||||
|
||||
def get_hash(fp: str) -> str:
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(fp, "rb") as f:
|
||||
@ -51,9 +59,12 @@ try:
    USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
    # If there is no user dictionary, build one; if there is, check its md5 and rebuild when it changes
    if os.path.exists(USERDIC_CSV_PATH):
        if not os.path.exists(USERDIC_BIN_PATH) or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r",encoding='utf-8').read():
        if (
            not os.path.exists(USERDIC_BIN_PATH)
            or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
        ):
            pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
            with open(USERDIC_HASH_PATH, "w", encoding='utf-8') as f:
            with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
                f.write(get_hash(USERDIC_CSV_PATH))

    if os.path.exists(USERDIC_BIN_PATH):
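As an aside, the hash check above is a small cache-invalidation pattern: the compiled dictionary is rebuilt only when the CSV's md5 no longer matches the stored one. A minimal self-contained sketch of the same idea; the function and file names here are illustrative, and `build` stands in for a rebuild step such as pyopenjtalk.mecab_dict_index:

import hashlib
import os


def file_md5(path: str) -> str:
    # Hash in chunks so large dictionaries do not need to fit in memory
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()


def ensure_built(csv_path: str, bin_path: str, hash_path: str, build) -> None:
    current = file_md5(csv_path)
    stored = None
    if os.path.exists(hash_path):
        with open(hash_path, "r", encoding="utf-8") as f:
            stored = f.read()
    if not os.path.exists(bin_path) or stored != current:
        build(csv_path, bin_path)  # e.g. pyopenjtalk.mecab_dict_index
        with open(hash_path, "w", encoding="utf-8") as f:
            f.write(current)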
@ -61,11 +72,13 @@ try:
except Exception:
    # print(e)
    import pyopenjtalk

    # failed to load user dictionary, ignore.
    pass


from text.symbols import punctuation

# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
    r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@ -123,9 +136,9 @@ def post_replace_ph(ph):


def replace_consecutive_punctuation(text):
    punctuations = ''.join(re.escape(p) for p in punctuation)
    pattern = f'([{punctuations}])([{punctuations}])+'
    result = re.sub(pattern, r'\1', text)
    punctuations = "".join(re.escape(p) for p in punctuation)
    pattern = f"([{punctuations}])([{punctuations}])+"
    result = re.sub(pattern, r"\1", text)
    return result


@ -152,7 +165,7 @@ def preprocess_jap(text, with_prosody=False):
            text += p.split(" ")

        if i < len(marks):
            if marks[i] == " ":# prevent unexpected UNK
            if marks[i] == " ":  # prevent unexpected UNK
                continue
            text += [marks[i].replace(" ", "")]
    return text
@ -165,6 +178,7 @@ def text_normalize(text):
    text = replace_consecutive_punctuation(text)
    return text


# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
    """Extract phoneme + prosody symbol sequence from input full-context labels.
@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):

    return phones


# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def _numeric_feature_by_regex(regex, s):
    match = re.search(regex, s)
@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
        return -50
    return int(match.group(1))


def g2p(norm_text, with_prosody=True):
    phones = preprocess_jap(norm_text, with_prosody)
    phones = [post_replace_ph(i) for i in phones]
@ -9,39 +9,43 @@ import importlib
import os

# Prevent the model from failing to load on Windows
if os.name == 'nt':
if os.name == "nt":

    class win_G2p(G2p):
        def check_mecab(self):
            super().check_mecab()
            spam_spec = importlib.util.find_spec("eunjeon")
            non_found = spam_spec is None
            if non_found:
                print('you have to install eunjeon. install it...')
                print("you have to install eunjeon. install it...")
            else:
                installpath = spam_spec.submodule_search_locations[0]
                if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):

                if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                    import sys
                    from eunjeon import Mecab as _Mecab

                    class Mecab(_Mecab):
                        def get_dicpath(installpath):
                            if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
                            if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                                import shutil
                                python_dir = os.getcwd()
                                if (installpath[:len(python_dir)].upper() == python_dir.upper()):
                                    dicpath = os.path.join(os.path.relpath(installpath,python_dir),'data','mecabrc')
                                else:
                                    if not os.path.exists('TEMP'):
                                        os.mkdir('TEMP')
                                    if not os.path.exists(os.path.join('TEMP', 'ko')):
                                        os.mkdir(os.path.join('TEMP', 'ko'))
                                    if os.path.exists(os.path.join('TEMP', 'ko', 'ko_dict')):
                                        shutil.rmtree(os.path.join('TEMP', 'ko', 'ko_dict'))

                                    shutil.copytree(os.path.join(installpath, 'data'), os.path.join('TEMP', 'ko', 'ko_dict'))
                                    dicpath = os.path.join('TEMP', 'ko', 'ko_dict', 'mecabrc')
                                python_dir = os.getcwd()
                                if installpath[: len(python_dir)].upper() == python_dir.upper():
                                    dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
                                else:
                                    if not os.path.exists("TEMP"):
                                        os.mkdir("TEMP")
                                    if not os.path.exists(os.path.join("TEMP", "ko")):
                                        os.mkdir(os.path.join("TEMP", "ko"))
                                    if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
                                        shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))

                                    shutil.copytree(
                                        os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
                                    )
                                    dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
                            else:
                                dicpath=os.path.abspath(os.path.join(installpath, 'data/mecabrc'))
                                dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
                            return dicpath

                        def __init__(self, dicpath=get_dicpath(installpath)):
@ -55,10 +59,14 @@ if os.name == 'nt':
from text.symbols2 import symbols

# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
_korean_classifiers = (
    "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
)

# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
_hangul_divided = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        # ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
        # ('ㄵ', 'ㄴㅈ'),
        # ('ㄶ', 'ㄴㅎ'),
@ -70,79 +78,86@@ _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
        # ('ㄿ', 'ㄹㅍ'),
        # ('ㅀ', 'ㄹㅎ'),
        # ('ㅄ', 'ㅂㅅ'),
    ('ㅘ', 'ㅗㅏ'),
    ('ㅙ', 'ㅗㅐ'),
    ('ㅚ', 'ㅗㅣ'),
    ('ㅝ', 'ㅜㅓ'),
    ('ㅞ', 'ㅜㅔ'),
    ('ㅟ', 'ㅜㅣ'),
    ('ㅢ', 'ㅡㅣ'),
    ('ㅑ', 'ㅣㅏ'),
    ('ㅒ', 'ㅣㅐ'),
    ('ㅕ', 'ㅣㅓ'),
    ('ㅖ', 'ㅣㅔ'),
    ('ㅛ', 'ㅣㅗ'),
    ('ㅠ', 'ㅣㅜ')
]]
        ("ㅘ", "ㅗㅏ"),
        ("ㅙ", "ㅗㅐ"),
        ("ㅚ", "ㅗㅣ"),
        ("ㅝ", "ㅜㅓ"),
        ("ㅞ", "ㅜㅔ"),
        ("ㅟ", "ㅜㅣ"),
        ("ㅢ", "ㅡㅣ"),
        ("ㅑ", "ㅣㅏ"),
        ("ㅒ", "ㅣㅐ"),
        ("ㅕ", "ㅣㅓ"),
        ("ㅖ", "ㅣㅔ"),
        ("ㅛ", "ㅣㅗ"),
        ("ㅠ", "ㅣㅜ"),
    ]
]

# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
    ('a', '에이'),
    ('b', '비'),
    ('c', '시'),
    ('d', '디'),
    ('e', '이'),
    ('f', '에프'),
    ('g', '지'),
    ('h', '에이치'),
    ('i', '아이'),
    ('j', '제이'),
    ('k', '케이'),
    ('l', '엘'),
    ('m', '엠'),
    ('n', '엔'),
    ('o', '오'),
    ('p', '피'),
    ('q', '큐'),
    ('r', '아르'),
    ('s', '에스'),
    ('t', '티'),
    ('u', '유'),
    ('v', '브이'),
    ('w', '더블유'),
    ('x', '엑스'),
    ('y', '와이'),
    ('z', '제트')
]]
_latin_to_hangul = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("a", "에이"),
        ("b", "비"),
        ("c", "시"),
        ("d", "디"),
        ("e", "이"),
        ("f", "에프"),
        ("g", "지"),
        ("h", "에이치"),
        ("i", "아이"),
        ("j", "제이"),
        ("k", "케이"),
        ("l", "엘"),
        ("m", "엠"),
        ("n", "엔"),
        ("o", "오"),
        ("p", "피"),
        ("q", "큐"),
        ("r", "아르"),
        ("s", "에스"),
        ("t", "티"),
        ("u", "유"),
        ("v", "브이"),
        ("w", "더블유"),
        ("x", "엑스"),
        ("y", "와이"),
        ("z", "제트"),
    ]
]

# List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
    ('t͡ɕ','ʧ'),
    ('d͡ʑ','ʥ'),
    ('ɲ','n^'),
    ('ɕ','ʃ'),
    ('ʷ','w'),
    ('ɭ','l`'),
    ('ʎ','ɾ'),
    ('ɣ','ŋ'),
    ('ɰ','ɯ'),
    ('ʝ','j'),
    ('ʌ','ə'),
    ('ɡ','g'),
    ('\u031a','#'),
    ('\u0348','='),
    ('\u031e',''),
    ('\u0320',''),
    ('\u0339','')
]]
_ipa_to_lazy_ipa = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("t͡ɕ", "ʧ"),
        ("d͡ʑ", "ʥ"),
        ("ɲ", "n^"),
        ("ɕ", "ʃ"),
        ("ʷ", "w"),
        ("ɭ", "l`"),
        ("ʎ", "ɾ"),
        ("ɣ", "ŋ"),
        ("ɰ", "ɯ"),
        ("ʝ", "j"),
        ("ʌ", "ə"),
        ("ɡ", "g"),
        ("\u031a", "#"),
        ("\u0348", "="),
        ("\u031e", ""),
        ("\u0320", ""),
        ("\u0339", ""),
    ]
]


def fix_g2pk2_error(text):
    new_text = ""
    i = 0
    while i < len(text) - 4:
        if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
            new_text += text[i:i+3] + ' ' + 'ㄴ'
        if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ":
            new_text += text[i : i + 3] + " " + "ㄴ"
            i += 5
        else:
            new_text += text[i]
@ -166,20 +181,20 @@ def divide_hangul(text):


def hangul_number(num, sino=True):
    '''Reference https://github.com/Kyubyong/g2pK'''
    num = re.sub(',', '', num)
    """Reference https://github.com/Kyubyong/g2pK"""
    num = re.sub(",", "", num)

    if num == '0':
        return '영'
    if not sino and num == '20':
        return '스무'
    if num == "0":
        return "영"
    if not sino and num == "20":
        return "스무"

    digits = '123456789'
    names = '일이삼사오육칠팔구'
    digits = "123456789"
    names = "일이삼사오육칠팔구"
    digit2name = {d: n for d, n in zip(digits, names)}

    modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
    decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
    modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
    decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
    digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
    digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}

@ -188,75 +203,75 @@ def hangul_number(num, sino=True):
        i = len(num) - i - 1
        if sino:
            if i == 0:
                name = digit2name.get(digit, '')
                name = digit2name.get(digit, "")
            elif i == 1:
                name = digit2name.get(digit, '') + '십'
                name = name.replace('일십', '십')
                name = digit2name.get(digit, "") + "십"
                name = name.replace("일십", "십")
        else:
            if i == 0:
                name = digit2mod.get(digit, '')
                name = digit2mod.get(digit, "")
            elif i == 1:
                name = digit2dec.get(digit, '')
        if digit == '0':
                name = digit2dec.get(digit, "")
        if digit == "0":
            if i % 4 == 0:
                last_three = spelledout[-min(3, len(spelledout)):]
                if ''.join(last_three) == '':
                    spelledout.append('')
                last_three = spelledout[-min(3, len(spelledout)) :]
                if "".join(last_three) == "":
                    spelledout.append("")
                    continue
            else:
                spelledout.append('')
                spelledout.append("")
                continue
        if i == 2:
            name = digit2name.get(digit, '') + '백'
            name = name.replace('일백', '백')
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 3:
            name = digit2name.get(digit, '') + '천'
            name = name.replace('일천', '천')
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 4:
            name = digit2name.get(digit, '') + '만'
            name = name.replace('일만', '만')
            name = digit2name.get(digit, "") + "만"
            name = name.replace("일만", "만")
        elif i == 5:
            name = digit2name.get(digit, '') + '십'
            name = name.replace('일십', '십')
            name = digit2name.get(digit, "") + "십"
            name = name.replace("일십", "십")
        elif i == 6:
            name = digit2name.get(digit, '') + '백'
            name = name.replace('일백', '백')
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 7:
            name = digit2name.get(digit, '') + '천'
            name = name.replace('일천', '천')
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 8:
            name = digit2name.get(digit, '') + '억'
            name = digit2name.get(digit, "") + "억"
        elif i == 9:
            name = digit2name.get(digit, '') + '십'
            name = digit2name.get(digit, "") + "십"
        elif i == 10:
            name = digit2name.get(digit, '') + '백'
            name = digit2name.get(digit, "") + "백"
        elif i == 11:
            name = digit2name.get(digit, '') + '천'
            name = digit2name.get(digit, "") + "천"
        elif i == 12:
            name = digit2name.get(digit, '') + '조'
            name = digit2name.get(digit, "") + "조"
        elif i == 13:
            name = digit2name.get(digit, '') + '십'
            name = digit2name.get(digit, "") + "십"
        elif i == 14:
            name = digit2name.get(digit, '') + '백'
            name = digit2name.get(digit, "") + "백"
        elif i == 15:
            name = digit2name.get(digit, '') + '천'
            name = digit2name.get(digit, "") + "천"
        spelledout.append(name)
    return ''.join(elem for elem in spelledout)
    return "".join(elem for elem in spelledout)


def number_to_hangul(text):
    '''Reference https://github.com/Kyubyong/g2pK'''
    tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
    """Reference https://github.com/Kyubyong/g2pK"""
    tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
    for token in tokens:
        num, classifier = token
        if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
            spelledout = hangul_number(num, sino=False)
        else:
            spelledout = hangul_number(num, sino=True)
        text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
        text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")
    # digit by digit for remaining digits
    digits = '0123456789'
    names = '영일이삼사오육칠팔구'
    digits = "0123456789"
    names = "영일이삼사오육칠팔구"
    for d, n in zip(digits, names):
        text = text.replace(d, n)
    return text
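The closing loop in number_to_hangul is a plain digit-by-digit fallback: any digits not consumed by the classifier branch are replaced one character at a time. A tiny runnable illustration using the same tables:

digits = "0123456789"
names = "영일이삼사오육칠팔구"

text = "문의는 1588로"
for d, n in zip(digits, names):
    text = text.replace(d, n)
print(text)  # 문의는 일오팔팔로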
@ -265,19 +280,23 @@ def number_to_hangul(text):
def korean_to_lazy_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
    text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
    for regex, replacement in _ipa_to_lazy_ipa:
        text = re.sub(regex, replacement, text)
    return text

_g2p=G2p()

_g2p = G2p()


def korean_to_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text = _g2p(text)
    text = fix_g2pk2_error(text)
    text = korean_to_lazy_ipa(text)
    return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
    return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ")


def post_replace_ph(ph):
    rep_map = {
@ -301,12 +320,13 @@ def post_replace_ph(ph):
        ph = "停"
    return ph


def g2p(text):
    text = latin_to_hangul(text)
    text = _g2p(text)
    text = divide_hangul(text)
    text = fix_g2pk2_error(text)
    text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
    text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
    # text = "".join([post_replace_ph(i) for i in text])
    text = [post_replace_ph(i) for i in text]
    return text
@ -1,4 +1,3 @@

# punctuation = ['!', '?', '…', ",", ".","@"]  # "@" is the SP pause marker
punctuation = ["!", "?", "…", ",", "."]  # "@" is the SP pause marker
punctuation.append("-")
@ -1,4 +1,3 @@

# punctuation = ['!', '?', '…', ",", ".","@"]  # "@" is the SP pause marker
punctuation = ["!", "?", "…", ",", "."]  # "@" is the SP pause marker
punctuation.append("-")
@ -395,24 +394,404 @@ arpa = {
    "SH",
}

ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停'
ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停"
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '

yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}
yue_symbols = {
    "Yeot3", "Yip1", "Yyu3", "Yeng4", "Yut5", "Yaan5", "Ym5", "Yaan6", "Yang1", "Yun4",
    "Yon2", "Yui5", "Yun2", "Yat3", "Ye", "Yeot1", "Yoeng5", "Yoek2", "Yam2", "Yeon6",
    "Yu6", "Yiu3", "Yaang6", "Yp5", "Yai4", "Yoek4", "Yit6", "Yam5", "Yoeng6", "Yg1",
    "Yk3", "Yoe4", "Yam3", "Yc", "Yyu4", "Yyut1", "Yiu4", "Ying3", "Yip3", "Yaap3",
    "Yau3", "Yan4", "Yau1", "Yap4", "Yk6", "Yok3", "Yai1", "Yeot6", "Yan2", "Yoek6",
    "Yt1", "Yoi1", "Yit5", "Yn4", "Yaau3", "Yau4", "Yuk6", "Ys", "Yuk", "Yin6",
    "Yung6", "Ya", "You", "Yaai5", "Yau5", "Yoi3", "Yaak3", "Yaat3", "Ying2", "Yok5",
    "Yeng2", "Yyut3", "Yam1", "Yip5", "You1", "Yam6", "Yaa5", "Yi6", "Yek4", "Yyu2",
    "Yuk5", "Yaam1", "Yang2", "Yai", "Yiu6", "Yin4", "Yok4", "Yot3", "Yui2", "Yeoi5",
    "Yyun6", "Yyu5", "Yoi5", "Yeot2", "Yim4", "Yeoi2", "Yaan1", "Yang6", "Yong1", "Yaang4",
    "Yung5", "Yeon1", "Yin2", "Ya3", "Yaang3", "Yg", "Yk2", "Yaau5", "Yut1", "Yt5",
    "Yip4", "Yung4", "Yj", "Yong3", "Ya1", "Yg6", "Yaau6", "Yit3", "Yun3", "Ying1",
    "Yn2", "Yg4", "Yl", "Yp3", "Yn3", "Yak1", "Yang5", "Yoe6", "You2", "Yap2",
    "Yak2", "Yt3", "Yot5", "Yim2", "Yi1", "Yn6", "Yaat5", "Yaam3", "Yoek5", "Ye3",
    "Yeon4", "Yaa2", "Yu3", "Yim6", "Ym", "Yoe3", "Yaai2", "Ym2", "Ya6", "Yeng6",
    "Yik4", "Yot4", "Yaai4", "Yyun3", "Yu1", "Yoeng1", "Yaap2", "Yuk3", "Yoek3", "Yeng5",
    "Yeoi1", "Yiu2", "Yok1", "Yo1", "Yoek1", "Yoeng2", "Yeon5", "Yiu1", "Yoeng4", "Yuk2",
    "Yat4", "Yg5", "Yut4", "Yan6", "Yin3", "Yaa6", "Yap1", "Yg2", "Yoe5", "Yt4",
    "Ya5", "Yo4", "Yyu1", "Yak3", "Yeon2", "Yong4", "Ym1", "Ye2", "Yaang5", "Yoi2",
    "Yeng3", "Yn", "Yyut4", "Yau", "Yaak2", "Yaan4", "Yek2", "Yin1", "Yi5", "Yoe2",
    "Yei5", "Yaat6", "Yak5", "Yp6", "Yok6", "Yei2", "Yaap1", "Yyut5", "Yi4", "Yim1",
    "Yk5", "Ye4", "Yok2", "Yaam6", "Yat2", "Yon6", "Yei3", "Yyu6", "Yeot5", "Yk4",
    "Yai6", "Yd", "Yg3", "Yei6", "Yau2", "Yok", "Yau6", "Yung3", "Yim5", "Yut6",
    "Yit1", "Yon3", "Yat1", "Yaam2", "Yyut2", "Yui6", "Yt2", "Yek6", "Yt", "Ye6",
    "Yang3", "Ying6", "Yaau1", "Yeon3", "Yng", "Yh", "Yang4", "Ying5", "Yaap6", "Yoeng3",
    "Yyun4", "You3", "Yan5", "Yat5", "Yot1", "Yun1", "Yi3", "Yaa1", "Yaap4", "You6",
    "Yaang2", "Yaap5", "Yaa3", "Yaak6", "Yeng1", "Yaak1", "Yo5", "Yoi4", "Yam4", "Yik1",
    "Ye1", "Yai5", "Yung1", "Yp2", "Yui4", "Yaak4", "Yung2", "Yak4", "Yaat4", "Yeoi4",
    "Yut2", "Yin5", "Yaau4", "Yap6", "Yb", "Yaam4", "Yw", "Yut3", "Yong2", "Yt6",
    "Yaai6", "Yap5", "Yik5", "Yun6", "Yaam5", "Yun5", "Yik3", "Ya2", "Yyut6", "Yon4",
    "Yk1", "Yit4", "Yak6", "Yaan2", "Yuk1", "Yai2", "Yik2", "Yaat2", "Yo3", "Ykw",
    "Yn5", "Yaa", "Ye5", "Yu4", "Yei1", "Yai3", "Yyun5", "Yip2", "Yaau2", "Yiu5",
    "Ym4", "Yeoi6", "Yk", "Ym6", "Yoe1", "Yeoi3", "Yon", "Yuk4", "Yaai3", "Yaa4",
    "Yot6", "Yaang1", "Yei4", "Yek1", "Yo", "Yp", "Yo6", "Yp4", "Yan3", "Yoi",
    "Yap3", "Yek3", "Yim3", "Yz", "Yot2", "Yoi6", "Yit2", "Yu5", "Yaan3", "Yan1",
    "Yon5", "Yp1", "Yong5", "Ygw", "Yak", "Yat6", "Ying4", "Yu2", "Yf", "Ya4",
    "Yon1", "You4", "Yik6", "Yui1", "Yaat1", "Yeot4", "Yi2", "Yaai1", "Yek5", "Ym3",
    "Yong6", "You5", "Yyun1", "Yn1", "Yo2", "Yip6", "Yui3", "Yaak5", "Yyun2",
}

# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)  ### appending yue directly like this scrambles the order
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))
# print(len(symbols))
symbols+=["[","]"]  ## Japanese: newly added rising/falling pitch contour marks
symbols+=sorted(list(ko_symbols))
symbols+=sorted(list(yue_symbols))  ## the newly added yue symbols all go at the end; verified that the leading "Y" prevents duplicates, and Korean obviously cannot clash
symbols += ["[", "]"]  ## Japanese: newly added rising/falling pitch contour marks
symbols += sorted(list(ko_symbols))
symbols += sorted(list(yue_symbols))  ## the newly added yue symbols all go at the end; verified that the leading "Y" prevents duplicates, and Korean obviously cannot clash
# print(len(symbols))
if __name__ == "__main__":
    print(len(symbols))
'''
"""
Cantonese:
732-353=379
Korean + Cantonese:
732-322=410
'''
"""
@ -510,12 +510,7 @@ class ToneSandhi:
        # e.g. 走了, 看着, 去过
        elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
            finals[-1] = finals[-1][:-1] + "5"
        elif (
            len(word) > 1
            and word[-1] in "们子"
            and pos in {"r", "n"}
            and word not in self.must_not_neural_tone_words
        ):
        elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words:
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 桌上, 地下, 家里
        elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -525,25 +520,18 @@ class ToneSandhi:
            finals[-1] = finals[-1][:-1] + "5"
        # 个 used as a classifier
        elif (
            ge_idx >= 1
            and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
            ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
        ) or word == "个":
            finals[ge_idx] = finals[ge_idx][:-1] + "5"
        else:
            if (
                word in self.must_neural_tone_words
                or word[-2:] in self.must_neural_tone_words
            ):
            if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
                finals[-1] = finals[-1][:-1] + "5"

        word_list = self._split_word(word)
        finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
        for i, word in enumerate(word_list):
            # conventional neural in Chinese
            if (
                word in self.must_neural_tone_words
                or word[-2:] in self.must_neural_tone_words
            ):
            if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
                finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
        finals = sum(finals_list, [])
        return finals
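For intuition: the repeated `finals[-1][:-1] + "5"` rewrites simply replace the trailing tone digit of a pypinyin final with 5, the neutral tone. A small runnable illustration (printed values are indicative):

from pypinyin import lazy_pinyin, Style

finals = lazy_pinyin("模样", neutral_tone_with_five=True, style=Style.FINALS_TONE3)
print(finals)  # e.g. ['o2', 'ang4']
# Neutral-tone rewrite, exactly as in the sandhi rules above:
finals[-1] = finals[-1][:-1] + "5"
print(finals)  # ['o2', 'ang5']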
@ -561,9 +549,7 @@ class ToneSandhi:

    def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
        # "一" in number sequences, e.g. 一零零, 二一零
        if word.find("一") != -1 and all(
            [item.isnumeric() for item in word if item != "一"]
        ):
        if word.find("一") != -1 and all([item.isnumeric() for item in word if item != "一"]):
            return finals
        # "一" between reduplication words should be yi5, e.g. 看一看
        elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
@ -697,13 +683,10 @@ class ToneSandhi:
        return new_seg

    # the first and the second words are all_tone_three
    def _merge_continuous_three_tones(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
    def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
@ -715,10 +698,7 @@ class ToneSandhi:
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
@ -732,13 +712,10 @@ class ToneSandhi:
        return len(word) == 2 and word[0] == word[1]

    # the last char of first word and the first char of second word is tone_three
    def _merge_continuous_three_tones_2(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
    def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
@ -750,10 +727,7 @@ class ToneSandhi:
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
File diff suppressed because one or more lines are too long
@ -21,25 +21,29 @@ from .num import verbalize_digit

def _time_num2str(num_string: str) -> str:
    """A special case for verbalizing number in time."""
    result = num2str(num_string.lstrip('0'))
    if num_string.startswith('0'):
        result = DIGITS['0'] + result
    result = num2str(num_string.lstrip("0"))
    if num_string.startswith("0"):
        result = DIGITS["0"] + result
    return result


# time-of-day expression
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
                     r':([0-5][0-9])'
                     r'(:([0-5][0-9]))?')
RE_TIME = re.compile(
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
)

# time range, e.g. 8:30-12:30
RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
                           r':([0-5][0-9])'
                           r'(:([0-5][0-9]))?'
                           r'(~|-)'
                           r'([0-1]?[0-9]|2[0-3])'
                           r':([0-5][0-9])'
                           r'(:([0-5][0-9]))?')
RE_TIME_RANGE = re.compile(
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
    r"(~|-)"
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
)


def replace_time(match) -> str:
@ -62,31 +66,33 @@ def replace_time(match) -> str:
        second_2 = match.group(9)

    result = f"{num2str(hour)}点"
    if minute.lstrip('0'):
    if minute.lstrip("0"):
        if int(minute) == 30:
            result += "半"
        else:
            result += f"{_time_num2str(minute)}分"
    if second and second.lstrip('0'):
    if second and second.lstrip("0"):
        result += f"{_time_num2str(second)}秒"

    if is_range:
        result += "至"
        result += f"{num2str(hour_2)}点"
        if minute_2.lstrip('0'):
        if minute_2.lstrip("0"):
            if int(minute) == 30:
                result += "半"
            else:
                result += f"{_time_num2str(minute_2)}分"
        if second_2 and second_2.lstrip('0'):
        if second_2 and second_2.lstrip("0"):
            result += f"{_time_num2str(second_2)}秒"

    return result
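To make the group numbering in replace_time concrete, here is a quick probe of RE_TIME on one sample (a sketch; output values shown as comments):

import re

RE_TIME = re.compile(r"([0-1]?[0-9]|2[0-3])" r":([0-5][0-9])" r"(:([0-5][0-9]))?")

m = RE_TIME.match("8:05:30")
# group(1) = hour, group(2) = minute, group(4) = second; group(3) holds the whole ":ss" part
print(m.group(1), m.group(2), m.group(4))  # 8 05 30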


RE_DATE = re.compile(r'(\d{4}|\d{2})年'
                     r'((0?[1-9]|1[0-2])月)?'
                     r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
RE_DATE = re.compile(
    r"(\d{4}|\d{2})年"
    r"((0?[1-9]|1[0-2])月)?"
    r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"
)


def replace_date(match) -> str:
@ -110,8 +116,7 @@ def replace_date(match) -> str:


# YY/MM/DD or YY-MM-DD dates separated by / or -
RE_DATE2 = re.compile(
    r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")


def replace_date2(match) -> str:
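Note the `\2` backreference in RE_DATE2: it forces the month and the day to reuse whatever separator followed the year. A quick check:

import re

RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")

print(bool(RE_DATE2.match("2024-01-05")))  # True: both separators are "-"
print(bool(RE_DATE2.match("2024-01/05")))  # False: "\2" requires the same separator twice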
@ -18,10 +18,7 @@ from pypinyin.constants import SUPPORT_UCS4

# full-width <-> half-width conversion
# full-width -> half-width mapping for ASCII letters (num: 52)
F2H_ASCII_LETTERS = {
    ord(char) + 65248: ord(char)
    for char in string.ascii_letters
}
F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}

# half-width -> full-width mapping for ASCII letters
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
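These tables map code point to code point precisely so they can be passed straight to str.translate; a short sketch:

import string

F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}

# Full-width letters (U+FF21..) collapse to plain ASCII
print("ＧＰＴ".translate(F2H_ASCII_LETTERS))  # GPT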
@ -37,26 +34,29 @@ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}

# space (num: 1)
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}
F2H_SPACE = {"\u3000": " "}
H2F_SPACE = {" ": "\u3000"}

# strings that are not "Chinese characters with a pinyin reading"; usable for NSW (non-standard word) extraction
if SUPPORT_UCS4:
    RE_NSW = re.compile(r'(?:[^'
                        r'\u3007'  # 〇
                        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
                        r'\u4e00-\u9fff'  # CJK Unified: [4E00-9FFF]
                        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
                        r'\U00020000-\U0002A6DF'  # CJK Extension B: [20000-2A6DF]
                        r'\U0002A703-\U0002B73F'  # CJK Extension C: [2A700-2B73F]
                        r'\U0002B740-\U0002B81D'  # CJK Extension D: [2B740-2B81D]
                        r'\U0002F80A-\U0002FA1F'  # CJK Compatibility Supplement: [2F800-2FA1F]
                        r'])+')
    RE_NSW = re.compile(
        r"(?:[^"
        r"\u3007"  # 〇
        r"\u3400-\u4dbf"  # CJK Extension A: [3400-4DBF]
        r"\u4e00-\u9fff"  # CJK Unified: [4E00-9FFF]
        r"\uf900-\ufaff"  # CJK Compatibility: [F900-FAFF]
        r"\U00020000-\U0002A6DF"  # CJK Extension B: [20000-2A6DF]
        r"\U0002A703-\U0002B73F"  # CJK Extension C: [2A700-2B73F]
        r"\U0002B740-\U0002B81D"  # CJK Extension D: [2B740-2B81D]
        r"\U0002F80A-\U0002FA1F"  # CJK Compatibility Supplement: [2F800-2FA1F]
        r"])+"
    )
else:
    RE_NSW = re.compile(  # pragma: no cover
        r'(?:[^'
        r'\u3007'  # 〇
        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
        r'\u4e00-\u9fff'  # CJK Unified: [4E00-9FFF]
        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
        r'])+')
        r"(?:[^"
        r"\u3007"  # 〇
        r"\u3400-\u4dbf"  # CJK Extension A: [3400-4DBF]
        r"\u4e00-\u9fff"  # CJK Unified: [4E00-9FFF]
        r"\uf900-\ufaff"  # CJK Compatibility: [F900-FAFF]
        r"])+"
    )
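In use, RE_NSW pulls out the runs of non-CJK characters (the non-standard words) that the normalizers in this package then verbalize; a compact illustration using the narrow-build variant of the pattern:

import re

RE_NSW = re.compile(r"(?:[^\u3007\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff])+")

print(RE_NSW.findall("价格是123.5元"))  # ['123.5']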
@ -15,23 +15,26 @@
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""

import re
from collections import OrderedDict
from typing import List

DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
    1: '十',
    2: '百',
    3: '千',
    4: '万',
    8: '亿',
})
DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
UNITS = OrderedDict(
    {
        1: "十",
        2: "百",
        3: "千",
        4: "万",
        8: "亿",
    }
)

COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"

# fraction expression
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")


def replace_frac(match) -> str:
@ -52,7 +55,7 @@ def replace_frac(match) -> str:


# percentage expression
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")


def replace_percentage(match) -> str:
@ -72,7 +75,7 @@ def replace_percentage(match) -> str:

# integer expression
# negative integer, e.g. -10
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
RE_INTEGER = re.compile(r"(-)" r"(\d+)")


def replace_negative_num(match) -> str:
@ -92,7 +95,7 @@ def replace_negative_num(match) -> str:

# serial-number-style unsigned integer, read digit by digit
# 00078
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")


def replace_default_num(match):
@ -110,15 +113,11 @@ def replace_default_num(match):
# RE_ASMD = re.compile(
#     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
)

asmd_map = {
    '+': '加',
    '-': '减',
    '×': '乘',
    '÷': '除',
    '=': '等于'
}
asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"}

def replace_asmd(match) -> str:
    """
@ -132,24 +131,25 @@ def replace_asmd(match) -> str:


# special handling for superscript powers
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")

power_map = {
    '⁰': '0',
    '¹': '1',
    '²': '2',
    '³': '3',
    '⁴': '4',
    '⁵': '5',
    '⁶': '6',
    '⁷': '7',
    '⁸': '8',
    '⁹': '9',
    'ˣ': 'x',
    'ʸ': 'y',
    'ⁿ': 'n'
    "⁰": "0",
    "¹": "1",
    "²": "2",
    "³": "3",
    "⁴": "4",
    "⁵": "5",
    "⁶": "6",
    "⁷": "7",
    "⁸": "8",
    "⁹": "9",
    "ˣ": "x",
    "ʸ": "y",
    "ⁿ": "n",
}


def replace_power(match) -> str:
    """
    Args:
@ -166,10 +166,10 @@ def replace_power(match) -> str:

# number expression
# pure decimal
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
# positive integer + classifier
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")


def replace_positive_quantifier(match) -> str:
@ -220,7 +220,9 @@ RE_RANGE = re.compile(
    [-~]  # match the range separator
    ((-?)((\d+)(\.\d+)?))  # match the negative or positive number (integer or decimal) that ends the range
    (?![\d\+\-\×÷=])  # lookahead to make sure no further digits or operators follow the range
    """, re.VERBOSE)
    """,
    re.VERBOSE,
)


def replace_range(match) -> str:
@ -239,7 +241,9 @@ def replace_range(match) -> str:

# "~" (read as 至) range expression
RE_TO_RANGE = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
)


def replace_to_range(match) -> str:
    """
@ -248,71 +252,66 @@ def replace_to_range(match) -> str:
    Returns:
        str
    """
    result = match.group(0).replace('~', '至')
    result = match.group(0).replace("~", "至")
    return result


def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
    stripped = value_string.lstrip('0')
def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
    stripped = value_string.lstrip("0")
    if len(stripped) == 0:
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS['0'], DIGITS[stripped]]
            return [DIGITS["0"], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        largest_unit = next(
            power for power in reversed(UNITS.keys()) if power < len(stripped))
        largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
            second_part)
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)


def verbalize_cardinal(value_string: str) -> str:
    if not value_string:
        return ''
        return ""

    # 000 -> '零' , 0 -> '零'
    value_string = value_string.lstrip('0')
    value_string = value_string.lstrip("0")
    if len(value_string) == 0:
        return DIGITS['0']
        return DIGITS["0"]

    result_symbols = _get_value(value_string)
    # verbalized number starting with '一十*' is abbreviated as `十*`
    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
            '1'] and result_symbols[1] == UNITS[1]:
    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS["1"] and result_symbols[1] == UNITS[1]:
        result_symbols = result_symbols[1:]
    return ''.join(result_symbols)
    return "".join(result_symbols)


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result_symbols = [DIGITS[digit] for digit in value_string]
    result = ''.join(result_symbols)
    result = "".join(result_symbols)
    if alt_one:
        result = result.replace("一", "幺")
    return result


def num2str(value_string: str) -> str:
    integer_decimal = value_string.split('.')
    integer_decimal = value_string.split(".")
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ''
        decimal = ""
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        raise ValueError(
            f"The value string: '${value_string}' has more than one point in it."
        )
        raise ValueError(f"The value string: '${value_string}' has more than one point in it.")

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip('0')
    decimal = decimal.rstrip("0")
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        result = result if result else "零"
        result += '点' + verbalize_digit(decimal)
        result += "点" + verbalize_digit(decimal)
    return result
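A worked example of the recursion above, using the functions defined in this file (the import path is an assumption and may differ per checkout):

from text.zh_normalization.num import num2str

# _get_value("10020") -> _get_value("1") + ["万"] + _get_value("0020")
# _get_value("0020")  -> ["零", "二", "十"]  (a run of leading zeros collapses to a single 零)
print(num2str("10020"))  # 一万零二十
print(num2str("3.20"))   # 三点二 (trailing zeros of the decimal part are stripped)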
@ -21,10 +21,8 @@ from .num import verbalize_digit
# China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
# China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
# China Telecom: 133, 153, 189, 180, 181, 177
RE_MOBILE_PHONE = re.compile(
    r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
    r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
RE_MOBILE_PHONE = re.compile(r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")

# nationwide uniform service numbers start with 400
RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
@ -32,14 +30,12 @@ RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")

def phone2str(phone_string: str, mobile=True) -> str:
    if mobile:
        sp_parts = phone_string.strip('+').split()
        result = ','.join(
            [verbalize_digit(part, alt_one=True) for part in sp_parts])
        sp_parts = phone_string.strip("+").split()
        result = ",".join([verbalize_digit(part, alt_one=True) for part in sp_parts])
        return result
    else:
        sil_parts = phone_string.split('-')
        result = ','.join(
            [verbalize_digit(part, alt_one=True) for part in sil_parts])
        sil_parts = phone_string.split("-")
        result = ",".join([verbalize_digit(part, alt_one=True) for part in sil_parts])
        return result
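The alt_one flag exists because spoken Mandarin reads the digit 1 as 幺 (yao) in phone numbers, to avoid mishearing 一 (yi) as 七 (qi). A self-contained check using the same DIGITS table as above:

DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result = "".join(DIGITS[digit] for digit in value_string)
    if alt_one:
        result = result.replace("一", "幺")
    return result


print(verbalize_digit("139", alt_one=True))  # 幺三九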
@ -17,7 +17,7 @@ from .num import num2str

# temperature expression; the temperature affects how the minus sign is read
# -3°C is read 零下三度 (three degrees below zero)
RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
RE_TEMPERATURE = re.compile(r"(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)")
measure_dict = {
    "cm2": "平方厘米",
    "cm²": "平方厘米",
@ -35,7 +35,7 @@ measure_dict = {
    "ml": "毫升",
    "m": "米",
    "mm": "毫米",
    "s": "秒"
    "s": "秒",
}
Some files were not shown because too many files have changed in this diff.