mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
Merge 77917ef6e25d37553c89de60718e4c8626fdcb9d into 9da7e17efe05041e31d3c3f42c8730ae890397f2
This commit is contained in:
commit
fcc0d24825
@ -1,5 +1,8 @@
|
||||
# Download moda ASR related models
|
||||
from modelscope import snapshot_download
|
||||
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
|
||||
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
|
||||
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
|
||||
|
||||
model_dir = snapshot_download(
|
||||
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
|
||||
)
|
||||
model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
|
||||
model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
|
||||
|
@ -4,14 +4,11 @@ import itertools
|
||||
import math
|
||||
import random
|
||||
from random import shuffle
|
||||
from typing import Iterator
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
from typing import Iterator, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Sampler
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
__all__ = [
|
||||
"DistributedBucketSampler",
|
||||
@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError(
|
||||
"Invalid rank {}, rank should be in the interval"
|
||||
" [0, {}]".format(rank, num_replicas - 1)
|
||||
)
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if (
|
||||
self.drop_last and len(self.dataset) % self.num_replicas != 0
|
||||
): # type: ignore[arg-type]
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(
|
||||
len(self.dataset) / self.num_replicas
|
||||
len(self.dataset) / self.num_replicas,
|
||||
) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
grouped_batch_size = self.batch_size * self.num_replicas
|
||||
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
||||
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
||||
batches = [
|
||||
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
|
||||
for b in range(n_batch)
|
||||
]
|
||||
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
|
||||
shuffle(batches)
|
||||
indices = list(itertools.chain(*batches))
|
||||
else:
|
||||
@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[
|
||||
:padding_size
|
||||
]
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
|
@ -1,9 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from AR.data.bucket_sampler import DistributedBucketSampler
|
||||
from AR.data.dataset import Text2SemanticDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Text2SemanticDataModule(LightningDataModule):
|
||||
@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
|
||||
# pad_val=self.config['data']['pad_val'])
|
||||
|
||||
def train_dataloader(self):
|
||||
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
||||
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
||||
batch_size = (
|
||||
self.config["train"]["batch_size"] // 2
|
||||
if self.config["train"].get("if_dpo", False) is True
|
||||
else self.config["train"]["batch_size"]
|
||||
)
|
||||
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
|
||||
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
||||
return DataLoader(
|
||||
self._train_dataset,
|
||||
|
@ -1,21 +1,17 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import pdb
|
||||
import sys
|
||||
|
||||
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
||||
import traceback, os
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch, json
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
version = os.environ.get('version',None)
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
from text import cleaned_text_to_sequence
|
||||
|
||||
@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
|
||||
|
||||
padded_sequences = []
|
||||
for seq, length in zip(sequences, seq_lengths):
|
||||
padding = (
|
||||
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
)
|
||||
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
||||
padded_sequences.append(padded_seq)
|
||||
batch = np.stack(padded_sequences)
|
||||
@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
|
||||
super().__init__()
|
||||
|
||||
self.semantic_data = pd.read_csv(
|
||||
semantic_path, delimiter="\t", encoding="utf-8"
|
||||
semantic_path,
|
||||
delimiter="\t",
|
||||
encoding="utf-8",
|
||||
)
|
||||
# get dict
|
||||
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
||||
self.path3 = "%s/3-bert" % (
|
||||
os.path.dirname(phoneme_path)
|
||||
os.path.dirname(
|
||||
phoneme_path,
|
||||
)
|
||||
) # "%s/3-bert"%exp_dir#bert_dir
|
||||
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
||||
assert os.path.exists(self.path2)
|
||||
@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
|
||||
for i in range(semantic_data_len):
|
||||
# 先依次遍历
|
||||
# get str
|
||||
item_name = self.semantic_data.iloc[i,0]
|
||||
item_name = self.semantic_data.iloc[i, 0]
|
||||
# print(self.phoneme_data)
|
||||
try:
|
||||
phoneme, word2ph, text = self.phoneme_data[item_name]
|
||||
@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
|
||||
semantic_str = self.semantic_data.iloc[i,1]
|
||||
semantic_str = self.semantic_data.iloc[i, 1]
|
||||
# get token list
|
||||
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
||||
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
||||
@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
||||
if (
|
||||
len(phoneme_ids) > self.max_sec * self.hz / 2.5
|
||||
): ###########2:改为恒定限制为semantic/2.5就行
|
||||
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
|
||||
num_deleted_ps += 1
|
||||
continue
|
||||
# if len(semantic_ids) > 1000:###########3
|
||||
@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
|
||||
|
||||
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
||||
|
||||
if (
|
||||
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
|
||||
): ##########4#3~25#每秒多少个phone
|
||||
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
|
||||
num_deleted_ps += 1
|
||||
# print(item_name)
|
||||
continue
|
||||
@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
|
||||
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
||||
if num_deleted_bigger > 0:
|
||||
print(
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
|
||||
)
|
||||
if num_deleted_ps > 0:
|
||||
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
||||
print(
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
|
||||
)
|
||||
"""
|
||||
there are 31 semantic datas not in phoneme datas
|
||||
@ -306,7 +300,10 @@ if __name__ == "__main__":
|
||||
|
||||
batch_size = 12
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate,
|
||||
shuffle=False,
|
||||
)
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i % 1000 == 0:
|
||||
|
@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -8,10 +9,12 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
|
||||
|
||||
class Text2SemanticLightningModule(LightningModule):
|
||||
def __init__(self, config, output_dir, is_train=True):
|
||||
super().__init__()
|
||||
@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(
|
||||
self.load_state_dict(
|
||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
||||
torch.load(
|
||||
pretrained_s1,
|
||||
map_location="cpu",
|
||||
)["weight"],
|
||||
)
|
||||
)
|
||||
if is_train:
|
||||
@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def training_step(self, batch: Dict, batch_idx: int):
|
||||
opt = self.optimizers()
|
||||
scheduler = self.lr_schedulers()
|
||||
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
|
||||
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
|
||||
loss, acc = forward(
|
||||
batch["phoneme_ids"],
|
||||
batch["phoneme_ids_len"],
|
||||
@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
||||
)
|
||||
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
|
@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@ -8,6 +9,7 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(
|
||||
self.load_state_dict(
|
||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
||||
)
|
||||
torch.load(
|
||||
pretrained_s1,
|
||||
map_location="cpu",
|
||||
)["weight"],
|
||||
),
|
||||
)
|
||||
if is_train:
|
||||
self.automatic_optimization = False
|
||||
@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
||||
)
|
||||
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
|
@ -2,27 +2,24 @@
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import make_pad_mask, make_pad_mask_left
|
||||
from AR.models.utils import (
|
||||
topk_sampling,
|
||||
sample,
|
||||
logits_to_probs,
|
||||
multinomial_sample_one_no_sync,
|
||||
dpo_loss,
|
||||
make_reject_y,
|
||||
get_batch_logps
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding
|
||||
from AR.modules.embedding import TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm
|
||||
from AR.modules.transformer import TransformerEncoder
|
||||
from AR.modules.transformer import TransformerEncoderLayer
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import (
|
||||
dpo_loss,
|
||||
get_batch_logps,
|
||||
make_pad_mask,
|
||||
make_pad_mask_left,
|
||||
make_reject_y,
|
||||
sample,
|
||||
topk_sampling,
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
@ -36,10 +33,17 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
|
||||
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
||||
# Efficient implementation equivalent to the following:
|
||||
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
|
||||
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
def scaled_dot_product_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
if scale is None:
|
||||
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
||||
else:
|
||||
@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
else:
|
||||
attn_mask[attn_mask!=float("-inf")] =0
|
||||
attn_mask[attn_mask==float("-inf")] =1
|
||||
attn_mask[attn_mask != float("-inf")] = 0
|
||||
attn_mask[attn_mask == float("-inf")] = 1
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
|
||||
return attn_weight @ value
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
@ -82,20 +87,20 @@ class T2SMLP:
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
@ -114,7 +119,11 @@ class T2SBlock:
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
def to_mask(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
@ -123,9 +132,13 @@ class T2SBlock:
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
|
||||
|
||||
def process_prompt(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
@ -149,9 +162,7 @@ class T2SBlock:
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@ -162,7 +173,14 @@ class T2SBlock:
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
|
||||
def decode_next_token(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
@ -176,7 +194,6 @@ class T2SBlock:
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
|
||||
if torch_sdpa:
|
||||
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
||||
else:
|
||||
@ -187,7 +204,11 @@ class T2SBlock:
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w1,
|
||||
self.norm_b1,
|
||||
self.norm_eps1,
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
@ -202,17 +223,19 @@ class T2SBlock:
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
|
||||
self.num_blocks: int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,
|
||||
padding_mask : Optional[torch.Tensor]=None,
|
||||
torch_sdpa:bool=True
|
||||
):
|
||||
k_cache : List[torch.Tensor] = []
|
||||
v_cache : List[torch.Tensor] = []
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
||||
k_cache.append(k_cache_)
|
||||
@ -220,14 +243,17 @@ class T2STransformer:
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask : torch.Tensor=None,
|
||||
torch_sdpa:bool=True
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
|
||||
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# assert self.EOS == 1024
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_text_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.phoneme_vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
|
||||
self.h = TransformerEncoder(
|
||||
@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
layer.linear2.bias,
|
||||
)
|
||||
|
||||
block = T2SBlock(
|
||||
@ -309,7 +345,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
layer.norm2.eps,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
||||
|
||||
###### DPO #############
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
||||
x, x_lens, reject_y, reject_y_lens, bert_feature
|
||||
)
|
||||
|
||||
reject_xy_dec, _ = self.h(
|
||||
(reject_xy_pos, None),
|
||||
@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||
def infer(
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||
y.device
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
||||
|
||||
xy_dec, _ = self.h(
|
||||
(xy_pos, None),
|
||||
mask=xy_attn_mask,
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
samples = topk_sampling(
|
||||
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
||||
)
|
||||
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
print("use early stop num:", early_stop_num)
|
||||
@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
|
||||
return y
|
||||
|
||||
def pad_y_eos(self, y, y_mask_int, eos_id):
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
||||
y_mask_int, (0, 1), value=1
|
||||
)
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def infer_panel_batch_infer(
|
||||
self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
@ -563,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
|
||||
):
|
||||
if prompts is None:
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
early_stop_num=early_stop_num,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
max_len = kwargs.get("max_len",x_lens.max())
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
@ -574,10 +615,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
||||
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
|
||||
x_item = (
|
||||
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
|
||||
) ### padding left
|
||||
x_list.append(x_item)
|
||||
x:torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
x: torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
@ -594,12 +636,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
|
||||
|
||||
##### create mask #####
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
@ -610,10 +650,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
|
||||
x_mask = F.pad(
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||
@ -621,7 +661,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
value=False,
|
||||
)
|
||||
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
||||
### 上面是错误的,会导致padding的token被"看见"
|
||||
|
||||
@ -639,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
||||
|
||||
|
||||
# 正确的attn_mask应该是这样的:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
@ -655,62 +694,57 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None]*y.shape[0]
|
||||
idx_list = [None] * y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask,(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
||||
tokens = torch.argmax(logits, dim=-1)
|
||||
reserved_idx_of_batch_for_y = None
|
||||
if (self.EOS in samples[:, 0]) or \
|
||||
(self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0]==self.EOS
|
||||
l2 = tokens==self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0] == self.EOS
|
||||
l2 = tokens == self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留batch中未生成完毕的序列
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None :
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
|
||||
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
|
||||
print("use early stop num:", early_stop_num)
|
||||
stop = True
|
||||
for i, batch_index in enumerate(batch_idx_map):
|
||||
@ -718,11 +752,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
if not (None in idx_list):
|
||||
if None not in idx_list:
|
||||
stop = True
|
||||
|
||||
if stop:
|
||||
if y.shape[1]==0:
|
||||
if y.shape[1] == 0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
print("bad zero prediction")
|
||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
@ -730,43 +764,48 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if (None in idx_list):
|
||||
if None in idx_list:
|
||||
for i in range(x.shape[0]):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
if ref_free:
|
||||
return y_list, [0]*x.shape[0]
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive_batched(self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
def infer_panel_naive_batched(
|
||||
self,
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
y_list = []
|
||||
idx_list = []
|
||||
for i in range(len(x)):
|
||||
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs)
|
||||
y, idx = self.infer_panel_naive(
|
||||
x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs,
|
||||
)
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
@ -774,16 +813,16 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def infer_panel_naive(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.num_head, -1, -1)
|
||||
.view(bsz, self.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
for idx in tqdm(range(1500)):
|
||||
if xy_attn_mask is not None:
|
||||
@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = None
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(
|
||||
@ -870,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx
|
||||
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
|
||||
return self.infer_panel_naive(
|
||||
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||
)
|
||||
|
@ -1,17 +1,13 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding
|
||||
from AR.modules.embedding_onnx import TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm
|
||||
from AR.modules.transformer_onnx import TransformerEncoder
|
||||
from AR.modules.transformer_onnx import TransformerEncoderLayer
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
"hidden_dim": 512,
|
||||
@ -26,12 +22,13 @@ default_config = {
|
||||
|
||||
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
||||
|
||||
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens = None,
|
||||
previous_tokens=None,
|
||||
temperature: float = 1.0,
|
||||
top_k = None,
|
||||
top_p = None,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
previous_tokens = previous_tokens.squeeze()
|
||||
@ -39,19 +36,27 @@ def logits_to_probs(
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
torch.nn.functional.softmax(
|
||||
sorted_logits,
|
||||
dim=-1,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=0,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove,
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@ -67,7 +72,7 @@ def logits_to_probs(
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(
|
||||
probs_sort
|
||||
probs_sort,
|
||||
): # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
@ -79,7 +84,9 @@ def sample(
|
||||
**sampling_kwargs,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
|
||||
|
||||
|
||||
class T2SFirstStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@ -114,8 +131,8 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
def forward(self, x, prompt):
|
||||
y = prompt
|
||||
x_example = x[:,:,0] * 0.0
|
||||
#N, 1, 512
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
# N, 1, 512
|
||||
cache = {
|
||||
"all_stage": self.num_layers,
|
||||
"k": None,
|
||||
@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
|
||||
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
||||
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
||||
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
|
||||
torch.ones_like(
|
||||
y_example.transpose(0, 1),
|
||||
dtype=torch.int64,
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
y_attn_mask = y_attn_mask > 0
|
||||
|
||||
@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
||||
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["k"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
cache["v"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
|
||||
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
|
||||
class T2SStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
|
||||
}
|
||||
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
[
|
||||
cache["y_emb"],
|
||||
self.ar_audio_embedding(y[:, -1:]),
|
||||
],
|
||||
1,
|
||||
)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
|
||||
xy_pos = y_pos[:, -1:]
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
|
||||
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
||||
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
||||
@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def init_onnx(self):
|
||||
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
self.stage_decoder = T2SStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
|
||||
def forward(self, x, prompts, bert_feature):
|
||||
early_stop_num = self.early_stop_num
|
||||
@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y = prompts
|
||||
prefix_len = y.shape[1]
|
||||
x_len = x.shape[1]
|
||||
x_example = x[:,:,0] * 0.0
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
||||
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
||||
|
||||
@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if cache["first_infer"] == 1:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
else:
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
)
|
||||
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
if cache["first_infer"] == 1:
|
||||
@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0), value=False
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
else:
|
||||
|
@ -1,8 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
n = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
|
||||
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
|
||||
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
|
||||
|
||||
return expaned_lengths<0
|
||||
return expaned_lengths < 0
|
||||
|
||||
|
||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||
def top_k_top_p_filtering(
|
||||
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
||||
logits,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
filter_value=-float("Inf"),
|
||||
min_tokens_to_keep=1,
|
||||
):
|
||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
@ -105,9 +111,7 @@ def top_k_top_p_filtering(
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
@ -130,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
||||
return token
|
||||
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(
|
||||
@ -156,19 +160,21 @@ def logits_to_probs(
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=1,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove,
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@ -176,7 +182,7 @@ def logits_to_probs(
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
@ -188,18 +194,19 @@ def sample(
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
**sampling_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||
)
|
||||
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
reference_chosen_logps: torch.FloatTensor,
|
||||
reference_rejected_logps: torch.FloatTensor,
|
||||
beta: float,
|
||||
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
|
||||
def dpo_loss(
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
reference_chosen_logps: torch.FloatTensor,
|
||||
reference_rejected_logps: torch.FloatTensor,
|
||||
beta: float,
|
||||
reference_free: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||
|
||||
@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
||||
|
||||
return losses.mean(), chosen_rewards, rejected_rewards
|
||||
|
||||
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
|
||||
def get_batch_logps(
|
||||
logits_target: torch.FloatTensor,
|
||||
logits_reject: torch.FloatTensor,
|
||||
labels_target: torch.LongTensor,
|
||||
labels_reject: torch.LongTensor,
|
||||
average_log_prob: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# dummy token; we'll ignore the losses on these tokens later
|
||||
|
||||
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
|
||||
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
|
||||
per_token_logps_target = torch.gather(
|
||||
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
per_token_logps_reject = torch.gather(
|
||||
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
|
||||
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
||||
|
||||
|
||||
def make_reject_y(y_o, y_lens):
|
||||
def repeat_P(y):
|
||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||
pre = y[:range_idx[0]]
|
||||
shf = y[range_idx[1]:]
|
||||
range_text = y[range_idx[0]:range_idx[1]]
|
||||
pre = y[: range_idx[0]]
|
||||
shf = y[range_idx[1] :]
|
||||
range_text = y[range_idx[0] : range_idx[1]]
|
||||
new_y = torch.cat([pre, range_text, range_text, shf])
|
||||
return new_y
|
||||
|
||||
def lost_P(y):
|
||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||
pre = y[:range_idx[0]]
|
||||
shf = y[range_idx[1]:]
|
||||
range_text = y[range_idx[0]:range_idx[1]]
|
||||
pre = y[: range_idx[0]]
|
||||
shf = y[range_idx[1] :]
|
||||
range_text = y[range_idx[0] : range_idx[1]]
|
||||
new_y = torch.cat([pre, shf])
|
||||
return new_y
|
||||
|
||||
bs = len(y_lens)
|
||||
reject_y = []
|
||||
reject_y_lens = []
|
||||
for b in range(bs):
|
||||
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
|
||||
process_item_idx = torch.randint(0, 1, size=(1,))[0]
|
||||
if process_item_idx == 0:
|
||||
new_y = repeat_P(y_o[b])
|
||||
reject_y.append(new_y)
|
||||
reject_y_lens.append(len(new_y))
|
||||
elif process_item_idx==1:
|
||||
elif process_item_idx == 1:
|
||||
new_y = lost_P(y_o[b])
|
||||
reject_y.append(new_y)
|
||||
reject_y_lens.append(len(new_y))
|
||||
@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
|
||||
pad_length = max_length - reject_y_lens[b]
|
||||
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
||||
|
||||
reject_y = torch.stack(reject_y, dim = 0)
|
||||
reject_y = torch.stack(reject_y, dim=0)
|
||||
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
||||
|
||||
return reject_y, reject_y_lens
|
||||
|
@ -1,17 +1,14 @@
|
||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Linear
|
||||
from torch.nn import Module
|
||||
from torch.nn.init import constant_
|
||||
from torch.nn.init import xavier_normal_
|
||||
from torch.nn.init import xavier_uniform_
|
||||
from torch.nn import Linear, Module
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
|
||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||
@ -73,6 +70,7 @@ class MultiheadAttention(Module):
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
"""
|
||||
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
@ -104,9 +102,7 @@ class MultiheadAttention(Module):
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
@ -117,31 +113,32 @@ class MultiheadAttention(Module):
|
||||
if linear1_cls == Linear:
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs),
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs),
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs),
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(
|
||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||
)
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
@ -150,7 +147,10 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
@ -164,7 +164,10 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
@ -261,28 +264,26 @@ class MultiheadAttention(Module):
|
||||
if key_padding_mask is not None:
|
||||
_kpm_dtype = key_padding_mask.dtype
|
||||
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
key_padding_mask
|
||||
key_padding_mask,
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||
why_not_fast_path = ""
|
||||
if not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif (
|
||||
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
|
||||
):
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
)
|
||||
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
)
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif not self.batch_first:
|
||||
@ -300,9 +301,7 @@ class MultiheadAttention(Module):
|
||||
elif attn_mask is not None:
|
||||
why_not_fast_path = "attn_mask was not None"
|
||||
elif query.is_nested and key_padding_mask is not None:
|
||||
why_not_fast_path = (
|
||||
"key_padding_mask is not supported with NestedTensor input"
|
||||
)
|
||||
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
|
||||
elif self.num_heads % 2 == 1:
|
||||
why_not_fast_path = "num_heads is odd"
|
||||
elif torch.is_autocast_enabled():
|
||||
@ -322,20 +321,10 @@ class MultiheadAttention(Module):
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif not all(
|
||||
[
|
||||
(x is None or x.is_cuda or "cpu" in str(x.device))
|
||||
for x in tensor_args
|
||||
]
|
||||
):
|
||||
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
|
||||
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
||||
elif torch.is_grad_enabled() and any(
|
||||
[x is not None and x.requires_grad for x in tensor_args]
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
|
||||
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
|
||||
if not why_not_fast_path:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
@ -350,11 +339,7 @@ class MultiheadAttention(Module):
|
||||
key_padding_mask if key_padding_mask is not None else attn_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
1
|
||||
if key_padding_mask is not None
|
||||
else 0
|
||||
if attn_mask is not None
|
||||
else None,
|
||||
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
|
@ -1,17 +1,13 @@
|
||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Linear
|
||||
from torch.nn import Module
|
||||
from torch.nn.init import constant_
|
||||
from torch.nn.init import xavier_normal_
|
||||
from torch.nn.init import xavier_uniform_
|
||||
from torch.nn import Linear, Module
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||
|
||||
|
||||
@ -47,9 +43,7 @@ class MultiheadAttention(Module):
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
@ -60,18 +54,30 @@ class MultiheadAttention(Module):
|
||||
if linear1_cls == Linear:
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, embed_dim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, self.kdim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, self.vdim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(3 * embed_dim, embed_dim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
@ -79,13 +85,11 @@ class MultiheadAttention(Module):
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(
|
||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||
torch.empty(3 * embed_dim, **factory_kwargs),
|
||||
)
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
self._reset_parameters()
|
||||
else:
|
||||
@ -93,7 +97,10 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
@ -107,7 +114,10 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
|
@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.embedding_dim)
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
|
||||
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
||||
|
||||
def extend_pe(self, x):
|
||||
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
|
||||
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
|
||||
scpe = (position * self.div_term).unsqueeze(0)
|
||||
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
||||
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
||||
|
@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
||||
lr = self.end_lr
|
||||
|
||||
else:
|
||||
decay_ratio = (self._current_step - self.warmup_steps) / (
|
||||
self.total_steps - self.warmup_steps
|
||||
)
|
||||
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
||||
if decay_ratio < 0.0 or decay_ratio > 1.0:
|
||||
raise RuntimeError(
|
||||
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
|
||||
)
|
||||
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
||||
|
||||
@ -70,7 +66,13 @@ if __name__ == "__main__":
|
||||
m = nn.Linear(10, 10)
|
||||
opt = Adam(m.parameters(), lr=1e-4)
|
||||
s = WarmupCosineLRSchedule(
|
||||
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
|
||||
opt,
|
||||
1e-6,
|
||||
2e-4,
|
||||
1e-6,
|
||||
warmup_steps=2000,
|
||||
total_steps=20000,
|
||||
current_step=0,
|
||||
)
|
||||
lrs = []
|
||||
for i in range(25000):
|
||||
|
@ -16,8 +16,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
|
||||
group_params_names: name for each parameter in group,
|
||||
which is List[str].
|
||||
"""
|
||||
batches = defaultdict(
|
||||
list
|
||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches_names = defaultdict(
|
||||
list
|
||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
||||
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
||||
|
||||
assert len(param_group) == len(group_params_names)
|
||||
for p, named_p in zip(param_group, group_params_names):
|
||||
@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
|
||||
batches_names[key].append(named_p)
|
||||
|
||||
batches_names_keys = list(batches_names.keys())
|
||||
sorted_idx = sorted(
|
||||
range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
||||
batches_names = [
|
||||
batches_names[batches_names_keys[idx]] for idx in sorted_idx
|
||||
]
|
||||
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
||||
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
|
||||
stacked_params_dict = dict()
|
||||
@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
|
||||
# group. class Optimizer will take care of saving/loading state.
|
||||
state = self.state[p]
|
||||
p_stacked = torch.stack(batch)
|
||||
grad = torch.stack([
|
||||
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
|
||||
])
|
||||
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
|
||||
p_stacked.grad = grad
|
||||
stacked_params_dict[key] = p_stacked
|
||||
tuples.append((p_stacked, state, batch_names))
|
||||
|
||||
yield tuples # <-- calling code will do the actual optimization here!
|
||||
|
||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
||||
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=3e-02,
|
||||
clipping_scale=None,
|
||||
betas=(0.9, 0.98),
|
||||
scalar_lr_scale=0.1,
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=3.0,
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True, ):
|
||||
|
||||
self,
|
||||
params,
|
||||
lr=3e-02,
|
||||
clipping_scale=None,
|
||||
betas=(0.9, 0.98),
|
||||
scalar_lr_scale=0.1,
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=3.0,
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True,
|
||||
):
|
||||
assert parameters_names is not None, (
|
||||
"Please prepare parameters_names,"
|
||||
"which is a List[List[str]]. Each List[str] is for a group"
|
||||
"and each str is for a parameter")
|
||||
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
|
||||
)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
param_max_rms=param_max_rms,
|
||||
scalar_max=scalar_max,
|
||||
size_update_period=size_update_period,
|
||||
clipping_update_period=clipping_update_period, )
|
||||
clipping_update_period=clipping_update_period,
|
||||
)
|
||||
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
assert len(self.param_groups) == len(parameters_names)
|
||||
@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
batch = True
|
||||
|
||||
for group, group_params_names in zip(self.param_groups,
|
||||
self.parameters_names):
|
||||
|
||||
with self.batched_params(group["params"],
|
||||
group_params_names) as batches:
|
||||
|
||||
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||
with self.batched_params(group["params"], group_params_names) as batches:
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
|
||||
if (len(batches[0][1]) ==
|
||||
0): # if len(first state) == 0: not yet initialized
|
||||
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||
clipping_scale = 1
|
||||
else:
|
||||
clipping_scale = self._get_clipping_scale(group, batches)
|
||||
@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# grad is not going to be None, we handled that when creating the batches.
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (
|
||||
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(size_update_period,
|
||||
*param_rms.shape, **kwargs)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
def _get_clipping_scale(self,
|
||||
group: dict,
|
||||
tuples: List[Tuple[Tensor, dict, List[str]]]
|
||||
) -> float:
|
||||
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
|
||||
"""
|
||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||
by this amount before applying the rest of the update.
|
||||
@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for (p, state, param_names) in tuples:
|
||||
for p, state, param_names in tuples:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients")
|
||||
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||
else:
|
||||
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
||||
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
||||
|
||||
tot_norm = tot_sumsq.sqrt()
|
||||
if "model_norms" not in first_state:
|
||||
first_state["model_norms"] = torch.zeros(
|
||||
clipping_update_period, device=p.device)
|
||||
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
|
||||
for n in range(0, 5):
|
||||
index = min(
|
||||
clipping_update_period - 1,
|
||||
(clipping_update_period // 4) * n, )
|
||||
(clipping_update_period // 4) * n,
|
||||
)
|
||||
quartiles.append(sorted_norms[index].item())
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (first_state["num_clipped"] * 100.0 /
|
||||
clipping_update_period
|
||||
if "num_clipped" in first_state else 0.0)
|
||||
percent_clipped = (
|
||||
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
|
||||
)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||
logging.info(
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
)
|
||||
|
||||
if step < clipping_update_period:
|
||||
@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except KeyError:
|
||||
logging.info(
|
||||
"Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?"
|
||||
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
|
||||
)
|
||||
return 1.0
|
||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||
if self.show_dominant_parameters:
|
||||
assert p.shape[0] == len(param_names)
|
||||
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
||||
return ans
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
self, tuples: List[Tuple[Tensor, dict, List[str]]],
|
||||
tot_sumsq: Tensor):
|
||||
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
|
||||
"""
|
||||
Show information of parameter wihch dominanting tot_sumsq.
|
||||
|
||||
@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
from tuples, we still pass it to save some time.
|
||||
"""
|
||||
all_sumsq_orig = {}
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
for p, state, batch_param_names in tuples:
|
||||
# p is a stacked batch parameters.
|
||||
batch_grad = p.grad
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
|
||||
batch_rms_orig = torch.ones(p.shape[0])
|
||||
else:
|
||||
batch_rms_orig = state["param_rms"]
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
|
||||
dim=list(range(1, batch_grad.ndim)))
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(batch_param_names,
|
||||
batch_sumsq_orig,
|
||||
batch_rms_orig, batch_grad):
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(
|
||||
batch_param_names,
|
||||
batch_sumsq_orig,
|
||||
batch_rms_orig,
|
||||
batch_grad,
|
||||
):
|
||||
proportion_orig = sumsq_orig / tot_sumsq
|
||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||
|
||||
assert torch.isclose(
|
||||
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||
torch.tensor(1.0), )
|
||||
torch.tensor(1.0),
|
||||
)
|
||||
sorted_by_proportion = {
|
||||
k: v
|
||||
for k, v in sorted(
|
||||
all_sumsq_orig.items(),
|
||||
key=lambda item: item[1][0],
|
||||
reverse=True, )
|
||||
reverse=True,
|
||||
)
|
||||
}
|
||||
dominant_param_name = next(iter(sorted_by_proportion))
|
||||
(dominant_proportion, dominant_sumsq, dominant_rms,
|
||||
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
|
||||
(
|
||||
dominant_proportion,
|
||||
dominant_sumsq,
|
||||
dominant_rms,
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(
|
||||
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||
)
|
||||
|
||||
def _step_one_batch(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
clipping_scale: float):
|
||||
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
|
||||
"""
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
|
||||
if numel > 1:
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True)
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_((p**2)
|
||||
.mean(dim=list(range(1, p.ndim)), keepdim=True)
|
||||
.sqrt())
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
state["step"] = step + 1
|
||||
|
||||
def _size_update(self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
def _size_update(
|
||||
self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||
@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2**size_update_period
|
||||
|
||||
scale_exp_avg_sq = state[
|
||||
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
|
||||
alpha=1 - beta2_corr,
|
||||
) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = (-size_lr * (bias_correction2**0.5) *
|
||||
scale_grads.sum(dim=0) / denom)
|
||||
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||
|
||||
is_too_small = param_rms < param_min_rms
|
||||
is_too_large = param_rms > param_max_rms
|
||||
@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
||||
|
||||
this_step = state["step"] - (state["zero_step"]
|
||||
if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2**(this_step + 1)
|
||||
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2**(state["step"] + 1)
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||
|
||||
delta = state["delta"]
|
||||
|
@ -5,7 +5,6 @@ from torch.nn.functional import (
|
||||
_none_or_dtype,
|
||||
_in_projection_packed,
|
||||
)
|
||||
from torch.nn import functional as F
|
||||
import torch
|
||||
# Tensor = torch.Tensor
|
||||
# from typing import Callable, List, Optional, Tuple, Union
|
||||
@ -25,18 +24,18 @@ def multi_head_attention_forward_patched(
|
||||
dropout_p: float,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
training = True,
|
||||
key_padding_mask = None,
|
||||
need_weights = True,
|
||||
attn_mask = None,
|
||||
use_separate_proj_weight = False,
|
||||
q_proj_weight = None,
|
||||
k_proj_weight = None,
|
||||
v_proj_weight = None,
|
||||
static_k = None,
|
||||
static_v = None,
|
||||
average_attn_weights = True,
|
||||
is_causal = False,
|
||||
training=True,
|
||||
key_padding_mask=None,
|
||||
need_weights=True,
|
||||
attn_mask=None,
|
||||
use_separate_proj_weight=False,
|
||||
q_proj_weight=None,
|
||||
k_proj_weight=None,
|
||||
v_proj_weight=None,
|
||||
static_k=None,
|
||||
static_v=None,
|
||||
average_attn_weights=True,
|
||||
is_causal=False,
|
||||
cache=None,
|
||||
):
|
||||
r"""
|
||||
@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
is_batched = _mha_shape_check(
|
||||
query, key, value, key_padding_mask, attn_mask, num_heads
|
||||
)
|
||||
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
||||
|
||||
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
||||
# is batched, run the computation and before returning squeeze the
|
||||
@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
|
||||
# longer causal.
|
||||
is_causal = False
|
||||
|
||||
assert (
|
||||
embed_dim == embed_dim_to_check
|
||||
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||
assert embed_dim == embed_dim_to_check, (
|
||||
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||
)
|
||||
if isinstance(embed_dim, torch.Tensor):
|
||||
# embed_dim can be a tensor when JIT tracing
|
||||
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
|
||||
else:
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||
if use_separate_proj_weight:
|
||||
# allow MHA to have different embedding dimensions when separate projection weights are used
|
||||
assert (
|
||||
key.shape[:2] == value.shape[:2]
|
||||
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||
assert key.shape[:2] == value.shape[:2], (
|
||||
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
key.shape == value.shape
|
||||
), f"key shape {key.shape} does not match value shape {value.shape}"
|
||||
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
||||
|
||||
#
|
||||
# compute in-projection
|
||||
#
|
||||
if not use_separate_proj_weight:
|
||||
assert (
|
||||
in_proj_weight is not None
|
||||
), "use_separate_proj_weight is False but in_proj_weight is None"
|
||||
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
||||
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
||||
else:
|
||||
assert (
|
||||
q_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but q_proj_weight is None"
|
||||
assert (
|
||||
k_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but k_proj_weight is None"
|
||||
assert (
|
||||
v_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but v_proj_weight is None"
|
||||
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
||||
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
||||
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
||||
if in_proj_bias is None:
|
||||
b_q = b_k = b_v = None
|
||||
else:
|
||||
@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
|
||||
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"attn_mask's dimension {attn_mask.dim()} is not supported"
|
||||
)
|
||||
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
||||
|
||||
# add bias along batch dimension (currently second)
|
||||
if bias_k is not None and bias_v is not None:
|
||||
@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
|
||||
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||
else:
|
||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||
assert (
|
||||
static_k.size(0) == bsz * num_heads
|
||||
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||
assert (
|
||||
static_k.size(2) == head_dim
|
||||
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||
assert static_k.size(0) == bsz * num_heads, (
|
||||
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||
)
|
||||
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||
k = static_k
|
||||
if static_v is None:
|
||||
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||
else:
|
||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||
assert (
|
||||
static_v.size(0) == bsz * num_heads
|
||||
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||
assert (
|
||||
static_v.size(2) == head_dim
|
||||
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||
assert static_v.size(0) == bsz * num_heads, (
|
||||
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||
)
|
||||
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||
v = static_v
|
||||
|
||||
# add zero attention along batch dimension (now first)
|
||||
if add_zero_attn:
|
||||
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
||||
k = torch.cat(
|
||||
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
|
||||
)
|
||||
v = torch.cat(
|
||||
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
|
||||
)
|
||||
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
||||
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
|
||||
src_len,
|
||||
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
||||
key_padding_mask = (
|
||||
key_padding_mask.view(bsz, 1, 1, src_len)
|
||||
.expand(-1, num_heads, -1, -1)
|
||||
.reshape(bsz * num_heads, 1, src_len)
|
||||
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
||||
)
|
||||
if attn_mask is None:
|
||||
attn_mask = key_padding_mask
|
||||
@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
|
||||
B, Nt, E = q.shape
|
||||
q_scaled = q / math.sqrt(E)
|
||||
|
||||
assert not (
|
||||
is_causal and attn_mask is None
|
||||
), "FIXME: is_causal not implemented for need_weights"
|
||||
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_output_weights = torch.baddbmm(
|
||||
attn_mask, q_scaled, k.transpose(-2, -1)
|
||||
)
|
||||
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
||||
else:
|
||||
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
||||
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
||||
@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||
|
||||
@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
|
||||
v = v.view(bsz, num_heads, src_len, head_dim)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
||||
attn_output = scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, dropout_p, is_causal
|
||||
)
|
||||
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||
|
||||
attn_output = (
|
||||
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||
)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||
|
@ -1,11 +1,9 @@
|
||||
from torch.nn.functional import *
|
||||
from torch.nn.functional import (
|
||||
_mha_shape_check,
|
||||
_canonical_mask,
|
||||
_none_or_dtype,
|
||||
_in_projection_packed,
|
||||
)
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
|
||||
is_causal: bool = False,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
|
||||
# set up shape vars
|
||||
_, _, embed_dim = query.shape
|
||||
attn_mask = _canonical_mask(
|
||||
@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
|
||||
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, dropout_p, is_causal
|
||||
)
|
||||
attn_output = (
|
||||
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||
)
|
||||
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
||||
|
||||
|
@ -13,12 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
||||
# floors), should be expectation-preserving.
|
||||
floor = -0.043637
|
||||
ceil = 1.2
|
||||
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
|
||||
deriv
|
||||
)
|
||||
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
|
||||
if __name__ == "__main__":
|
||||
# for self-testing only.
|
||||
assert d_scaled.min() >= 0.0
|
||||
@ -153,13 +148,9 @@ def _compute_scale_factor(
|
||||
else:
|
||||
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
||||
# x_abs)_mean , min_abs.
|
||||
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
|
||||
|
||||
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
|
||||
|
||||
return below_threshold - above_threshold
|
||||
|
||||
@ -181,18 +172,16 @@ def _compute_sign_factor(
|
||||
else:
|
||||
# 0 if proportion_positive >= min_positive, else can be
|
||||
# as large as max_factor.
|
||||
factor1 = (
|
||||
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
||||
).clamp_(min=0, max=max_factor)
|
||||
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
|
||||
|
||||
if max_positive == 1.0:
|
||||
factor2 = 0.0
|
||||
else:
|
||||
# 0 if self.proportion_positive <= max_positive, else can be
|
||||
# as large as -max_factor.
|
||||
factor2 = (
|
||||
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
|
||||
).clamp_(min=0, max=max_factor)
|
||||
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
sign_factor = factor1 - factor2
|
||||
# require min_positive != 0 or max_positive != 1:
|
||||
assert not isinstance(sign_factor, float)
|
||||
@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
|
||||
return _no_op(x)
|
||||
|
||||
|
||||
def BalancedDoubleSwish(
|
||||
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
||||
) -> nn.Sequential:
|
||||
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
|
||||
"""
|
||||
ActivationBalancer -> DoubleSwish
|
||||
"""
|
||||
balancer = ActivationBalancer(
|
||||
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
||||
)
|
||||
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
|
||||
return nn.Sequential(
|
||||
balancer,
|
||||
DoubleSwish(),
|
||||
|
@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
||||
)
|
||||
|
||||
assert embedding is None
|
||||
return F.layer_norm(
|
||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||
)
|
||||
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
"{normalized_shape}, eps={eps}, "
|
||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
)
|
||||
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
|
||||
|
||||
class IdentityNorm(nn.Module):
|
||||
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out = transformer_encoder(src)
|
||||
"""
|
||||
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = linear1_feedforward_cls(
|
||||
d_model, dim_feedforward, **factory_kwargs
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = linear2_feedforward_cls(
|
||||
dim_feedforward, d_model, **factory_kwargs
|
||||
)
|
||||
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||
|
||||
self.norm_first = norm_first
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
if src_key_padding_mask is not None:
|
||||
_skpm_dtype = src_key_padding_mask.dtype
|
||||
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
src_key_padding_mask
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
|
||||
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||
|
||||
if self.norm_first:
|
||||
x = x + self._sa_block(
|
||||
|
@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
||||
)
|
||||
|
||||
assert embedding is None
|
||||
return F.layer_norm(
|
||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||
)
|
||||
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
"{normalized_shape}, eps={eps}, "
|
||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
)
|
||||
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
|
||||
|
||||
class IdentityNorm(nn.Module):
|
||||
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out = transformer_encoder(src)
|
||||
"""
|
||||
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
__constants__ = ["batch_first", "norm_first"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
|
||||
linear2_cls=linear2_self_attention_cls,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(
|
||||
d_model, dim_feedforward, **factory_kwargs
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = linear2_feedforward_cls(
|
||||
dim_feedforward, d_model, **factory_kwargs
|
||||
)
|
||||
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||
self.norm_first = norm_first
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
@ -30,9 +30,7 @@ class GruutPhonemizer:
|
||||
"«": "«",
|
||||
"»": "»",
|
||||
}
|
||||
self._punctuation_regexp: str = (
|
||||
rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
)
|
||||
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
|
||||
def _normalize_punctuation(self, text: str) -> str:
|
||||
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
||||
@ -53,13 +51,8 @@ class GruutPhonemizer:
|
||||
|
||||
def phonemize(self, text: str, espeak: bool = False) -> str:
|
||||
text_to_phonemize: str = self._normalize_punctuation(text)
|
||||
sents: List[Sentence] = [
|
||||
sent
|
||||
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
|
||||
]
|
||||
words: List[str] = [
|
||||
self._convert_punctuation(word) for word in itertools.chain(*sents)
|
||||
]
|
||||
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
||||
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
||||
return " ".join(words)
|
||||
|
||||
def transform(self, phonemes):
|
||||
|
@ -3,7 +3,9 @@
|
||||
PAD = "_"
|
||||
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
||||
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
IPA_LETTERS = (
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
||||
SPACE_ID = SYMBOLS.index(" ")
|
||||
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
||||
|
@ -2,12 +2,12 @@ import re
|
||||
|
||||
|
||||
def str2bool(str):
|
||||
return True if str.lower() == 'true' else False
|
||||
return True if str.lower() == "true" else False
|
||||
|
||||
|
||||
def get_newest_ckpt(string_list):
|
||||
# 定义一个正则表达式模式,用于匹配字符串中的数字
|
||||
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
|
||||
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
|
||||
|
||||
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
|
||||
extracted_info = []
|
||||
@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
|
||||
step = int(match.group(2))
|
||||
extracted_info.append((epoch, step, string))
|
||||
# 按照 epoch 后面的数字和 step 后面的数字进行排序
|
||||
sorted_info = sorted(
|
||||
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
||||
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
||||
# 获取最新的 ckpt 文件名
|
||||
newest_ckpt = sorted_info[0][2]
|
||||
return newest_ckpt
|
||||
@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
|
||||
# 文本存在且不为空时 return True
|
||||
def check_txt_file(file_path):
|
||||
try:
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
text = file.readline().strip()
|
||||
assert text.strip() != ''
|
||||
assert text.strip() != ""
|
||||
return text
|
||||
except Exception:
|
||||
return False
|
||||
|
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Initialize modules for espnet2 neural networks."""
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
|
||||
|
||||
|
||||
def write_args(args, path):
|
||||
args_dict = dict(
|
||||
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
|
||||
)
|
||||
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
|
||||
with open(path, "a") as args_file:
|
||||
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
||||
args_file.write(
|
||||
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
|
||||
)
|
||||
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
|
||||
args_file.write("==> Cmd:\n")
|
||||
args_file.write(str(sys.argv))
|
||||
args_file.write("\n==> args:\n")
|
||||
|
@ -23,9 +23,7 @@ class Snake(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
|
@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||
activation_results = anti_alias_activation_cuda.forward(
|
||||
inputs, up_ftr, down_ftr, alpha, beta
|
||||
)
|
||||
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
|
||||
|
||||
return activation_results
|
||||
|
||||
@ -61,17 +59,11 @@ class Activation1d(nn.Module):
|
||||
if self.act.__class__.__name__ == "Snake":
|
||||
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||
else:
|
||||
beta = (
|
||||
self.act.beta.data
|
||||
) # Snakebeta uses different params for alpha and beta
|
||||
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
|
||||
alpha = self.act.alpha.data
|
||||
if (
|
||||
not self.act.alpha_logscale
|
||||
): # Exp baked into cuda kernel, cancel it out with a log
|
||||
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
|
||||
alpha = torch.log(alpha)
|
||||
beta = torch.log(beta)
|
||||
|
||||
x = FusedAntiAliasActivation.apply(
|
||||
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
||||
)
|
||||
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
|
||||
return x
|
||||
|
@ -58,17 +58,13 @@ def load():
|
||||
srcpath / "anti_alias_activation.cpp",
|
||||
srcpath / "anti_alias_activation_cuda.cu",
|
||||
]
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
||||
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
||||
)
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
|
||||
|
||||
return anti_alias_activation_cuda
|
||||
|
||||
|
||||
def _get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
|
@ -27,9 +27,7 @@ else:
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff, half_width, kernel_size
|
||||
): # return filter [1,1,kernel_size]
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
|
||||
|
@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
return x
|
||||
@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
|
@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(
|
||||
self.convs2
|
||||
) # Total number of conv layers
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@ -283,9 +265,7 @@ class BigVGAN(
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# Pre-conv
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
if h.resblock == "1":
|
||||
@ -293,9 +273,7 @@ class BigVGAN(
|
||||
elif h.resblock == "2":
|
||||
resblock_class = AMPBlock2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
||||
)
|
||||
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||
|
||||
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
@ -320,22 +298,14 @@ class BigVGAN(
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(
|
||||
resblock_class(h, ch, k, d, activation=h.activation)
|
||||
)
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# Post-conv
|
||||
activation_post = (
|
||||
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snake"
|
||||
else (
|
||||
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snakebeta"
|
||||
else None
|
||||
)
|
||||
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
|
||||
)
|
||||
if activation_post is None:
|
||||
raise NotImplementedError(
|
||||
@ -346,9 +316,7 @@ class BigVGAN(
|
||||
|
||||
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
||||
)
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||
|
||||
# Weight initialization
|
||||
for i in range(len(self.ups)):
|
||||
@ -451,13 +419,13 @@ class BigVGAN(
|
||||
# instantiate BigVGAN using h
|
||||
if use_cuda_kernel:
|
||||
print(
|
||||
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
)
|
||||
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||
|
||||
@ -485,7 +453,7 @@ class BigVGAN(
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
except RuntimeError:
|
||||
print(
|
||||
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
)
|
||||
model.remove_weight_norm()
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
|
@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
|
||||
from env import AttrDict
|
||||
from utils import get_padding
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
self.mpd_reshapes = h.mpd_reshapes
|
||||
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
|
||||
for rs in self.mpd_reshapes
|
||||
]
|
||||
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.resolution = resolution
|
||||
assert (
|
||||
len(self.resolution) == 3
|
||||
), f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
self.lrelu_slope = 0.1
|
||||
|
||||
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||
print(
|
||||
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
|
||||
)
|
||||
norm_f = (
|
||||
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
)
|
||||
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
|
||||
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
self.d_mult = cfg.discriminator_channel_mult
|
||||
if hasattr(cfg, "mrd_channel_mult"):
|
||||
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||
@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert (
|
||||
len(self.resolutions) == 3
|
||||
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||
assert len(self.resolutions) == 3, (
|
||||
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
|
||||
),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
# Remove DC offset
|
||||
@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
|
||||
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorCQT(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
|
||||
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
|
||||
)
|
||||
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding(
|
||||
(self.kernel_size[0], self.kernel_size[0])
|
||||
),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||
if self.cqtd_normalize_volume:
|
||||
print(
|
||||
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
)
|
||||
|
||||
def get_2d_padding(
|
||||
@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
# Multi-scale params to loop over
|
||||
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
|
||||
"cqtd_bins_per_octaves", [24, 36, 48]
|
||||
)
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
self.discrimiantor = nn.ModuleList(list_discriminator)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
|
@ -35,9 +35,7 @@ def inference(a, h):
|
||||
with torch.no_grad():
|
||||
for i, filname in enumerate(filelist):
|
||||
# Load the ground truth audio and resample if necessary
|
||||
wav, sr = librosa.load(
|
||||
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
|
||||
)
|
||||
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
|
||||
wav = torch.FloatTensor(wav).to(device)
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||
@ -48,9 +46,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
@ -61,9 +61,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
@ -6,13 +6,12 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
@ -123,9 +122,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert (
|
||||
hop_length == window_length // 4
|
||||
), "For match_stride, hop must equal n_fft // 4"
|
||||
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(
|
||||
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
|
||||
)
|
||||
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(
|
||||
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
||||
):
|
||||
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(
|
||||
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(
|
||||
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(
|
||||
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
|
||||
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
@ -226,7 +212,6 @@ def feature_loss(
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
@ -243,7 +228,6 @@ def discriminator_loss(
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
|
@ -86,9 +86,7 @@ def mel_spectrogram(
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
@ -96,9 +94,7 @@ def mel_spectrogram(
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
@ -150,17 +146,13 @@ def get_dataset_filelist(a):
|
||||
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first training file: {training_files[0]}")
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first validation file: {validation_files[0]}")
|
||||
|
||||
@ -171,9 +163,7 @@ def get_dataset_filelist(a):
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(
|
||||
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
|
||||
)
|
||||
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
|
||||
list_unseen_validation_files.append(unseen_validation_files)
|
||||
|
||||
return training_files, validation_files, list_unseen_validation_files
|
||||
@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
print("[INFO] checking dataset integrity...")
|
||||
for i in tqdm(range(len(self.audio_files))):
|
||||
assert os.path.exists(
|
||||
self.audio_files[i]
|
||||
), f"{self.audio_files[i]} not found"
|
||||
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
try:
|
||||
filename = self.audio_files[index]
|
||||
|
||||
@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
# Obtain randomized audio chunk
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
# Adjust segment size to crop if the source sr is different
|
||||
target_segment_size = math.ceil(
|
||||
self.segment_size
|
||||
* (source_sampling_rate / self.sampling_rate)
|
||||
)
|
||||
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
|
||||
else:
|
||||
target_segment_size = self.segment_size
|
||||
|
||||
# Compute upper bound index for the random chunk
|
||||
random_chunk_upper_bound = max(
|
||||
0, audio.shape[0] - target_segment_size
|
||||
)
|
||||
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
|
||||
|
||||
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||
if audio.shape[0] >= target_segment_size:
|
||||
@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||
assert (
|
||||
source_sampling_rate == self.sampling_rate
|
||||
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
assert source_sampling_rate == self.sampling_rate, (
|
||||
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
)
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start
|
||||
* self.hop_size : (mel_start + frames_per_seg)
|
||||
* self.hop_size,
|
||||
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
|
||||
]
|
||||
|
||||
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
|
||||
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
||||
)
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||
mel_loss = mel_spectrogram(
|
||||
@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Shape sanity checks
|
||||
assert (
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size
|
||||
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), (
|
||||
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
if self.fine_tuning:
|
||||
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
|
||||
)
|
||||
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
|
||||
return self[random.randrange(len(self))]
|
||||
|
||||
def __len__(self):
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations.Snake cuda vs. torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations, Snake CUDA vs. Torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
@ -57,7 +54,6 @@ def test_anti_alias_activation():
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
|
@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
|
||||
|
||||
|
||||
def get_mel(x, h):
|
||||
return mel_spectrogram(
|
||||
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
|
||||
)
|
||||
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test script to check CUDA kernel correctness."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
type=str,
|
||||
@ -109,9 +105,7 @@ if __name__ == "__main__":
|
||||
diff += test_result.mean(dim=-1).item()
|
||||
|
||||
diff /= num_sample
|
||||
if (
|
||||
diff <= 2e-3
|
||||
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
print(
|
||||
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
||||
f"\n > mean_difference={diff}"
|
||||
@ -175,8 +169,8 @@ if __name__ == "__main__":
|
||||
audio_second = audio_length_total / h.sampling_rate
|
||||
khz_original = audio_length_total / toc_total_original / 1000
|
||||
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
||||
|
||||
# Print results
|
||||
print(
|
||||
|
@ -77,24 +77,18 @@ def train(rank, a, h):
|
||||
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||
print(
|
||||
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||
mrd = MultiBandDiscriminator(h).to(device)
|
||||
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||
print(
|
||||
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||
else: # Fallback to original MRD in BigVGAN-v1
|
||||
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||
|
||||
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||
if h.get("use_multiscale_melloss", False):
|
||||
print(
|
||||
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
|
||||
)
|
||||
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
|
||||
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||
sampling_rate=h.sampling_rate
|
||||
) # NOTE: accepts waveform as input
|
||||
@ -114,9 +108,7 @@ def train(rank, a, h):
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||
cp_g = scan_checkpoint(
|
||||
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
|
||||
)
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
|
||||
cp_do = scan_checkpoint(
|
||||
a.checkpoint_path,
|
||||
prefix="do_",
|
||||
@ -143,9 +135,7 @@ def train(rank, a, h):
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(
|
||||
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
|
||||
)
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(
|
||||
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||
h.learning_rate,
|
||||
@ -156,12 +146,8 @@ def train(rank, a, h):
|
||||
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
# Define training and validation datasets
|
||||
|
||||
@ -169,9 +155,7 @@ def train(rank, a, h):
|
||||
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||
Example: trained on LibriTTS, validate on VCTK
|
||||
"""
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = (
|
||||
get_dataset_filelist(a)
|
||||
)
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
trainset = MelDataset(
|
||||
training_filelist,
|
||||
@ -324,33 +308,26 @@ def train(rank, a, h):
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
|
||||
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
|
||||
|
||||
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
|
||||
if (
|
||||
not "nonspeech" in mode
|
||||
): # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
|
||||
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
# Resample to 16000 for pesq
|
||||
y_16k = pesq_resampler(y)
|
||||
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
y_g_hat_int_16k = (
|
||||
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
)
|
||||
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||
|
||||
# MRSTFT calculation
|
||||
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
|
||||
|
||||
# Log audio and figures to Tensorboard
|
||||
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||
if steps >= 0:
|
||||
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y[0],
|
||||
os.path.join(
|
||||
@ -373,9 +350,7 @@ def train(rank, a, h):
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y_g_hat[0, 0],
|
||||
os.path.join(
|
||||
@ -487,15 +462,11 @@ def train(rank, a, h):
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
|
||||
y_df_hat_r, y_df_hat_g
|
||||
)
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MRD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
|
||||
y_ds_hat_r, y_ds_hat_g
|
||||
)
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
@ -505,17 +476,11 @@ def train(rank, a, h):
|
||||
# Whether to freeze D for initial training steps
|
||||
if steps >= a.freeze_step:
|
||||
loss_disc_all.backward()
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
|
||||
mpd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
|
||||
mrd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
|
||||
optim_d.step()
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
|
||||
grad_norm_mpd = 0.0
|
||||
grad_norm_mrd = 0.0
|
||||
|
||||
@ -523,9 +488,7 @@ def train(rank, a, h):
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
lambda_melloss = h.get(
|
||||
"lambda_melloss", 45.0
|
||||
) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||
@ -542,27 +505,19 @@ def train(rank, a, h):
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
|
||||
if steps >= a.freeze_step:
|
||||
loss_gen_all = (
|
||||
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(
|
||||
generator.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to stdout
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
|
||||
print(
|
||||
f"Steps: {steps:d}, "
|
||||
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||
@ -577,11 +532,7 @@ def train(rank, a, h):
|
||||
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"generator": (
|
||||
generator.module if h.num_gpus > 1 else generator
|
||||
).state_dict()
|
||||
},
|
||||
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
|
||||
)
|
||||
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||
save_checkpoint(
|
||||
@ -598,9 +549,7 @@ def train(rank, a, h):
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to tensorboard
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||
@ -612,12 +561,8 @@ def train(rank, a, h):
|
||||
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||
|
||||
# Validation
|
||||
@ -660,9 +605,7 @@ def train(rank, a, h):
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
|
||||
)
|
||||
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
|
||||
|
||||
|
||||
def main():
|
||||
@ -674,12 +617,8 @@ def main():
|
||||
|
||||
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||
parser.add_argument(
|
||||
"--input_training_file", default="tests/LibriTTS/train-full.txt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
|
||||
)
|
||||
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
|
||||
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
|
||||
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_wavs_dir",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,9 @@
|
||||
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
@ -18,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language=os.environ.get("language","Auto")
|
||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
punctuation = set(['!', '?', '…', ',', '.', '-'])
|
||||
punctuation = set(["!", "?", "…", ",", ".", "-"])
|
||||
|
||||
def get_first(text:str) -> str:
|
||||
|
||||
def get_first(text: str) -> str:
|
||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||
text = re.split(pattern, text)[0].strip()
|
||||
return text
|
||||
|
||||
def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
|
||||
def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||
if (len(texts)) < 2:
|
||||
return texts
|
||||
result = []
|
||||
@ -38,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
if len(text) >= threshold:
|
||||
result.append(text)
|
||||
text = ""
|
||||
if (len(text) > 0):
|
||||
if len(text) > 0:
|
||||
if len(result) == 0:
|
||||
result.append(text)
|
||||
else:
|
||||
@ -46,28 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
||||
tokenizer:AutoTokenizer, device:torch.device):
|
||||
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
|
||||
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
|
||||
print(f'############ {i18n("切分文本")} ############')
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
text = self.replace_consecutive_punctuation(text)
|
||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||
result = []
|
||||
print(f'############ {i18n("提取文本Bert特征")} ############')
|
||||
print(f"############ {i18n('提取文本Bert特征')} ############")
|
||||
for text in tqdm(texts):
|
||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||
if phones is None or norm_text=="":
|
||||
if phones is None or norm_text == "":
|
||||
continue
|
||||
res={
|
||||
res = {
|
||||
"phones": phones,
|
||||
"bert_features": bert_features,
|
||||
"norm_text": norm_text,
|
||||
@ -75,11 +74,11 @@ class TextPreprocessor:
|
||||
result.append(res)
|
||||
return result
|
||||
|
||||
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
|
||||
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
||||
text = text.strip("\n")
|
||||
if len(text) == 0:
|
||||
return []
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
if text[0] not in splits and len(get_first(text)) < 4:
|
||||
text = "。" + text if lang != "en" else "." + text
|
||||
print(i18n("实际输入的目标文本:"))
|
||||
print(text)
|
||||
@ -95,18 +94,18 @@ class TextPreprocessor:
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
|
||||
|
||||
for text in _texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if len(text.strip()) == 0:
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||
if text[-1] not in splits:
|
||||
text += "。" if lang != "en" else "."
|
||||
|
||||
# 解决句子过长导致Bert报错的问题
|
||||
if (len(text) > 510):
|
||||
if len(text) > 510:
|
||||
texts.extend(split_big_text(text))
|
||||
else:
|
||||
texts.append(text)
|
||||
@ -115,78 +114,79 @@ class TextPreprocessor:
|
||||
print(texts)
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1"
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
|
||||
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
# language = language.replace("all_","")
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
if language == "all_zh":
|
||||
if re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext,"zh",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext,"yue",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float32,
|
||||
).to(self.device)
|
||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
if language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = ''.join(norm_text_list)
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
# language = language.replace("all_","")
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
if language == "all_zh":
|
||||
if re.search(r"[A-Za-z]", formattext):
|
||||
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext, "zh", version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
|
||||
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext, "yue", version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float32,
|
||||
).to(self.device)
|
||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text,language,version,final=True)
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||
|
||||
return phones, bert, norm_text
|
||||
return phones, bert, norm_text
|
||||
|
||||
|
||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
@ -201,14 +201,14 @@ class TextPreprocessor:
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
return phone_level_feature.T
|
||||
|
||||
def clean_text_inf(self, text:str, language:str, version:str="v2"):
|
||||
language = language.replace("all_","")
|
||||
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
|
||||
language = language.replace("all_", "")
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
|
||||
language=language.replace("all_","")
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
else:
|
||||
@ -219,21 +219,19 @@ class TextPreprocessor:
|
||||
|
||||
return feature
|
||||
|
||||
|
||||
def filter_text(self,texts):
|
||||
_text=[]
|
||||
if all(text in [None, " ", "\n",""] for text in texts):
|
||||
def filter_text(self, texts):
|
||||
_text = []
|
||||
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||
raise ValueError(i18n("请输入有效文本"))
|
||||
for text in texts:
|
||||
if text in [None, " ", ""]:
|
||||
if text in [None, " ", ""]:
|
||||
pass
|
||||
else:
|
||||
_text.append(text)
|
||||
return _text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(self,text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
def replace_consecutive_punctuation(self, text):
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
@ -1,40 +1,56 @@
|
||||
|
||||
|
||||
|
||||
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
|
||||
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
|
||||
METHODS = dict()
|
||||
|
||||
def get_method(name:str)->Callable:
|
||||
|
||||
def get_method(name: str) -> Callable:
|
||||
method = METHODS.get(name, None)
|
||||
if method is None:
|
||||
raise ValueError(f"Method {name} not found")
|
||||
return method
|
||||
|
||||
def get_method_names()->list:
|
||||
|
||||
def get_method_names() -> list:
|
||||
return list(METHODS.keys())
|
||||
|
||||
|
||||
def register_method(name):
|
||||
def decorator(func):
|
||||
METHODS[name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||
|
||||
splits = {
|
||||
",",
|
||||
"。",
|
||||
"?",
|
||||
"!",
|
||||
",",
|
||||
".",
|
||||
"?",
|
||||
"!",
|
||||
"~",
|
||||
":",
|
||||
":",
|
||||
"—",
|
||||
"…",
|
||||
}
|
||||
|
||||
|
||||
def split_big_text(text, max_len=510):
|
||||
# 定义全角和半角标点符号
|
||||
punctuation = "".join(splits)
|
||||
|
||||
# 切割文本
|
||||
segments = re.split('([' + punctuation + '])', text)
|
||||
segments = re.split("([" + punctuation + "])", text)
|
||||
|
||||
# 初始化结果列表和当前片段
|
||||
result = []
|
||||
current_segment = ''
|
||||
current_segment = ""
|
||||
|
||||
for segment in segments:
|
||||
# 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
|
||||
@ -51,7 +67,6 @@ def split_big_text(text, max_len=510):
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||
if todo_text[-1] not in splits:
|
||||
@ -90,7 +105,7 @@ def cut1(inp):
|
||||
if len(split_idx) > 1:
|
||||
opts = []
|
||||
for idx in range(len(split_idx) - 1):
|
||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
||||
else:
|
||||
opts = [inp]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
@ -123,6 +138,7 @@ def cut2(inp):
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# 按中文句号。切
|
||||
@register_method("cut3")
|
||||
def cut3(inp):
|
||||
@ -131,26 +147,28 @@ def cut3(inp):
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
#按英文句号.切
|
||||
|
||||
# 按英文句号.切
|
||||
@register_method("cut4")
|
||||
def cut4(inp):
|
||||
inp = inp.strip("\n")
|
||||
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
|
||||
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# 按标点符号切
|
||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||
@register_method("cut5")
|
||||
def cut5(inp):
|
||||
inp = inp.strip("\n")
|
||||
punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
|
||||
punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||
mergeitems = []
|
||||
items = []
|
||||
|
||||
for i, char in enumerate(inp):
|
||||
if char in punds:
|
||||
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||
items.append(char)
|
||||
else:
|
||||
items.append(char)
|
||||
@ -166,8 +184,6 @@ def cut5(inp):
|
||||
return "\n".join(opt)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
method = get_method("cut5")
|
||||
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
|
||||
|
||||
|
@ -1,5 +1,13 @@
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.insert(0, now_dir)
|
||||
from text.g2pw import G2PWPinyin
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
||||
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
@ -3,7 +3,6 @@
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
from text import cleaned_text_to_sequence
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
@ -33,6 +32,7 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
config = dict_s1["config"]
|
||||
config["model"]["dropout"] = float(config["model"]["dropout"])
|
||||
@ -41,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
t2s_model = t2s_model.eval()
|
||||
return t2s_model
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
@ -57,39 +58,35 @@ def logits_to_probs(
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def multinomial_sample_one_no_sync(probs_sort):
|
||||
# Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def sample(
|
||||
logits,
|
||||
@ -100,15 +97,20 @@ def sample(
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
|
||||
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
|
||||
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
||||
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
@ -158,6 +160,7 @@ class DictToAttrRecursive(dict):
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
@ -171,23 +174,24 @@ class T2SMLP:
|
||||
x = F.linear(x, self.w2, self.b2)
|
||||
return x
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1: float,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2: float,
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1: float,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2: float,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
@ -206,7 +210,7 @@ class T2SBlock:
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
@ -215,7 +219,7 @@ class T2SBlock:
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
|
||||
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
@ -232,22 +236,20 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
if self.false.device!= padding_mask.device:
|
||||
if self.false.device != padding_mask.device:
|
||||
self.false = self.false.to(padding_mask.device)
|
||||
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
|
||||
x_item = x[i,idx,:].unsqueeze(0)
|
||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
||||
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
||||
x_item = x[i, idx, :].unsqueeze(0)
|
||||
attn_item = attn[i, idx, :].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(
|
||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
@ -256,13 +258,11 @@ class T2SBlock:
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i,idx,:] = x_item.squeeze(0)
|
||||
x[i, idx, :] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@ -273,7 +273,7 @@ class T2SBlock:
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
|
||||
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
@ -289,14 +289,12 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(attn, self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@ -307,37 +305,35 @@ class T2SBlock:
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
|
||||
self.num_blocks: int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
|
||||
k_cache : list[torch.Tensor] = []
|
||||
v_cache : list[torch.Tensor] = []
|
||||
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||
k_cache: list[torch.Tensor] = []
|
||||
v_cache: list[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
|
||||
k_cache.append(k_cache_)
|
||||
v_cache.append(v_cache_)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
k_cache: list[torch.Tensor],
|
||||
v_cache: list[torch.Tensor]):
|
||||
def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path)
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
@ -348,7 +344,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model
|
||||
**self.hps.model,
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
@ -360,12 +356,13 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
||||
|
||||
|
||||
class T2SModel(nn.Module):
|
||||
def __init__(self,raw_t2s:Text2SemanticLightningModule):
|
||||
def __init__(self, raw_t2s: Text2SemanticLightningModule):
|
||||
super(T2SModel, self).__init__()
|
||||
self.model_dim = raw_t2s.model.model_dim
|
||||
self.embedding_dim = raw_t2s.model.embedding_dim
|
||||
@ -374,7 +371,7 @@ class T2SModel(nn.Module):
|
||||
self.vocab_size = raw_t2s.model.vocab_size
|
||||
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
|
||||
# self.p_dropout = float(raw_t2s.model.p_dropout)
|
||||
self.EOS:int = int(raw_t2s.model.EOS)
|
||||
self.EOS: int = int(raw_t2s.model.EOS)
|
||||
self.norm_first = raw_t2s.model.norm_first
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
self.hz = 50
|
||||
@ -393,12 +390,7 @@ class T2SModel(nn.Module):
|
||||
|
||||
for i in range(self.num_layers):
|
||||
layer = h.layers[i]
|
||||
t2smlp = T2SMLP(
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
)
|
||||
t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
|
||||
|
||||
block = T2SBlock(
|
||||
self.num_head,
|
||||
@ -413,7 +405,7 @@ class T2SModel(nn.Module):
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
layer.norm2.eps,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
@ -427,19 +419,26 @@ class T2SModel(nn.Module):
|
||||
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
||||
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||
|
||||
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
|
||||
def forward(
|
||||
self,
|
||||
prompts: LongTensor,
|
||||
ref_seq: LongTensor,
|
||||
text_seq: LongTensor,
|
||||
ref_bert: torch.Tensor,
|
||||
text_bert: torch.Tensor,
|
||||
top_k: LongTensor,
|
||||
):
|
||||
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||
bert = bert.unsqueeze(0)
|
||||
|
||||
x = self.ar_text_embedding(all_phoneme_ids)
|
||||
x = x + self.bert_proj(bert.transpose(1, 2))
|
||||
x:torch.Tensor = self.ar_text_position(x)
|
||||
x: torch.Tensor = self.ar_text_position(x)
|
||||
|
||||
early_stop_num = self.early_stop_num
|
||||
|
||||
|
||||
#[1,N,512] [1,N]
|
||||
# [1,N,512] [1,N]
|
||||
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
y = prompts
|
||||
# x_example = x[:,:,0] * 0.0
|
||||
@ -465,11 +464,13 @@ class T2SModel(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.num_head, -1, -1)
|
||||
.view(bsz, self.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
idx = 0
|
||||
top_k = int(top_k)
|
||||
@ -481,17 +482,19 @@ class T2SModel(nn.Module):
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
stop = False
|
||||
# for idx in range(1, 50):
|
||||
for idx in range(1, 1500):
|
||||
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
@ -508,20 +511,22 @@ class T2SModel(nn.Module):
|
||||
break
|
||||
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
y[0,-1] = 0
|
||||
y[0, -1] = 0
|
||||
|
||||
return y[:, -idx:].unsqueeze(0)
|
||||
|
||||
bert_path = os.environ.get(
|
||||
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
)
|
||||
|
||||
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
|
||||
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
|
||||
phone_level_feature = []
|
||||
for i in range(word2ph.shape[0]):
|
||||
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
|
||||
@ -530,39 +535,45 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
|
||||
# [sum(word2ph), 1024]
|
||||
return phone_level_feature
|
||||
|
||||
|
||||
class MyBertModel(torch.nn.Module):
|
||||
def __init__(self, bert_model):
|
||||
super(MyBertModel, self).__init__()
|
||||
self.bert = bert_model
|
||||
|
||||
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
|
||||
):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
|
||||
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
|
||||
return build_phone_level_feature(res, word2ph)
|
||||
|
||||
|
||||
class SSLModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ssl = cnhubert.get_model().model
|
||||
|
||||
def forward(self, ref_audio_16k)-> torch.Tensor:
|
||||
def forward(self, ref_audio_16k) -> torch.Tensor:
|
||||
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
||||
return ssl_content
|
||||
|
||||
|
||||
class ExportSSLModel(torch.nn.Module):
|
||||
def __init__(self,ssl:SSLModel):
|
||||
def __init__(self, ssl: SSLModel):
|
||||
super().__init__()
|
||||
self.ssl = ssl
|
||||
|
||||
def forward(self, ref_audio:torch.Tensor):
|
||||
def forward(self, ref_audio: torch.Tensor):
|
||||
return self.ssl(ref_audio)
|
||||
|
||||
@torch.jit.export
|
||||
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
audio = resamplex(ref_audio,src_sr,dst_sr).float()
|
||||
def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||
audio = resamplex(ref_audio, src_sr, dst_sr).float()
|
||||
return audio
|
||||
|
||||
|
||||
def export_bert(output_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
|
||||
@ -570,33 +581,34 @@ def export_bert(output_path):
|
||||
ref_bert_inputs = tokenizer(text, return_tensors="pt")
|
||||
word2ph = []
|
||||
for c in text:
|
||||
if c in [',','。',':','?',",",".","?"]:
|
||||
if c in [",", "。", ":", "?", ",", ".", "?"]:
|
||||
word2ph.append(1)
|
||||
else:
|
||||
word2ph.append(2)
|
||||
ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
|
||||
ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
|
||||
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
|
||||
my_bert_model = MyBertModel(bert_model)
|
||||
|
||||
ref_bert_inputs = {
|
||||
'input_ids': ref_bert_inputs['input_ids'],
|
||||
'attention_mask': ref_bert_inputs['attention_mask'],
|
||||
'token_type_ids': ref_bert_inputs['token_type_ids'],
|
||||
'word2ph': ref_bert_inputs['word2ph']
|
||||
"input_ids": ref_bert_inputs["input_ids"],
|
||||
"attention_mask": ref_bert_inputs["attention_mask"],
|
||||
"token_type_ids": ref_bert_inputs["token_type_ids"],
|
||||
"word2ph": ref_bert_inputs["word2ph"],
|
||||
}
|
||||
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
|
||||
|
||||
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
|
||||
my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
|
||||
output_path = os.path.join(output_path, "bert_model.pt")
|
||||
my_bert_model.save(output_path)
|
||||
print('#### exported bert ####')
|
||||
print("#### exported bert ####")
|
||||
|
||||
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
|
||||
|
||||
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
print(f"目录已创建: {output_path}")
|
||||
@ -606,27 +618,28 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
if export_bert_and_ssl:
|
||||
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
|
||||
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
||||
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
||||
torch.jit.script(s).save(ssl_path)
|
||||
print('#### exported ssl ####')
|
||||
print("#### exported ssl ####")
|
||||
export_bert(output_path)
|
||||
else:
|
||||
s = ExportSSLModel(ssl)
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
|
||||
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio).to(device)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path).to(device)
|
||||
vits.eval()
|
||||
|
||||
@ -634,18 +647,18 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print('#### get_raw_t2s_model ####')
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
t2s = torch.jit.script(t2s_m).to(device)
|
||||
print('#### script t2s_m ####')
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
|
||||
gpt_sovits.eval()
|
||||
|
||||
ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
|
||||
ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
|
||||
|
||||
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||
@ -658,32 +671,28 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
example_inputs=(
|
||||
ssl_content,
|
||||
ref_audio_sr,
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert,
|
||||
top_k))
|
||||
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
)
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
print('#### exported gpt_sovits ####')
|
||||
print("#### exported gpt_sovits ####")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def parse_audio(ref_audio):
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
|
||||
return ref_audio_16k,ref_audio_sr
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
|
||||
return ref_audio_16k, ref_audio_sr
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
|
||||
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||
return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
|
||||
|
||||
|
||||
class GPT_SoVITS(nn.Module):
|
||||
def __init__(self, t2s:T2SModel,vits:VitsModel):
|
||||
def __init__(self, t2s: T2SModel, vits: VitsModel):
|
||||
super().__init__()
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
@ -710,12 +719,11 @@ class GPT_SoVITS(nn.Module):
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
gpt_path = args.gpt_model
|
||||
@ -726,7 +734,7 @@ def test():
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
||||
# bert = MyBertModel(bert_model)
|
||||
my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
|
||||
my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
|
||||
|
||||
# dict_s1 = torch.load(gpt_path, map_location="cuda")
|
||||
# raw_t2s = get_raw_t2s_model(dict_s1)
|
||||
@ -740,95 +748,97 @@ def test():
|
||||
|
||||
# ssl = ExportSSLModel(SSLModel()).to('cuda')
|
||||
# ssl.eval()
|
||||
ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
|
||||
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
|
||||
|
||||
# gpt_sovits = GPT_SoVITS(t2s,vits)
|
||||
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
|
||||
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
|
||||
|
||||
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||
ref_seq = torch.LongTensor([ref_seq_id])
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
|
||||
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
|
||||
|
||||
text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
|
||||
|
||||
test_bert = tokenizer(text, return_tensors="pt")
|
||||
word2ph = []
|
||||
for c in text:
|
||||
if c in [',','。',':','?',"?",",","."]:
|
||||
if c in [",", "。", ":", "?", "?", ",", "."]:
|
||||
word2ph.append(1)
|
||||
else:
|
||||
word2ph.append(2)
|
||||
test_bert['word2ph'] = torch.Tensor(word2ph).int()
|
||||
test_bert["word2ph"] = torch.Tensor(word2ph).int()
|
||||
|
||||
test_bert = my_bert(
|
||||
test_bert['input_ids'].to('cuda'),
|
||||
test_bert['attention_mask'].to('cuda'),
|
||||
test_bert['token_type_ids'].to('cuda'),
|
||||
test_bert['word2ph'].to('cuda')
|
||||
test_bert["input_ids"].to("cuda"),
|
||||
test_bert["attention_mask"].to("cuda"),
|
||||
test_bert["token_type_ids"].to("cuda"),
|
||||
test_bert["word2ph"].to("cuda"),
|
||||
)
|
||||
|
||||
text_seq = torch.LongTensor([text_seq_id])
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
|
||||
print('text_bert:',text_bert.shape,text_bert)
|
||||
print('test_bert:',test_bert.shape,test_bert)
|
||||
print(torch.allclose(text_bert.to('cuda'),test_bert))
|
||||
print("text_bert:", text_bert.shape, text_bert)
|
||||
print("test_bert:", test_bert.shape, test_bert)
|
||||
print(torch.allclose(text_bert.to("cuda"), test_bert))
|
||||
|
||||
print('text_seq:',text_seq.shape)
|
||||
print('text_bert:',text_bert.shape,text_bert.type())
|
||||
print("text_seq:", text_seq.shape)
|
||||
print("text_bert:", text_bert.shape, text_bert.type())
|
||||
|
||||
#[1,N]
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
|
||||
print('ref_audio:',ref_audio.shape)
|
||||
# [1,N]
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
|
||||
print("ref_audio:", ref_audio.shape)
|
||||
|
||||
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
|
||||
print('start ssl')
|
||||
ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
|
||||
print("start ssl")
|
||||
ssl_content = ssl(ref_audio)
|
||||
|
||||
print('start gpt_sovits:')
|
||||
print('ssl_content:',ssl_content.shape)
|
||||
print('ref_audio_sr:',ref_audio_sr.shape)
|
||||
print('ref_seq:',ref_seq.shape)
|
||||
ref_seq=ref_seq.to('cuda')
|
||||
print('text_seq:',text_seq.shape)
|
||||
text_seq=text_seq.to('cuda')
|
||||
print('ref_bert:',ref_bert.shape)
|
||||
ref_bert=ref_bert.to('cuda')
|
||||
print('text_bert:',text_bert.shape)
|
||||
text_bert=text_bert.to('cuda')
|
||||
print("start gpt_sovits:")
|
||||
print("ssl_content:", ssl_content.shape)
|
||||
print("ref_audio_sr:", ref_audio_sr.shape)
|
||||
print("ref_seq:", ref_seq.shape)
|
||||
ref_seq = ref_seq.to("cuda")
|
||||
print("text_seq:", text_seq.shape)
|
||||
text_seq = text_seq.to("cuda")
|
||||
print("ref_bert:", ref_bert.shape)
|
||||
ref_bert = ref_bert.to("cuda")
|
||||
print("text_bert:", text_bert.shape)
|
||||
text_bert = text_bert.to("cuda")
|
||||
|
||||
top_k = torch.LongTensor([5]).to('cuda')
|
||||
top_k = torch.LongTensor([5]).to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||
print('start write wav')
|
||||
print("start write wav")
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
import text
|
||||
import json
|
||||
|
||||
def export_symbel(version='v2'):
|
||||
if version=='v1':
|
||||
|
||||
def export_symbel(version="v2"):
|
||||
if version == "v1":
|
||||
symbols = text._symbol_to_id_v1
|
||||
with open(f"onnx/symbols_v1.json", "w") as file:
|
||||
with open("onnx/symbols_v1.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
else:
|
||||
symbols = text._symbol_to_id_v2
|
||||
with open(f"onnx/symbols_v2.json", "w") as file:
|
||||
with open("onnx/symbols_v2.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
|
||||
parser.add_argument('--device', help="Device to use")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
||||
parser.add_argument("--device", help="Device to use")
|
||||
|
||||
args = parser.parse_args()
|
||||
export(
|
||||
@ -841,9 +851,11 @@ def main():
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
|
||||
|
||||
import inference_webui
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference_webui.is_half=False
|
||||
inference_webui.dtype=torch.float32
|
||||
inference_webui.is_half = False
|
||||
inference_webui.dtype = torch.float32
|
||||
main()
|
||||
# test()
|
||||
|
@ -6,16 +6,16 @@ from export_torch_script import (
|
||||
spectrogram_torch,
|
||||
)
|
||||
from f5_tts.model.backbones.dit import DiT
|
||||
from feature_extractor import cnhubert
|
||||
from inference_webui import get_phones_and_bert
|
||||
import librosa
|
||||
from module import commons
|
||||
from module.mel_processing import mel_spectrogram_torch, spectral_normalize_torch
|
||||
from module.mel_processing import mel_spectrogram_torch
|
||||
from module.models_onnx import CFM, SynthesizerTrnV3
|
||||
import numpy as np
|
||||
import torch._dynamo.config
|
||||
import torchaudio
|
||||
import logging, uvicorn
|
||||
import logging
|
||||
import uvicorn
|
||||
import torch
|
||||
import soundfile
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
@ -32,7 +32,6 @@ now_dir = os.getcwd()
|
||||
|
||||
|
||||
class MelSpectrgram(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype,
|
||||
@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
|
||||
self.n_fft:int = n_fft
|
||||
self.hop_size:int = hop_size
|
||||
self.win_size:int = win_size
|
||||
self.center:bool = center
|
||||
self.n_fft: int = n_fft
|
||||
self.hop_size: int = hop_size
|
||||
self.win_size: int = win_size
|
||||
self.center: bool = center
|
||||
|
||||
def forward(self, y):
|
||||
y = torch.nn.functional.pad(
|
||||
@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
|
||||
):
|
||||
T_min = fea_ref.size(2)
|
||||
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
||||
cfm_res = self.cfm(
|
||||
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
|
||||
)
|
||||
cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
|
||||
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
||||
mel2 = cfm_res[:, :, -T_min:]
|
||||
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
||||
@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
|
||||
spec_min = -12
|
||||
spec_max = 2
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def norm_spec(x):
|
||||
spec_min = -12
|
||||
@ -212,7 +208,6 @@ def denorm_spec(x):
|
||||
|
||||
|
||||
class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
|
||||
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
|
||||
super().__init__()
|
||||
self.hps = hps
|
||||
@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
center=False,
|
||||
)
|
||||
# self.dtype = dtype
|
||||
self.filter_length:int = hps.data.filter_length
|
||||
self.sampling_rate:int = hps.data.sampling_rate
|
||||
self.hop_length:int = hps.data.hop_length
|
||||
self.win_length:int = hps.data.win_length
|
||||
self.filter_length: int = hps.data.filter_length
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ssl_content,
|
||||
ref_audio_32k:torch.FloatTensor,
|
||||
ref_audio_32k: torch.FloatTensor,
|
||||
phoneme_ids0,
|
||||
phoneme_ids1,
|
||||
bert1,
|
||||
@ -255,18 +250,14 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
center=False,
|
||||
).to(ssl_content.dtype)
|
||||
|
||||
|
||||
codes = self.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0)
|
||||
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
pred_semantic = self.t2s_m(
|
||||
prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
||||
)
|
||||
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
||||
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
|
||||
ge = self.vq_model.create_ge(refer)
|
||||
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
|
||||
@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
|
||||
return fea_ref, fea_todo, mel2
|
||||
|
||||
|
||||
class GPTSoVITSV3(torch.nn.Module):
|
||||
def __init__(self, gpt_sovits_half, cfm, bigvgan):
|
||||
super().__init__()
|
||||
@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
ssl_content,
|
||||
ref_audio_32k:torch.FloatTensor,
|
||||
phoneme_ids0:torch.LongTensor,
|
||||
phoneme_ids1:torch.LongTensor,
|
||||
ref_audio_32k: torch.FloatTensor,
|
||||
phoneme_ids0: torch.LongTensor,
|
||||
phoneme_ids1: torch.LongTensor,
|
||||
bert1,
|
||||
bert2,
|
||||
top_k: torch.LongTensor,
|
||||
@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
):
|
||||
# current_time = datetime.now()
|
||||
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
||||
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
|
||||
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
||||
)
|
||||
chunk_len = 934 - fea_ref.shape[2]
|
||||
wav_gen_list = []
|
||||
idx = 0
|
||||
@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
|
||||
complete_len = chunk_len - fea_todo_chunk.shape[-1]
|
||||
if complete_len != 0:
|
||||
fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
|
||||
fea_todo_chunk = torch.cat(
|
||||
[
|
||||
fea_todo_chunk,
|
||||
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
|
||||
],
|
||||
2,
|
||||
)
|
||||
|
||||
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
||||
idx += chunk_len
|
||||
@ -343,13 +343,13 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
wav_gen = torch.cat(wav_gen_list, 2)
|
||||
return wav_gen[0][0][:wav_gen_length]
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
global bigvgan_model
|
||||
from BigVGAN import bigvgan
|
||||
|
||||
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
|
||||
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
|
||||
% (now_dir,),
|
||||
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
|
||||
use_cuda_kernel=False,
|
||||
) # if True, RuntimeError: Ninja is required to load C++ extensions
|
||||
# remove weight norm in the model and set to eval mode
|
||||
@ -467,10 +467,7 @@ def export_cfm(
|
||||
cfm = e_cfm.cfm
|
||||
|
||||
B, T = mu.size(0), mu.size(1)
|
||||
x = (
|
||||
torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
|
||||
* temperature
|
||||
)
|
||||
x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
|
||||
print("x:", x.shape, x.dtype)
|
||||
prompt_len = prompt.size(-1)
|
||||
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
|
||||
@ -565,11 +562,7 @@ def export():
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||
"last_hidden_state"
|
||||
].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = sovits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
@ -626,10 +619,7 @@ def export():
|
||||
"create_ge": refer,
|
||||
}
|
||||
|
||||
|
||||
trace_vq_model = torch.jit.trace_module(
|
||||
sovits.vq_model, inputs, optimize=True
|
||||
)
|
||||
trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
|
||||
trace_vq_model.save("onnx/ad/vq_model.pt")
|
||||
|
||||
print(fea_ref.shape, fea_ref.dtype, ge.shape)
|
||||
@ -714,9 +704,7 @@ def export():
|
||||
|
||||
idx += chunk_len
|
||||
|
||||
cfm_res, fea_ref, mel2 = export_cfm_(
|
||||
fea_ref, fea_todo_chunk, mel2, sample_steps
|
||||
)
|
||||
cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
||||
cfm_resss.append(cfm_res)
|
||||
continue
|
||||
|
||||
@ -726,9 +714,7 @@ def export():
|
||||
with torch.inference_mode():
|
||||
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
|
||||
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
|
||||
bigvgan_model_ = torch.jit.trace(
|
||||
bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
|
||||
)
|
||||
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
||||
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
||||
@ -748,7 +734,6 @@ def test_export(
|
||||
bigvgan,
|
||||
output,
|
||||
):
|
||||
|
||||
# hps = sovits.hps
|
||||
ref_wav_path = "onnx/ad/ref.wav"
|
||||
speed = 1.0
|
||||
@ -773,13 +758,9 @@ def test_export(
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||
"last_hidden_state"
|
||||
].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
|
||||
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
|
||||
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
|
||||
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
|
||||
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(
|
||||
@ -799,8 +780,18 @@ def test_export(
|
||||
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info("start inference %s", current_time)
|
||||
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
|
||||
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
||||
print(
|
||||
ssl_content.shape,
|
||||
ref_audio_32k.shape,
|
||||
phoneme_ids0.shape,
|
||||
phoneme_ids1.shape,
|
||||
bert1.shape,
|
||||
bert2.shape,
|
||||
top_k.shape,
|
||||
)
|
||||
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
|
||||
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
||||
)
|
||||
chunk_len = 934 - fea_ref.shape[2]
|
||||
print(fea_ref.shape, fea_todo.shape, mel2.shape)
|
||||
|
||||
@ -812,7 +803,6 @@ def test_export(
|
||||
wav_gen_length = fea_todo.shape[2] * 256
|
||||
|
||||
while 1:
|
||||
|
||||
current_time = datetime.now()
|
||||
print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
||||
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
||||
@ -861,7 +851,6 @@ def test_export1(
|
||||
gpt_sovits_v3,
|
||||
output,
|
||||
):
|
||||
|
||||
# hps = sovits.hps
|
||||
ref_wav_path = "onnx/ad/ref.wav"
|
||||
speed = 1.0
|
||||
@ -886,14 +875,10 @@ def test_export1(
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||
"last_hidden_state"
|
||||
].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
print("ssl_content:", ssl_content.shape, ssl_content.dtype)
|
||||
|
||||
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
|
||||
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
|
||||
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
|
||||
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(
|
||||
@ -913,11 +898,19 @@ def test_export1(
|
||||
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info("start inference %s", current_time)
|
||||
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
|
||||
print(
|
||||
ssl_content.shape,
|
||||
ref_audio_32k.shape,
|
||||
phoneme_ids0.shape,
|
||||
phoneme_ids1.shape,
|
||||
bert1.shape,
|
||||
bert2.shape,
|
||||
top_k.shape,
|
||||
)
|
||||
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
|
||||
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
||||
|
||||
wav_gen = torch.cat([wav_gen,zero_wav_torch],0)
|
||||
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
|
||||
|
||||
audio = wav_gen.cpu().detach().numpy()
|
||||
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
||||
@ -929,7 +922,6 @@ import time
|
||||
|
||||
|
||||
def test_():
|
||||
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
|
||||
# cfm = ExportCFM(sovits.cfm)
|
||||
@ -942,7 +934,7 @@ def test_():
|
||||
|
||||
cfm.eval()
|
||||
|
||||
logger.info(f"cfm ok")
|
||||
logger.info("cfm ok")
|
||||
|
||||
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
||||
# v2 的 gpt 也可以用
|
||||
@ -957,17 +949,14 @@ def test_():
|
||||
t2s_m = torch.jit.script(t2s_m)
|
||||
t2s_m.eval()
|
||||
# t2s_m.top_k = 15
|
||||
logger.info(f"t2s_m ok")
|
||||
logger.info("t2s_m ok")
|
||||
|
||||
|
||||
vq_model: torch.jit.ScriptModule = torch.jit.load(
|
||||
"onnx/ad/vq_model.pt", map_location=device
|
||||
)
|
||||
vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
|
||||
# vq_model = torch.jit.optimize_for_inference(vq_model)
|
||||
# vq_model = vq_model.half().to(device)
|
||||
vq_model.eval()
|
||||
# vq_model = sovits.vq_model
|
||||
logger.info(f"vq_model ok")
|
||||
logger.info("vq_model ok")
|
||||
|
||||
# gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
|
||||
# gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
|
||||
@ -975,7 +964,7 @@ def test_():
|
||||
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
|
||||
# gpt_sovits_v3_half.eval()
|
||||
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
|
||||
logger.info(f"gpt_sovits_v3_half ok")
|
||||
logger.info("gpt_sovits_v3_half ok")
|
||||
|
||||
# init_bigvgan()
|
||||
# global bigvgan_model
|
||||
@ -985,7 +974,7 @@ def test_():
|
||||
bigvgan_model = bigvgan_model.cuda()
|
||||
bigvgan_model.eval()
|
||||
|
||||
logger.info(f"bigvgan ok")
|
||||
logger.info("bigvgan ok")
|
||||
|
||||
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
|
||||
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
|
||||
@ -1020,8 +1009,9 @@ def test_():
|
||||
# "out2.wav",
|
||||
# )
|
||||
|
||||
|
||||
def test_export_gpt_sovits_v3():
|
||||
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device)
|
||||
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
|
||||
# test_export1(
|
||||
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
||||
# gpt_sovits_v3,
|
||||
|
@ -11,7 +11,6 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
|
||||
|
||||
from module.commons import sequence_mask
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
@ -130,26 +130,24 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def forward(#x, prompt_x, x_lens, t, style,cond
|
||||
self,#d is channel,n is T
|
||||
def forward( # x, prompt_x, x_lens, t, style,cond
|
||||
self, # d is channel,n is T
|
||||
x0: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond0: float["b n d"], # masked cond audio # noqa: F722
|
||||
x_lens,
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
dt_base_bootstrap,
|
||||
dt_base_bootstrap,
|
||||
text0, # : int["b nt"] # noqa: F722#####condition feature
|
||||
use_grad_ckpt=False, # bool
|
||||
###no-use
|
||||
drop_audio_cond=False, # cfg for cond audio
|
||||
drop_text=False, # cfg for text
|
||||
# mask: bool["b n"] | None = None, # noqa: F722
|
||||
|
||||
):
|
||||
|
||||
x=x0.transpose(2,1)
|
||||
cond=cond0.transpose(2,1)
|
||||
text=text0.transpose(2,1)
|
||||
mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
|
||||
x = x0.transpose(2, 1)
|
||||
cond = cond0.transpose(2, 1)
|
||||
text = text0.transpose(2, 1)
|
||||
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
|
||||
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@ -158,8 +156,8 @@ class DiT(nn.Module):
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
dt = self.d_embed(dt_base_bootstrap)
|
||||
t+=dt
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
|
||||
t += dt
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
@ -391,6 +391,7 @@ class Attention(nn.Module):
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
# from torch.nn.attention import SDPBackend
|
||||
# torch.backends.cuda.enable_flash_sdp(True)
|
||||
class AttnProcessor:
|
||||
@ -545,6 +546,7 @@ class JointAttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
@ -1,6 +1,3 @@
|
||||
from . import cnhubert, whisper_enc
|
||||
|
||||
content_module_map = {
|
||||
'cnhubert': cnhubert,
|
||||
'whisper': whisper_enc
|
||||
}
|
||||
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
|
||||
|
@ -1,14 +1,11 @@
|
||||
import time
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import soundfile as sf
|
||||
import os
|
||||
from transformers import logging as tf_logging
|
||||
|
||||
tf_logging.set_verbosity_error()
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
from transformers import (
|
||||
@ -23,21 +20,19 @@ cnhubert_base_path = None
|
||||
|
||||
|
||||
class CNHubert(nn.Module):
|
||||
def __init__(self, base_path:str=None):
|
||||
def __init__(self, base_path: str = None):
|
||||
super().__init__()
|
||||
if base_path is None:
|
||||
base_path = cnhubert_base_path
|
||||
if os.path.exists(base_path):...
|
||||
else:raise FileNotFoundError(base_path)
|
||||
if os.path.exists(base_path):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(base_path)
|
||||
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_path, local_files_only=True
|
||||
)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
|
||||
|
||||
def forward(self, x):
|
||||
input_values = self.feature_extractor(
|
||||
x, return_tensors="pt", sampling_rate=16000
|
||||
).input_values.to(x.device)
|
||||
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
feats = self.model(input_values)["last_hidden_state"]
|
||||
return feats
|
||||
|
||||
|
@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
|
||||
feature_len = mel.shape[-1] // 2
|
||||
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
||||
with torch.no_grad():
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
|
||||
:1, :feature_len, :
|
||||
].transpose(1, 2)
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
|
||||
return feature
|
||||
|
@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
|
||||
|
||||
def synthesize(
|
||||
GPT_model_path,
|
||||
SoVITS_model_path,
|
||||
ref_audio_path,
|
||||
ref_text_path,
|
||||
ref_language,
|
||||
target_text_path,
|
||||
target_language,
|
||||
output_path,
|
||||
):
|
||||
# Read reference text
|
||||
with open(ref_text_path, 'r', encoding='utf-8') as file:
|
||||
with open(ref_text_path, "r", encoding="utf-8") as file:
|
||||
ref_text = file.read()
|
||||
|
||||
# Read target text
|
||||
with open(target_text_path, 'r', encoding='utf-8') as file:
|
||||
with open(target_text_path, "r", encoding="utf-8") as file:
|
||||
target_text = file.read()
|
||||
|
||||
# Change model weights
|
||||
@ -21,11 +31,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
|
||||
# Synthesize audio
|
||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language), top_p=1, temperature=1)
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language),
|
||||
top_p=1,
|
||||
temperature=1,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
|
||||
parser.add_argument('--target_text', required=True, help="Path to the target text file")
|
||||
parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument(
|
||||
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
||||
)
|
||||
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
||||
parser.add_argument(
|
||||
"--target_language",
|
||||
required=True,
|
||||
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
||||
help="Language of the target text",
|
||||
)
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
|
||||
synthesize(
|
||||
args.gpt_model,
|
||||
args.sovits_model,
|
||||
args.ref_audio,
|
||||
args.ref_text,
|
||||
args.ref_language,
|
||||
args.target_text,
|
||||
args.target_language,
|
||||
args.output_path,
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.setWindowTitle('GPT-SoVITS GUI')
|
||||
self.setWindowTitle("GPT-SoVITS GUI")
|
||||
self.setGeometry(800, 450, 950, 850)
|
||||
|
||||
self.setStyleSheet("""
|
||||
@ -64,8 +65,9 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
""")
|
||||
|
||||
license_text = (
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
||||
)
|
||||
license_label = QLabel(license_text)
|
||||
license_label.setWordWrap(True)
|
||||
|
||||
@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
self.output_text = QTextEdit()
|
||||
self.output_text.setReadOnly(True)
|
||||
|
||||
self.add_drag_drop_events([
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
])
|
||||
self.add_drag_drop_events(
|
||||
[
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
]
|
||||
)
|
||||
|
||||
self.synthesize_button = QPushButton("合成")
|
||||
self.synthesize_button.clicked.connect(self.synthesize)
|
||||
@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
def upload_ref_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.ref_text_input.setText(content)
|
||||
|
||||
def upload_target_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.target_text_input.setText(content)
|
||||
|
||||
@ -284,11 +288,13 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
self.SoVITS_Path = SoVITS_model_path
|
||||
|
||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox)
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
@ -303,7 +309,7 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
self.output_text.append("处理结果:\n" + result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
mainWin = GPTSoVITSGUI()
|
||||
mainWin.show()
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,14 +1,19 @@
|
||||
'''
|
||||
"""
|
||||
按中英混合识别
|
||||
按日英混合识别
|
||||
多语种启动切分识别语种
|
||||
全部按中文识别
|
||||
全部按英文识别
|
||||
全部按日文识别
|
||||
'''
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import os, re, logging, json
|
||||
import re
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
@ -20,13 +25,14 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
import pdb
|
||||
import torch
|
||||
|
||||
try:
|
||||
import gradio.analytics as analytics
|
||||
analytics.version_check = lambda:None
|
||||
except:...
|
||||
|
||||
analytics.version_check = lambda: None
|
||||
except:
|
||||
...
|
||||
|
||||
|
||||
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
||||
@ -41,16 +47,16 @@ gpt_path = os.environ.get("gpt_path", None)
|
||||
sovits_path = os.environ.get("sovits_path", None)
|
||||
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
||||
bert_path = os.environ.get("bert_path", None)
|
||||
version=model_version=os.environ.get("version","v2")
|
||||
version = model_version = os.environ.get("version", "v2")
|
||||
|
||||
import gradio as gr
|
||||
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
|
||||
from TTS_infer_pack.text_segmentation_method import get_method
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from inference_webui import DictToAttrRecursive
|
||||
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
||||
|
||||
language=os.environ.get("language","Auto")
|
||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
|
||||
|
||||
@ -67,30 +73,30 @@ else:
|
||||
# device = "cpu"
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh",#全部按中文识别
|
||||
i18n("英文"): "en",#全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja",#全部按日文识别
|
||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language_v2 = {
|
||||
i18n("中文"): "all_zh",#全部按中文识别
|
||||
i18n("英文"): "en",#全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja",#全部按日文识别
|
||||
i18n("粤语"): "all_yue",#全部按中文识别
|
||||
i18n("韩文"): "all_ko",#全部按韩文识别
|
||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
||||
i18n("粤英混合"): "yue",#按粤英混合识别####不变
|
||||
i18n("韩英混合"): "ko",#按韩英混合识别####不变
|
||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
||||
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("粤语"): "all_yue", # 全部按中文识别
|
||||
i18n("韩文"): "all_ko", # 全部按韩文识别
|
||||
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
||||
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
|
||||
cut_method = {
|
||||
i18n("不切"):"cut0",
|
||||
i18n("不切"): "cut0",
|
||||
i18n("凑四句一切"): "cut1",
|
||||
i18n("凑50字一切"): "cut2",
|
||||
i18n("按中文句号。切"): "cut3",
|
||||
@ -117,22 +123,33 @@ gpt_path = tts_config.t2s_weights_path
|
||||
sovits_path = tts_config.vits_weights_path
|
||||
version = tts_config.version
|
||||
|
||||
def inference(text, text_lang,
|
||||
ref_audio_path,
|
||||
aux_ref_audio_paths,
|
||||
prompt_text,
|
||||
prompt_lang, top_k,
|
||||
top_p, temperature,
|
||||
text_split_method, batch_size,
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty, sample_steps, super_sampling,
|
||||
):
|
||||
|
||||
def inference(
|
||||
text,
|
||||
text_lang,
|
||||
ref_audio_path,
|
||||
aux_ref_audio_paths,
|
||||
prompt_text,
|
||||
prompt_lang,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
text_split_method,
|
||||
batch_size,
|
||||
speed_factor,
|
||||
ref_text_free,
|
||||
split_bucket,
|
||||
fragment_interval,
|
||||
seed,
|
||||
keep_random,
|
||||
parallel_infer,
|
||||
repetition_penalty,
|
||||
sample_steps,
|
||||
super_sampling,
|
||||
):
|
||||
seed = -1 if keep_random else seed
|
||||
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
|
||||
inputs={
|
||||
inputs = {
|
||||
"text": text,
|
||||
"text_lang": dict_language[text_lang],
|
||||
"ref_audio_path": ref_audio_path,
|
||||
@ -143,12 +160,12 @@ def inference(text, text_lang,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": cut_method[text_split_method],
|
||||
"batch_size":int(batch_size),
|
||||
"speed_factor":float(speed_factor),
|
||||
"split_bucket":split_bucket,
|
||||
"return_fragment":False,
|
||||
"fragment_interval":fragment_interval,
|
||||
"seed":actual_seed,
|
||||
"batch_size": int(batch_size),
|
||||
"speed_factor": float(speed_factor),
|
||||
"split_bucket": split_bucket,
|
||||
"return_fragment": False,
|
||||
"fragment_interval": fragment_interval,
|
||||
"seed": actual_seed,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"sample_steps": int(sample_steps),
|
||||
@ -158,11 +175,12 @@ def inference(text, text_lang,
|
||||
for item in tts_pipeline.run(inputs):
|
||||
yield item, actual_seed
|
||||
except NO_PROMPT_ERROR:
|
||||
gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!'))
|
||||
gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
|
||||
|
||||
|
||||
def custom_sort_key(s):
|
||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||
parts = re.split('(\d+)', s)
|
||||
parts = re.split("(\d+)", s)
|
||||
# 将数字部分转换为整数,非数字部分保持不变
|
||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||
return parts
|
||||
@ -170,125 +188,193 @@ def custom_sort_key(s):
|
||||
|
||||
def change_choices():
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
|
||||
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
|
||||
"choices": sorted(GPT_names, key=custom_sort_key),
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
|
||||
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
|
||||
|
||||
_ =[[],[]]
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
pretrained_sovits_name = [
|
||||
"GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
|
||||
path_sovits_v3,
|
||||
]
|
||||
pretrained_gpt_name = [
|
||||
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
|
||||
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||
]
|
||||
|
||||
_ = [[], []]
|
||||
for i in range(3):
|
||||
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
|
||||
if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
|
||||
pretrained_gpt_name,pretrained_sovits_name = _
|
||||
if os.path.exists(pretrained_gpt_name[i]):
|
||||
_[0].append(pretrained_gpt_name[i])
|
||||
if os.path.exists(pretrained_sovits_name[i]):
|
||||
_[-1].append(pretrained_sovits_name[i])
|
||||
pretrained_gpt_name, pretrained_sovits_name = _
|
||||
|
||||
|
||||
if os.path.exists(f"./weight.json"):
|
||||
if os.path.exists("./weight.json"):
|
||||
pass
|
||||
else:
|
||||
with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
|
||||
with open("./weight.json", "w", encoding="utf-8") as file:
|
||||
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
||||
|
||||
with open(f"./weight.json", 'r', encoding="utf-8") as file:
|
||||
with open("./weight.json", "r", encoding="utf-8") as file:
|
||||
weight_data = file.read()
|
||||
weight_data=json.loads(weight_data)
|
||||
gpt_path = os.environ.get(
|
||||
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
|
||||
sovits_path = os.environ.get(
|
||||
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
|
||||
if isinstance(gpt_path,list):
|
||||
weight_data = json.loads(weight_data)
|
||||
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
|
||||
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
|
||||
if isinstance(gpt_path, list):
|
||||
gpt_path = gpt_path[0]
|
||||
if isinstance(sovits_path,list):
|
||||
if isinstance(sovits_path, list):
|
||||
sovits_path = sovits_path[0]
|
||||
|
||||
|
||||
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
|
||||
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
|
||||
for path in SoVITS_weight_root + GPT_weight_root:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
|
||||
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
|
||||
for path in SoVITS_weight_root+GPT_weight_root:
|
||||
os.makedirs(path,exist_ok=True)
|
||||
|
||||
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
||||
SoVITS_names = [i for i in pretrained_sovits_name]
|
||||
for path in SoVITS_weight_root:
|
||||
for name in os.listdir(path):
|
||||
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
|
||||
if name.endswith(".pth"):
|
||||
SoVITS_names.append("%s/%s" % (path, name))
|
||||
GPT_names = [i for i in pretrained_gpt_name]
|
||||
for path in GPT_weight_root:
|
||||
for name in os.listdir(path):
|
||||
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
|
||||
if name.endswith(".ckpt"):
|
||||
GPT_names.append("%s/%s" % (path, name))
|
||||
return SoVITS_names, GPT_names
|
||||
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
|
||||
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
|
||||
global version, model_version, dict_language,if_lora_v3
|
||||
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
|
||||
from process_ckpt import get_sovits_version_from_path_fast
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
global version, model_version, dict_language, if_lora_v3
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
# print(sovits_path,version, model_version, if_lora_v3)
|
||||
if if_lora_v3 and not os.path.exists(path_sovits_v3):
|
||||
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
gr.Warning(info)
|
||||
raise FileExistsError(info)
|
||||
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
if prompt_language is not None and text_language is not None:
|
||||
if prompt_language in list(dict_language.keys()):
|
||||
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
|
||||
prompt_text_update, prompt_language_update = (
|
||||
{"__type__": "update"},
|
||||
{"__type__": "update", "value": prompt_language},
|
||||
)
|
||||
else:
|
||||
prompt_text_update = {'__type__':'update', 'value':''}
|
||||
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
|
||||
prompt_text_update = {"__type__": "update", "value": ""}
|
||||
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
if text_language in list(dict_language.keys()):
|
||||
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
|
||||
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
||||
else:
|
||||
text_update = {'__type__':'update', 'value':''}
|
||||
text_language_update = {'__type__':'update', 'value':i18n("中文")}
|
||||
if model_version=="v3":
|
||||
visible_sample_steps=True
|
||||
visible_inp_refs=False
|
||||
text_update = {"__type__": "update", "value": ""}
|
||||
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
if model_version == "v3":
|
||||
visible_sample_steps = True
|
||||
visible_inp_refs = False
|
||||
else:
|
||||
visible_sample_steps=False
|
||||
visible_inp_refs=True
|
||||
#prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
|
||||
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "interactive": visible_sample_steps,"value":32},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "interactive": True if model_version!="v3"else False},{"__type__": "update", "value":i18n("模型加载中,请等待"),"interactive":False}
|
||||
visible_sample_steps = False
|
||||
visible_inp_refs = True
|
||||
# prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version != "v3" else False},
|
||||
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
||||
)
|
||||
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "interactive": visible_sample_steps,"value":32},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "interactive": True if model_version!="v3"else False},{"__type__": "update", "value":i18n("合成语音"),"interactive":True}
|
||||
with open("./weight.json")as f:
|
||||
data=f.read()
|
||||
data=json.loads(data)
|
||||
data["SoVITS"][version]=sovits_path
|
||||
with open("./weight.json","w")as f:f.write(json.dumps(data))
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version != "v3" else False},
|
||||
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
||||
)
|
||||
with open("./weight.json") as f:
|
||||
data = f.read()
|
||||
data = json.loads(data)
|
||||
data["SoVITS"][version] = sovits_path
|
||||
with open("./weight.json", "w") as f:
|
||||
f.write(json.dumps(data))
|
||||
|
||||
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
gr.Markdown(
|
||||
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
|
||||
+ "<br>"
|
||||
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
# with gr.Group():
|
||||
gr.Markdown(value=i18n("模型切换"))
|
||||
with gr.Row():
|
||||
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
|
||||
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=sorted(GPT_names, key=custom_sort_key),
|
||||
value=gpt_path,
|
||||
interactive=True,
|
||||
)
|
||||
SoVITS_dropdown = gr.Dropdown(
|
||||
label=i18n("SoVITS模型列表"),
|
||||
choices=sorted(SoVITS_names, key=custom_sort_key),
|
||||
value=sovits_path,
|
||||
interactive=True,
|
||||
)
|
||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||
with gr.Row():
|
||||
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
||||
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple", visible=True if model_version!="v3"else False)
|
||||
inp_refs = gr.File(
|
||||
label=i18n("辅参考音频(可选多个,或不选)"),
|
||||
file_count="multiple",
|
||||
visible=True if model_version != "v3" else False,
|
||||
)
|
||||
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
||||
with gr.Row():
|
||||
prompt_language = gr.Dropdown(
|
||||
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
with gr.Column():
|
||||
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True if model_version!="v3"else False, show_label=True)
|
||||
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
|
||||
ref_text_free = gr.Checkbox(
|
||||
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
|
||||
value=False,
|
||||
interactive=True if model_version != "v3" else False,
|
||||
show_label=True,
|
||||
)
|
||||
gr.Markdown(
|
||||
i18n("使用无参考文本模式时建议使用微调的GPT")
|
||||
+ "<br>"
|
||||
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
|
||||
@ -297,42 +383,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("推理设置"))
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
|
||||
batch_size = gr.Slider(
|
||||
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||
)
|
||||
sample_steps = gr.Radio(
|
||||
label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
|
||||
)
|
||||
with gr.Row():
|
||||
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
|
||||
fragment_interval = gr.Slider(
|
||||
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
||||
)
|
||||
speed_factor = gr.Slider(
|
||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
|
||||
temperature = gr.Slider(
|
||||
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
||||
)
|
||||
repetition_penalty = gr.Slider(
|
||||
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
how_to_cut = gr.Dropdown(
|
||||
label=i18n("怎么切"),
|
||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True, scale=1
|
||||
)
|
||||
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
i18n("不切"),
|
||||
i18n("凑四句一切"),
|
||||
i18n("凑50字一切"),
|
||||
i18n("按中文句号。切"),
|
||||
i18n("按英文句号.切"),
|
||||
i18n("按标点符号切"),
|
||||
],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
scale=1,
|
||||
)
|
||||
super_sampling = gr.Checkbox(
|
||||
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(
|
||||
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
|
||||
seed = gr.Number(label=i18n("随机种子"),value=-1)
|
||||
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
@ -340,40 +450,78 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
||||
|
||||
|
||||
inference_button.click(
|
||||
inference,
|
||||
[
|
||||
text,text_language, inp_ref, inp_refs,
|
||||
prompt_text, prompt_language,
|
||||
top_k, top_p, temperature,
|
||||
how_to_cut, batch_size,
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty, sample_steps, super_sampling,
|
||||
],
|
||||
text,
|
||||
text_language,
|
||||
inp_ref,
|
||||
inp_refs,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
how_to_cut,
|
||||
batch_size,
|
||||
speed_factor,
|
||||
ref_text_free,
|
||||
split_bucket,
|
||||
fragment_interval,
|
||||
seed,
|
||||
keep_random,
|
||||
parallel_infer,
|
||||
repetition_penalty,
|
||||
sample_steps,
|
||||
super_sampling,
|
||||
],
|
||||
[output, seed],
|
||||
)
|
||||
stop_infer.click(tts_pipeline.stop, [], [])
|
||||
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free,inference_button])#
|
||||
SoVITS_dropdown.change(
|
||||
change_sovits_weights,
|
||||
[SoVITS_dropdown, prompt_language, text_language],
|
||||
[
|
||||
prompt_language,
|
||||
text_language,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
sample_steps,
|
||||
inp_refs,
|
||||
ref_text_free,
|
||||
inference_button,
|
||||
],
|
||||
) #
|
||||
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
||||
gr.Markdown(
|
||||
value=i18n(
|
||||
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
||||
with gr.Column():
|
||||
_how_to_cut = gr.Radio(
|
||||
label=i18n("怎么切"),
|
||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
)
|
||||
cut_text= gr.Button(i18n("切分"), variant="primary")
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
i18n("不切"),
|
||||
i18n("凑四句一切"),
|
||||
i18n("凑50字一切"),
|
||||
i18n("按中文句号。切"),
|
||||
i18n("按英文句号.切"),
|
||||
i18n("按标点符号切"),
|
||||
],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
)
|
||||
cut_text = gr.Button(i18n("切分"), variant="primary")
|
||||
|
||||
def to_cut(text_inp, how_to_cut):
|
||||
if len(text_inp.strip()) == 0 or text_inp==[]:
|
||||
if len(text_inp.strip()) == 0 or text_inp == []:
|
||||
return ""
|
||||
method = get_method(cut_method[how_to_cut])
|
||||
return method(text_inp)
|
||||
@ -382,8 +530,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
|
||||
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.queue().launch(#concurrency_count=511, max_size=1022
|
||||
if __name__ == "__main__":
|
||||
app.queue().launch( # concurrency_count=511, max_size=1022
|
||||
server_name="0.0.0.0",
|
||||
inbrowser=True,
|
||||
share=is_share,
|
||||
|
@ -18,7 +18,7 @@ class Encoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
window_size=4,
|
||||
isflow=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@ -56,9 +56,7 @@ class Encoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
if isflow:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
|
||||
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
@ -74,9 +72,7 @@ class Encoder(nn.Module):
|
||||
x = self.cond_pre(x)
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x, g_l, torch.IntTensor([self.hidden_channels])
|
||||
)
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
@ -99,7 +95,7 @@ class Decoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
proximal_bias=False,
|
||||
proximal_init=True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@ -131,9 +127,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
||||
self.encdec_attn_layers.append(
|
||||
MultiHeadAttention(
|
||||
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
||||
)
|
||||
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
|
||||
)
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(
|
||||
@ -153,9 +147,7 @@ class Decoder(nn.Module):
|
||||
x: decoder input
|
||||
h: encoder output
|
||||
"""
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_v = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
||||
if self.window_size is not None:
|
||||
assert (
|
||||
t_s == t_t
|
||||
), "Relative attention is only available for self-attention."
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(
|
||||
query / math.sqrt(self.k_channels), key_relative_embeddings
|
||||
)
|
||||
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
|
||||
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
||||
scores = scores + scores_local
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(
|
||||
device=scores.device, dtype=scores.dtype
|
||||
)
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
if self.block_length is not None:
|
||||
assert (
|
||||
t_s == t_t
|
||||
), "Local attention is only available for self-attention."
|
||||
block_mask = (
|
||||
torch.ones_like(scores)
|
||||
.triu(-self.block_length)
|
||||
.tril(self.block_length)
|
||||
)
|
||||
assert t_s == t_t, "Local attention is only available for self-attention."
|
||||
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
||||
scores = scores.masked_fill(block_mask == 0, -1e4)
|
||||
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
||||
p_attn = self.drop(p_attn)
|
||||
output = torch.matmul(p_attn, value)
|
||||
if self.window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(
|
||||
self.emb_rel_v, t_s
|
||||
)
|
||||
output = output + self._matmul_with_relative_values(
|
||||
relative_weights, value_relative_embeddings
|
||||
)
|
||||
output = (
|
||||
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
||||
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
|
||||
)
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
used_relative_embeddings = padded_relative_embeddings[
|
||||
:, slice_start_position:slice_end_position
|
||||
]
|
||||
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(
|
||||
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
||||
:, :, :length, length - 1 :
|
||||
]
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(
|
||||
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
|
||||
|
||||
|
||||
def weight_norm_modules(module, name="weight", dim=0):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
||||
module, Depthwise_Separable_TransposeConv1D
|
||||
):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
|
||||
module.weight_norm()
|
||||
return module
|
||||
else:
|
||||
@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
|
||||
|
||||
|
||||
def remove_weight_norm_modules(module, name="weight"):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
||||
module, Depthwise_Separable_TransposeConv1D
|
||||
):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
|
||||
module.remove_weight_norm()
|
||||
else:
|
||||
remove_weight_norm(module, name)
|
||||
@ -567,7 +523,7 @@ class FFT(nn.Module):
|
||||
proximal_bias=False,
|
||||
proximal_init=True,
|
||||
isflow=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@ -579,9 +535,7 @@ class FFT(nn.Module):
|
||||
self.proximal_bias = proximal_bias
|
||||
self.proximal_init = proximal_init
|
||||
if isflow:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
|
||||
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
@ -622,18 +576,14 @@ class FFT(nn.Module):
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
if g is not None:
|
||||
x = self.cond_pre(x)
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x, g_l, torch.IntTensor([self.hidden_channels])
|
||||
)
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
|
||||
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_0[i](x + y)
|
||||
|
@ -7,6 +7,7 @@ from module import commons
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
@ -43,7 +44,7 @@ class Encoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
window_size=4,
|
||||
isflow=True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@ -65,13 +66,9 @@ class Encoder(nn.Module):
|
||||
if self.gin_channels != 0:
|
||||
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
||||
# vits2 says 3rd block, so idx is 2 by default
|
||||
self.cond_layer_idx = (
|
||||
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||
)
|
||||
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||
logging.debug(self.gin_channels, self.cond_layer_idx)
|
||||
assert (
|
||||
self.cond_layer_idx < self.n_layers
|
||||
), "cond_layer_idx should be less than n_layers"
|
||||
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.attn_layers = nn.ModuleList()
|
||||
self.norm_layers_1 = nn.ModuleList()
|
||||
@ -121,7 +118,9 @@ class Encoder(nn.Module):
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
|
||||
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
|
||||
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
|
||||
):
|
||||
y = attn_layers(x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = norm_layers_1(x + y)
|
||||
@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_v = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
|
||||
def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
|
||||
def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, _ = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||
@ -224,7 +217,7 @@ class MultiHeadAttention(nn.Module):
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
|
||||
output = (output.transpose(2, 3).contiguous().view(b, d, -1))
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, -1)
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
max_relative_position = 2 * self.window_size + 1
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
|
||||
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
|
||||
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
|
||||
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
|
||||
pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
|
||||
pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
|
||||
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
|
||||
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
|
||||
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
padded_relative_embeddings = F.pad(
|
||||
relative_embeddings,
|
||||
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||
)
|
||||
used_relative_embeddings = padded_relative_embeddings[
|
||||
:, slice_start_position:slice_end_position
|
||||
]
|
||||
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(
|
||||
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
||||
:, :, :length, length - 1 :
|
||||
]
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(
|
||||
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
@ -395,12 +380,6 @@ class MRTE(nn.Module):
|
||||
|
||||
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
@ -28,9 +28,7 @@ def intersperse(lst, item):
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += (
|
||||
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
)
|
||||
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
return kl
|
||||
|
||||
|
||||
@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
||||
num_timescales - 1
|
||||
)
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||
)
|
||||
|
@ -30,6 +30,7 @@
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
||||
uniform_init if not kmeans_init else torch.zeros
|
||||
)
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):
|
||||
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
|
||||
|
||||
def forward(
|
||||
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
|
||||
):
|
||||
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses, out_quantized
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
|
@ -1,24 +1,18 @@
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from tqdm import tqdm
|
||||
|
||||
from module import commons
|
||||
from module.mel_processing import spectrogram_torch,spec_to_mel_torch
|
||||
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
|
||||
from text import cleaned_text_to_sequence
|
||||
from utils import load_wav_to_torch, load_filepaths_and_text
|
||||
import torch.nn.functional as F
|
||||
from functools import lru_cache
|
||||
import requests
|
||||
from scipy.io import wavfile
|
||||
from io import BytesIO
|
||||
from tools.my_utils import load_audio
|
||||
version = os.environ.get('version',None)
|
||||
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
|
||||
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
|
||||
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
@ -43,7 +37,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
@ -51,7 +45,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@ -76,7 +70,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@ -111,7 +105,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
@ -129,8 +123,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
return spec, audio_norm
|
||||
|
||||
@ -146,12 +141,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
return len(self.audiopaths_sid_text)
|
||||
|
||||
def random_slice(self, ssl, wav, mel):
|
||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
|
||||
"first", ssl.shape, wav.shape)
|
||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
|
||||
|
||||
len_mel = mel.shape[1]
|
||||
if self.val:
|
||||
reference_mel = mel[:, :len_mel // 3]
|
||||
reference_mel = mel[:, : len_mel // 3]
|
||||
return reference_mel, ssl, wav, mel
|
||||
dir = random.randint(0, 1)
|
||||
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
|
||||
@ -159,20 +153,29 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
if dir == 0:
|
||||
reference_mel = mel[:, :sep_point]
|
||||
ssl = ssl[:, :, sep_point:]
|
||||
wav2 = wav[:, sep_point * self.hop_length:]
|
||||
wav2 = wav[:, sep_point * self.hop_length :]
|
||||
mel = mel[:, sep_point:]
|
||||
else:
|
||||
reference_mel = mel[:, sep_point:]
|
||||
ssl = ssl[:, :, :sep_point]
|
||||
wav2 = wav[:, :sep_point * self.hop_length]
|
||||
wav2 = wav[:, : sep_point * self.hop_length]
|
||||
mel = mel[:, :sep_point]
|
||||
|
||||
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
|
||||
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
|
||||
ssl.shape,
|
||||
wav.shape,
|
||||
wav2.shape,
|
||||
mel.shape,
|
||||
sep_point,
|
||||
self.hop_length,
|
||||
sep_point * self.hop_length,
|
||||
dir,
|
||||
)
|
||||
return reference_mel, ssl, wav2, mel
|
||||
class TextAudioSpeakerCollate():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
|
||||
class TextAudioSpeakerCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@ -184,9 +187,7 @@ class TextAudioSpeakerCollate():
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
|
||||
@ -214,22 +215,24 @@ class TextAudioSpeakerCollate():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wav = row[2]
|
||||
wav_padded[i, :, :wav.size(1)] = wav
|
||||
wav_padded[i, :, : wav.size(1)] = wav
|
||||
wav_lengths[i] = wav.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
@ -253,7 +256,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
@ -261,7 +264,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@ -286,7 +289,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@ -313,15 +316,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
self.spec_min=-12
|
||||
self.spec_max=2
|
||||
self.spec_min = -12
|
||||
self.spec_max = 2
|
||||
|
||||
self.filter_length_mel = self.win_length_mel = 1024
|
||||
self.hop_length_mel = 256
|
||||
self.n_mel_channels = 100
|
||||
self.sampling_rate_mel = 24000
|
||||
self.mel_fmin = 0
|
||||
self.mel_fmax = None
|
||||
|
||||
self.filter_length_mel=self.win_length_mel=1024
|
||||
self.hop_length_mel=256
|
||||
self.n_mel_channels=100
|
||||
self.sampling_rate_mel=24000
|
||||
self.mel_fmin=0
|
||||
self.mel_fmax=None
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
@ -332,7 +336,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
@ -347,25 +351,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
return (ssl, spec, mel, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio=torch.FloatTensor(audio_array)#/32768
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24=torch.FloatTensor(audio_array24)#/32768
|
||||
audio_array24 = load_audio(
|
||||
filename, 24000
|
||||
) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24 = torch.FloatTensor(audio_array24) # /32768
|
||||
audio_norm24 = audio24
|
||||
audio_norm24 = audio_norm24.unsqueeze(0)
|
||||
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
|
||||
|
||||
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
|
||||
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
|
||||
spec1 = spectrogram_torch(
|
||||
audio_norm24,
|
||||
self.filter_length_mel,
|
||||
self.sampling_rate_mel,
|
||||
self.hop_length_mel,
|
||||
self.win_length_mel,
|
||||
center=False,
|
||||
)
|
||||
mel = spec_to_mel_torch(
|
||||
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
|
||||
)
|
||||
mel = torch.squeeze(mel, 0)
|
||||
mel=self.norm_spec(mel)
|
||||
mel = self.norm_spec(mel)
|
||||
# print(1111111,spec.shape,mel.shape)
|
||||
return spec, mel
|
||||
|
||||
@ -379,9 +393,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
class TextAudioSpeakerCollateV3():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
|
||||
class TextAudioSpeakerCollateV3:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@ -392,12 +407,10 @@ class TextAudioSpeakerCollateV3():
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
#ssl, spec, wav,mel, text
|
||||
# ssl, spec, wav,mel, text
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
#(ssl, spec,mel, text)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
# (ssl, spec,mel, text)
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
|
||||
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
|
||||
@ -411,7 +424,7 @@ class TextAudioSpeakerCollateV3():
|
||||
# max_wav_len = max([x[2].size(1) for x in batch])
|
||||
|
||||
max_text_len = max([x[3].size(0) for x in batch])
|
||||
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
|
||||
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
|
||||
|
||||
ssl_lengths = torch.LongTensor(len(batch))
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
@ -422,7 +435,7 @@ class TextAudioSpeakerCollateV3():
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
@ -435,11 +448,11 @@ class TextAudioSpeakerCollateV3():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
# wav = row[2]
|
||||
@ -447,15 +460,17 @@ class TextAudioSpeakerCollateV3():
|
||||
# wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[2]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
@ -479,7 +494,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
@ -487,7 +502,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@ -512,7 +527,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@ -539,15 +554,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
self.spec_min=-12
|
||||
self.spec_max=2
|
||||
self.spec_min = -12
|
||||
self.spec_max = 2
|
||||
|
||||
self.filter_length_mel = self.win_length_mel = 1024
|
||||
self.hop_length_mel = 256
|
||||
self.n_mel_channels = 100
|
||||
self.sampling_rate_mel = 24000
|
||||
self.mel_fmin = 0
|
||||
self.mel_fmax = None
|
||||
|
||||
self.filter_length_mel=self.win_length_mel=1024
|
||||
self.hop_length_mel=256
|
||||
self.n_mel_channels=100
|
||||
self.sampling_rate_mel=24000
|
||||
self.mel_fmin=0
|
||||
self.mel_fmax=None
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
@ -555,10 +571,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
audiopath, phoneme_ids = audiopath_sid_text
|
||||
text = torch.FloatTensor(phoneme_ids)
|
||||
try:
|
||||
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
@ -573,27 +589,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
return (ssl, spec, wav, mel, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio=torch.FloatTensor(audio_array)#/32768
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24=torch.FloatTensor(audio_array24)#/32768
|
||||
audio_array24 = load_audio(
|
||||
filename, 24000
|
||||
) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24 = torch.FloatTensor(audio_array24) # /32768
|
||||
audio_norm24 = audio24
|
||||
audio_norm24 = audio_norm24.unsqueeze(0)
|
||||
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
|
||||
|
||||
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
|
||||
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
|
||||
spec1 = spectrogram_torch(
|
||||
audio_norm24,
|
||||
self.filter_length_mel,
|
||||
self.sampling_rate_mel,
|
||||
self.hop_length_mel,
|
||||
self.win_length_mel,
|
||||
center=False,
|
||||
)
|
||||
mel = spec_to_mel_torch(
|
||||
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
|
||||
)
|
||||
mel = torch.squeeze(mel, 0)
|
||||
mel=self.norm_spec(mel)
|
||||
mel = self.norm_spec(mel)
|
||||
# print(1111111,spec.shape,mel.shape)
|
||||
return spec, mel,audio_norm
|
||||
return spec, mel, audio_norm
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
@ -605,9 +631,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
class TextAudioSpeakerCollateV3b():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
|
||||
class TextAudioSpeakerCollateV3b:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@ -618,12 +645,10 @@ class TextAudioSpeakerCollateV3b():
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
#ssl, spec, wav,mel, text
|
||||
# ssl, spec, wav,mel, text
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
#(ssl, spec,mel, text)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
# (ssl, spec,mel, text)
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
|
||||
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
|
||||
@ -636,7 +661,7 @@ class TextAudioSpeakerCollateV3b():
|
||||
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
|
||||
max_wav_len = max([x[2].size(1) for x in batch])
|
||||
max_text_len = max([x[4].size(0) for x in batch])
|
||||
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
|
||||
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
|
||||
|
||||
ssl_lengths = torch.LongTensor(len(batch))
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
@ -647,7 +672,7 @@ class TextAudioSpeakerCollateV3b():
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
@ -660,28 +685,40 @@ class TextAudioSpeakerCollateV3b():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wav = row[2]
|
||||
wav_padded[i, :, :wav.size(1)] = wav
|
||||
wav_padded[i, :, : wav.size(1)] = wav
|
||||
wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[3]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[4]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return (
|
||||
ssl_padded,
|
||||
spec_padded,
|
||||
mel_padded,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text_padded,
|
||||
text_lengths,
|
||||
wav_padded,
|
||||
wav_lengths,
|
||||
mel_lengths,
|
||||
)
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
|
||||
|
||||
|
||||
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
"""
|
||||
Maintain similar input lengths in a batch.
|
||||
@ -745,12 +782,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
num_samples_bucket = self.num_samples_per_bucket[i]
|
||||
|
||||
rem = num_samples_bucket - len_bucket
|
||||
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
|
||||
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
|
||||
|
||||
ids_bucket = ids_bucket[self.rank::self.num_replicas]
|
||||
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
||||
|
||||
for j in range(len(ids_bucket) // self.batch_size):
|
||||
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
|
||||
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
|
||||
batches.append(batch)
|
||||
|
||||
if self.shuffle:
|
||||
|
@ -1,7 +1,6 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
@ -66,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
|
||||
torch.exp(-2 * logs) * ((z - m) ** 2)
|
||||
) # neg normal likelihood w/o the constant term
|
||||
l = l - torch.sum(logdet) # log jacobian determinant
|
||||
l = l / torch.sum(
|
||||
torch.ones_like(z) * mask
|
||||
) # averaging across batch, channel and time axes
|
||||
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
|
||||
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
|
||||
return l
|
||||
|
@ -1,16 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import librosa
|
||||
import librosa.util as librosa_util
|
||||
from librosa.util import normalize, pad_center, tiny
|
||||
from scipy.signal import get_window
|
||||
from scipy.io.wavfile import read
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
@ -58,9 +47,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
@ -90,20 +77,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=spec.dtype, device=spec.device
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mel_spectrogram_torch(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
@ -114,16 +95,10 @@ def mel_spectrogram_torch(
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
|
@ -1,9 +1,7 @@
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -13,16 +11,18 @@ from module import commons
|
||||
from module import modules
|
||||
from module import attentions
|
||||
from f5_tts.model import DiT
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.mrte_model import MRTE
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from torch.cuda.amp import autocast
|
||||
import contextlib,random
|
||||
import contextlib
|
||||
import random
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
@ -48,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
@ -91,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = (
|
||||
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* x_mask
|
||||
)
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
@ -102,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum(
|
||||
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||
)
|
||||
logq = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||
- logdet_tot_q
|
||||
)
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
@ -117,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = (
|
||||
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||
- logdet_tot
|
||||
)
|
||||
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = (
|
||||
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* noise_scale
|
||||
)
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
@ -137,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
||||
):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
@ -149,13 +125,9 @@ class DurationPredictor(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(
|
||||
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
@ -190,7 +162,7 @@ class TextEncoder(nn.Module):
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
latent_channels=192,
|
||||
version = "v2",
|
||||
version="v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
@ -237,26 +209,22 @@ class TextEncoder(nn.Module):
|
||||
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
|
||||
text_mask = torch.unsqueeze(
|
||||
commons.sequence_mask(text_lengths, text.size(1)), 1
|
||||
).to(y.dtype)
|
||||
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
|
||||
if test == 1:
|
||||
text[:, :] = 0
|
||||
text = self.text_embedding(text).transpose(1, 2)
|
||||
text = self.encoder_text(text * text_mask, text_mask)
|
||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
if(speed!=1):
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
|
||||
if speed != 1:
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||
stats = self.proj(y) * y_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
@ -360,9 +328,7 @@ class PosteriorEncoder(nn.Module):
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
@ -372,14 +338,9 @@ class PosteriorEncoder(nn.Module):
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@ -394,7 +355,7 @@ class Encoder(nn.Module):
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if(g!=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
@ -402,6 +363,7 @@ class Encoder(nn.Module):
|
||||
stats = self.proj(x) * x_mask
|
||||
return stats, x_mask
|
||||
|
||||
|
||||
class WNEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -434,9 +396,7 @@ class WNEncoder(nn.Module):
|
||||
self.norm = modules.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
out = self.proj(x) * x_mask
|
||||
@ -459,9 +419,7 @@ class Generator(torch.nn.Module):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(
|
||||
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
||||
)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
@ -481,9 +439,7 @@ class Generator(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
@ -636,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [
|
||||
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
||||
]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
@ -738,10 +692,7 @@ class Quantizer(torch.nn.Module):
|
||||
super(Quantizer, self).__init__()
|
||||
assert embed_dim % n_code_groups == 0
|
||||
self.quantizer_modules = nn.ModuleList(
|
||||
[
|
||||
Quantizer_module(n_codes, embed_dim // n_code_groups)
|
||||
for _ in range(n_code_groups)
|
||||
]
|
||||
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
|
||||
)
|
||||
self.n_code_groups = n_code_groups
|
||||
self.embed_dim = embed_dim
|
||||
@ -759,9 +710,7 @@ class Quantizer(torch.nn.Module):
|
||||
z_q.append(_z_q)
|
||||
min_indicies.append(_min_indicies) # B * T,
|
||||
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
|
||||
(z_q - xin.detach()) ** 2
|
||||
)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
||||
z_q = xin + (z_q - xin).detach()
|
||||
z_q = z_q.transpose(1, 2)
|
||||
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
|
||||
@ -801,13 +750,9 @@ class CodePredictor(nn.Module):
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
ssl_dim, style_vector_dim=hidden_channels
|
||||
)
|
||||
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
|
||||
|
||||
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
|
||||
self.n_q = n_q
|
||||
@ -820,9 +765,7 @@ class CodePredictor(nn.Module):
|
||||
x = x + g
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
x = self.out_proj(x * x_mask) * x_mask
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
|
||||
2, 3
|
||||
)
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
|
||||
target = codes[1:].transpose(0, 1)
|
||||
if not infer:
|
||||
logits = logits.reshape(-1, self.dims)
|
||||
@ -870,8 +813,8 @@ class SynthesizerTrn(nn.Module):
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version = "v2",
|
||||
**kwargs
|
||||
version="v2",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
@ -902,7 +845,7 @@ class SynthesizerTrn(nn.Module):
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
version = version,
|
||||
version=version,
|
||||
)
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
@ -923,12 +866,10 @@ class SynthesizerTrn(nn.Module):
|
||||
16,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
# self.version=os.environ.get("version","v1")
|
||||
if(self.version=="v1"):
|
||||
if self.version == "v1":
|
||||
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
|
||||
else:
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
|
||||
@ -945,13 +886,11 @@ class SynthesizerTrn(nn.Module):
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
|
||||
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
if(self.version=="v1"):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
else:
|
||||
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
||||
with autocast(enabled=False):
|
||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||
with maybe_no_grad:
|
||||
@ -959,24 +898,16 @@ class SynthesizerTrn(nn.Module):
|
||||
self.ssl_proj.eval()
|
||||
self.quantizer.eval()
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||
ssl, layers=[0]
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge
|
||||
)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||
z_p = self.flow(z, y_mask, g=ge)
|
||||
|
||||
z_slice, ids_slice = commons.rand_slice_segments(
|
||||
z, y_lengths, self.segment_size
|
||||
)
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
o = self.dec(z_slice, g=ge)
|
||||
return (
|
||||
o,
|
||||
@ -989,24 +920,18 @@ class SynthesizerTrn(nn.Module):
|
||||
)
|
||||
|
||||
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
|
||||
y.dtype
|
||||
)
|
||||
if(self.version=="v1"):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
else:
|
||||
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
||||
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge, test=test
|
||||
)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@ -1015,39 +940,34 @@ class SynthesizerTrn(nn.Module):
|
||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
|
||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
def get_ge(refer):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(
|
||||
commons.sequence_mask(refer_lengths, refer.size(2)), 1
|
||||
).to(refer.dtype)
|
||||
if (self.version == "v1"):
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
return ge
|
||||
if(type(refer)==list):
|
||||
ges=[]
|
||||
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for _refer in refer:
|
||||
ge=get_ge(_refer)
|
||||
ge = get_ge(_refer)
|
||||
ges.append(ge)
|
||||
ge=torch.stack(ges,0).mean(0)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
else:
|
||||
ge=get_ge(refer)
|
||||
ge = get_ge(refer)
|
||||
|
||||
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(
|
||||
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
|
||||
)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, y_lengths, text, text_lengths, ge,speed
|
||||
)
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@ -1059,11 +979,10 @@ class SynthesizerTrn(nn.Module):
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
|
||||
class CFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,dit
|
||||
):
|
||||
def __init__(self, in_channels, dit):
|
||||
super().__init__()
|
||||
self.sigma_min = 1e-6
|
||||
|
||||
@ -1077,41 +996,54 @@ class CFM(torch.nn.Module):
|
||||
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
|
||||
"""Forward diffusion"""
|
||||
B, T = mu.size(0), mu.size(1)
|
||||
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature
|
||||
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
|
||||
prompt_len = prompt.size(-1)
|
||||
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
|
||||
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
|
||||
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
|
||||
x[..., :prompt_len] = 0
|
||||
mu=mu.transpose(2,1)
|
||||
mu = mu.transpose(2, 1)
|
||||
t = 0
|
||||
d = 1 / n_timesteps
|
||||
for j in range(n_timesteps):
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
|
||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
|
||||
if inference_cfg_rate>1e-5:
|
||||
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
||||
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
||||
v_pred = self.estimator(
|
||||
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
|
||||
).transpose(2, 1)
|
||||
if inference_cfg_rate > 1e-5:
|
||||
neg = self.estimator(
|
||||
x,
|
||||
prompt_x,
|
||||
x_lens,
|
||||
t_tensor,
|
||||
d_tensor,
|
||||
mu,
|
||||
use_grad_ckpt=False,
|
||||
drop_audio_cond=True,
|
||||
drop_text=True,
|
||||
).transpose(2, 1)
|
||||
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
|
||||
x = x + d * v_pred
|
||||
t = t + d
|
||||
x[:, :, :prompt_len] = 0
|
||||
return x
|
||||
|
||||
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
|
||||
b, _, t = x1.shape
|
||||
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
|
||||
x0 = torch.randn_like(x1,device=mu.device)
|
||||
x0 = torch.randn_like(x1, device=mu.device)
|
||||
vt = x1 - x0
|
||||
xt = x0 + t[:, None, None] * vt
|
||||
dt = torch.zeros_like(t,device=mu.device)
|
||||
dt = torch.zeros_like(t, device=mu.device)
|
||||
prompt = torch.zeros_like(x1)
|
||||
for i in range(b):
|
||||
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
|
||||
xt[i, :, :prompt_lens[i]] = 0
|
||||
gailv=0.3# if ttime()>1736250488 else 0.1
|
||||
prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
|
||||
xt[i, :, : prompt_lens[i]] = 0
|
||||
gailv = 0.3 # if ttime()>1736250488 else 0.1
|
||||
if random.random() < gailv:
|
||||
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
|
||||
d = 1/torch.pow(2, base)
|
||||
d = 1 / torch.pow(2, base)
|
||||
d_input = d.clone()
|
||||
d_input[d_input < 1e-2] = 0
|
||||
# with torch.no_grad():
|
||||
@ -1119,52 +1051,55 @@ class CFM(torch.nn.Module):
|
||||
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
|
||||
x_mid = xt + d[:, None, None] * v_pred_1
|
||||
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
|
||||
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
|
||||
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
|
||||
vt = (v_pred_1 + v_pred_2) / 2
|
||||
vt = vt.detach()
|
||||
dt = 2*d
|
||||
dt = 2 * d
|
||||
|
||||
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
|
||||
vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
|
||||
loss = 0
|
||||
for i in range(b):
|
||||
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
|
||||
loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
|
||||
loss /= b
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def set_no_grad(net_g):
|
||||
for name, param in net_g.named_parameters():
|
||||
param.requires_grad=False
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
class SynthesizerTrnV3(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@ -1185,132 +1120,133 @@ class SynthesizerTrnV3(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.model_dim=512
|
||||
self.model_dim = 512
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
||||
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
# gin_channels=gin_channels)
|
||||
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ['25hz', "50hz"]
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == '25hz':
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(
|
||||
dimension=ssl_dim,
|
||||
n_q=1,
|
||||
bins=1024
|
||||
)
|
||||
self.freeze_quantizer=freeze_quantizer
|
||||
inter_channels2=512
|
||||
self.bridge=nn.Sequential(
|
||||
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
|
||||
nn.LeakyReLU()
|
||||
)
|
||||
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
|
||||
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
|
||||
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
|
||||
if self.freeze_quantizer==True:
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
inter_channels2 = 512
|
||||
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
||||
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
||||
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
||||
self.cfm = CFM(
|
||||
100,
|
||||
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
|
||||
) # text_dim is condition feature dim
|
||||
if self.freeze_quantizer == True:
|
||||
set_no_grad(self.ssl_proj)
|
||||
set_no_grad(self.quantizer)
|
||||
set_no_grad(self.enc_p)
|
||||
|
||||
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
|
||||
def forward(
|
||||
self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
|
||||
): # ssl_lengths no need now
|
||||
with autocast(enabled=False):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||
with maybe_no_grad:
|
||||
if self.freeze_quantizer:
|
||||
self.ssl_proj.eval()#
|
||||
self.ssl_proj.eval() #
|
||||
self.quantizer.eval()
|
||||
self.enc_p.eval()
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||
ssl, layers=[0]
|
||||
)
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
fea=self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
|
||||
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
|
||||
B=ssl.shape[0]
|
||||
prompt_len_max = mel_lengths*2/3
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
fea, y_mask_ = self.wns1(
|
||||
fea, mel_lengths, ge
|
||||
) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
|
||||
B = ssl.shape[0]
|
||||
prompt_len_max = mel_lengths * 2 / 3
|
||||
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
|
||||
minn=min(mel.shape[-1],fea.shape[-1])
|
||||
mel=mel[:,:,:minn]
|
||||
fea=fea[:,:,:minn]
|
||||
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
|
||||
minn = min(mel.shape[-1], fea.shape[-1])
|
||||
mel = mel[:, :, :minn]
|
||||
fea = fea[:, :, :minn]
|
||||
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
|
||||
return cfm_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_encp(self, codes,text, refer,ge=None,speed=1):
|
||||
def decode_encp(self, codes, text, refer, ge=None, speed=1):
|
||||
# print(2333333,refer.shape)
|
||||
# ge=None
|
||||
if(ge==None):
|
||||
if ge == None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
|
||||
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
|
||||
if speed==1:
|
||||
sizee=int(codes.size(2)*2.5*1.5)
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
|
||||
if speed == 1:
|
||||
sizee = int(codes.size(2) * 2.5 * 1.5)
|
||||
else:
|
||||
sizee=int(codes.size(2)*2.5*1.5/speed)+1
|
||||
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
|
||||
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
|
||||
fea=self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
|
||||
return fea,ge
|
||||
return fea, ge
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0,1)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
|
||||
class SynthesizerTrnV3b(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
**kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@ -1330,47 +1266,52 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.model_dim=512
|
||||
self.model_dim = 512
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
gin_channels=gin_channels)
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
||||
self.dec = Generator(
|
||||
inter_channels,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=gin_channels,
|
||||
)
|
||||
self.enc_q = PosteriorEncoder(
|
||||
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ['25hz', "50hz"]
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == '25hz':
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(
|
||||
dimension=ssl_dim,
|
||||
n_q=1,
|
||||
bins=1024
|
||||
)
|
||||
self.freeze_quantizer=freeze_quantizer
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
|
||||
inter_channels2=512
|
||||
self.bridge=nn.Sequential(
|
||||
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
|
||||
nn.LeakyReLU()
|
||||
)
|
||||
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
|
||||
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
|
||||
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
|
||||
inter_channels2 = 512
|
||||
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
||||
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
||||
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
||||
self.cfm = CFM(
|
||||
100,
|
||||
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
|
||||
) # text_dim is condition feature dim
|
||||
|
||||
|
||||
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
|
||||
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
|
||||
with autocast(enabled=False):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
|
||||
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
||||
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
|
||||
# ge=None
|
||||
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
||||
@ -1379,51 +1320,59 @@ class SynthesizerTrnV3b(nn.Module):
|
||||
self.ssl_proj.eval()
|
||||
self.quantizer.eval()
|
||||
ssl = self.ssl_proj(ssl)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||
ssl, layers=[0]
|
||||
)
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
||||
z_p = self.flow(z, y_mask, g=ge)
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
o = self.dec(z_slice, g=ge)
|
||||
fea=self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
|
||||
learned_mel = self.linear_mel(fea)
|
||||
B=ssl.shape[0]
|
||||
prompt_len_max = mel_lengths*2/3
|
||||
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
|
||||
minn=min(mel.shape[-1],fea.shape[-1])
|
||||
mel=mel[:,:,:minn]
|
||||
fea=fea[:,:,:minn]
|
||||
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
|
||||
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
|
||||
B = ssl.shape[0]
|
||||
prompt_len_max = mel_lengths * 2 / 3
|
||||
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
|
||||
minn = min(mel.shape[-1], fea.shape[-1])
|
||||
mel = mel[:, :, :minn]
|
||||
fea = fea[:, :, :minn]
|
||||
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need
|
||||
return (
|
||||
commit_loss,
|
||||
cfm_loss,
|
||||
F.mse_loss(learned_mel, mel),
|
||||
o,
|
||||
ids_slice,
|
||||
y_mask,
|
||||
y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
quantized,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_encp(self, codes,text, refer,ge=None):
|
||||
def decode_encp(self, codes, text, refer, ge=None):
|
||||
# print(2333333,refer.shape)
|
||||
# ge=None
|
||||
if(ge==None):
|
||||
if ge == None:
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
|
||||
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
|
||||
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
|
||||
y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
|
||||
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
||||
fea=self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
|
||||
return fea,ge
|
||||
return fea, ge
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0,1)
|
||||
return codes.transpose(0, 1)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
import torch
|
||||
@ -11,14 +10,14 @@ from module import attentions_onnx as attentions
|
||||
|
||||
from f5_tts.model import DiT
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
@ -44,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
@ -87,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = (
|
||||
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* x_mask
|
||||
)
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
@ -98,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum(
|
||||
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||
)
|
||||
logq = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||
- logdet_tot_q
|
||||
)
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
@ -113,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = (
|
||||
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||
- logdet_tot
|
||||
)
|
||||
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = (
|
||||
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* noise_scale
|
||||
)
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
@ -133,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
||||
):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
@ -145,13 +120,9 @@ class DurationPredictor(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(
|
||||
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
@ -234,7 +205,7 @@ class TextEncoder(nn.Module):
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, text, ge, speed=1):
|
||||
y_mask = torch.ones_like(y[:1,:1,:])
|
||||
y_mask = torch.ones_like(y[:1, :1, :])
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
@ -246,8 +217,8 @@ class TextEncoder(nn.Module):
|
||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
if(speed!=1):
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
|
||||
if speed != 1:
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||
|
||||
stats = self.proj(y) * y_mask
|
||||
@ -333,9 +304,7 @@ class PosteriorEncoder(nn.Module):
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
@ -345,14 +314,9 @@ class PosteriorEncoder(nn.Module):
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
@ -367,7 +331,7 @@ class Encoder(nn.Module):
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if(g!=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
@ -375,6 +339,7 @@ class Encoder(nn.Module):
|
||||
stats = self.proj(x) * x_mask
|
||||
return stats, x_mask
|
||||
|
||||
|
||||
class WNEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -407,9 +372,7 @@ class WNEncoder(nn.Module):
|
||||
self.norm = modules.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
out = self.proj(x) * x_mask
|
||||
@ -432,9 +395,7 @@ class Generator(torch.nn.Module):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(
|
||||
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
||||
)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
@ -454,9 +415,7 @@ class Generator(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
@ -465,7 +424,7 @@ class Generator(torch.nn.Module):
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g:Optional[torch.Tensor]=None):
|
||||
def forward(self, x, g: Optional[torch.Tensor] = None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@ -609,9 +568,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [
|
||||
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
||||
]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
@ -711,10 +668,7 @@ class Quantizer(torch.nn.Module):
|
||||
super(Quantizer, self).__init__()
|
||||
assert embed_dim % n_code_groups == 0
|
||||
self.quantizer_modules = nn.ModuleList(
|
||||
[
|
||||
Quantizer_module(n_codes, embed_dim // n_code_groups)
|
||||
for _ in range(n_code_groups)
|
||||
]
|
||||
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
|
||||
)
|
||||
self.n_code_groups = n_code_groups
|
||||
self.embed_dim = embed_dim
|
||||
@ -732,9 +686,7 @@ class Quantizer(torch.nn.Module):
|
||||
z_q.append(_z_q)
|
||||
min_indicies.append(_min_indicies) # B * T,
|
||||
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
|
||||
(z_q - xin.detach()) ** 2
|
||||
)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
||||
z_q = xin + (z_q - xin).detach()
|
||||
z_q = z_q.transpose(1, 2)
|
||||
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
|
||||
@ -774,13 +726,9 @@ class CodePredictor(nn.Module):
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
ssl_dim, style_vector_dim=hidden_channels
|
||||
)
|
||||
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
|
||||
|
||||
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
|
||||
self.n_q = n_q
|
||||
@ -793,9 +741,7 @@ class CodePredictor(nn.Module):
|
||||
x = x + g
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
x = self.out_proj(x * x_mask) * x_mask
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
|
||||
2, 3
|
||||
)
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
|
||||
target = codes[1:].transpose(0, 1)
|
||||
if not infer:
|
||||
logits = logits.reshape(-1, self.dims)
|
||||
@ -844,7 +790,7 @@ class SynthesizerTrn(nn.Module):
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v2",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
@ -896,9 +842,7 @@ class SynthesizerTrn(nn.Module):
|
||||
# 16,
|
||||
# gin_channels=gin_channels,
|
||||
# )
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
# self.version=os.environ.get("version","v1")
|
||||
if self.version == "v1":
|
||||
@ -923,9 +867,9 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
|
||||
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
|
||||
refer_mask = torch.ones_like(refer[:1,:1,:])
|
||||
if (self.version == "v1"):
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
refer_mask = torch.ones_like(refer[:1, :1, :])
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
@ -935,9 +879,7 @@ class SynthesizerTrn(nn.Module):
|
||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, text, ge, speed
|
||||
)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
@ -951,11 +893,9 @@ class SynthesizerTrn(nn.Module):
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
|
||||
class CFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,dit
|
||||
):
|
||||
def __init__(self, in_channels, dit):
|
||||
super().__init__()
|
||||
# self.sigma_min = 1e-6
|
||||
|
||||
@ -965,27 +905,34 @@ class CFM(torch.nn.Module):
|
||||
|
||||
# self.criterion = torch.nn.MSELoss()
|
||||
|
||||
def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0):
|
||||
def forward(
|
||||
self,
|
||||
mu: torch.Tensor,
|
||||
x_lens: torch.LongTensor,
|
||||
prompt: torch.Tensor,
|
||||
n_timesteps: torch.LongTensor,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""Forward diffusion"""
|
||||
B, T = mu.size(0), mu.size(1)
|
||||
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype)
|
||||
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
|
||||
|
||||
ntimesteps = int(n_timesteps)
|
||||
|
||||
prompt_len = prompt.size(-1)
|
||||
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
|
||||
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
|
||||
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
|
||||
x[..., :prompt_len] = 0.0
|
||||
mu=mu.transpose(2,1)
|
||||
t = torch.tensor(0.0,dtype=x.dtype,device=x.device)
|
||||
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device)
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||
mu = mu.transpose(2, 1)
|
||||
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
|
||||
|
||||
for j in range(ntimesteps):
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
||||
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
|
||||
# if inference_cfg_rate>1e-5:
|
||||
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
||||
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
||||
@ -997,47 +944,51 @@ class CFM(torch.nn.Module):
|
||||
|
||||
def set_no_grad(net_g):
|
||||
for name, param in net_g.named_parameters():
|
||||
param.requires_grad=False
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def compile_codes_length(codes):
|
||||
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
|
||||
return y_lengths1 * 2.5 * 1.5
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def compile_ref_length(refer):
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
return refer_lengths
|
||||
|
||||
|
||||
class SynthesizerTrnV3(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
@ -1058,41 +1009,38 @@ class SynthesizerTrnV3(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.model_dim=512
|
||||
self.model_dim = 512
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
||||
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
# gin_channels=gin_channels)
|
||||
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ['25hz', "50hz"]
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == '25hz':
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(
|
||||
dimension=ssl_dim,
|
||||
n_q=1,
|
||||
bins=1024
|
||||
)
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
freeze_quantizer
|
||||
inter_channels2=512
|
||||
self.bridge=nn.Sequential(
|
||||
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
|
||||
nn.LeakyReLU()
|
||||
)
|
||||
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
|
||||
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
|
||||
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
|
||||
if freeze_quantizer==True:
|
||||
inter_channels2 = 512
|
||||
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
||||
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
||||
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
||||
self.cfm = CFM(
|
||||
100,
|
||||
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
|
||||
) # text_dim is condition feature dim
|
||||
if freeze_quantizer == True:
|
||||
set_no_grad(self.ssl_proj)
|
||||
set_no_grad(self.quantizer)
|
||||
set_no_grad(self.enc_p)
|
||||
@ -1100,24 +1048,23 @@ class SynthesizerTrnV3(nn.Module):
|
||||
def create_ge(self, refer):
|
||||
refer_lengths = compile_ref_length(refer)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
return ge
|
||||
|
||||
def forward(self, codes, text,ge,speed=1):
|
||||
|
||||
y_lengths1=compile_codes_length(codes)
|
||||
def forward(self, codes, text, ge, speed=1):
|
||||
y_lengths1 = compile_codes_length(codes)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == '25hz':
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed)
|
||||
fea=self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
|
||||
return fea
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0,1)
|
||||
return codes.transpose(0, 1)
|
||||
|
@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.conv_layers.append(
|
||||
nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
)
|
||||
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
||||
for _ in range(n_layers - 1):
|
||||
@ -156,9 +152,7 @@ class WN(torch.nn.Module):
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
gin_channels, 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
||||
|
||||
for i in range(n_layers):
|
||||
@ -479,9 +473,7 @@ class ConvFlow(nn.Module):
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
||||
self.proj = nn.Conv1d(
|
||||
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
||||
)
|
||||
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
|
||||
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
||||
self.filter_channels
|
||||
)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
||||
|
||||
x1, logabsdet = piecewise_rational_quadratic_transform(
|
||||
@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
|
||||
self.attention = ScaledDotProductAttention(
|
||||
temperature=np.power(d_model, 0.5), dropout=dropout
|
||||
)
|
||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
|
||||
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
|
||||
output, attn = self.attention(q, k, v, mask=slf_mask)
|
||||
|
||||
output = output.view(n_head, sz_b, len_x, d_v)
|
||||
output = (
|
||||
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
|
||||
) # b x lq x (n*dv)
|
||||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
|
||||
|
||||
output = self.fc(output)
|
||||
|
||||
@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
|
||||
if mask is not None:
|
||||
mask = (mask.int() == 0).squeeze(1)
|
||||
max_len = x.shape[1]
|
||||
slf_attn_mask = (
|
||||
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
|
||||
)
|
||||
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
|
||||
|
||||
# spectral
|
||||
x = self.spectral(x)
|
||||
@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
|
||||
mu = self.fc1(enc_out)
|
||||
logvar = self.fc2(enc_out)
|
||||
posterior = D.Normal(mu, torch.exp(logvar))
|
||||
kl_divergence = D.kl_divergence(
|
||||
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
|
||||
)
|
||||
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
|
||||
loss_kl = kl_divergence.mean()
|
||||
|
||||
z = posterior.rsample()
|
||||
@ -825,9 +807,7 @@ class ActNorm(nn.Module):
|
||||
|
||||
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
|
||||
if x_mask is None:
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
if not self.initialized:
|
||||
self.initialize(x, x_mask)
|
||||
@ -856,9 +836,7 @@ class ActNorm(nn.Module):
|
||||
v = m_sq - (m**2)
|
||||
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
|
||||
|
||||
bias_init = (
|
||||
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
|
||||
)
|
||||
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
|
||||
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
|
||||
|
||||
self.bias.data.copy_(bias_init)
|
||||
@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
|
||||
self.n_split = n_split
|
||||
self.no_jacobian = no_jacobian
|
||||
|
||||
w_init = torch.linalg.qr(
|
||||
torch.FloatTensor(self.n_split, self.n_split).normal_()
|
||||
)[0]
|
||||
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
|
||||
if torch.det(w_init) < 0:
|
||||
w_init[:, 0] = -1 * w_init[:, 0]
|
||||
self.weight = nn.Parameter(w_init)
|
||||
@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
|
||||
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
|
||||
x = (
|
||||
x.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(b, self.n_split, c // self.n_split, t)
|
||||
)
|
||||
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
|
||||
|
||||
if reverse:
|
||||
if hasattr(self, "weight_inv"):
|
||||
|
@ -31,32 +31,15 @@ class MRTE(nn.Module):
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
if test != None:
|
||||
if test == 0:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
elif test == 1:
|
||||
x = ssl_enc + ge
|
||||
elif test == 2:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
|
||||
else:
|
||||
raise ValueError("test should be 0,1,2")
|
||||
else:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
||||
@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
|
||||
model_embedding_size=256,
|
||||
):
|
||||
super(SpeakerEncoder, self).__init__()
|
||||
self.lstm = nn.LSTM(
|
||||
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
|
||||
)
|
||||
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
|
@ -7,7 +7,6 @@
|
||||
"""Residual vector quantizer implementation."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
@ -88,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
raise ValueError(
|
||||
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(
|
||||
x, n_q=n_q, layers=layers
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
|
||||
return quantized, codes, torch.mean(commit_loss), quantized_list
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
|
@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs
|
||||
**spline_kwargs,
|
||||
)
|
||||
return outputs, logabsdet
|
||||
|
||||
@ -175,8 +175,7 @@ def rational_quadratic_spline(
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * root.pow(2)
|
||||
@ -190,12 +189,9 @@ def rational_quadratic_spline(
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (
|
||||
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
||||
)
|
||||
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
|
@ -1,23 +1,22 @@
|
||||
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
||||
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
||||
from feature_extractor import cnhubert
|
||||
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
||||
from torch import nn
|
||||
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
ssl_model = cnhubert.get_model()
|
||||
from text import cleaned_text_to_sequence
|
||||
import soundfile
|
||||
from tools.my_utils import load_audio
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import soundfile
|
||||
from text import cleaned_text_to_sequence
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
hann_window = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
@ -102,22 +101,22 @@ class T2SModel(nn.Module):
|
||||
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
|
||||
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
||||
self.stage_decoder = self.t2s_model.stage_decoder
|
||||
#self.t2s_model = torch.jit.script(self.t2s_model)
|
||||
# self.t2s_model = torch.jit.script(self.t2s_model)
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||
early_stop_num = self.t2s_model.early_stop_num
|
||||
|
||||
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
|
||||
prefix_len = prompts.shape[1]
|
||||
|
||||
#[1,N,512] [1,N]
|
||||
# [1,N,512] [1,N]
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
|
||||
stop = False
|
||||
for idx in range(1, 1500):
|
||||
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
y, k, v, y_emb, logits, samples = enco
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
@ -131,13 +130,11 @@ class T2SModel(nn.Module):
|
||||
return y[:, -idx:].unsqueeze(0)
|
||||
|
||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
||||
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
||||
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
||||
if dynamo:
|
||||
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||
onnx_encoder_export_output = torch.onnx.dynamo_export(
|
||||
self.onnx_encoder,
|
||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||
export_options=export_options
|
||||
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
|
||||
)
|
||||
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
||||
return
|
||||
@ -149,13 +146,13 @@ class T2SModel(nn.Module):
|
||||
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
||||
output_names=["x", "prompts"],
|
||||
dynamic_axes={
|
||||
"ref_seq": {1 : "ref_length"},
|
||||
"text_seq": {1 : "text_length"},
|
||||
"ref_bert": {0 : "ref_length"},
|
||||
"text_bert": {0 : "text_length"},
|
||||
"ssl_content": {2 : "ssl_length"},
|
||||
"ref_seq": {1: "ref_length"},
|
||||
"text_seq": {1: "text_length"},
|
||||
"ref_bert": {0: "ref_length"},
|
||||
"text_bert": {0: "text_length"},
|
||||
"ssl_content": {2: "ssl_length"},
|
||||
},
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
|
||||
@ -166,11 +163,11 @@ class T2SModel(nn.Module):
|
||||
input_names=["x", "prompts"],
|
||||
output_names=["y", "k", "v", "y_emb", "x_example"],
|
||||
dynamic_axes={
|
||||
"x": {1 : "x_length"},
|
||||
"prompts": {1 : "prompts_length"},
|
||||
"x": {1: "x_length"},
|
||||
"prompts": {1: "prompts_length"},
|
||||
},
|
||||
verbose=False,
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
|
||||
@ -181,23 +178,23 @@ class T2SModel(nn.Module):
|
||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
||||
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
||||
dynamic_axes={
|
||||
"iy": {1 : "iy_length"},
|
||||
"ik": {1 : "ik_length"},
|
||||
"iv": {1 : "iv_length"},
|
||||
"iy_emb": {1 : "iy_emb_length"},
|
||||
"ix_example": {1 : "ix_example_length"},
|
||||
"iy": {1: "iy_length"},
|
||||
"ik": {1: "ik_length"},
|
||||
"iv": {1: "iv_length"},
|
||||
"iy_emb": {1: "iy_emb_length"},
|
||||
"ix_example": {1: "ix_example_length"},
|
||||
},
|
||||
verbose=False,
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path, map_location="cpu")
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
@ -208,7 +205,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model
|
||||
**self.hps.model,
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
@ -220,7 +217,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
||||
|
||||
@ -236,12 +233,16 @@ class GptSoVits(nn.Module):
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
||||
if debug:
|
||||
import onnxruntime
|
||||
|
||||
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
||||
audio1 = sess.run(None, {
|
||||
"text_seq" : text_seq.detach().cpu().numpy(),
|
||||
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
|
||||
"ref_audio" : ref_audio.detach().cpu().numpy()
|
||||
})
|
||||
audio1 = sess.run(
|
||||
None,
|
||||
{
|
||||
"text_seq": text_seq.detach().cpu().numpy(),
|
||||
"pred_semantic": pred_semantic.detach().cpu().numpy(),
|
||||
"ref_audio": ref_audio.detach().cpu().numpy(),
|
||||
},
|
||||
)
|
||||
return audio, audio1
|
||||
return audio
|
||||
|
||||
@ -255,12 +256,12 @@ class GptSoVits(nn.Module):
|
||||
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"text_seq": {1 : "text_length"},
|
||||
"pred_semantic": {2 : "pred_length"},
|
||||
"ref_audio": {1 : "audio_length"},
|
||||
"text_seq": {1: "text_length"},
|
||||
"pred_semantic": {2: "pred_length"},
|
||||
"ref_audio": {1: "audio_length"},
|
||||
},
|
||||
opset_version=17,
|
||||
verbose=False
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@ -278,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
|
||||
gpt = T2SModel(gpt_path, vits)
|
||||
gpt_sovits = GptSoVits(vits, gpt)
|
||||
ssl = SSLModel()
|
||||
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
|
||||
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
|
||||
ref_seq = torch.LongTensor(
|
||||
[
|
||||
cleaned_text_to_sequence(
|
||||
[
|
||||
"n",
|
||||
"i2",
|
||||
"h",
|
||||
"ao3",
|
||||
",",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
],
|
||||
version=vits_model,
|
||||
)
|
||||
]
|
||||
)
|
||||
text_seq = torch.LongTensor(
|
||||
[
|
||||
cleaned_text_to_sequence(
|
||||
[
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
],
|
||||
version=vits_model,
|
||||
)
|
||||
]
|
||||
)
|
||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||
ref_audio = torch.randn((1, 48000 * 5)).float()
|
||||
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
|
||||
|
||||
try:
|
||||
os.mkdir(f"onnx/{project_name}")
|
||||
@ -326,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
|
||||
}
|
||||
|
||||
MoeVSConfJson = json.dumps(MoeVSConf)
|
||||
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
|
||||
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
|
||||
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
|
||||
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -8,19 +8,17 @@ exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
version = os.environ.get('version', None)
|
||||
import sys, numpy as np, traceback, pdb
|
||||
version = os.environ.get("version", None)
|
||||
import traceback
|
||||
import os.path
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
from text.cleaner import clean_text
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from tools.my_utils import clean_path
|
||||
|
||||
# inp_text=sys.argv[1]
|
||||
@ -36,13 +34,13 @@ from time import time as ttime
|
||||
import shutil
|
||||
|
||||
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
||||
tmp_path="%s%s.pth"%(ttime(),i_part)
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
|
||||
@ -56,8 +54,10 @@ if os.path.exists(txt_path) == False:
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
if os.path.exists(bert_pretrained_dir):...
|
||||
else:raise FileNotFoundError(bert_pretrained_dir)
|
||||
if os.path.exists(bert_pretrained_dir):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(bert_pretrained_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
|
||||
if is_half == True:
|
||||
@ -86,12 +86,10 @@ if os.path.exists(txt_path) == False:
|
||||
def process(data, res):
|
||||
for name, text, lan in data:
|
||||
try:
|
||||
name=clean_path(name)
|
||||
name = clean_path(name)
|
||||
name = os.path.basename(name)
|
||||
print(name)
|
||||
phones, word2ph, norm_text = clean_text(
|
||||
text.replace("%", "-").replace("¥", ","), lan, version
|
||||
)
|
||||
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
|
||||
path_bert = "%s/%s.pt" % (bert_dir, name)
|
||||
if os.path.exists(path_bert) == False and lan == "zh":
|
||||
bert_feature = get_bert_feature(norm_text, word2ph)
|
||||
@ -131,9 +129,7 @@ if os.path.exists(txt_path) == False:
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
# todo.append([name,text,"zh"])
|
||||
if language in language_v1_to_language_v2.keys():
|
||||
todo.append(
|
||||
[wav_name, text, language_v1_to_language_v2.get(language, language)]
|
||||
)
|
||||
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
|
||||
else:
|
||||
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
|
||||
except:
|
||||
|
@ -1,25 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys,os
|
||||
inp_text= os.environ.get("inp_text")
|
||||
inp_wav_dir= os.environ.get("inp_wav_dir")
|
||||
exp_name= os.environ.get("exp_name")
|
||||
i_part= os.environ.get("i_part")
|
||||
all_parts= os.environ.get("all_parts")
|
||||
import sys
|
||||
import os
|
||||
|
||||
inp_text = os.environ.get("inp_text")
|
||||
inp_wav_dir = os.environ.get("inp_wav_dir")
|
||||
exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
from feature_extractor import cnhubert
|
||||
opt_dir= os.environ.get("opt_dir")
|
||||
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
|
||||
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
|
||||
import pdb,traceback,numpy as np,logging
|
||||
import traceback
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
import librosa
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from tools.my_utils import load_audio,clean_path
|
||||
from tools.my_utils import load_audio, clean_path
|
||||
|
||||
# from config import cnhubert_base_path
|
||||
# cnhubert.cnhubert_base_path=cnhubert_base_path
|
||||
@ -34,90 +40,95 @@ from tools.my_utils import load_audio,clean_path
|
||||
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
|
||||
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
||||
tmp_path="%s%s.pth"%(ttime(),i_part)
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
hubert_dir="%s/4-cnhubert"%(opt_dir)
|
||||
wav32dir="%s/5-wav32k"%(opt_dir)
|
||||
os.makedirs(opt_dir,exist_ok=True)
|
||||
os.makedirs(hubert_dir,exist_ok=True)
|
||||
os.makedirs(wav32dir,exist_ok=True)
|
||||
|
||||
maxx=0.95
|
||||
alpha=0.5
|
||||
hubert_dir = "%s/4-cnhubert" % (opt_dir)
|
||||
wav32dir = "%s/5-wav32k" % (opt_dir)
|
||||
os.makedirs(opt_dir, exist_ok=True)
|
||||
os.makedirs(hubert_dir, exist_ok=True)
|
||||
os.makedirs(wav32dir, exist_ok=True)
|
||||
|
||||
maxx = 0.95
|
||||
alpha = 0.5
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
# elif torch.backends.mps.is_available():
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
model=cnhubert.get_model()
|
||||
model = cnhubert.get_model()
|
||||
# is_half=False
|
||||
if(is_half==True):
|
||||
model=model.half().to(device)
|
||||
if is_half == True:
|
||||
model = model.half().to(device)
|
||||
else:
|
||||
model = model.to(device)
|
||||
|
||||
nan_fails=[]
|
||||
def name2go(wav_name,wav_path):
|
||||
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
|
||||
if(os.path.exists(hubert_path)):return
|
||||
nan_fails = []
|
||||
|
||||
|
||||
def name2go(wav_name, wav_path):
|
||||
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
|
||||
if os.path.exists(hubert_path):
|
||||
return
|
||||
tmp_audio = load_audio(wav_path, 32000)
|
||||
tmp_max = np.abs(tmp_audio).max()
|
||||
if tmp_max > 2.2:
|
||||
print("%s-filtered,%s" % (wav_name, tmp_max))
|
||||
return
|
||||
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
|
||||
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
|
||||
tmp_audio = librosa.resample(
|
||||
tmp_audio32b, orig_sr=32000, target_sr=16000
|
||||
)#不是重采样问题
|
||||
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
|
||||
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
|
||||
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
|
||||
tensor_wav16 = torch.from_numpy(tmp_audio)
|
||||
if (is_half == True):
|
||||
tensor_wav16=tensor_wav16.half().to(device)
|
||||
if is_half == True:
|
||||
tensor_wav16 = tensor_wav16.half().to(device)
|
||||
else:
|
||||
tensor_wav16 = tensor_wav16.to(device)
|
||||
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
|
||||
if np.isnan(ssl.detach().numpy()).sum()!= 0:
|
||||
nan_fails.append((wav_name,wav_path))
|
||||
print("nan filtered:%s"%wav_name)
|
||||
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
|
||||
if np.isnan(ssl.detach().numpy()).sum() != 0:
|
||||
nan_fails.append((wav_name, wav_path))
|
||||
print("nan filtered:%s" % wav_name)
|
||||
return
|
||||
wavfile.write(
|
||||
"%s/%s"%(wav32dir,wav_name),
|
||||
"%s/%s" % (wav32dir, wav_name),
|
||||
32000,
|
||||
tmp_audio32.astype("int16"),
|
||||
)
|
||||
my_save(ssl,hubert_path)
|
||||
my_save(ssl, hubert_path)
|
||||
|
||||
with open(inp_text,"r",encoding="utf8")as f:
|
||||
lines=f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines[int(i_part)::int(all_parts)]:
|
||||
with open(inp_text, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines[int(i_part) :: int(all_parts)]:
|
||||
try:
|
||||
# wav_name,text=line.split("\t")
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
wav_name=clean_path(wav_name)
|
||||
if (inp_wav_dir != "" and inp_wav_dir != None):
|
||||
wav_name = clean_path(wav_name)
|
||||
if inp_wav_dir != "" and inp_wav_dir != None:
|
||||
wav_name = os.path.basename(wav_name)
|
||||
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
|
||||
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
|
||||
|
||||
else:
|
||||
wav_path=wav_name
|
||||
wav_path = wav_name
|
||||
wav_name = os.path.basename(wav_name)
|
||||
name2go(wav_name,wav_path)
|
||||
name2go(wav_name, wav_path)
|
||||
except:
|
||||
print(line,traceback.format_exc())
|
||||
print(line, traceback.format_exc())
|
||||
|
||||
if(len(nan_fails)>0 and is_half==True):
|
||||
is_half=False
|
||||
model=model.float()
|
||||
if len(nan_fails) > 0 and is_half == True:
|
||||
is_half = False
|
||||
model = model.float()
|
||||
for wav in nan_fails:
|
||||
try:
|
||||
name2go(wav[0],wav[1])
|
||||
name2go(wav[0], wav[1])
|
||||
except:
|
||||
print(wav_name,traceback.format_exc())
|
||||
print(wav_name, traceback.format_exc())
|
||||
|
@ -5,13 +5,15 @@ exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
pretrained_s2G = os.environ.get("pretrained_s2G")
|
||||
s2config_path = os.environ.get("s2config_path")
|
||||
|
||||
if os.path.exists(pretrained_s2G):...
|
||||
else:raise FileNotFoundError(pretrained_s2G)
|
||||
if os.path.exists(pretrained_s2G):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(pretrained_s2G)
|
||||
# version=os.environ.get("version","v2")
|
||||
size = os.path.getsize(pretrained_s2G)
|
||||
if size < 82978 * 1024:
|
||||
@ -25,23 +27,22 @@ elif size < 700 * 1024 * 1024:
|
||||
else:
|
||||
version = "v3"
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
import math, traceback
|
||||
import multiprocessing
|
||||
import sys, pdb
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from random import shuffle
|
||||
import torch.multiprocessing as mp
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
import logging, librosa, utils
|
||||
if version!="v3":
|
||||
import logging
|
||||
import utils
|
||||
|
||||
if version != "v3":
|
||||
from module.models import SynthesizerTrn
|
||||
else:
|
||||
from module.models import SynthesizerTrnV3 as SynthesizerTrn
|
||||
from tools.my_utils import clean_path
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
# from config import pretrained_s2G
|
||||
|
||||
@ -70,7 +71,7 @@ if os.path.exists(semantic_path) == False:
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
version=version,
|
||||
**hps.model
|
||||
**hps.model,
|
||||
)
|
||||
if is_half == True:
|
||||
vq_model = vq_model.half().to(device)
|
||||
@ -107,7 +108,7 @@ if os.path.exists(semantic_path) == False:
|
||||
try:
|
||||
# wav_name,text=line.split("\t")
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
wav_name=clean_path(wav_name)
|
||||
wav_name = clean_path(wav_name)
|
||||
wav_name = os.path.basename(wav_name)
|
||||
# name2go(name,lines1)
|
||||
name2go(wav_name, lines1)
|
||||
|
@ -1,37 +1,44 @@
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from time import time as ttime
|
||||
import shutil,os
|
||||
import shutil
|
||||
import os
|
||||
import torch
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
tmp_path="%s.pth"%(ttime())
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
|
||||
'''
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
tmp_path = "%s.pth" % (ttime())
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
"""
|
||||
00:v1
|
||||
01:v2
|
||||
02:v3
|
||||
03:v3lora
|
||||
|
||||
|
||||
'''
|
||||
"""
|
||||
from io import BytesIO
|
||||
def my_save2(fea,path):
|
||||
|
||||
|
||||
def my_save2(fea, path):
|
||||
bio = BytesIO()
|
||||
torch.save(fea, bio)
|
||||
bio.seek(0)
|
||||
data = bio.getvalue()
|
||||
data = b'03' + data[2:]###temp for v3lora only, todo
|
||||
with open(path, "wb") as f: f.write(data)
|
||||
data = b"03" + data[2:] ###temp for v3lora only, todo
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
|
||||
try:
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
@ -42,7 +49,7 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
|
||||
opt["config"] = hps
|
||||
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
||||
if lora_rank:
|
||||
opt["lora_rank"]=lora_rank
|
||||
opt["lora_rank"] = lora_rank
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
else:
|
||||
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
@ -50,41 +57,48 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
head2version={
|
||||
b'00':["v1","v1",False],
|
||||
b'01':["v2","v2",False],
|
||||
b'02':["v2","v3",False],
|
||||
b'03':["v2","v3",True],
|
||||
|
||||
head2version = {
|
||||
b"00": ["v1", "v1", False],
|
||||
b"01": ["v2", "v2", False],
|
||||
b"02": ["v2", "v3", False],
|
||||
b"03": ["v2", "v3", True],
|
||||
}
|
||||
hash_pretrained_dict={
|
||||
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
|
||||
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
|
||||
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
|
||||
hash_pretrained_dict = {
|
||||
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
|
||||
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
|
||||
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
|
||||
}
|
||||
import hashlib
|
||||
|
||||
|
||||
def get_hash_from_file(sovits_path):
|
||||
with open(sovits_path,"rb")as f:data=f.read(8192)
|
||||
with open(sovits_path, "rb") as f:
|
||||
data = f.read(8192)
|
||||
hash_md5 = hashlib.md5()
|
||||
hash_md5.update(data)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_sovits_version_from_path_fast(sovits_path):
|
||||
###1-if it is pretrained sovits models, by hash
|
||||
hash=get_hash_from_file(sovits_path)
|
||||
hash = get_hash_from_file(sovits_path)
|
||||
if hash in hash_pretrained_dict:
|
||||
return hash_pretrained_dict[hash]
|
||||
###2-new weights or old weights, by head
|
||||
with open(sovits_path,"rb")as f:version=f.read(2)
|
||||
if version!=b"PK":
|
||||
with open(sovits_path, "rb") as f:
|
||||
version = f.read(2)
|
||||
if version != b"PK":
|
||||
return head2version[version]
|
||||
###3-old weights, by file size
|
||||
if_lora_v3=False
|
||||
size=os.path.getsize(sovits_path)
|
||||
'''
|
||||
if_lora_v3 = False
|
||||
size = os.path.getsize(sovits_path)
|
||||
"""
|
||||
v1weights:about 82942KB
|
||||
half thr:82978KB
|
||||
v2weights:about 83014KB
|
||||
v3weights:about 750MB
|
||||
'''
|
||||
"""
|
||||
if size < 82978 * 1024:
|
||||
model_version = version = "v1"
|
||||
elif size < 700 * 1024 * 1024:
|
||||
@ -92,15 +106,16 @@ def get_sovits_version_from_path_fast(sovits_path):
|
||||
else:
|
||||
version = "v2"
|
||||
model_version = "v3"
|
||||
return version,model_version,if_lora_v3
|
||||
return version, model_version, if_lora_v3
|
||||
|
||||
|
||||
def load_sovits_new(sovits_path):
|
||||
f=open(sovits_path,"rb")
|
||||
meta=f.read(2)
|
||||
if meta!="PK":
|
||||
data = b'PK' + f.read()
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != "PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
bio.seek(0)
|
||||
return torch.load(bio, map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path,map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
|
@ -1,31 +1,28 @@
|
||||
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
||||
import os
|
||||
import pdb
|
||||
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
import argparse
|
||||
import logging
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
import torch, platform
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
import torch
|
||||
from AR.data.data_module import Text2SemanticDataModule
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from AR.utils.io import load_yaml_config
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||
torch.set_float32_matmul_precision("high")
|
||||
from AR.utils import get_newest_ckpt
|
||||
|
||||
from collections import OrderedDict
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
|
||||
from AR.utils import get_newest_ckpt
|
||||
from process_ckpt import my_save
|
||||
|
||||
|
||||
@ -37,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
|
||||
if_save_every_weights,
|
||||
half_weights_save_dir,
|
||||
exp_name,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.if_save_latest = if_save_latest
|
||||
@ -50,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
|
||||
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
|
||||
if self._should_save_on_train_epoch_end(trainer):
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
if (
|
||||
self._every_n_epochs >= 1
|
||||
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
|
||||
):
|
||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
||||
if (
|
||||
self.if_save_latest == True
|
||||
): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
||||
@ -75,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
|
||||
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
||||
# torch.save(
|
||||
# print(os.environ)
|
||||
if(os.environ.get("LOCAL_RANK","0")=="0"):
|
||||
if os.environ.get("LOCAL_RANK", "0") == "0":
|
||||
my_save(
|
||||
to_save_od,
|
||||
"%s/%s-e%s.ckpt"
|
||||
@ -112,7 +106,7 @@ def main(args):
|
||||
dirpath=ckpt_dir,
|
||||
)
|
||||
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
||||
os.environ["MASTER_ADDR"]="localhost"
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
trainer: Trainer = Trainer(
|
||||
max_epochs=config["train"]["epochs"],
|
||||
@ -123,9 +117,9 @@ def main(args):
|
||||
devices=-1 if torch.cuda.is_available() else 1,
|
||||
benchmark=False,
|
||||
fast_dev_run=False,
|
||||
strategy = DDPStrategy(
|
||||
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
|
||||
) if torch.cuda.is_available() else "auto",
|
||||
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
||||
if torch.cuda.is_available()
|
||||
else "auto",
|
||||
precision=config["train"]["precision"],
|
||||
logger=logger,
|
||||
num_sanity_val_steps=0,
|
||||
@ -133,9 +127,7 @@ def main(args):
|
||||
use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题!
|
||||
)
|
||||
|
||||
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
|
||||
config, output_dir
|
||||
)
|
||||
model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
|
||||
|
||||
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
||||
config,
|
||||
|
@ -1,36 +1,41 @@
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import utils, os
|
||||
import os
|
||||
|
||||
import utils
|
||||
|
||||
hps = utils.get_hparams(stage=2)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist, traceback
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
import logging, traceback
|
||||
|
||||
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
||||
logging.getLogger("h5py").setLevel(logging.INFO)
|
||||
logging.getLogger("numba").setLevel(logging.INFO)
|
||||
from random import randint
|
||||
from module import commons
|
||||
|
||||
from module import commons
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerLoader,
|
||||
TextAudioSpeakerCollate,
|
||||
DistributedBucketSampler,
|
||||
TextAudioSpeakerCollate,
|
||||
TextAudioSpeakerLoader,
|
||||
)
|
||||
from module.models import (
|
||||
SynthesizerTrn,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||
from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
|
||||
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from module.models import (
|
||||
MultiPeriodDiscriminator,
|
||||
SynthesizerTrn,
|
||||
)
|
||||
from process_ckpt import savee
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@ -46,7 +51,6 @@ device = "cpu" # cuda以外的设备,等mps优化后加入
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
n_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
@ -74,7 +78,7 @@ def run(rank, n_gpus, hps):
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
|
||||
dist.init_process_group(
|
||||
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
init_method="env://?use_libuv=False",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
@ -128,19 +132,27 @@ def run(rank, n_gpus, hps):
|
||||
# batch_size=1, pin_memory=True,
|
||||
# drop_last=False, collate_fn=collate_fn)
|
||||
|
||||
net_g = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).to(device)
|
||||
net_g = (
|
||||
SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).cuda(rank)
|
||||
if torch.cuda.is_available()
|
||||
else SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).to(device)
|
||||
)
|
||||
|
||||
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
|
||||
net_d = (
|
||||
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
||||
if torch.cuda.is_available()
|
||||
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
|
||||
)
|
||||
for name, param in net_g.named_parameters():
|
||||
if not param.requires_grad:
|
||||
print(name, "not requires_grad")
|
||||
@ -193,7 +205,7 @@ def run(rank, n_gpus, hps):
|
||||
|
||||
try: # 如果能加载自动resume
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
|
||||
net_d,
|
||||
optim_d,
|
||||
) # D多半加载没事
|
||||
@ -201,11 +213,11 @@ def run(rank, n_gpus, hps):
|
||||
logger.info("loaded D")
|
||||
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
|
||||
net_g,
|
||||
optim_g,
|
||||
)
|
||||
epoch_str+=1
|
||||
epoch_str += 1
|
||||
global_step = (epoch_str - 1) * len(train_loader)
|
||||
# epoch_str = 1
|
||||
# global_step = 0
|
||||
@ -213,37 +225,55 @@ def run(rank, n_gpus, hps):
|
||||
# traceback.print_exc()
|
||||
epoch_str = 1
|
||||
global_step = 0
|
||||
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
|
||||
if (
|
||||
hps.train.pretrained_s2G != ""
|
||||
and hps.train.pretrained_s2G != None
|
||||
and os.path.exists(hps.train.pretrained_s2G)
|
||||
):
|
||||
if rank == 0:
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||
print("loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
print(
|
||||
"loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
net_g.module.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
) if torch.cuda.is_available() else net_g.load_state_dict(
|
||||
)
|
||||
if torch.cuda.is_available()
|
||||
else net_g.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
)
|
||||
),
|
||||
) ##测试不加载优化器
|
||||
if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
|
||||
if (
|
||||
hps.train.pretrained_s2D != ""
|
||||
and hps.train.pretrained_s2D != None
|
||||
and os.path.exists(hps.train.pretrained_s2D)
|
||||
):
|
||||
if rank == 0:
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
|
||||
print("loaded pretrained %s" % hps.train.pretrained_s2D,
|
||||
print(
|
||||
"loaded pretrained %s" % hps.train.pretrained_s2D,
|
||||
net_d.module.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
|
||||
) if torch.cuda.is_available() else net_d.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
|
||||
)
|
||||
if torch.cuda.is_available()
|
||||
else net_d.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
|
||||
),
|
||||
)
|
||||
|
||||
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
optim_g,
|
||||
gamma=hps.train.lr_decay,
|
||||
last_epoch=-1,
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
optim_d,
|
||||
gamma=hps.train.lr_decay,
|
||||
last_epoch=-1,
|
||||
)
|
||||
for _ in range(epoch_str):
|
||||
scheduler_g.step()
|
||||
@ -285,9 +315,7 @@ def run(rank, n_gpus, hps):
|
||||
print("training done")
|
||||
|
||||
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||
):
|
||||
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
@ -311,17 +339,38 @@ def train_and_evaluate(
|
||||
text_lengths,
|
||||
) in enumerate(tqdm(train_loader)):
|
||||
if torch.cuda.is_available():
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
spec, spec_lengths = (
|
||||
spec.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
spec_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
y, y_lengths = (
|
||||
y.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
y_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
ssl = ssl.cuda(rank, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
text, text_lengths = (
|
||||
text.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
text_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||
@ -350,9 +399,7 @@ def train_and_evaluate(
|
||||
hps.data.mel_fmin,
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
y_mel = commons.slice_segments(
|
||||
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
|
||||
)
|
||||
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
hps.data.filter_length,
|
||||
@ -364,15 +411,14 @@ def train_and_evaluate(
|
||||
hps.data.mel_fmax,
|
||||
)
|
||||
|
||||
y = commons.slice_segments(
|
||||
y, ids_slice * hps.data.hop_length, hps.train.segment_size
|
||||
) # slice
|
||||
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
||||
|
||||
# Discriminator
|
||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||
with autocast(enabled=False):
|
||||
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
||||
y_d_hat_r, y_d_hat_g
|
||||
y_d_hat_r,
|
||||
y_d_hat_g,
|
||||
)
|
||||
loss_disc_all = loss_disc
|
||||
optim_d.zero_grad()
|
||||
@ -405,7 +451,8 @@ def train_and_evaluate(
|
||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch, 100.0 * batch_idx / len(train_loader)
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
@ -429,25 +476,37 @@ def train_and_evaluate(
|
||||
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
||||
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
||||
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
||||
image_dict=None
|
||||
try:###Some people installed the wrong version of matplotlib.
|
||||
image_dict = None
|
||||
try: ###Some people installed the wrong version of matplotlib.
|
||||
image_dict = {
|
||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||
y_mel[0].data.cpu().numpy()
|
||||
y_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].data.cpu().numpy()
|
||||
y_hat_mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].data.cpu().numpy()
|
||||
mel[0].data.cpu().numpy(),
|
||||
),
|
||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||
stats_ssl[0].data.cpu().numpy()
|
||||
stats_ssl[0].data.cpu().numpy(),
|
||||
),
|
||||
}
|
||||
except:pass
|
||||
if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,)
|
||||
else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,)
|
||||
except:
|
||||
pass
|
||||
if image_dict:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
images=image_dict,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
else:
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
@ -457,7 +516,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"G_{}.pth".format(global_step),
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
@ -466,7 +526,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"D_{}.pth".format(global_step),
|
||||
),
|
||||
)
|
||||
else:
|
||||
@ -476,7 +537,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"G_{}.pth".format(233333333333),
|
||||
),
|
||||
)
|
||||
utils.save_checkpoint(
|
||||
@ -485,7 +547,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"D_{}.pth".format(233333333333),
|
||||
),
|
||||
)
|
||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||
@ -540,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
ssl = ssl.to(device)
|
||||
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||
for test in [0, 1]:
|
||||
y_hat, mask, *_ = generator.module.infer(
|
||||
ssl, spec, spec_lengths, text, text_lengths, test=test
|
||||
) if torch.cuda.is_available() else generator.infer(
|
||||
ssl, spec, spec_lengths, text, text_lengths, test=test
|
||||
y_hat, mask, *_ = (
|
||||
generator.module.infer(
|
||||
ssl,
|
||||
spec,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
test=test,
|
||||
)
|
||||
if torch.cuda.is_available()
|
||||
else generator.infer(
|
||||
ssl,
|
||||
spec,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
test=test,
|
||||
)
|
||||
)
|
||||
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
||||
|
||||
@ -568,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||
image_dict.update(
|
||||
{
|
||||
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
||||
y_hat_mel[0].cpu().numpy()
|
||||
)
|
||||
y_hat_mel[0].cpu().numpy(),
|
||||
),
|
||||
}
|
||||
)
|
||||
audio_dict.update(
|
||||
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
|
||||
{
|
||||
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
|
||||
},
|
||||
)
|
||||
image_dict.update(
|
||||
{
|
||||
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
|
||||
mel[0].cpu().numpy()
|
||||
)
|
||||
}
|
||||
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
|
||||
},
|
||||
)
|
||||
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
||||
|
||||
|
@ -1,36 +1,41 @@
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import utils, os
|
||||
import os
|
||||
|
||||
import utils
|
||||
|
||||
hps = utils.get_hparams(stage=2)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist, traceback
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
import logging, traceback
|
||||
|
||||
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
||||
logging.getLogger("h5py").setLevel(logging.INFO)
|
||||
logging.getLogger("numba").setLevel(logging.INFO)
|
||||
from random import randint
|
||||
from module import commons
|
||||
|
||||
from module import commons
|
||||
from module.data_utils import (
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
|
||||
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.models import (
|
||||
SynthesizerTrnV3 as SynthesizerTrn,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from process_ckpt import savee
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@ -46,7 +51,6 @@ device = "cpu" # cuda以外的设备,等mps优化后加入
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
n_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
@ -74,7 +78,7 @@ def run(rank, n_gpus, hps):
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
|
||||
dist.init_process_group(
|
||||
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
init_method="env://?use_libuv=False",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
@ -128,17 +132,21 @@ def run(rank, n_gpus, hps):
|
||||
# batch_size=1, pin_memory=True,
|
||||
# drop_last=False, collate_fn=collate_fn)
|
||||
|
||||
net_g = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).to(device)
|
||||
net_g = (
|
||||
SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).cuda(rank)
|
||||
if torch.cuda.is_available()
|
||||
else SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).to(device)
|
||||
)
|
||||
|
||||
# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
|
||||
# for name, param in net_g.named_parameters():
|
||||
@ -146,7 +154,7 @@ def run(rank, n_gpus, hps):
|
||||
# print(name, "not requires_grad")
|
||||
|
||||
optim_g = torch.optim.AdamW(
|
||||
filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
|
||||
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
|
||||
hps.train.learning_rate,
|
||||
betas=hps.train.betas,
|
||||
eps=hps.train.eps,
|
||||
@ -174,11 +182,11 @@ def run(rank, n_gpus, hps):
|
||||
# logger.info("loaded D")
|
||||
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
|
||||
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
|
||||
net_g,
|
||||
optim_g,
|
||||
)
|
||||
epoch_str+=1
|
||||
epoch_str += 1
|
||||
global_step = (epoch_str - 1) * len(train_loader)
|
||||
# epoch_str = 1
|
||||
# global_step = 0
|
||||
@ -186,17 +194,24 @@ def run(rank, n_gpus, hps):
|
||||
# traceback.print_exc()
|
||||
epoch_str = 1
|
||||
global_step = 0
|
||||
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
|
||||
if (
|
||||
hps.train.pretrained_s2G != ""
|
||||
and hps.train.pretrained_s2G != None
|
||||
and os.path.exists(hps.train.pretrained_s2G)
|
||||
):
|
||||
if rank == 0:
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||
print("loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
print(
|
||||
"loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
net_g.module.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
) if torch.cuda.is_available() else net_g.load_state_dict(
|
||||
)
|
||||
if torch.cuda.is_available()
|
||||
else net_g.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
)
|
||||
),
|
||||
) ##测试不加载优化器
|
||||
# if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
|
||||
# if rank == 0:
|
||||
@ -212,9 +227,7 @@ def run(rank, n_gpus, hps):
|
||||
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
|
||||
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
# optim_d, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
# )
|
||||
@ -224,7 +237,7 @@ def run(rank, n_gpus, hps):
|
||||
|
||||
scaler = GradScaler(enabled=hps.train.fp16_run)
|
||||
|
||||
net_d=optim_d=scheduler_d=None
|
||||
net_d = optim_d = scheduler_d = None
|
||||
print("start training from epoch %s" % epoch_str)
|
||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||
if rank == 0:
|
||||
@ -260,7 +273,16 @@ def run(rank, n_gpus, hps):
|
||||
|
||||
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||
rank,
|
||||
epoch,
|
||||
hps,
|
||||
nets,
|
||||
optims,
|
||||
schedulers,
|
||||
scaler,
|
||||
loaders,
|
||||
logger,
|
||||
writers,
|
||||
):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
@ -284,19 +306,33 @@ def train_and_evaluate(
|
||||
# text,
|
||||
# text_lengths,
|
||||
# ) in enumerate(tqdm(train_loader)):
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
tqdm(train_loader)
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
spec, spec_lengths = (
|
||||
spec.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
spec_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
|
||||
ssl = ssl.cuda(rank, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
text, text_lengths = (
|
||||
text.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
text_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||
@ -307,8 +343,18 @@ def train_and_evaluate(
|
||||
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
|
||||
loss_gen_all=cfm_loss
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
@ -318,12 +364,15 @@ def train_and_evaluate(
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]['lr']
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
|
||||
losses = [cfm_loss]
|
||||
logger.info('Train Epoch: {} [{:.0f}%]'.format(
|
||||
epoch,
|
||||
100. * batch_idx / len(train_loader)))
|
||||
logger.info(
|
||||
"Train Epoch: {} [{:.0f}%]".format(
|
||||
epoch,
|
||||
100.0 * batch_idx / len(train_loader),
|
||||
)
|
||||
)
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
@ -337,7 +386,8 @@ def train_and_evaluate(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
# images=image_dict,
|
||||
scalars=scalar_dict)
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
# if global_step % hps.train.eval_interval == 0:
|
||||
# # evaluate(hps, net_g, eval_loader, writer_eval)
|
||||
@ -347,7 +397,6 @@ def train_and_evaluate(
|
||||
# # if keep_ckpts > 0:
|
||||
# # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
|
||||
|
||||
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||
if hps.train.if_save_latest == 0:
|
||||
@ -357,7 +406,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"G_{}.pth".format(global_step),
|
||||
),
|
||||
)
|
||||
# utils.save_checkpoint(
|
||||
@ -376,7 +426,8 @@ def train_and_evaluate(
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
|
||||
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||
"G_{}.pth".format(233333333333),
|
||||
),
|
||||
)
|
||||
# utils.save_checkpoint(
|
||||
|
@ -1,38 +1,45 @@
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import utils, os
|
||||
import os
|
||||
|
||||
import utils
|
||||
|
||||
hps = utils.get_hparams(stage=2)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist, traceback
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from tqdm import tqdm
|
||||
import logging, traceback
|
||||
|
||||
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
||||
logging.getLogger("h5py").setLevel(logging.INFO)
|
||||
logging.getLogger("numba").setLevel(logging.INFO)
|
||||
from collections import OrderedDict as od
|
||||
from random import randint
|
||||
|
||||
from module import commons
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from module.data_utils import (
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
|
||||
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.models import (
|
||||
SynthesizerTrnV3 as SynthesizerTrn,
|
||||
MultiPeriodDiscriminator,
|
||||
)
|
||||
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
|
||||
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from process_ckpt import savee
|
||||
from collections import OrderedDict as od
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = False
|
||||
###反正A100fp32更快,那试试tf32吧
|
||||
@ -46,7 +53,6 @@ device = "cpu" # cuda以外的设备,等mps优化后加入
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
n_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
@ -65,7 +71,7 @@ def main():
|
||||
|
||||
|
||||
def run(rank, n_gpus, hps):
|
||||
global global_step,no_grad_names,save_root,lora_rank
|
||||
global global_step, no_grad_names, save_root, lora_rank
|
||||
if rank == 0:
|
||||
logger = utils.get_logger(hps.data.exp_dir)
|
||||
logger.info(hps)
|
||||
@ -74,7 +80,7 @@ def run(rank, n_gpus, hps):
|
||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||
|
||||
dist.init_process_group(
|
||||
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||
init_method="env://?use_libuv=False",
|
||||
world_size=n_gpus,
|
||||
rank=rank,
|
||||
@ -122,21 +128,24 @@ def run(rank, n_gpus, hps):
|
||||
persistent_workers=True,
|
||||
prefetch_factor=4,
|
||||
)
|
||||
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
|
||||
os.makedirs(save_root,exist_ok=True)
|
||||
lora_rank=int(hps.train.lora_rank)
|
||||
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
||||
os.makedirs(save_root, exist_ok=True)
|
||||
lora_rank = int(hps.train.lora_rank)
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_rank,
|
||||
init_lora_weights=True,
|
||||
)
|
||||
def get_model(hps):return SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
)
|
||||
|
||||
def get_model(hps):
|
||||
return SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
)
|
||||
|
||||
def get_optim(net_g):
|
||||
return torch.optim.AdamW(
|
||||
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
|
||||
@ -144,61 +153,66 @@ def run(rank, n_gpus, hps):
|
||||
betas=hps.train.betas,
|
||||
eps=hps.train.eps,
|
||||
)
|
||||
def model2cuda(net_g,rank):
|
||||
|
||||
def model2cuda(net_g, rank):
|
||||
if torch.cuda.is_available():
|
||||
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
|
||||
else:
|
||||
net_g = net_g.to(device)
|
||||
return net_g
|
||||
try:# 如果能加载自动resume
|
||||
|
||||
try: # 如果能加载自动resume
|
||||
net_g = get_model(hps)
|
||||
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
|
||||
net_g=model2cuda(net_g,rank)
|
||||
optim_g=get_optim(net_g)
|
||||
net_g = model2cuda(net_g, rank)
|
||||
optim_g = get_optim(net_g)
|
||||
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
||||
_, _, _, epoch_str = utils.load_checkpoint(
|
||||
utils.latest_checkpoint_path(save_root, "G_*.pth"),
|
||||
net_g,
|
||||
optim_g,
|
||||
)
|
||||
epoch_str+=1
|
||||
epoch_str += 1
|
||||
global_step = (epoch_str - 1) * len(train_loader)
|
||||
except: # 如果首次不能加载,加载pretrain
|
||||
# traceback.print_exc()
|
||||
epoch_str = 1
|
||||
global_step = 0
|
||||
net_g = get_model(hps)
|
||||
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
|
||||
if (
|
||||
hps.train.pretrained_s2G != ""
|
||||
and hps.train.pretrained_s2G != None
|
||||
and os.path.exists(hps.train.pretrained_s2G)
|
||||
):
|
||||
if rank == 0:
|
||||
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
||||
print("loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
print(
|
||||
"loaded pretrained %s" % hps.train.pretrained_s2G,
|
||||
net_g.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
||||
strict=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
|
||||
net_g=model2cuda(net_g,rank)
|
||||
net_g = model2cuda(net_g, rank)
|
||||
optim_g = get_optim(net_g)
|
||||
|
||||
no_grad_names=set()
|
||||
no_grad_names = set()
|
||||
for name, param in net_g.named_parameters():
|
||||
if not param.requires_grad:
|
||||
no_grad_names.add(name.replace("module.",""))
|
||||
no_grad_names.add(name.replace("module.", ""))
|
||||
# print(name, "not requires_grad")
|
||||
# print(no_grad_names)
|
||||
# os._exit(233333)
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
|
||||
)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
|
||||
for _ in range(epoch_str):
|
||||
scheduler_g.step()
|
||||
|
||||
scaler = GradScaler(enabled=hps.train.fp16_run)
|
||||
|
||||
net_d=optim_d=scheduler_d=None
|
||||
print("start training from epoch %s"%epoch_str)
|
||||
net_d = optim_d = scheduler_d = None
|
||||
print("start training from epoch %s" % epoch_str)
|
||||
for epoch in range(epoch_str, hps.train.epochs + 1):
|
||||
if rank == 0:
|
||||
train_and_evaluate(
|
||||
@ -230,9 +244,8 @@ def run(rank, n_gpus, hps):
|
||||
scheduler_g.step()
|
||||
print("training done")
|
||||
|
||||
def train_and_evaluate(
|
||||
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
|
||||
):
|
||||
|
||||
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
||||
net_g, net_d = nets
|
||||
optim_g, optim_d = optims
|
||||
# scheduler_g, scheduler_d = schedulers
|
||||
@ -244,18 +257,32 @@ def train_and_evaluate(
|
||||
global global_step
|
||||
|
||||
net_g.train()
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
|
||||
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
|
||||
tqdm(train_loader)
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
)
|
||||
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
spec, spec_lengths = (
|
||||
spec.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
spec_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
|
||||
ssl = ssl.cuda(rank, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
|
||||
rank, non_blocking=True
|
||||
text, text_lengths = (
|
||||
text.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
text_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||
@ -265,8 +292,18 @@ def train_and_evaluate(
|
||||
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
|
||||
loss_gen_all=cfm_loss
|
||||
cfm_loss = net_g(
|
||||
ssl,
|
||||
spec,
|
||||
mel,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
mel_lengths,
|
||||
use_grad_ckpt=hps.train.grad_ckpt,
|
||||
)
|
||||
loss_gen_all = cfm_loss
|
||||
optim_g.zero_grad()
|
||||
scaler.scale(loss_gen_all).backward()
|
||||
scaler.unscale_(optim_g)
|
||||
@ -276,18 +313,17 @@ def train_and_evaluate(
|
||||
|
||||
if rank == 0:
|
||||
if global_step % hps.train.log_interval == 0:
|
||||
lr = optim_g.param_groups[0]['lr']
|
||||
lr = optim_g.param_groups[0]["lr"]
|
||||
losses = [cfm_loss]
|
||||
logger.info('Train Epoch: {} [{:.0f}%]'.format(
|
||||
epoch,
|
||||
100. * batch_idx / len(train_loader)))
|
||||
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
|
||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||
|
||||
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
|
||||
utils.summarize(
|
||||
writer=writer,
|
||||
global_step=global_step,
|
||||
scalars=scalar_dict)
|
||||
scalars=scalar_dict,
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||
@ -297,9 +333,7 @@ def train_and_evaluate(
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
save_root, "G_{}.pth".format(global_step)
|
||||
),
|
||||
os.path.join(save_root, "G_{}.pth".format(global_step)),
|
||||
)
|
||||
else:
|
||||
utils.save_checkpoint(
|
||||
@ -307,21 +341,19 @@ def train_and_evaluate(
|
||||
optim_g,
|
||||
hps.train.learning_rate,
|
||||
epoch,
|
||||
os.path.join(
|
||||
save_root, "G_{}.pth".format(233333333333)
|
||||
),
|
||||
os.path.join(save_root, "G_{}.pth".format(233333333333)),
|
||||
)
|
||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||
if hasattr(net_g, "module"):
|
||||
ckpt = net_g.module.state_dict()
|
||||
else:
|
||||
ckpt = net_g.state_dict()
|
||||
sim_ckpt=od()
|
||||
sim_ckpt = od()
|
||||
for key in ckpt:
|
||||
# if "cfm"not in key:
|
||||
# print(key)
|
||||
if key not in no_grad_names:
|
||||
sim_ckpt[key]=ckpt[key].half().cpu()
|
||||
sim_ckpt[key] = ckpt[key].half().cpu()
|
||||
logger.info(
|
||||
"saving ckpt %s_e%s:%s"
|
||||
% (
|
||||
@ -329,10 +361,11 @@ def train_and_evaluate(
|
||||
epoch,
|
||||
savee(
|
||||
sim_ckpt,
|
||||
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
|
||||
hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
|
||||
epoch,
|
||||
global_step,
|
||||
hps,lora_rank=lora_rank
|
||||
hps,
|
||||
lora_rank=lora_rank,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
@ -3,38 +3,44 @@ import re
|
||||
|
||||
# jieba静音
|
||||
import jieba
|
||||
|
||||
jieba.setLogLevel(logging.CRITICAL)
|
||||
|
||||
# 更改fast_langdetect大模型位置
|
||||
from pathlib import Path
|
||||
import fast_langdetect
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
|
||||
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
|
||||
fast_langdetect.infer.LangDetectConfig(
|
||||
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
from split_lang import LangSplitter
|
||||
|
||||
|
||||
def full_en(text):
|
||||
pattern = r'^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
|
||||
pattern = r"^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
|
||||
return bool(re.match(pattern, text))
|
||||
|
||||
|
||||
def full_cjk(text):
|
||||
# 来自wiki
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
]
|
||||
|
||||
pattern = r'[0-9、-〜。!?.!?… ]+$'
|
||||
pattern = r"[0-9、-〜。!?.!?… ]+$"
|
||||
|
||||
cjk_text = ""
|
||||
for char in text:
|
||||
@ -45,7 +51,7 @@ def full_cjk(text):
|
||||
return cjk_text
|
||||
|
||||
|
||||
def split_jako(tag_lang,item):
|
||||
def split_jako(tag_lang, item):
|
||||
if tag_lang == "ja":
|
||||
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
|
||||
else:
|
||||
@ -53,41 +59,40 @@ def split_jako(tag_lang,item):
|
||||
|
||||
lang_list: list[dict] = []
|
||||
tag = 0
|
||||
for match in re.finditer(pattern, item['text']):
|
||||
for match in re.finditer(pattern, item["text"]):
|
||||
if match.start() > tag:
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
|
||||
|
||||
tag = match.end()
|
||||
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
|
||||
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
|
||||
|
||||
if tag < len(item['text']):
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
|
||||
if tag < len(item["text"]):
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
|
||||
|
||||
return lang_list
|
||||
|
||||
|
||||
def merge_lang(lang_list, item):
|
||||
if lang_list and item['lang'] == lang_list[-1]['lang']:
|
||||
lang_list[-1]['text'] += item['text']
|
||||
if lang_list and item["lang"] == lang_list[-1]["lang"]:
|
||||
lang_list[-1]["text"] += item["text"]
|
||||
else:
|
||||
lang_list.append(item)
|
||||
return lang_list
|
||||
|
||||
|
||||
class LangSegmenter():
|
||||
class LangSegmenter:
|
||||
# 默认过滤器, 基于gsv目前四种语言
|
||||
DEFAULT_LANG_MAP = {
|
||||
"zh": "zh",
|
||||
"yue": "zh", # 粤语
|
||||
"wuu": "zh", # 吴语
|
||||
"zh-cn": "zh",
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"ko": "ko",
|
||||
"ja": "ja",
|
||||
"en": "en",
|
||||
}
|
||||
|
||||
|
||||
def getTexts(text):
|
||||
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
|
||||
substr = lang_splitter.split_by_lang(text=text)
|
||||
@ -95,18 +100,18 @@ class LangSegmenter():
|
||||
lang_list: list[dict] = []
|
||||
|
||||
for _, item in enumerate(substr):
|
||||
dict_item = {'lang':item.lang,'text':item.text}
|
||||
dict_item = {"lang": item.lang, "text": item.text}
|
||||
|
||||
# 处理短英文被识别为其他语言的问题
|
||||
if full_en(dict_item['text']):
|
||||
dict_item['lang'] = 'en'
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
if full_en(dict_item["text"]):
|
||||
dict_item["lang"] = "en"
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 处理非日语夹日文的问题(不包含CJK)
|
||||
ja_list: list[dict] = []
|
||||
if dict_item['lang'] != 'ja':
|
||||
ja_list = split_jako('ja',dict_item)
|
||||
if dict_item["lang"] != "ja":
|
||||
ja_list = split_jako("ja", dict_item)
|
||||
|
||||
if not ja_list:
|
||||
ja_list.append(dict_item)
|
||||
@ -115,8 +120,8 @@ class LangSegmenter():
|
||||
ko_list: list[dict] = []
|
||||
temp_list: list[dict] = []
|
||||
for _, ko_item in enumerate(ja_list):
|
||||
if ko_item["lang"] != 'ko':
|
||||
ko_list = split_jako('ko',ko_item)
|
||||
if ko_item["lang"] != "ko":
|
||||
ko_list = split_jako("ko", ko_item)
|
||||
|
||||
if ko_list:
|
||||
temp_list.extend(ko_list)
|
||||
@ -126,26 +131,26 @@ class LangSegmenter():
|
||||
# 未存在非日韩文夹日韩文
|
||||
if len(temp_list) == 1:
|
||||
# 未知语言检查是否为CJK
|
||||
if dict_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(dict_item['text'])
|
||||
if dict_item["lang"] == "x":
|
||||
cjk_text = full_cjk(dict_item["text"])
|
||||
if cjk_text:
|
||||
dict_item = {'lang':'zh','text':cjk_text}
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item = {"lang": "zh", "text": cjk_text}
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 存在非日韩文夹日韩文
|
||||
for _, temp_item in enumerate(temp_list):
|
||||
# 未知语言检查是否为CJK
|
||||
if temp_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(dict_item['text'])
|
||||
if temp_item["lang"] == "x":
|
||||
cjk_text = full_cjk(dict_item["text"])
|
||||
if cjk_text:
|
||||
dict_item = {'lang':'zh','text':cjk_text}
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item = {"lang": "zh", "text": cjk_text}
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
return lang_list
|
||||
|
||||
|
||||
@ -155,4 +160,3 @@ if __name__ == "__main__":
|
||||
|
||||
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
|
||||
print(LangSegmenter.getTexts(text))
|
||||
|
||||
|
@ -10,18 +10,19 @@ from text import symbols2 as symbols_v2
|
||||
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
|
||||
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
|
||||
|
||||
|
||||
def cleaned_text_to_sequence(cleaned_text, version=None):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
if version is None:version=os.environ.get('version', 'v2')
|
||||
if version == "v1":
|
||||
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
|
||||
else:
|
||||
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
|
||||
|
||||
return phones
|
||||
"""
|
||||
if version is None:
|
||||
version = os.environ.get("version", "v2")
|
||||
if version == "v1":
|
||||
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
|
||||
else:
|
||||
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
|
||||
|
||||
return phones
|
||||
|
@ -1,6 +1,5 @@
|
||||
# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py
|
||||
|
||||
import sys
|
||||
import re
|
||||
import cn2an
|
||||
import ToJyutping
|
||||
@ -99,9 +98,7 @@ def replace_punctuation(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
@ -115,7 +112,9 @@ def text_normalize(text):
|
||||
return dest_text
|
||||
|
||||
|
||||
punctuation_set=set(punctuation)
|
||||
punctuation_set = set(punctuation)
|
||||
|
||||
|
||||
def jyuping_to_initials_finals_tones(jyuping_syllables):
|
||||
initials_finals = []
|
||||
tones = []
|
||||
@ -160,12 +159,14 @@ def jyuping_to_initials_finals_tones(jyuping_syllables):
|
||||
assert len(initials_finals) == len(tones)
|
||||
|
||||
###魔改为辅音+带音调的元音
|
||||
phones=[]
|
||||
for a,b in zip(initials_finals,tones):
|
||||
if(b not in [-1,0]):###防止粤语和普通话重合开头加Y,如果是标点,不加。
|
||||
todo="%s%s"%(a,b)
|
||||
else:todo=a
|
||||
if(todo not in punctuation_set):todo="Y%s"%todo
|
||||
phones = []
|
||||
for a, b in zip(initials_finals, tones):
|
||||
if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y,如果是标点,不加。
|
||||
todo = "%s%s" % (a, b)
|
||||
else:
|
||||
todo = a
|
||||
if todo not in punctuation_set:
|
||||
todo = "Y%s" % todo
|
||||
phones.append(todo)
|
||||
|
||||
# return initials_finals, tones, word2ph
|
||||
|
@ -1,5 +1,4 @@
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import cn2an
|
||||
@ -17,7 +16,9 @@ pinyin_to_symbol_map = {
|
||||
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
|
||||
}
|
||||
|
||||
import jieba_fast, logging
|
||||
import jieba_fast
|
||||
import logging
|
||||
|
||||
jieba_fast.setLogLevel(logging.CRITICAL)
|
||||
import jieba_fast.posseg as psg
|
||||
|
||||
@ -37,7 +38,7 @@ rep_map = {
|
||||
"/": ",",
|
||||
"—": "-",
|
||||
"~": "…",
|
||||
"~":"…",
|
||||
"~": "…",
|
||||
}
|
||||
|
||||
tone_modifier = ToneSandhi()
|
||||
@ -49,9 +50,7 @@ def replace_punctuation(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
@ -62,17 +61,15 @@ def replace_punctuation_with_en(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
|
||||
@ -87,9 +84,7 @@ def _get_initials_finals(word):
|
||||
initials = []
|
||||
finals = []
|
||||
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
|
||||
orig_finals = lazy_pinyin(
|
||||
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
|
||||
)
|
||||
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
for c, v in zip(orig_initials, orig_finals):
|
||||
initials.append(c)
|
||||
finals.append(v)
|
||||
|
@ -1,10 +1,9 @@
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import cn2an
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
from pypinyin.contrib.tone_convert import to_normal, to_finals_tone3, to_initials, to_finals
|
||||
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
|
||||
|
||||
from text.symbols import punctuation
|
||||
from text.tone_sandhi import ToneSandhi
|
||||
@ -18,18 +17,26 @@ pinyin_to_symbol_map = {
|
||||
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
|
||||
}
|
||||
|
||||
import jieba_fast, logging
|
||||
import jieba_fast
|
||||
import logging
|
||||
|
||||
jieba_fast.setLogLevel(logging.CRITICAL)
|
||||
import jieba_fast.posseg as psg
|
||||
|
||||
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
|
||||
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
|
||||
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
|
||||
is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False
|
||||
if is_g2pw:
|
||||
# print("当前使用g2pw进行拼音推理")
|
||||
from text.g2pw import G2PWPinyin, correct_pronunciation
|
||||
|
||||
parent_directory = os.path.dirname(current_file_path)
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
||||
rep_map = {
|
||||
":": ",",
|
||||
@ -46,7 +53,7 @@ rep_map = {
|
||||
"/": ",",
|
||||
"—": "-",
|
||||
"~": "…",
|
||||
"~":"…",
|
||||
"~": "…",
|
||||
}
|
||||
|
||||
tone_modifier = ToneSandhi()
|
||||
@ -58,9 +65,7 @@ def replace_punctuation(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
@ -77,9 +82,7 @@ def _get_initials_finals(word):
|
||||
finals = []
|
||||
|
||||
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
|
||||
orig_finals = lazy_pinyin(
|
||||
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
|
||||
)
|
||||
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
|
||||
for c, v in zip(orig_initials, orig_finals):
|
||||
initials.append(c)
|
||||
@ -87,31 +90,66 @@ def _get_initials_finals(word):
|
||||
return initials, finals
|
||||
|
||||
|
||||
must_erhua = {
|
||||
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
|
||||
}
|
||||
must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"}
|
||||
not_erhua = {
|
||||
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
|
||||
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
|
||||
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
|
||||
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
|
||||
"狗儿", "少儿"
|
||||
"虐儿",
|
||||
"为儿",
|
||||
"护儿",
|
||||
"瞒儿",
|
||||
"救儿",
|
||||
"替儿",
|
||||
"有儿",
|
||||
"一儿",
|
||||
"我儿",
|
||||
"俺儿",
|
||||
"妻儿",
|
||||
"拐儿",
|
||||
"聋儿",
|
||||
"乞儿",
|
||||
"患儿",
|
||||
"幼儿",
|
||||
"孤儿",
|
||||
"婴儿",
|
||||
"婴幼儿",
|
||||
"连体儿",
|
||||
"脑瘫儿",
|
||||
"流浪儿",
|
||||
"体弱儿",
|
||||
"混血儿",
|
||||
"蜜雪儿",
|
||||
"舫儿",
|
||||
"祖儿",
|
||||
"美儿",
|
||||
"应采儿",
|
||||
"可儿",
|
||||
"侄儿",
|
||||
"孙儿",
|
||||
"侄孙儿",
|
||||
"女儿",
|
||||
"男儿",
|
||||
"红孩儿",
|
||||
"花儿",
|
||||
"虫儿",
|
||||
"马儿",
|
||||
"鸟儿",
|
||||
"猪儿",
|
||||
"猫儿",
|
||||
"狗儿",
|
||||
"少儿",
|
||||
}
|
||||
def _merge_erhua(initials: list[str],
|
||||
finals: list[str],
|
||||
word: str,
|
||||
pos: str) -> list[list[str]]:
|
||||
|
||||
|
||||
def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]:
|
||||
"""
|
||||
Do erhub.
|
||||
"""
|
||||
# fix er1
|
||||
for i, phn in enumerate(finals):
|
||||
if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
|
||||
finals[i] = 'er2'
|
||||
if i == len(finals) - 1 and word[i] == "儿" and phn == "er1":
|
||||
finals[i] = "er2"
|
||||
|
||||
# 发音
|
||||
if word not in must_erhua and (word in not_erhua or
|
||||
pos in {"a", "j", "nr"}):
|
||||
if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}):
|
||||
return initials, finals
|
||||
|
||||
# "……" 等情况直接返回
|
||||
@ -124,9 +162,13 @@ def _merge_erhua(initials: list[str],
|
||||
new_initials = []
|
||||
new_finals = []
|
||||
for i, phn in enumerate(finals):
|
||||
if i == len(finals) - 1 and word[i] == "儿" and phn in {
|
||||
"er2", "er5"
|
||||
} and word[-2:] not in not_erhua and new_finals:
|
||||
if (
|
||||
i == len(finals) - 1
|
||||
and word[i] == "儿"
|
||||
and phn in {"er2", "er5"}
|
||||
and word[-2:] not in not_erhua
|
||||
and new_finals
|
||||
):
|
||||
phn = "er" + new_finals[-1][-1]
|
||||
|
||||
new_initials.append(initials[i])
|
||||
@ -160,7 +202,7 @@ def _g2p(segments):
|
||||
# assert len(sub_initials) == len(sub_finals) == len(word)
|
||||
initials = sum(initials, [])
|
||||
finals = sum(finals, [])
|
||||
print("pypinyin结果",initials,finals)
|
||||
print("pypinyin结果", initials, finals)
|
||||
else:
|
||||
# g2pw采用整句推理
|
||||
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
@ -171,19 +213,19 @@ def _g2p(segments):
|
||||
sub_finals = []
|
||||
now_word_length = pre_word_length + len(word)
|
||||
|
||||
if pos == 'eng':
|
||||
if pos == "eng":
|
||||
pre_word_length = now_word_length
|
||||
continue
|
||||
|
||||
word_pinyins = pinyins[pre_word_length:now_word_length]
|
||||
|
||||
# 多音字消歧
|
||||
word_pinyins = correct_pronunciation(word,word_pinyins)
|
||||
word_pinyins = correct_pronunciation(word, word_pinyins)
|
||||
|
||||
for pinyin in word_pinyins:
|
||||
if pinyin[0].isalpha():
|
||||
sub_initials.append(to_initials(pinyin))
|
||||
sub_finals.append(to_finals_tone3(pinyin,neutral_tone_with_five=True))
|
||||
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
|
||||
else:
|
||||
sub_initials.append(pinyin)
|
||||
sub_finals.append(pinyin)
|
||||
@ -259,18 +301,18 @@ def replace_punctuation_with_en(text):
|
||||
|
||||
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
|
||||
replaced_text = re.sub(
|
||||
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
|
||||
)
|
||||
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
|
||||
|
||||
return replaced_text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
|
||||
def text_normalize(text):
|
||||
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
|
||||
tx = TextNormalizer()
|
||||
@ -283,6 +325,7 @@ def text_normalize(text):
|
||||
dest_text = replace_consecutive_punctuation(dest_text)
|
||||
return dest_text
|
||||
|
||||
|
||||
# 不排除英文的文本格式化
|
||||
def mix_text_normalize(text):
|
||||
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
|
||||
|
@ -19,55 +19,57 @@ special = [
|
||||
|
||||
|
||||
def clean_text(text, language, version=None):
|
||||
if version is None:version=os.environ.get('version', 'v2')
|
||||
if version is None:
|
||||
version = os.environ.get("version", "v2")
|
||||
if version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
|
||||
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
|
||||
|
||||
if(language not in language_module_map):
|
||||
language="en"
|
||||
text=" "
|
||||
if language not in language_module_map:
|
||||
language = "en"
|
||||
text = " "
|
||||
for special_s, special_l, target_symbol in special:
|
||||
if special_s in text and language == special_l:
|
||||
return clean_special(text, language, special_s, target_symbol, version)
|
||||
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
|
||||
if hasattr(language_module,"text_normalize"):
|
||||
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
|
||||
if hasattr(language_module, "text_normalize"):
|
||||
norm_text = language_module.text_normalize(text)
|
||||
else:
|
||||
norm_text=text
|
||||
if language == "zh" or language=="yue":##########
|
||||
norm_text = text
|
||||
if language == "zh" or language == "yue": ##########
|
||||
phones, word2ph = language_module.g2p(norm_text)
|
||||
assert len(phones) == sum(word2ph)
|
||||
assert len(norm_text) == len(word2ph)
|
||||
elif language == "en":
|
||||
phones = language_module.g2p(norm_text)
|
||||
if len(phones) < 4:
|
||||
phones = [','] + phones
|
||||
phones = [","] + phones
|
||||
word2ph = None
|
||||
else:
|
||||
phones = language_module.g2p(norm_text)
|
||||
word2ph = None
|
||||
phones = ['UNK' if ph not in symbols else ph for ph in phones]
|
||||
phones = ["UNK" if ph not in symbols else ph for ph in phones]
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
def clean_special(text, language, special_s, target_symbol, version=None):
|
||||
if version is None:version=os.environ.get('version', 'v2')
|
||||
if version is None:
|
||||
version = os.environ.get("version", "v2")
|
||||
if version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
|
||||
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
|
||||
|
||||
"""
|
||||
特殊静音段sp符号处理
|
||||
"""
|
||||
text = text.replace(special_s, ",")
|
||||
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
|
||||
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
|
||||
norm_text = language_module.text_normalize(text)
|
||||
phones = language_module.g2p(norm_text)
|
||||
new_ph = []
|
||||
@ -81,8 +83,9 @@ def clean_special(text, language, special_s, target_symbol, version=None):
|
||||
|
||||
|
||||
def text_to_sequence(text, language, version=None):
|
||||
version = os.environ.get('version',version)
|
||||
if version is None:version='v2'
|
||||
version = os.environ.get("version", version)
|
||||
if version is None:
|
||||
version = "v2"
|
||||
phones = clean_text(text)
|
||||
return cleaned_text_to_sequence(phones, version)
|
||||
|
||||
|
@ -9,17 +9,17 @@ import unicodedata
|
||||
# 后缀计量单位替换表
|
||||
measurement_map = {
|
||||
"m": ["meter", "meters"],
|
||||
'km': ["kilometer", "kilometers"],
|
||||
"km": ["kilometer", "kilometers"],
|
||||
"km/h": ["kilometer per hour", "kilometers per hour"],
|
||||
"ft": ["feet", "feet"],
|
||||
"L": ["liter", "liters"],
|
||||
"tbsp": ["tablespoon", "tablespoons"],
|
||||
'tsp': ["teaspoon", "teaspoons"],
|
||||
"tsp": ["teaspoon", "teaspoons"],
|
||||
"h": ["hour", "hours"],
|
||||
"min": ["minute", "minutes"],
|
||||
"s": ["second", "seconds"],
|
||||
"°C": ["degree celsius", "degrees celsius"],
|
||||
"°F": ["degree fahrenheit", "degrees fahrenheit"]
|
||||
"°F": ["degree fahrenheit", "degrees fahrenheit"],
|
||||
}
|
||||
|
||||
|
||||
@ -27,37 +27,38 @@ measurement_map = {
|
||||
_inflect = inflect.engine()
|
||||
|
||||
# 转化数字序数词
|
||||
_ordinal_number_re = re.compile(r'\b([0-9]+)\. ')
|
||||
_ordinal_number_re = re.compile(r"\b([0-9]+)\. ")
|
||||
|
||||
# 我听说好像对于数字正则识别其实用 \d 会好一点
|
||||
|
||||
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
||||
|
||||
# 时间识别
|
||||
_time_re = re.compile(r'\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b')
|
||||
_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")
|
||||
|
||||
# 后缀计量单位识别
|
||||
_measurement_re = re.compile(r'\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b')
|
||||
_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")
|
||||
|
||||
# 前后 £ 识别 ( 写了识别两边某一边的,但是不知道为什么失败了┭┮﹏┭┮ )
|
||||
_pounds_re_start = re.compile(r'£([0-9\.\,]*[0-9]+)')
|
||||
_pounds_re_end = re.compile(r'([0-9\.\,]*[0-9]+)£')
|
||||
_pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)")
|
||||
_pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£")
|
||||
|
||||
# 前后 $ 识别
|
||||
_dollars_re_start = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||
_dollars_re_end = re.compile(r'([(0-9\.\,]*[0-9]+)\$')
|
||||
_dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
||||
_dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$")
|
||||
|
||||
# 小数的识别
|
||||
_decimal_number_re = re.compile(r'([0-9]+\.\s*[0-9]+)')
|
||||
_decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)")
|
||||
|
||||
# 分数识别 (形式 "3/4" )
|
||||
_fraction_re = re.compile(r'([0-9]+/[0-9]+)')
|
||||
_fraction_re = re.compile(r"([0-9]+/[0-9]+)")
|
||||
|
||||
# 序数词识别
|
||||
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
||||
|
||||
# 数字处理
|
||||
_number_re = re.compile(r'[0-9]+')
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
|
||||
|
||||
def _convert_ordinal(m):
|
||||
"""
|
||||
@ -70,8 +71,10 @@ def _convert_ordinal(m):
|
||||
ordinal = _inflect.ordinal(m.group(1))
|
||||
return ordinal + ", "
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(',', '')
|
||||
return m.group(1).replace(",", "")
|
||||
|
||||
|
||||
def _expand_time(m):
|
||||
"""
|
||||
@ -82,12 +85,12 @@ def _expand_time(m):
|
||||
output: "one o'clock p.m. / four o'clock am. / one thirty p.m."
|
||||
"""
|
||||
hours, minutes = map(int, m.group(1, 2))
|
||||
period = 'a.m.' if hours < 12 else 'p.m.'
|
||||
period = "a.m." if hours < 12 else "p.m."
|
||||
if hours > 12:
|
||||
hours -= 12
|
||||
|
||||
hour_word = _inflect.number_to_words(hours)
|
||||
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ''
|
||||
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ""
|
||||
|
||||
if minutes == 0:
|
||||
return f"{hour_word} o'clock {period}"
|
||||
@ -103,7 +106,7 @@ def _expand_measurement(m):
|
||||
sign = m.group(3)
|
||||
ptr = 1
|
||||
# 想不到怎么方便的取数字,又懒得改正则,诶,1.2 反正也是复数读法,干脆直接去掉 "."
|
||||
num = int(m.group(1).replace(sign, '').replace(".",''))
|
||||
num = int(m.group(1).replace(sign, "").replace(".", ""))
|
||||
decimal_part = m.group(2)
|
||||
# 上面判断的漏洞,比如 0.1 的情况,在这里排除了
|
||||
if decimal_part == None and num == 1:
|
||||
@ -116,23 +119,24 @@ def _expand_pounds(m):
|
||||
没找到特别规范的说明,和美元的处理一样,其实可以把两个合并在一起
|
||||
"""
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
parts = match.split(".")
|
||||
if len(parts) > 2:
|
||||
return match + ' pounds' # Unexpected format
|
||||
return match + " pounds" # Unexpected format
|
||||
pounds = int(parts[0]) if parts[0] else 0
|
||||
pence = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
|
||||
pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
|
||||
if pounds and pence:
|
||||
pound_unit = 'pound' if pounds == 1 else 'pounds'
|
||||
penny_unit = 'penny' if pence == 1 else 'pence'
|
||||
return '%s %s and %s %s' % (pounds, pound_unit, pence, penny_unit)
|
||||
pound_unit = "pound" if pounds == 1 else "pounds"
|
||||
penny_unit = "penny" if pence == 1 else "pence"
|
||||
return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit)
|
||||
elif pounds:
|
||||
pound_unit = 'pound' if pounds == 1 else 'pounds'
|
||||
return '%s %s' % (pounds, pound_unit)
|
||||
pound_unit = "pound" if pounds == 1 else "pounds"
|
||||
return "%s %s" % (pounds, pound_unit)
|
||||
elif pence:
|
||||
penny_unit = 'penny' if pence == 1 else 'pence'
|
||||
return '%s %s' % (pence, penny_unit)
|
||||
penny_unit = "penny" if pence == 1 else "pence"
|
||||
return "%s %s" % (pence, penny_unit)
|
||||
else:
|
||||
return 'zero pounds'
|
||||
return "zero pounds"
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
"""
|
||||
@ -142,23 +146,24 @@ def _expand_dollars(m):
|
||||
output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents"
|
||||
"""
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
parts = match.split(".")
|
||||
if len(parts) > 2:
|
||||
return match + ' dollars' # Unexpected format
|
||||
return match + " dollars" # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
|
||||
cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s and %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||
return '%s %s' % (dollars, dollar_unit)
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
return "%s %s" % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||
return '%s %s' % (cents, cent_unit)
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s" % (cents, cent_unit)
|
||||
else:
|
||||
return 'zero dollars'
|
||||
return "zero dollars"
|
||||
|
||||
|
||||
# 小数的处理
|
||||
def _expand_decimal_number(m):
|
||||
@ -168,11 +173,11 @@ def _expand_decimal_number(m):
|
||||
output: "thirteen point two three four"
|
||||
"""
|
||||
match = m.group(1)
|
||||
parts = match.split('.')
|
||||
parts = match.split(".")
|
||||
words = []
|
||||
# 遍历字符串中的每个字符
|
||||
for char in parts[1]:
|
||||
if char == '.':
|
||||
if char == ".":
|
||||
words.append("point")
|
||||
else:
|
||||
words.append(char)
|
||||
@ -196,39 +201,41 @@ def _expend_fraction(m):
|
||||
| 3/2 | three halves |
|
||||
"""
|
||||
match = m.group(0)
|
||||
numerator, denominator = map(int, match.split('/'))
|
||||
numerator, denominator = map(int, match.split("/"))
|
||||
|
||||
numerator_part = _inflect.number_to_words(numerator)
|
||||
if denominator == 2:
|
||||
if numerator == 1:
|
||||
denominator_part = 'half'
|
||||
denominator_part = "half"
|
||||
else:
|
||||
denominator_part = 'halves'
|
||||
denominator_part = "halves"
|
||||
elif denominator == 1:
|
||||
return f'{numerator_part}'
|
||||
return f"{numerator_part}"
|
||||
else:
|
||||
denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
|
||||
if numerator > 1:
|
||||
denominator_part += 's'
|
||||
denominator_part += "s"
|
||||
|
||||
return f"{numerator_part} {denominator_part}"
|
||||
|
||||
return f'{numerator_part} {denominator_part}'
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return 'two thousand'
|
||||
return "two thousand"
|
||||
elif num > 2000 and num < 2010:
|
||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||
return "two thousand " + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||
return _inflect.number_to_words(num // 100) + " hundred"
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword='')
|
||||
return _inflect.number_to_words(num, andword="")
|
||||
|
||||
|
||||
def normalize(text):
|
||||
@ -238,7 +245,7 @@ def normalize(text):
|
||||
"""
|
||||
|
||||
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
|
||||
text = re.sub(r'(?<!\d)-|-(?!\d)', ' minus ', text)
|
||||
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_time_re, _expand_time, text)
|
||||
text = re.sub(_measurement_re, _expand_measurement, text)
|
||||
@ -251,19 +258,20 @@ def normalize(text):
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
|
||||
text = ''.join(char for char in unicodedata.normalize('NFD', text)
|
||||
if unicodedata.category(char) != 'Mn') # Strip accents
|
||||
text = "".join(
|
||||
char for char in unicodedata.normalize("NFD", text) if unicodedata.category(char) != "Mn"
|
||||
) # Strip accents
|
||||
|
||||
text = re.sub("%", " percent", text)
|
||||
text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
|
||||
text = re.sub(r"(?i)i\.e\.", "that is", text)
|
||||
text = re.sub(r"(?i)e\.g\.", "for example", text)
|
||||
# 增加纯大写单词拆分
|
||||
text = re.sub(r'(?<!^)(?<![\s])([A-Z])', r' \1', text)
|
||||
text = re.sub(r"(?<!^)(?<![\s])([A-Z])", r" \1", text)
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# 我觉得其实可以把切分结果展示出来(只读,或者修改不影响传给TTS的实际text)
|
||||
# 然后让用户确认后再输入给 TTS,可以让用户检查自己有没有不标准的输入
|
||||
print(normalize("1. test ordinal number 1st"))
|
||||
|
@ -8,10 +8,10 @@ from text.symbols import punctuation
|
||||
|
||||
from text.symbols2 import symbols
|
||||
|
||||
import unicodedata
|
||||
from builtins import str as unicode
|
||||
from text.en_normalization.expend import normalize
|
||||
from nltk.tokenize import TweetTokenizer
|
||||
|
||||
word_tokenize = TweetTokenizer().tokenize
|
||||
from nltk import pos_tag
|
||||
|
||||
@ -122,9 +122,9 @@ def replace_phs(phs):
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}\s])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}\s])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
|
||||
@ -183,6 +183,7 @@ def read_dict_new():
|
||||
|
||||
return g2p_dict
|
||||
|
||||
|
||||
def hot_reload_hot(g2p_dict):
|
||||
with open(CMU_DICT_HOT_PATH) as f:
|
||||
line = f.readline()
|
||||
@ -259,9 +260,12 @@ class en_G2p(G2p):
|
||||
del self.cmu[word.lower()]
|
||||
|
||||
# 修正多音字
|
||||
self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
|
||||
self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
|
||||
|
||||
self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
|
||||
self.homograph2features["complex"] = (
|
||||
["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
|
||||
["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
|
||||
"JJ",
|
||||
)
|
||||
|
||||
def __call__(self, text):
|
||||
# tokenization
|
||||
@ -280,7 +284,7 @@ class en_G2p(G2p):
|
||||
elif len(word) == 1:
|
||||
# 单读 A 发音修正, 这里需要原格式 o_word 判断大写
|
||||
if o_word == "A":
|
||||
pron = ['EY1']
|
||||
pron = ["EY1"]
|
||||
else:
|
||||
pron = self.cmu[word][0]
|
||||
# g2p_en 原版多音字处理
|
||||
@ -289,7 +293,7 @@ class en_G2p(G2p):
|
||||
if pos.startswith(pos1):
|
||||
pron = pron1
|
||||
# pos1比pos长仅出现在read
|
||||
elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
|
||||
elif len(pos) < len(pos1) and pos == pos1[: len(pos)]:
|
||||
pron = pron1
|
||||
else:
|
||||
pron = pron2
|
||||
@ -302,7 +306,6 @@ class en_G2p(G2p):
|
||||
|
||||
return prons[:-1]
|
||||
|
||||
|
||||
def qryword(self, o_word):
|
||||
word = o_word.lower()
|
||||
|
||||
@ -320,7 +323,7 @@ class en_G2p(G2p):
|
||||
for w in word:
|
||||
# 单读 A 发音修正, 此处不存在大写的情况
|
||||
if w == "a":
|
||||
phones.extend(['EY1'])
|
||||
phones.extend(["EY1"])
|
||||
elif not w.isalpha():
|
||||
phones.extend([w])
|
||||
else:
|
||||
@ -331,23 +334,23 @@ class en_G2p(G2p):
|
||||
if re.match(r"^([a-z]+)('s)$", word):
|
||||
phones = self.qryword(word[:-2])[:]
|
||||
# P T K F TH HH 无声辅音结尾 's 发 ['S']
|
||||
if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
|
||||
phones.extend(['S'])
|
||||
if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
|
||||
phones.extend(["S"])
|
||||
# S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z']
|
||||
elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
|
||||
phones.extend(['AH0', 'Z'])
|
||||
elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
|
||||
phones.extend(["AH0", "Z"])
|
||||
# B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z']
|
||||
# AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
|
||||
# ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z']
|
||||
else:
|
||||
phones.extend(['Z'])
|
||||
phones.extend(["Z"])
|
||||
return phones
|
||||
|
||||
# 尝试进行分词,应对复合词
|
||||
comps = wordsegment.segment(word.lower())
|
||||
|
||||
# 无法分词的送回去预测
|
||||
if len(comps)==1:
|
||||
if len(comps) == 1:
|
||||
return self.predict(word)
|
||||
|
||||
# 可以分词的递归处理
|
||||
|
@ -15,6 +15,7 @@
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
@ -23,21 +24,24 @@ import numpy as np
|
||||
|
||||
from .utils import tokenize_and_map
|
||||
|
||||
ANCHOR_CHAR = '▁'
|
||||
ANCHOR_CHAR = "▁"
|
||||
|
||||
|
||||
def prepare_onnx_input(tokenizer,
|
||||
labels: List[str],
|
||||
char2phonemes: Dict[str, List[int]],
|
||||
chars: List[str],
|
||||
texts: List[str],
|
||||
query_ids: List[int],
|
||||
use_mask: bool=False,
|
||||
window_size: int=None,
|
||||
max_len: int=512) -> Dict[str, np.array]:
|
||||
def prepare_onnx_input(
|
||||
tokenizer,
|
||||
labels: List[str],
|
||||
char2phonemes: Dict[str, List[int]],
|
||||
chars: List[str],
|
||||
texts: List[str],
|
||||
query_ids: List[int],
|
||||
use_mask: bool = False,
|
||||
window_size: int = None,
|
||||
max_len: int = 512,
|
||||
) -> Dict[str, np.array]:
|
||||
if window_size is not None:
|
||||
truncated_texts, truncated_query_ids = _truncate_texts(
|
||||
window_size=window_size, texts=texts, query_ids=query_ids)
|
||||
window_size=window_size, texts=texts, query_ids=query_ids
|
||||
)
|
||||
input_ids = []
|
||||
token_type_ids = []
|
||||
attention_masks = []
|
||||
@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
|
||||
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
||||
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(
|
||||
tokenizer=tokenizer, text=text)
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
|
||||
text, query_id, tokens, text2token, token2text = _truncate(
|
||||
max_len=max_len,
|
||||
text=text,
|
||||
query_id=query_id,
|
||||
tokens=tokens,
|
||||
text2token=text2token,
|
||||
token2text=token2text)
|
||||
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
|
||||
)
|
||||
|
||||
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
|
||||
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
|
||||
|
||||
input_id = list(
|
||||
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
query_char = text[query_id]
|
||||
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
|
||||
if use_mask else [1] * len(labels)
|
||||
phoneme_mask = (
|
||||
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
|
||||
)
|
||||
char_id = chars.index(query_char)
|
||||
position_id = text2token[
|
||||
query_id] + 1 # [CLS] token locate at first place
|
||||
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
|
||||
|
||||
input_ids.append(input_id)
|
||||
token_type_ids.append(token_type_id)
|
||||
@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
|
||||
position_ids.append(position_id)
|
||||
|
||||
outputs = {
|
||||
'input_ids': np.array(input_ids).astype(np.int64),
|
||||
'token_type_ids': np.array(token_type_ids).astype(np.int64),
|
||||
'attention_masks': np.array(attention_masks).astype(np.int64),
|
||||
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
|
||||
'char_ids': np.array(char_ids).astype(np.int64),
|
||||
'position_ids': np.array(position_ids).astype(np.int64),
|
||||
"input_ids": np.array(input_ids).astype(np.int64),
|
||||
"token_type_ids": np.array(token_type_ids).astype(np.int64),
|
||||
"attention_masks": np.array(attention_masks).astype(np.int64),
|
||||
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
|
||||
"char_ids": np.array(char_ids).astype(np.int64),
|
||||
"position_ids": np.array(position_ids).astype(np.int64),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
||||
def _truncate_texts(window_size: int, texts: List[str],
|
||||
query_ids: List[int]) -> Tuple[List[str], List[int]]:
|
||||
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
|
||||
truncated_texts = []
|
||||
truncated_query_ids = []
|
||||
for text, query_id in zip(texts, query_ids):
|
||||
@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
|
||||
return truncated_texts, truncated_query_ids
|
||||
|
||||
|
||||
def _truncate(max_len: int,
|
||||
text: str,
|
||||
query_id: int,
|
||||
tokens: List[str],
|
||||
text2token: List[int],
|
||||
token2text: List[Tuple[int]]):
|
||||
def _truncate(
|
||||
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
|
||||
):
|
||||
truncate_len = max_len - 2
|
||||
if len(tokens) <= truncate_len:
|
||||
return (text, query_id, tokens, text2token, token2text)
|
||||
@ -137,14 +131,16 @@ def _truncate(max_len: int,
|
||||
start = token2text[token_start][0]
|
||||
end = token2text[token_end - 1][1]
|
||||
|
||||
return (text[start:end], query_id - start, tokens[token_start:token_end], [
|
||||
i - token_start if i is not None else None
|
||||
for i in text2token[start:end]
|
||||
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
|
||||
return (
|
||||
text[start:end],
|
||||
query_id - start,
|
||||
tokens[token_start:token_end],
|
||||
[i - token_start if i is not None else None for i in text2token[start:end]],
|
||||
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
|
||||
)
|
||||
|
||||
|
||||
def get_phoneme_labels(polyphonic_chars: List[List[str]]
|
||||
) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
|
||||
return labels, char2phonemes
|
||||
|
||||
|
||||
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
|
||||
) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
labels = sorted(
|
||||
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
|
||||
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
|
||||
labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
if char not in char2phonemes:
|
||||
char2phonemes[char] = []
|
||||
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
|
||||
char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
|
||||
return labels, char2phonemes
|
||||
|
@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
|
||||
|
||||
|
||||
class G2PWPinyin(Pinyin):
|
||||
def __init__(self, model_dir='G2PWModel/', model_source=None,
|
||||
enable_non_tradional_chinese=True,
|
||||
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir="G2PWModel/",
|
||||
model_source=None,
|
||||
enable_non_tradional_chinese=True,
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=False,
|
||||
tone_sandhi=False,
|
||||
**kwargs,
|
||||
):
|
||||
self._g2pw = G2PWOnnxConverter(
|
||||
model_dir=model_dir,
|
||||
style='pinyin',
|
||||
style="pinyin",
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
self._converter = Converter(
|
||||
self._g2pw, v_to_u=v_to_u,
|
||||
self._g2pw,
|
||||
v_to_u=v_to_u,
|
||||
neutral_tone_with_five=neutral_tone_with_five,
|
||||
tone_sandhi=tone_sandhi,
|
||||
)
|
||||
@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):
|
||||
|
||||
|
||||
class Converter(UltimateConverter):
|
||||
def __init__(self, g2pw_instance, v_to_u=False,
|
||||
neutral_tone_with_five=False,
|
||||
tone_sandhi=False, **kwargs):
|
||||
def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
|
||||
super(Converter, self).__init__(
|
||||
v_to_u=v_to_u,
|
||||
neutral_tone_with_five=neutral_tone_with_five,
|
||||
tone_sandhi=tone_sandhi, **kwargs)
|
||||
v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
|
||||
)
|
||||
|
||||
self._g2pw = g2pw_instance
|
||||
|
||||
def convert(self, words, style, heteronym, errors, strict, **kwargs):
|
||||
pys = []
|
||||
if RE_HANS.match(words):
|
||||
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
|
||||
errors=errors, strict=strict)
|
||||
pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
|
||||
post_data = self.post_pinyin(words, heteronym, pys)
|
||||
if post_data is not None:
|
||||
pys = post_data
|
||||
|
||||
pys = self.convert_styles(
|
||||
pys, words, style, heteronym, errors, strict)
|
||||
pys = self.convert_styles(pys, words, style, heteronym, errors, strict)
|
||||
|
||||
else:
|
||||
py = self.handle_nopinyin(words, style=style, errors=errors,
|
||||
heteronym=heteronym, strict=strict)
|
||||
py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
|
||||
if py:
|
||||
pys.extend(py)
|
||||
|
||||
@ -73,13 +75,11 @@ class Converter(UltimateConverter):
|
||||
g2pw_pinyin = self._g2pw(han)
|
||||
|
||||
if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
|
||||
return super(Converter, self).convert(
|
||||
han, Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
|
||||
for i, item in enumerate(g2pw_pinyin[0]):
|
||||
if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
|
||||
py = super(Converter, self).convert(
|
||||
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
|
||||
pinyins.extend(py)
|
||||
else:
|
||||
pinyins.append([to_tone(item)])
|
||||
@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
|
||||
if lst:
|
||||
new_lst_list.append(lst)
|
||||
else:
|
||||
new_lst_list.append([''])
|
||||
new_lst_list.append([""])
|
||||
|
||||
return new_lst_list
|
||||
|
||||
@ -127,17 +127,17 @@ def get_dict():
|
||||
|
||||
def read_dict():
|
||||
polyphonic_dict = {}
|
||||
with open(PP_DICT_PATH,encoding="utf-8") as f:
|
||||
with open(PP_DICT_PATH, encoding="utf-8") as f:
|
||||
line = f.readline()
|
||||
while line:
|
||||
key, value_str = line.split(':')
|
||||
key, value_str = line.split(":")
|
||||
value = eval(value_str.strip())
|
||||
polyphonic_dict[key.strip()] = value
|
||||
line = f.readline()
|
||||
with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
|
||||
with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
|
||||
line = f.readline()
|
||||
while line:
|
||||
key, value_str = line.split(':')
|
||||
key, value_str = line.split(":")
|
||||
value = eval(value_str.strip())
|
||||
polyphonic_dict[key.strip()] = value
|
||||
line = f.readline()
|
||||
|
@ -2,44 +2,43 @@
|
||||
# This code is modified from https://github.com/GitYCC/g2pW
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import json
|
||||
import os
|
||||
import zipfile,requests
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
import zipfile
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import requests
|
||||
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
from opencc import OpenCC
|
||||
from pypinyin import Style, pinyin
|
||||
from transformers import AutoTokenizer
|
||||
from pypinyin import pinyin
|
||||
from pypinyin import Style
|
||||
|
||||
from .dataset import get_char_phoneme_labels
|
||||
from .dataset import get_phoneme_labels
|
||||
from .dataset import prepare_onnx_input
|
||||
from .utils import load_config
|
||||
from ..zh_normalization.char_convert import tranditional_to_simplified
|
||||
from .dataset import get_char_phoneme_labels, get_phoneme_labels, prepare_onnx_input
|
||||
from .utils import load_config
|
||||
|
||||
model_version = '1.1'
|
||||
model_version = "1.1"
|
||||
|
||||
|
||||
def predict(session, onnx_input: Dict[str, Any],
|
||||
labels: List[str]) -> Tuple[List[str], List[float]]:
|
||||
def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
|
||||
all_preds = []
|
||||
all_confidences = []
|
||||
probs = session.run([], {
|
||||
"input_ids": onnx_input['input_ids'],
|
||||
"token_type_ids": onnx_input['token_type_ids'],
|
||||
"attention_mask": onnx_input['attention_masks'],
|
||||
"phoneme_mask": onnx_input['phoneme_masks'],
|
||||
"char_ids": onnx_input['char_ids'],
|
||||
"position_ids": onnx_input['position_ids']
|
||||
})[0]
|
||||
probs = session.run(
|
||||
[],
|
||||
{
|
||||
"input_ids": onnx_input["input_ids"],
|
||||
"token_type_ids": onnx_input["token_type_ids"],
|
||||
"attention_mask": onnx_input["attention_masks"],
|
||||
"phoneme_mask": onnx_input["phoneme_masks"],
|
||||
"char_ids": onnx_input["char_ids"],
|
||||
"position_ids": onnx_input["position_ids"],
|
||||
},
|
||||
)[0]
|
||||
|
||||
preds = np.argmax(probs, axis=1).tolist()
|
||||
max_probs = []
|
||||
@ -51,17 +50,17 @@ def predict(session, onnx_input: Dict[str, Any],
|
||||
return all_preds, all_confidences
|
||||
|
||||
|
||||
def download_and_decompress(model_dir: str='G2PWModel/'):
|
||||
def download_and_decompress(model_dir: str = "G2PWModel/"):
|
||||
if not os.path.exists(model_dir):
|
||||
parent_directory = os.path.dirname(model_dir)
|
||||
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
|
||||
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
|
||||
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
|
||||
zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
|
||||
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
|
||||
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
|
||||
print("Downloading g2pw model...")
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"#"https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
|
||||
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
|
||||
with requests.get(modelscope_url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(zip_dir, 'wb') as f:
|
||||
with open(zip_dir, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
@ -74,12 +73,15 @@ def download_and_decompress(model_dir: str='G2PWModel/'):
|
||||
|
||||
return model_dir
|
||||
|
||||
|
||||
class G2PWOnnxConverter:
|
||||
def __init__(self,
|
||||
model_dir: str='G2PWModel/',
|
||||
style: str='bopomofo',
|
||||
model_source: str=None,
|
||||
enable_non_tradional_chinese: bool=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
style: str = "bopomofo",
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
uncompress_path = download_and_decompress(model_dir)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
@ -87,41 +89,59 @@ class G2PWOnnxConverter:
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
try:
|
||||
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
except:
|
||||
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
|
||||
self.config = load_config(
|
||||
config_path=os.path.join(uncompress_path, 'config.py'),
|
||||
use_default=True)
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
||||
|
||||
polyphonic_chars_path = os.path.join(uncompress_path,
|
||||
'POLYPHONIC_CHARS.txt')
|
||||
monophonic_chars_path = os.path.join(uncompress_path,
|
||||
'MONOPHONIC_CHARS.txt')
|
||||
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
|
||||
self.polyphonic_chars = [
|
||||
line.split('\t')
|
||||
for line in open(polyphonic_chars_path, encoding='utf-8').read()
|
||||
.strip().split('\n')
|
||||
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
|
||||
]
|
||||
self.non_polyphonic = {
|
||||
'一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
|
||||
'肖', '瘙', '誒', '泊', '听', '噢'
|
||||
"一",
|
||||
"不",
|
||||
"和",
|
||||
"咋",
|
||||
"嗲",
|
||||
"剖",
|
||||
"差",
|
||||
"攢",
|
||||
"倒",
|
||||
"難",
|
||||
"奔",
|
||||
"勁",
|
||||
"拗",
|
||||
"肖",
|
||||
"瘙",
|
||||
"誒",
|
||||
"泊",
|
||||
"听",
|
||||
"噢",
|
||||
}
|
||||
self.non_monophonic = {'似', '攢'}
|
||||
self.non_monophonic = {"似", "攢"}
|
||||
self.monophonic_chars = [
|
||||
line.split('\t')
|
||||
for line in open(monophonic_chars_path, encoding='utf-8').read()
|
||||
.strip().split('\n')
|
||||
line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
|
||||
]
|
||||
self.labels, self.char2phonemes = get_char_phoneme_labels(
|
||||
polyphonic_chars=self.polyphonic_chars
|
||||
) if self.config.use_char_phoneme else get_phoneme_labels(
|
||||
polyphonic_chars=self.polyphonic_chars)
|
||||
self.labels, self.char2phonemes = (
|
||||
get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
|
||||
if self.config.use_char_phoneme
|
||||
else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
|
||||
)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
|
||||
@ -130,41 +150,29 @@ class G2PWOnnxConverter:
|
||||
if char in self.polyphonic_chars_new:
|
||||
self.polyphonic_chars_new.remove(char)
|
||||
|
||||
self.monophonic_chars_dict = {
|
||||
char: phoneme
|
||||
for char, phoneme in self.monophonic_chars
|
||||
}
|
||||
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
|
||||
for char in self.non_monophonic:
|
||||
if char in self.monophonic_chars_dict:
|
||||
self.monophonic_chars_dict.pop(char)
|
||||
|
||||
self.pos_tags = [
|
||||
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
||||
]
|
||||
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
|
||||
|
||||
with open(
|
||||
os.path.join(uncompress_path,
|
||||
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
||||
'r',
|
||||
encoding='utf-8') as fr:
|
||||
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
'bopomofo': lambda x: x,
|
||||
'pinyin': self._convert_bopomofo_to_pinyin,
|
||||
"bopomofo": lambda x: x,
|
||||
"pinyin": self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(
|
||||
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
||||
'r',
|
||||
encoding='utf-8') as fr:
|
||||
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC('s2tw')
|
||||
self.cc = OpenCC("s2tw")
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
||||
tone = bopomofo[-1]
|
||||
assert tone in '12345'
|
||||
assert tone in "12345"
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
@ -184,8 +192,7 @@ class G2PWOnnxConverter:
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(
|
||||
sentences=sentences)
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
@ -198,14 +205,12 @@ class G2PWOnnxConverter:
|
||||
texts=texts,
|
||||
query_ids=query_ids,
|
||||
use_mask=self.config.use_mask,
|
||||
window_size=None)
|
||||
window_size=None,
|
||||
)
|
||||
|
||||
preds, confidences = predict(
|
||||
session=self.session_g2pW,
|
||||
onnx_input=onnx_input,
|
||||
labels=self.labels)
|
||||
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(' ')[1] for pred in preds]
|
||||
preds = [pred.split(" ")[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
@ -213,15 +218,12 @@ class G2PWOnnxConverter:
|
||||
|
||||
return results
|
||||
|
||||
def _prepare_data(
|
||||
self, sentences: List[str]
|
||||
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
||||
sent_s = tranditional_to_simplified(sent)
|
||||
pypinyin_result = pinyin(
|
||||
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
for i, char in enumerate(sent):
|
||||
if char in self.polyphonic_chars_new:
|
||||
@ -229,8 +231,7 @@ class G2PWOnnxConverter:
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
elif char in self.monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(
|
||||
self.monophonic_chars_dict[char])
|
||||
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
|
@ -15,6 +15,7 @@
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
@ -24,14 +25,14 @@ def wordize_and_map(text: str):
|
||||
index_map_from_text_to_word = []
|
||||
index_map_from_word_to_text = []
|
||||
while len(text) > 0:
|
||||
match_space = re.match(r'^ +', text)
|
||||
match_space = re.match(r"^ +", text)
|
||||
if match_space:
|
||||
space_str = match_space.group(0)
|
||||
index_map_from_text_to_word += [None] * len(space_str)
|
||||
text = text[len(space_str):]
|
||||
text = text[len(space_str) :]
|
||||
continue
|
||||
|
||||
match_en = re.match(r'^[a-zA-Z0-9]+', text)
|
||||
match_en = re.match(r"^[a-zA-Z0-9]+", text)
|
||||
if match_en:
|
||||
en_word = match_en.group(0)
|
||||
|
||||
@ -42,7 +43,7 @@ def wordize_and_map(text: str):
|
||||
index_map_from_text_to_word += [len(words)] * len(en_word)
|
||||
|
||||
words.append(en_word)
|
||||
text = text[len(en_word):]
|
||||
text = text[len(en_word) :]
|
||||
else:
|
||||
word_start_pos = len(index_map_from_text_to_word)
|
||||
word_end_pos = word_start_pos + 1
|
||||
@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
|
||||
for word, (word_start, word_end) in zip(words, word2text):
|
||||
word_tokens = tokenizer.tokenize(word)
|
||||
|
||||
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
|
||||
if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
|
||||
index_map_from_token_to_text.append((word_start, word_end))
|
||||
tokens.append('[UNK]')
|
||||
tokens.append("[UNK]")
|
||||
else:
|
||||
current_word_start = word_start
|
||||
for word_token in word_tokens:
|
||||
word_token_len = len(re.sub(r'^##', '', word_token))
|
||||
index_map_from_token_to_text.append(
|
||||
(current_word_start, current_word_start + word_token_len))
|
||||
word_token_len = len(re.sub(r"^##", "", word_token))
|
||||
index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
|
||||
current_word_start = current_word_start + word_token_len
|
||||
tokens.append(word_token)
|
||||
|
||||
@ -85,53 +85,51 @@ def tokenize_and_map(tokenizer, text: str):
|
||||
|
||||
def _load_config(config_path: os.PathLike):
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location('__init__', config_path)
|
||||
|
||||
spec = importlib.util.spec_from_file_location("__init__", config_path)
|
||||
config = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config)
|
||||
return config
|
||||
|
||||
|
||||
default_config_dict = {
|
||||
'manual_seed': 1313,
|
||||
'model_source': 'bert-base-chinese',
|
||||
'window_size': 32,
|
||||
'num_workers': 2,
|
||||
'use_mask': True,
|
||||
'use_char_phoneme': False,
|
||||
'use_conditional': True,
|
||||
'param_conditional': {
|
||||
'affect_location': 'softmax',
|
||||
'bias': True,
|
||||
'char-linear': True,
|
||||
'pos-linear': False,
|
||||
'char+pos-second': True,
|
||||
'char+pos-second_lowrank': False,
|
||||
'lowrank_size': 0,
|
||||
'char+pos-second_fm': False,
|
||||
'fm_size': 0,
|
||||
'fix_mode': None,
|
||||
'count_json': 'train.count.json'
|
||||
"manual_seed": 1313,
|
||||
"model_source": "bert-base-chinese",
|
||||
"window_size": 32,
|
||||
"num_workers": 2,
|
||||
"use_mask": True,
|
||||
"use_char_phoneme": False,
|
||||
"use_conditional": True,
|
||||
"param_conditional": {
|
||||
"affect_location": "softmax",
|
||||
"bias": True,
|
||||
"char-linear": True,
|
||||
"pos-linear": False,
|
||||
"char+pos-second": True,
|
||||
"char+pos-second_lowrank": False,
|
||||
"lowrank_size": 0,
|
||||
"char+pos-second_fm": False,
|
||||
"fm_size": 0,
|
||||
"fix_mode": None,
|
||||
"count_json": "train.count.json",
|
||||
},
|
||||
'lr': 5e-5,
|
||||
'val_interval': 200,
|
||||
'num_iter': 10000,
|
||||
'use_focal': False,
|
||||
'param_focal': {
|
||||
'alpha': 0.0,
|
||||
'gamma': 0.7
|
||||
"lr": 5e-5,
|
||||
"val_interval": 200,
|
||||
"num_iter": 10000,
|
||||
"use_focal": False,
|
||||
"param_focal": {"alpha": 0.0, "gamma": 0.7},
|
||||
"use_pos": True,
|
||||
"param_pos ": {
|
||||
"weight": 0.1,
|
||||
"pos_joint_training": True,
|
||||
"train_pos_path": "train.pos",
|
||||
"valid_pos_path": "dev.pos",
|
||||
"test_pos_path": "test.pos",
|
||||
},
|
||||
'use_pos': True,
|
||||
'param_pos ': {
|
||||
'weight': 0.1,
|
||||
'pos_joint_training': True,
|
||||
'train_pos_path': 'train.pos',
|
||||
'valid_pos_path': 'dev.pos',
|
||||
'test_pos_path': 'test.pos'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_config(config_path: os.PathLike, use_default: bool=False):
|
||||
def load_config(config_path: os.PathLike, use_default: bool = False):
|
||||
config = _load_config(config_path)
|
||||
if use_default:
|
||||
for attr, val in default_config_dict.items():
|
||||
|
@ -2,43 +2,51 @@
|
||||
import re
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
try:
|
||||
import pyopenjtalk
|
||||
|
||||
current_file_path = os.path.dirname(__file__)
|
||||
|
||||
# 防止win下无法读取模型
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
python_dir = os.getcwd()
|
||||
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
|
||||
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
|
||||
if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
|
||||
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
|
||||
else:
|
||||
import shutil
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
|
||||
shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
|
||||
shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), )
|
||||
shutil.copytree(
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
|
||||
os.path.join("TEMP", "ja", "open_jtalk_dic"),
|
||||
)
|
||||
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
|
||||
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
|
||||
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
|
||||
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
|
||||
if current_file_path[: len(python_dir)].upper() == python_dir.upper():
|
||||
current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
|
||||
else:
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ja")):
|
||||
os.mkdir(os.path.join("TEMP", "ja"))
|
||||
if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
|
||||
os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
|
||||
shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"))
|
||||
shutil.copyfile(
|
||||
os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
|
||||
os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
|
||||
)
|
||||
current_file_path = os.path.join("TEMP", "ja")
|
||||
|
||||
|
||||
def get_hash(fp: str) -> str:
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(fp, "rb") as f:
|
||||
@ -51,21 +59,26 @@ try:
|
||||
USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
|
||||
# 如果没有用户词典,就生成一个;如果有,就检查md5,如果不一样,就重新生成
|
||||
if os.path.exists(USERDIC_CSV_PATH):
|
||||
if not os.path.exists(USERDIC_BIN_PATH) or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r",encoding='utf-8').read():
|
||||
if (
|
||||
not os.path.exists(USERDIC_BIN_PATH)
|
||||
or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
|
||||
):
|
||||
pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
|
||||
with open(USERDIC_HASH_PATH, "w", encoding='utf-8') as f:
|
||||
with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(get_hash(USERDIC_CSV_PATH))
|
||||
|
||||
if os.path.exists(USERDIC_BIN_PATH):
|
||||
pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# print(e)
|
||||
import pyopenjtalk
|
||||
|
||||
# failed to load user dictionary, ignore.
|
||||
pass
|
||||
|
||||
|
||||
from text.symbols import punctuation
|
||||
|
||||
# Regular expression matching Japanese without punctuation marks:
|
||||
_japanese_characters = re.compile(
|
||||
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
|
||||
@ -123,9 +136,9 @@ def post_replace_ph(ph):
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
|
||||
@ -152,7 +165,7 @@ def preprocess_jap(text, with_prosody=False):
|
||||
text += p.split(" ")
|
||||
|
||||
if i < len(marks):
|
||||
if marks[i] == " ":# 防止意外的UNK
|
||||
if marks[i] == " ": # 防止意外的UNK
|
||||
continue
|
||||
text += [marks[i].replace(" ", "")]
|
||||
return text
|
||||
@ -165,6 +178,7 @@ def text_normalize(text):
|
||||
text = replace_consecutive_punctuation(text)
|
||||
return text
|
||||
|
||||
|
||||
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
|
||||
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
|
||||
"""Extract phoneme + prosoody symbol sequence from input full-context labels.
|
||||
@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
|
||||
|
||||
return phones
|
||||
|
||||
|
||||
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
|
||||
def _numeric_feature_by_regex(regex, s):
|
||||
match = re.search(regex, s)
|
||||
@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
|
||||
return -50
|
||||
return int(match.group(1))
|
||||
|
||||
|
||||
def g2p(norm_text, with_prosody=True):
|
||||
phones = preprocess_jap(norm_text, with_prosody)
|
||||
phones = [post_replace_ph(i) for i in phones]
|
||||
|
@ -9,39 +9,43 @@ import importlib
|
||||
import os
|
||||
|
||||
# 防止win下无法读取模型
|
||||
if os.name == 'nt':
|
||||
if os.name == "nt":
|
||||
|
||||
class win_G2p(G2p):
|
||||
def check_mecab(self):
|
||||
super().check_mecab()
|
||||
spam_spec = importlib.util.find_spec("eunjeon")
|
||||
non_found = spam_spec is None
|
||||
if non_found:
|
||||
print(f'you have to install eunjeon. install it...')
|
||||
print("you have to install eunjeon. install it...")
|
||||
else:
|
||||
installpath = spam_spec.submodule_search_locations[0]
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
|
||||
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
|
||||
import sys
|
||||
from eunjeon import Mecab as _Mecab
|
||||
|
||||
class Mecab(_Mecab):
|
||||
def get_dicpath(installpath):
|
||||
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
|
||||
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
|
||||
import shutil
|
||||
python_dir = os.getcwd()
|
||||
if (installpath[:len(python_dir)].upper() == python_dir.upper()):
|
||||
dicpath = os.path.join(os.path.relpath(installpath,python_dir),'data','mecabrc')
|
||||
else:
|
||||
if not os.path.exists('TEMP'):
|
||||
os.mkdir('TEMP')
|
||||
if not os.path.exists(os.path.join('TEMP', 'ko')):
|
||||
os.mkdir(os.path.join('TEMP', 'ko'))
|
||||
if os.path.exists(os.path.join('TEMP', 'ko', 'ko_dict')):
|
||||
shutil.rmtree(os.path.join('TEMP', 'ko', 'ko_dict'))
|
||||
|
||||
shutil.copytree(os.path.join(installpath, 'data'), os.path.join('TEMP', 'ko', 'ko_dict'))
|
||||
dicpath = os.path.join('TEMP', 'ko', 'ko_dict', 'mecabrc')
|
||||
python_dir = os.getcwd()
|
||||
if installpath[: len(python_dir)].upper() == python_dir.upper():
|
||||
dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
|
||||
else:
|
||||
if not os.path.exists("TEMP"):
|
||||
os.mkdir("TEMP")
|
||||
if not os.path.exists(os.path.join("TEMP", "ko")):
|
||||
os.mkdir(os.path.join("TEMP", "ko"))
|
||||
if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
|
||||
shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))
|
||||
|
||||
shutil.copytree(
|
||||
os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
|
||||
)
|
||||
dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
|
||||
else:
|
||||
dicpath=os.path.abspath(os.path.join(installpath, 'data/mecabrc'))
|
||||
dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
|
||||
return dicpath
|
||||
|
||||
def __init__(self, dicpath=get_dicpath(installpath)):
|
||||
@ -55,94 +59,105 @@ if os.name == 'nt':
|
||||
from text.symbols2 import symbols
|
||||
|
||||
# This is a list of Korean classifiers preceded by pure Korean numerals.
|
||||
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
|
||||
_korean_classifiers = (
|
||||
"군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
|
||||
)
|
||||
|
||||
# List of (hangul, hangul divided) pairs:
|
||||
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
|
||||
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
|
||||
# ('ㄵ', 'ㄴㅈ'),
|
||||
# ('ㄶ', 'ㄴㅎ'),
|
||||
# ('ㄺ', 'ㄹㄱ'),
|
||||
# ('ㄻ', 'ㄹㅁ'),
|
||||
# ('ㄼ', 'ㄹㅂ'),
|
||||
# ('ㄽ', 'ㄹㅅ'),
|
||||
# ('ㄾ', 'ㄹㅌ'),
|
||||
# ('ㄿ', 'ㄹㅍ'),
|
||||
# ('ㅀ', 'ㄹㅎ'),
|
||||
# ('ㅄ', 'ㅂㅅ'),
|
||||
('ㅘ', 'ㅗㅏ'),
|
||||
('ㅙ', 'ㅗㅐ'),
|
||||
('ㅚ', 'ㅗㅣ'),
|
||||
('ㅝ', 'ㅜㅓ'),
|
||||
('ㅞ', 'ㅜㅔ'),
|
||||
('ㅟ', 'ㅜㅣ'),
|
||||
('ㅢ', 'ㅡㅣ'),
|
||||
('ㅑ', 'ㅣㅏ'),
|
||||
('ㅒ', 'ㅣㅐ'),
|
||||
('ㅕ', 'ㅣㅓ'),
|
||||
('ㅖ', 'ㅣㅔ'),
|
||||
('ㅛ', 'ㅣㅗ'),
|
||||
('ㅠ', 'ㅣㅜ')
|
||||
]]
|
||||
_hangul_divided = [
|
||||
(re.compile("%s" % x[0]), x[1])
|
||||
for x in [
|
||||
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
|
||||
# ('ㄵ', 'ㄴㅈ'),
|
||||
# ('ㄶ', 'ㄴㅎ'),
|
||||
# ('ㄺ', 'ㄹㄱ'),
|
||||
# ('ㄻ', 'ㄹㅁ'),
|
||||
# ('ㄼ', 'ㄹㅂ'),
|
||||
# ('ㄽ', 'ㄹㅅ'),
|
||||
# ('ㄾ', 'ㄹㅌ'),
|
||||
# ('ㄿ', 'ㄹㅍ'),
|
||||
# ('ㅀ', 'ㄹㅎ'),
|
||||
# ('ㅄ', 'ㅂㅅ'),
|
||||
("ㅘ", "ㅗㅏ"),
|
||||
("ㅙ", "ㅗㅐ"),
|
||||
("ㅚ", "ㅗㅣ"),
|
||||
("ㅝ", "ㅜㅓ"),
|
||||
("ㅞ", "ㅜㅔ"),
|
||||
("ㅟ", "ㅜㅣ"),
|
||||
("ㅢ", "ㅡㅣ"),
|
||||
("ㅑ", "ㅣㅏ"),
|
||||
("ㅒ", "ㅣㅐ"),
|
||||
("ㅕ", "ㅣㅓ"),
|
||||
("ㅖ", "ㅣㅔ"),
|
||||
("ㅛ", "ㅣㅗ"),
|
||||
("ㅠ", "ㅣㅜ"),
|
||||
]
|
||||
]
|
||||
|
||||
# List of (Latin alphabet, hangul) pairs:
|
||||
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('a', '에이'),
|
||||
('b', '비'),
|
||||
('c', '시'),
|
||||
('d', '디'),
|
||||
('e', '이'),
|
||||
('f', '에프'),
|
||||
('g', '지'),
|
||||
('h', '에이치'),
|
||||
('i', '아이'),
|
||||
('j', '제이'),
|
||||
('k', '케이'),
|
||||
('l', '엘'),
|
||||
('m', '엠'),
|
||||
('n', '엔'),
|
||||
('o', '오'),
|
||||
('p', '피'),
|
||||
('q', '큐'),
|
||||
('r', '아르'),
|
||||
('s', '에스'),
|
||||
('t', '티'),
|
||||
('u', '유'),
|
||||
('v', '브이'),
|
||||
('w', '더블유'),
|
||||
('x', '엑스'),
|
||||
('y', '와이'),
|
||||
('z', '제트')
|
||||
]]
|
||||
_latin_to_hangul = [
|
||||
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("a", "에이"),
|
||||
("b", "비"),
|
||||
("c", "시"),
|
||||
("d", "디"),
|
||||
("e", "이"),
|
||||
("f", "에프"),
|
||||
("g", "지"),
|
||||
("h", "에이치"),
|
||||
("i", "아이"),
|
||||
("j", "제이"),
|
||||
("k", "케이"),
|
||||
("l", "엘"),
|
||||
("m", "엠"),
|
||||
("n", "엔"),
|
||||
("o", "오"),
|
||||
("p", "피"),
|
||||
("q", "큐"),
|
||||
("r", "아르"),
|
||||
("s", "에스"),
|
||||
("t", "티"),
|
||||
("u", "유"),
|
||||
("v", "브이"),
|
||||
("w", "더블유"),
|
||||
("x", "엑스"),
|
||||
("y", "와이"),
|
||||
("z", "제트"),
|
||||
]
|
||||
]
|
||||
|
||||
# List of (ipa, lazy ipa) pairs:
|
||||
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('t͡ɕ','ʧ'),
|
||||
('d͡ʑ','ʥ'),
|
||||
('ɲ','n^'),
|
||||
('ɕ','ʃ'),
|
||||
('ʷ','w'),
|
||||
('ɭ','l`'),
|
||||
('ʎ','ɾ'),
|
||||
('ɣ','ŋ'),
|
||||
('ɰ','ɯ'),
|
||||
('ʝ','j'),
|
||||
('ʌ','ə'),
|
||||
('ɡ','g'),
|
||||
('\u031a','#'),
|
||||
('\u0348','='),
|
||||
('\u031e',''),
|
||||
('\u0320',''),
|
||||
('\u0339','')
|
||||
]]
|
||||
_ipa_to_lazy_ipa = [
|
||||
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("t͡ɕ", "ʧ"),
|
||||
("d͡ʑ", "ʥ"),
|
||||
("ɲ", "n^"),
|
||||
("ɕ", "ʃ"),
|
||||
("ʷ", "w"),
|
||||
("ɭ", "l`"),
|
||||
("ʎ", "ɾ"),
|
||||
("ɣ", "ŋ"),
|
||||
("ɰ", "ɯ"),
|
||||
("ʝ", "j"),
|
||||
("ʌ", "ə"),
|
||||
("ɡ", "g"),
|
||||
("\u031a", "#"),
|
||||
("\u0348", "="),
|
||||
("\u031e", ""),
|
||||
("\u0320", ""),
|
||||
("\u0339", ""),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def fix_g2pk2_error(text):
|
||||
new_text = ""
|
||||
i = 0
|
||||
while i < len(text) - 4:
|
||||
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
|
||||
new_text += text[i:i+3] + ' ' + 'ㄴ'
|
||||
if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ":
|
||||
new_text += text[i : i + 3] + " " + "ㄴ"
|
||||
i += 5
|
||||
else:
|
||||
new_text += text[i]
|
||||
@ -166,20 +181,20 @@ def divide_hangul(text):
|
||||
|
||||
|
||||
def hangul_number(num, sino=True):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
num = re.sub(',', '', num)
|
||||
"""Reference https://github.com/Kyubyong/g2pK"""
|
||||
num = re.sub(",", "", num)
|
||||
|
||||
if num == '0':
|
||||
return '영'
|
||||
if not sino and num == '20':
|
||||
return '스무'
|
||||
if num == "0":
|
||||
return "영"
|
||||
if not sino and num == "20":
|
||||
return "스무"
|
||||
|
||||
digits = '123456789'
|
||||
names = '일이삼사오육칠팔구'
|
||||
digits = "123456789"
|
||||
names = "일이삼사오육칠팔구"
|
||||
digit2name = {d: n for d, n in zip(digits, names)}
|
||||
|
||||
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
|
||||
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
|
||||
modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
|
||||
decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
|
||||
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
|
||||
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
|
||||
|
||||
@ -188,75 +203,75 @@ def hangul_number(num, sino=True):
|
||||
i = len(num) - i - 1
|
||||
if sino:
|
||||
if i == 0:
|
||||
name = digit2name.get(digit, '')
|
||||
name = digit2name.get(digit, "")
|
||||
elif i == 1:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
name = digit2name.get(digit, "") + "십"
|
||||
name = name.replace("일십", "십")
|
||||
else:
|
||||
if i == 0:
|
||||
name = digit2mod.get(digit, '')
|
||||
name = digit2mod.get(digit, "")
|
||||
elif i == 1:
|
||||
name = digit2dec.get(digit, '')
|
||||
if digit == '0':
|
||||
name = digit2dec.get(digit, "")
|
||||
if digit == "0":
|
||||
if i % 4 == 0:
|
||||
last_three = spelledout[-min(3, len(spelledout)):]
|
||||
if ''.join(last_three) == '':
|
||||
spelledout.append('')
|
||||
last_three = spelledout[-min(3, len(spelledout)) :]
|
||||
if "".join(last_three) == "":
|
||||
spelledout.append("")
|
||||
continue
|
||||
else:
|
||||
spelledout.append('')
|
||||
spelledout.append("")
|
||||
continue
|
||||
if i == 2:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
name = digit2name.get(digit, "") + "백"
|
||||
name = name.replace("일백", "백")
|
||||
elif i == 3:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
name = digit2name.get(digit, "") + "천"
|
||||
name = name.replace("일천", "천")
|
||||
elif i == 4:
|
||||
name = digit2name.get(digit, '') + '만'
|
||||
name = name.replace('일만', '만')
|
||||
name = digit2name.get(digit, "") + "만"
|
||||
name = name.replace("일만", "만")
|
||||
elif i == 5:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
name = digit2name.get(digit, "") + "십"
|
||||
name = name.replace("일십", "십")
|
||||
elif i == 6:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
name = digit2name.get(digit, "") + "백"
|
||||
name = name.replace("일백", "백")
|
||||
elif i == 7:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
name = digit2name.get(digit, "") + "천"
|
||||
name = name.replace("일천", "천")
|
||||
elif i == 8:
|
||||
name = digit2name.get(digit, '') + '억'
|
||||
name = digit2name.get(digit, "") + "억"
|
||||
elif i == 9:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = digit2name.get(digit, "") + "십"
|
||||
elif i == 10:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = digit2name.get(digit, "") + "백"
|
||||
elif i == 11:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = digit2name.get(digit, "") + "천"
|
||||
elif i == 12:
|
||||
name = digit2name.get(digit, '') + '조'
|
||||
name = digit2name.get(digit, "") + "조"
|
||||
elif i == 13:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = digit2name.get(digit, "") + "십"
|
||||
elif i == 14:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = digit2name.get(digit, "") + "백"
|
||||
elif i == 15:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = digit2name.get(digit, "") + "천"
|
||||
spelledout.append(name)
|
||||
return ''.join(elem for elem in spelledout)
|
||||
return "".join(elem for elem in spelledout)
|
||||
|
||||
|
||||
def number_to_hangul(text):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
|
||||
"""Reference https://github.com/Kyubyong/g2pK"""
|
||||
tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
|
||||
for token in tokens:
|
||||
num, classifier = token
|
||||
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
|
||||
spelledout = hangul_number(num, sino=False)
|
||||
else:
|
||||
spelledout = hangul_number(num, sino=True)
|
||||
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
|
||||
text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")
|
||||
# digit by digit for remaining digits
|
||||
digits = '0123456789'
|
||||
names = '영일이삼사오육칠팔구'
|
||||
digits = "0123456789"
|
||||
names = "영일이삼사오육칠팔구"
|
||||
for d, n in zip(digits, names):
|
||||
text = text.replace(d, n)
|
||||
return text
|
||||
@ -265,19 +280,23 @@ def number_to_hangul(text):
|
||||
def korean_to_lazy_ipa(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = number_to_hangul(text)
|
||||
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
|
||||
text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
|
||||
for regex, replacement in _ipa_to_lazy_ipa:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
_g2p=G2p()
|
||||
|
||||
_g2p = G2p()
|
||||
|
||||
|
||||
def korean_to_ipa(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = number_to_hangul(text)
|
||||
text = _g2p(text)
|
||||
text = fix_g2pk2_error(text)
|
||||
text = korean_to_lazy_ipa(text)
|
||||
return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
|
||||
return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ")
|
||||
|
||||
|
||||
def post_replace_ph(ph):
|
||||
rep_map = {
|
||||
@ -301,12 +320,13 @@ def post_replace_ph(ph):
|
||||
ph = "停"
|
||||
return ph
|
||||
|
||||
|
||||
def g2p(text):
|
||||
text = latin_to_hangul(text)
|
||||
text = _g2p(text)
|
||||
text = divide_hangul(text)
|
||||
text = fix_g2pk2_error(text)
|
||||
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
|
||||
text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
|
||||
# text = "".join([post_replace_ph(i) for i in text])
|
||||
text = [post_replace_ph(i) for i in text]
|
||||
return text
|
||||
|
@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
|
||||
punctuation = ["!", "?", "…", ",", "."] # @是SP停顿
|
||||
punctuation.append("-")
|
||||
|
@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
|
||||
punctuation = ["!", "?", "…", ",", "."] # @是SP停顿
|
||||
punctuation.append("-")
|
||||
@ -396,24 +394,404 @@ arpa = {
|
||||
"SH",
|
||||
}
|
||||
|
||||
ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停'
|
||||
ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停"
|
||||
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
||||
|
||||
yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}
|
||||
yue_symbols = {
|
||||
"Yeot3",
|
||||
"Yip1",
|
||||
"Yyu3",
|
||||
"Yeng4",
|
||||
"Yut5",
|
||||
"Yaan5",
|
||||
"Ym5",
|
||||
"Yaan6",
|
||||
"Yang1",
|
||||
"Yun4",
|
||||
"Yon2",
|
||||
"Yui5",
|
||||
"Yun2",
|
||||
"Yat3",
|
||||
"Ye",
|
||||
"Yeot1",
|
||||
"Yoeng5",
|
||||
"Yoek2",
|
||||
"Yam2",
|
||||
"Yeon6",
|
||||
"Yu6",
|
||||
"Yiu3",
|
||||
"Yaang6",
|
||||
"Yp5",
|
||||
"Yai4",
|
||||
"Yoek4",
|
||||
"Yit6",
|
||||
"Yam5",
|
||||
"Yoeng6",
|
||||
"Yg1",
|
||||
"Yk3",
|
||||
"Yoe4",
|
||||
"Yam3",
|
||||
"Yc",
|
||||
"Yyu4",
|
||||
"Yyut1",
|
||||
"Yiu4",
|
||||
"Ying3",
|
||||
"Yip3",
|
||||
"Yaap3",
|
||||
"Yau3",
|
||||
"Yan4",
|
||||
"Yau1",
|
||||
"Yap4",
|
||||
"Yk6",
|
||||
"Yok3",
|
||||
"Yai1",
|
||||
"Yeot6",
|
||||
"Yan2",
|
||||
"Yoek6",
|
||||
"Yt1",
|
||||
"Yoi1",
|
||||
"Yit5",
|
||||
"Yn4",
|
||||
"Yaau3",
|
||||
"Yau4",
|
||||
"Yuk6",
|
||||
"Ys",
|
||||
"Yuk",
|
||||
"Yin6",
|
||||
"Yung6",
|
||||
"Ya",
|
||||
"You",
|
||||
"Yaai5",
|
||||
"Yau5",
|
||||
"Yoi3",
|
||||
"Yaak3",
|
||||
"Yaat3",
|
||||
"Ying2",
|
||||
"Yok5",
|
||||
"Yeng2",
|
||||
"Yyut3",
|
||||
"Yam1",
|
||||
"Yip5",
|
||||
"You1",
|
||||
"Yam6",
|
||||
"Yaa5",
|
||||
"Yi6",
|
||||
"Yek4",
|
||||
"Yyu2",
|
||||
"Yuk5",
|
||||
"Yaam1",
|
||||
"Yang2",
|
||||
"Yai",
|
||||
"Yiu6",
|
||||
"Yin4",
|
||||
"Yok4",
|
||||
"Yot3",
|
||||
"Yui2",
|
||||
"Yeoi5",
|
||||
"Yyun6",
|
||||
"Yyu5",
|
||||
"Yoi5",
|
||||
"Yeot2",
|
||||
"Yim4",
|
||||
"Yeoi2",
|
||||
"Yaan1",
|
||||
"Yang6",
|
||||
"Yong1",
|
||||
"Yaang4",
|
||||
"Yung5",
|
||||
"Yeon1",
|
||||
"Yin2",
|
||||
"Ya3",
|
||||
"Yaang3",
|
||||
"Yg",
|
||||
"Yk2",
|
||||
"Yaau5",
|
||||
"Yut1",
|
||||
"Yt5",
|
||||
"Yip4",
|
||||
"Yung4",
|
||||
"Yj",
|
||||
"Yong3",
|
||||
"Ya1",
|
||||
"Yg6",
|
||||
"Yaau6",
|
||||
"Yit3",
|
||||
"Yun3",
|
||||
"Ying1",
|
||||
"Yn2",
|
||||
"Yg4",
|
||||
"Yl",
|
||||
"Yp3",
|
||||
"Yn3",
|
||||
"Yak1",
|
||||
"Yang5",
|
||||
"Yoe6",
|
||||
"You2",
|
||||
"Yap2",
|
||||
"Yak2",
|
||||
"Yt3",
|
||||
"Yot5",
|
||||
"Yim2",
|
||||
"Yi1",
|
||||
"Yn6",
|
||||
"Yaat5",
|
||||
"Yaam3",
|
||||
"Yoek5",
|
||||
"Ye3",
|
||||
"Yeon4",
|
||||
"Yaa2",
|
||||
"Yu3",
|
||||
"Yim6",
|
||||
"Ym",
|
||||
"Yoe3",
|
||||
"Yaai2",
|
||||
"Ym2",
|
||||
"Ya6",
|
||||
"Yeng6",
|
||||
"Yik4",
|
||||
"Yot4",
|
||||
"Yaai4",
|
||||
"Yyun3",
|
||||
"Yu1",
|
||||
"Yoeng1",
|
||||
"Yaap2",
|
||||
"Yuk3",
|
||||
"Yoek3",
|
||||
"Yeng5",
|
||||
"Yeoi1",
|
||||
"Yiu2",
|
||||
"Yok1",
|
||||
"Yo1",
|
||||
"Yoek1",
|
||||
"Yoeng2",
|
||||
"Yeon5",
|
||||
"Yiu1",
|
||||
"Yoeng4",
|
||||
"Yuk2",
|
||||
"Yat4",
|
||||
"Yg5",
|
||||
"Yut4",
|
||||
"Yan6",
|
||||
"Yin3",
|
||||
"Yaa6",
|
||||
"Yap1",
|
||||
"Yg2",
|
||||
"Yoe5",
|
||||
"Yt4",
|
||||
"Ya5",
|
||||
"Yo4",
|
||||
"Yyu1",
|
||||
"Yak3",
|
||||
"Yeon2",
|
||||
"Yong4",
|
||||
"Ym1",
|
||||
"Ye2",
|
||||
"Yaang5",
|
||||
"Yoi2",
|
||||
"Yeng3",
|
||||
"Yn",
|
||||
"Yyut4",
|
||||
"Yau",
|
||||
"Yaak2",
|
||||
"Yaan4",
|
||||
"Yek2",
|
||||
"Yin1",
|
||||
"Yi5",
|
||||
"Yoe2",
|
||||
"Yei5",
|
||||
"Yaat6",
|
||||
"Yak5",
|
||||
"Yp6",
|
||||
"Yok6",
|
||||
"Yei2",
|
||||
"Yaap1",
|
||||
"Yyut5",
|
||||
"Yi4",
|
||||
"Yim1",
|
||||
"Yk5",
|
||||
"Ye4",
|
||||
"Yok2",
|
||||
"Yaam6",
|
||||
"Yat2",
|
||||
"Yon6",
|
||||
"Yei3",
|
||||
"Yyu6",
|
||||
"Yeot5",
|
||||
"Yk4",
|
||||
"Yai6",
|
||||
"Yd",
|
||||
"Yg3",
|
||||
"Yei6",
|
||||
"Yau2",
|
||||
"Yok",
|
||||
"Yau6",
|
||||
"Yung3",
|
||||
"Yim5",
|
||||
"Yut6",
|
||||
"Yit1",
|
||||
"Yon3",
|
||||
"Yat1",
|
||||
"Yaam2",
|
||||
"Yyut2",
|
||||
"Yui6",
|
||||
"Yt2",
|
||||
"Yek6",
|
||||
"Yt",
|
||||
"Ye6",
|
||||
"Yang3",
|
||||
"Ying6",
|
||||
"Yaau1",
|
||||
"Yeon3",
|
||||
"Yng",
|
||||
"Yh",
|
||||
"Yang4",
|
||||
"Ying5",
|
||||
"Yaap6",
|
||||
"Yoeng3",
|
||||
"Yyun4",
|
||||
"You3",
|
||||
"Yan5",
|
||||
"Yat5",
|
||||
"Yot1",
|
||||
"Yun1",
|
||||
"Yi3",
|
||||
"Yaa1",
|
||||
"Yaap4",
|
||||
"You6",
|
||||
"Yaang2",
|
||||
"Yaap5",
|
||||
"Yaa3",
|
||||
"Yaak6",
|
||||
"Yeng1",
|
||||
"Yaak1",
|
||||
"Yo5",
|
||||
"Yoi4",
|
||||
"Yam4",
|
||||
"Yik1",
|
||||
"Ye1",
|
||||
"Yai5",
|
||||
"Yung1",
|
||||
"Yp2",
|
||||
"Yui4",
|
||||
"Yaak4",
|
||||
"Yung2",
|
||||
"Yak4",
|
||||
"Yaat4",
|
||||
"Yeoi4",
|
||||
"Yut2",
|
||||
"Yin5",
|
||||
"Yaau4",
|
||||
"Yap6",
|
||||
"Yb",
|
||||
"Yaam4",
|
||||
"Yw",
|
||||
"Yut3",
|
||||
"Yong2",
|
||||
"Yt6",
|
||||
"Yaai6",
|
||||
"Yap5",
|
||||
"Yik5",
|
||||
"Yun6",
|
||||
"Yaam5",
|
||||
"Yun5",
|
||||
"Yik3",
|
||||
"Ya2",
|
||||
"Yyut6",
|
||||
"Yon4",
|
||||
"Yk1",
|
||||
"Yit4",
|
||||
"Yak6",
|
||||
"Yaan2",
|
||||
"Yuk1",
|
||||
"Yai2",
|
||||
"Yik2",
|
||||
"Yaat2",
|
||||
"Yo3",
|
||||
"Ykw",
|
||||
"Yn5",
|
||||
"Yaa",
|
||||
"Ye5",
|
||||
"Yu4",
|
||||
"Yei1",
|
||||
"Yai3",
|
||||
"Yyun5",
|
||||
"Yip2",
|
||||
"Yaau2",
|
||||
"Yiu5",
|
||||
"Ym4",
|
||||
"Yeoi6",
|
||||
"Yk",
|
||||
"Ym6",
|
||||
"Yoe1",
|
||||
"Yeoi3",
|
||||
"Yon",
|
||||
"Yuk4",
|
||||
"Yaai3",
|
||||
"Yaa4",
|
||||
"Yot6",
|
||||
"Yaang1",
|
||||
"Yei4",
|
||||
"Yek1",
|
||||
"Yo",
|
||||
"Yp",
|
||||
"Yo6",
|
||||
"Yp4",
|
||||
"Yan3",
|
||||
"Yoi",
|
||||
"Yap3",
|
||||
"Yek3",
|
||||
"Yim3",
|
||||
"Yz",
|
||||
"Yot2",
|
||||
"Yoi6",
|
||||
"Yit2",
|
||||
"Yu5",
|
||||
"Yaan3",
|
||||
"Yan1",
|
||||
"Yon5",
|
||||
"Yp1",
|
||||
"Yong5",
|
||||
"Ygw",
|
||||
"Yak",
|
||||
"Yat6",
|
||||
"Ying4",
|
||||
"Yu2",
|
||||
"Yf",
|
||||
"Ya4",
|
||||
"Yon1",
|
||||
"You4",
|
||||
"Yik6",
|
||||
"Yui1",
|
||||
"Yaat1",
|
||||
"Yeot4",
|
||||
"Yi2",
|
||||
"Yaai1",
|
||||
"Yek5",
|
||||
"Ym3",
|
||||
"Yong6",
|
||||
"You5",
|
||||
"Yyun1",
|
||||
"Yn1",
|
||||
"Yo2",
|
||||
"Yip6",
|
||||
"Yui3",
|
||||
"Yaak5",
|
||||
"Yyun2",
|
||||
}
|
||||
|
||||
# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)###直接这么加yue顺序乱了
|
||||
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
|
||||
symbols = sorted(set(symbols))
|
||||
# print(len(symbols))
|
||||
symbols+=["[","]"]##日文新增上升下降调型
|
||||
symbols+=sorted(list(ko_symbols))
|
||||
symbols+=sorted(list(yue_symbols))##新加的yue统一摆在后头#已查过开头加Y后没有重复,韩文显然不会重复
|
||||
symbols += ["[", "]"] ##日文新增上升下降调型
|
||||
symbols += sorted(list(ko_symbols))
|
||||
symbols += sorted(list(yue_symbols)) ##新加的yue统一摆在后头#已查过开头加Y后没有重复,韩文显然不会重复
|
||||
# print(len(symbols))
|
||||
if __name__ == "__main__":
|
||||
print(len(symbols))
|
||||
'''
|
||||
"""
|
||||
粤语:
|
||||
732-353=379
|
||||
韩文+粤语:
|
||||
732-322=410
|
||||
'''
|
||||
"""
|
||||
|
@ -510,12 +510,7 @@ class ToneSandhi:
|
||||
# e.g. 走了, 看着, 去过
|
||||
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
elif (
|
||||
len(word) > 1
|
||||
and word[-1] in "们子"
|
||||
and pos in {"r", "n"}
|
||||
and word not in self.must_not_neural_tone_words
|
||||
):
|
||||
elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
# e.g. 桌上, 地下, 家里
|
||||
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
|
||||
@ -525,25 +520,18 @@ class ToneSandhi:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
# 个做量词
|
||||
elif (
|
||||
ge_idx >= 1
|
||||
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
|
||||
ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
|
||||
) or word == "个":
|
||||
finals[ge_idx] = finals[ge_idx][:-1] + "5"
|
||||
else:
|
||||
if (
|
||||
word in self.must_neural_tone_words
|
||||
or word[-2:] in self.must_neural_tone_words
|
||||
):
|
||||
if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
|
||||
word_list = self._split_word(word)
|
||||
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
|
||||
for i, word in enumerate(word_list):
|
||||
# conventional neural in Chinese
|
||||
if (
|
||||
word in self.must_neural_tone_words
|
||||
or word[-2:] in self.must_neural_tone_words
|
||||
):
|
||||
if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
|
||||
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
|
||||
finals = sum(finals_list, [])
|
||||
return finals
|
||||
@ -561,9 +549,7 @@ class ToneSandhi:
|
||||
|
||||
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||
# "一" in number sequences, e.g. 一零零, 二一零
|
||||
if word.find("一") != -1 and all(
|
||||
[item.isnumeric() for item in word if item != "一"]
|
||||
):
|
||||
if word.find("一") != -1 and all([item.isnumeric() for item in word if item != "一"]):
|
||||
return finals
|
||||
# "一" between reduplication words shold be yi5, e.g. 看一看
|
||||
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
|
||||
@ -697,13 +683,10 @@ class ToneSandhi:
|
||||
return new_seg
|
||||
|
||||
# the first and the second words are all_tone_three
|
||||
def _merge_continuous_three_tones(
|
||||
self, seg: List[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str]]:
|
||||
def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
new_seg = []
|
||||
sub_finals_list = [
|
||||
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
for (word, pos) in seg
|
||||
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
|
||||
]
|
||||
assert len(sub_finals_list) == len(seg)
|
||||
merge_last = [False] * len(seg)
|
||||
@ -715,10 +698,7 @@ class ToneSandhi:
|
||||
and not merge_last[i - 1]
|
||||
):
|
||||
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
|
||||
if (
|
||||
not self._is_reduplication(seg[i - 1][0])
|
||||
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
|
||||
):
|
||||
if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
|
||||
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||
merge_last[i] = True
|
||||
else:
|
||||
@ -732,13 +712,10 @@ class ToneSandhi:
|
||||
return len(word) == 2 and word[0] == word[1]
|
||||
|
||||
# the last char of first word and the first char of second word is tone_three
|
||||
def _merge_continuous_three_tones_2(
|
||||
self, seg: List[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str]]:
|
||||
def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
new_seg = []
|
||||
sub_finals_list = [
|
||||
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
for (word, pos) in seg
|
||||
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
|
||||
]
|
||||
assert len(sub_finals_list) == len(seg)
|
||||
merge_last = [False] * len(seg)
|
||||
@ -750,10 +727,7 @@ class ToneSandhi:
|
||||
and not merge_last[i - 1]
|
||||
):
|
||||
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
|
||||
if (
|
||||
not self._is_reduplication(seg[i - 1][0])
|
||||
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
|
||||
):
|
||||
if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
|
||||
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||
merge_last[i] = True
|
||||
else:
|
||||
|
File diff suppressed because one or more lines are too long
@ -21,25 +21,29 @@ from .num import verbalize_digit
|
||||
|
||||
def _time_num2str(num_string: str) -> str:
|
||||
"""A special case for verbalizing number in time."""
|
||||
result = num2str(num_string.lstrip('0'))
|
||||
if num_string.startswith('0'):
|
||||
result = DIGITS['0'] + result
|
||||
result = num2str(num_string.lstrip("0"))
|
||||
if num_string.startswith("0"):
|
||||
result = DIGITS["0"] + result
|
||||
return result
|
||||
|
||||
|
||||
# 时刻表达式
|
||||
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
|
||||
r':([0-5][0-9])'
|
||||
r'(:([0-5][0-9]))?')
|
||||
RE_TIME = re.compile(
|
||||
r"([0-1]?[0-9]|2[0-3])"
|
||||
r":([0-5][0-9])"
|
||||
r"(:([0-5][0-9]))?"
|
||||
)
|
||||
|
||||
# 时间范围,如8:30-12:30
|
||||
RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
|
||||
r':([0-5][0-9])'
|
||||
r'(:([0-5][0-9]))?'
|
||||
r'(~|-)'
|
||||
r'([0-1]?[0-9]|2[0-3])'
|
||||
r':([0-5][0-9])'
|
||||
r'(:([0-5][0-9]))?')
|
||||
RE_TIME_RANGE = re.compile(
|
||||
r"([0-1]?[0-9]|2[0-3])"
|
||||
r":([0-5][0-9])"
|
||||
r"(:([0-5][0-9]))?"
|
||||
r"(~|-)"
|
||||
r"([0-1]?[0-9]|2[0-3])"
|
||||
r":([0-5][0-9])"
|
||||
r"(:([0-5][0-9]))?"
|
||||
)
|
||||
|
||||
|
||||
def replace_time(match) -> str:
|
||||
@ -62,31 +66,33 @@ def replace_time(match) -> str:
|
||||
second_2 = match.group(9)
|
||||
|
||||
result = f"{num2str(hour)}点"
|
||||
if minute.lstrip('0'):
|
||||
if minute.lstrip("0"):
|
||||
if int(minute) == 30:
|
||||
result += "半"
|
||||
else:
|
||||
result += f"{_time_num2str(minute)}分"
|
||||
if second and second.lstrip('0'):
|
||||
if second and second.lstrip("0"):
|
||||
result += f"{_time_num2str(second)}秒"
|
||||
|
||||
if is_range:
|
||||
result += "至"
|
||||
result += f"{num2str(hour_2)}点"
|
||||
if minute_2.lstrip('0'):
|
||||
if minute_2.lstrip("0"):
|
||||
if int(minute) == 30:
|
||||
result += "半"
|
||||
else:
|
||||
result += f"{_time_num2str(minute_2)}分"
|
||||
if second_2 and second_2.lstrip('0'):
|
||||
if second_2 and second_2.lstrip("0"):
|
||||
result += f"{_time_num2str(second_2)}秒"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
RE_DATE = re.compile(r'(\d{4}|\d{2})年'
|
||||
r'((0?[1-9]|1[0-2])月)?'
|
||||
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
|
||||
RE_DATE = re.compile(
|
||||
r"(\d{4}|\d{2})年"
|
||||
r"((0?[1-9]|1[0-2])月)?"
|
||||
r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"
|
||||
)
|
||||
|
||||
|
||||
def replace_date(match) -> str:
|
||||
@ -110,8 +116,7 @@ def replace_date(match) -> str:
|
||||
|
||||
|
||||
# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
|
||||
RE_DATE2 = re.compile(
|
||||
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
|
||||
RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")
|
||||
|
||||
|
||||
def replace_date2(match) -> str:
|
||||
|
@ -18,10 +18,7 @@ from pypinyin.constants import SUPPORT_UCS4
|
||||
|
||||
# 全角半角转换
|
||||
# 英文字符全角 -> 半角映射表 (num: 52)
|
||||
F2H_ASCII_LETTERS = {
|
||||
ord(char) + 65248: ord(char)
|
||||
for char in string.ascii_letters
|
||||
}
|
||||
F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}
|
||||
|
||||
# 英文字符半角 -> 全角映射表
|
||||
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
|
||||
@ -37,26 +34,29 @@ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
|
||||
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
|
||||
|
||||
# 空格 (num: 1)
|
||||
F2H_SPACE = {'\u3000': ' '}
|
||||
H2F_SPACE = {' ': '\u3000'}
|
||||
F2H_SPACE = {"\u3000": " "}
|
||||
H2F_SPACE = {" ": "\u3000"}
|
||||
|
||||
# 非"有拼音的汉字"的字符串,可用于NSW提取
|
||||
if SUPPORT_UCS4:
|
||||
RE_NSW = re.compile(r'(?:[^'
|
||||
r'\u3007' # 〇
|
||||
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
|
||||
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
|
||||
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
|
||||
r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF]
|
||||
r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F]
|
||||
r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D]
|
||||
r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F]
|
||||
r'])+')
|
||||
RE_NSW = re.compile(
|
||||
r"(?:[^"
|
||||
r"\u3007" # 〇
|
||||
r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
|
||||
r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
|
||||
r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
|
||||
r"\U00020000-\U0002A6DF" # CJK扩展B:[20000-2A6DF]
|
||||
r"\U0002A703-\U0002B73F" # CJK扩展C:[2A700-2B73F]
|
||||
r"\U0002B740-\U0002B81D" # CJK扩展D:[2B740-2B81D]
|
||||
r"\U0002F80A-\U0002FA1F" # CJK兼容扩展:[2F800-2FA1F]
|
||||
r"])+"
|
||||
)
|
||||
else:
|
||||
RE_NSW = re.compile( # pragma: no cover
|
||||
r'(?:[^'
|
||||
r'\u3007' # 〇
|
||||
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
|
||||
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
|
||||
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
|
||||
r'])+')
|
||||
r"(?:[^"
|
||||
r"\u3007" # 〇
|
||||
r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
|
||||
r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
|
||||
r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
|
||||
r"])+"
|
||||
)
|
||||
|
@ -15,23 +15,26 @@
|
||||
Rules to verbalize numbers into Chinese characters.
|
||||
https://zh.wikipedia.org/wiki/中文数字#現代中文
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
|
||||
UNITS = OrderedDict({
|
||||
1: '十',
|
||||
2: '百',
|
||||
3: '千',
|
||||
4: '万',
|
||||
8: '亿',
|
||||
})
|
||||
DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
|
||||
UNITS = OrderedDict(
|
||||
{
|
||||
1: "十",
|
||||
2: "百",
|
||||
3: "千",
|
||||
4: "万",
|
||||
8: "亿",
|
||||
}
|
||||
)
|
||||
|
||||
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
|
||||
COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"
|
||||
|
||||
# 分数表达式
|
||||
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
|
||||
RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")
|
||||
|
||||
|
||||
def replace_frac(match) -> str:
|
||||
@ -52,7 +55,7 @@ def replace_frac(match) -> str:
|
||||
|
||||
|
||||
# 百分数表达式
|
||||
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
|
||||
RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")
|
||||
|
||||
|
||||
def replace_percentage(match) -> str:
|
||||
@ -72,7 +75,7 @@ def replace_percentage(match) -> str:
|
||||
|
||||
# 整数表达式
|
||||
# 带负号的整数 -10
|
||||
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
||||
RE_INTEGER = re.compile(r"(-)" r"(\d+)")
|
||||
|
||||
|
||||
def replace_negative_num(match) -> str:
|
||||
@ -92,7 +95,7 @@ def replace_negative_num(match) -> str:
|
||||
|
||||
# 编号-无符号整形
|
||||
# 00078
|
||||
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
|
||||
RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")
|
||||
|
||||
|
||||
def replace_default_num(match):
|
||||
@ -110,15 +113,11 @@ def replace_default_num(match):
|
||||
# RE_ASMD = re.compile(
|
||||
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
|
||||
RE_ASMD = re.compile(
|
||||
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
|
||||
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
|
||||
)
|
||||
|
||||
asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"}
|
||||
|
||||
asmd_map = {
|
||||
'+': '加',
|
||||
'-': '减',
|
||||
'×': '乘',
|
||||
'÷': '除',
|
||||
'=': '等于'
|
||||
}
|
||||
|
||||
def replace_asmd(match) -> str:
|
||||
"""
|
||||
@ -132,24 +131,25 @@ def replace_asmd(match) -> str:
|
||||
|
||||
|
||||
# 次方专项
|
||||
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
|
||||
RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")
|
||||
|
||||
power_map = {
|
||||
'⁰': '0',
|
||||
'¹': '1',
|
||||
'²': '2',
|
||||
'³': '3',
|
||||
'⁴': '4',
|
||||
'⁵': '5',
|
||||
'⁶': '6',
|
||||
'⁷': '7',
|
||||
'⁸': '8',
|
||||
'⁹': '9',
|
||||
'ˣ': 'x',
|
||||
'ʸ': 'y',
|
||||
'ⁿ': 'n'
|
||||
"⁰": "0",
|
||||
"¹": "1",
|
||||
"²": "2",
|
||||
"³": "3",
|
||||
"⁴": "4",
|
||||
"⁵": "5",
|
||||
"⁶": "6",
|
||||
"⁷": "7",
|
||||
"⁸": "8",
|
||||
"⁹": "9",
|
||||
"ˣ": "x",
|
||||
"ʸ": "y",
|
||||
"ⁿ": "n",
|
||||
}
|
||||
|
||||
|
||||
def replace_power(match) -> str:
|
||||
"""
|
||||
Args:
|
||||
@ -166,10 +166,10 @@ def replace_power(match) -> str:
|
||||
|
||||
# 数字表达式
|
||||
# 纯小数
|
||||
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
|
||||
RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
|
||||
# 正整数 + 量词
|
||||
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
|
||||
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
|
||||
RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")
|
||||
|
||||
|
||||
def replace_positive_quantifier(match) -> str:
|
||||
@ -220,7 +220,9 @@ RE_RANGE = re.compile(
|
||||
[-~] # 匹配范围分隔符
|
||||
((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
|
||||
(?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
|
||||
""", re.VERBOSE)
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
def replace_range(match) -> str:
|
||||
@ -239,7 +241,9 @@ def replace_range(match) -> str:
|
||||
|
||||
# ~至表达式
|
||||
RE_TO_RANGE = re.compile(
|
||||
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
|
||||
r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
|
||||
)
|
||||
|
||||
|
||||
def replace_to_range(match) -> str:
|
||||
"""
|
||||
@ -248,71 +252,66 @@ def replace_to_range(match) -> str:
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
result = match.group(0).replace('~', '至')
|
||||
result = match.group(0).replace("~", "至")
|
||||
return result
|
||||
|
||||
|
||||
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
|
||||
stripped = value_string.lstrip('0')
|
||||
def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
|
||||
stripped = value_string.lstrip("0")
|
||||
if len(stripped) == 0:
|
||||
return []
|
||||
elif len(stripped) == 1:
|
||||
if use_zero and len(stripped) < len(value_string):
|
||||
return [DIGITS['0'], DIGITS[stripped]]
|
||||
return [DIGITS["0"], DIGITS[stripped]]
|
||||
else:
|
||||
return [DIGITS[stripped]]
|
||||
else:
|
||||
largest_unit = next(
|
||||
power for power in reversed(UNITS.keys()) if power < len(stripped))
|
||||
largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
|
||||
first_part = value_string[:-largest_unit]
|
||||
second_part = value_string[-largest_unit:]
|
||||
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
|
||||
second_part)
|
||||
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
|
||||
|
||||
|
||||
def verbalize_cardinal(value_string: str) -> str:
|
||||
if not value_string:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
# 000 -> '零' , 0 -> '零'
|
||||
value_string = value_string.lstrip('0')
|
||||
value_string = value_string.lstrip("0")
|
||||
if len(value_string) == 0:
|
||||
return DIGITS['0']
|
||||
return DIGITS["0"]
|
||||
|
||||
result_symbols = _get_value(value_string)
|
||||
# verbalized number starting with '一十*' is abbreviated as `十*`
|
||||
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
|
||||
'1'] and result_symbols[1] == UNITS[1]:
|
||||
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS["1"] and result_symbols[1] == UNITS[1]:
|
||||
result_symbols = result_symbols[1:]
|
||||
return ''.join(result_symbols)
|
||||
return "".join(result_symbols)
|
||||
|
||||
|
||||
def verbalize_digit(value_string: str, alt_one=False) -> str:
|
||||
result_symbols = [DIGITS[digit] for digit in value_string]
|
||||
result = ''.join(result_symbols)
|
||||
result = "".join(result_symbols)
|
||||
if alt_one:
|
||||
result = result.replace("一", "幺")
|
||||
return result
|
||||
|
||||
|
||||
def num2str(value_string: str) -> str:
|
||||
integer_decimal = value_string.split('.')
|
||||
integer_decimal = value_string.split(".")
|
||||
if len(integer_decimal) == 1:
|
||||
integer = integer_decimal[0]
|
||||
decimal = ''
|
||||
decimal = ""
|
||||
elif len(integer_decimal) == 2:
|
||||
integer, decimal = integer_decimal
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The value string: '${value_string}' has more than one point in it."
|
||||
)
|
||||
raise ValueError(f"The value string: '${value_string}' has more than one point in it.")
|
||||
|
||||
result = verbalize_cardinal(integer)
|
||||
|
||||
decimal = decimal.rstrip('0')
|
||||
decimal = decimal.rstrip("0")
|
||||
if decimal:
|
||||
# '.22' is verbalized as '零点二二'
|
||||
# '3.20' is verbalized as '三点二
|
||||
result = result if result else "零"
|
||||
result += '点' + verbalize_digit(decimal)
|
||||
result += "点" + verbalize_digit(decimal)
|
||||
return result
|
||||
|
@ -21,10 +21,8 @@ from .num import verbalize_digit
|
||||
# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
|
||||
# 联通:130、131、132、156、155、186、185、176
|
||||
# 电信:133、153、189、180、181、177
|
||||
RE_MOBILE_PHONE = re.compile(
|
||||
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
|
||||
RE_TELEPHONE = re.compile(
|
||||
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
|
||||
RE_MOBILE_PHONE = re.compile(r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
|
||||
RE_TELEPHONE = re.compile(r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
|
||||
|
||||
# 全国统一的号码400开头
|
||||
RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
|
||||
@ -32,14 +30,12 @@ RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
|
||||
|
||||
def phone2str(phone_string: str, mobile=True) -> str:
|
||||
if mobile:
|
||||
sp_parts = phone_string.strip('+').split()
|
||||
result = ','.join(
|
||||
[verbalize_digit(part, alt_one=True) for part in sp_parts])
|
||||
sp_parts = phone_string.strip("+").split()
|
||||
result = ",".join([verbalize_digit(part, alt_one=True) for part in sp_parts])
|
||||
return result
|
||||
else:
|
||||
sil_parts = phone_string.split('-')
|
||||
result = ','.join(
|
||||
[verbalize_digit(part, alt_one=True) for part in sil_parts])
|
||||
sil_parts = phone_string.split("-")
|
||||
result = ",".join([verbalize_digit(part, alt_one=True) for part in sil_parts])
|
||||
return result
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@ from .num import num2str
|
||||
|
||||
# 温度表达式,温度会影响负号的读法
|
||||
# -3°C 零下三度
|
||||
RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
|
||||
RE_TEMPERATURE = re.compile(r"(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)")
|
||||
measure_dict = {
|
||||
"cm2": "平方厘米",
|
||||
"cm²": "平方厘米",
|
||||
@ -35,7 +35,7 @@ measure_dict = {
|
||||
"ml": "毫升",
|
||||
"m": "米",
|
||||
"mm": "毫米",
|
||||
"s": "秒"
|
||||
"s": "秒",
|
||||
}
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user