ruff format --line-length 120 --target-version py39

This commit is contained in:
XXXXRT666 2025-04-01 11:14:56 +01:00
parent a893a4e283
commit dec3df3282
130 changed files with 7986 additions and 6424 deletions

View File

@ -1,5 +1,8 @@
# Download moda ASR related models
from modelscope import snapshot_download
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
model_dir = snapshot_download(
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
)
model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")

View File

@ -4,14 +4,11 @@ import itertools
import math
import random
from random import shuffle
from typing import Iterator
from typing import Optional
from typing import TypeVar
from typing import Iterator, Optional, TypeVar
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torch.utils.data import Dataset, Sampler
__all__ = [
"DistributedBucketSampler",
@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if torch.cuda.is_available():
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas
len(self.dataset) / self.num_replicas,
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
grouped_batch_size = self.batch_size * self.num_replicas
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
shuffle(batches)
indices = list(itertools.chain(*batches))
else:
@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]

View File

@ -1,9 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
batch_size = (
self.config["train"]["batch_size"] // 2
if self.config["train"].get("if_dpo", False) is True
else self.config["train"]["batch_size"]
)
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,

View File

@ -2,18 +2,16 @@
# reference: https://github.com/lifeiteng/vall-e
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback
import os
from typing import Dict
from typing import List
import traceback
from typing import Dict, List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
version = os.environ.get('version',None)
version = os.environ.get("version", None)
from text import cleaned_text_to_sequence
@ -32,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
@ -59,12 +55,16 @@ class Text2SemanticDataset(Dataset):
super().__init__()
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
semantic_path,
delimiter="\t",
encoding="utf-8",
)
# get dict
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.dirname(phoneme_path)
os.path.dirname(
phoneme_path,
)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
@ -125,7 +125,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data.iloc[i,0]
item_name = self.semantic_data.iloc[i, 0]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@ -135,7 +135,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data.iloc[i,1]
semantic_str = self.semantic_data.iloc[i, 1]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
@ -156,9 +156,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@ -167,9 +165,7 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@ -192,12 +188,12 @@ class Text2SemanticDataset(Dataset):
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
print(
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
)
if num_deleted_ps > 0:
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
)
"""
there are 31 semantic datas not in phoneme datas
@ -304,7 +300,10 @@ if __name__ == "__main__":
batch_size = 12
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False,
)
for i, batch in enumerate(dataloader):
if i % 1000 == 0:

View File

@ -9,10 +9,12 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@ -24,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
)
)
if is_train:
@ -36,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
@ -114,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@ -9,6 +9,7 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
@ -25,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
),
)
if is_train:
self.automatic_optimization = False
@ -80,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@ -2,25 +2,24 @@
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask, make_pad_mask_left
from AR.models.utils import (
topk_sampling,
sample,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
make_pad_mask_left,
make_reject_y,
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@ -34,10 +33,17 @@ default_config = {
"EOS": 1024,
}
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
@ -57,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_mask[attn_mask != float("-inf")] = 0
attn_mask[attn_mask == float("-inf")] = 1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@ -112,7 +119,11 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor],
):
if padding_mask is None:
return x
@ -121,9 +132,13 @@ class T2SBlock:
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def process_prompt(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
@ -147,9 +162,7 @@ class T2SBlock:
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@ -160,7 +173,14 @@ class T2SBlock:
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
@ -174,7 +194,6 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
@ -185,7 +204,11 @@ class T2SBlock:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
@ -200,17 +223,19 @@ class T2SBlock:
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
@ -218,14 +243,17 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
self,
x: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : torch.Tensor=None,
torch_sdpa:bool=True
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
)
return x, k_cache, v_cache
@ -247,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
self.embedding_dim,
self.phoneme_vocab_size,
self.p_dropout,
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
self.embedding_dim,
self.vocab_size,
self.p_dropout,
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.h = TransformerEncoder(
@ -291,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
layer.linear2.bias,
)
block = T2SBlock(
@ -307,7 +345,7 @@ class Text2SemanticDecoder(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
@ -385,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
x, x_lens, reject_y, reject_y_lens, bert_feature
)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@ -506,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
@ -540,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel_batch_infer(
self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@ -561,10 +595,19 @@ class Text2SemanticDecoder(nn.Module):
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len",x_lens.max())
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
@ -572,10 +615,11 @@ class Text2SemanticDecoder(nn.Module):
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x:torch.Tensor = torch.stack(x_list, dim=0)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
@ -592,12 +636,10 @@ class Text2SemanticDecoder(nn.Module):
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
@ -619,7 +661,7 @@ class Text2SemanticDecoder(nn.Module):
value=False,
)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### 上面是错误的会导致padding的token被"看见"
@ -637,10 +679,9 @@ class Text2SemanticDecoder(nn.Module):
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# 正确的attn_mask应该是这样的
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@ -653,25 +694,22 @@ class Text2SemanticDecoder(nn.Module):
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
attn_mask = F.pad(attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
@ -682,13 +720,12 @@ class Text2SemanticDecoder(nn.Module):
####### 移除batch中已经生成完毕的序列,进一步优化计算量
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS则停止
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
@ -702,13 +739,12 @@ class Text2SemanticDecoder(nn.Module):
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
@ -720,7 +756,7 @@ class Text2SemanticDecoder(nn.Module):
stop = True
if stop:
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@ -728,34 +764,38 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ###如果没有生成到EOS就用最大长度代替
idx_list[i] = 1500 - 1 ###如果没有生成到EOS就用最大长度代替
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor],
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor], #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
@ -764,7 +804,8 @@ class Text2SemanticDecoder(nn.Module):
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
@ -772,16 +813,16 @@ class Text2SemanticDecoder(nn.Module):
def infer_panel_naive(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -826,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
@ -838,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if(idx<11):###至少预测出10个token不然不给停止0.4s
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(
@ -868,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token
bert_feature:torch.LongTensor,
x: torch.LongTensor, #####全部文本token
x_lens: torch.LongTensor,
prompts: torch.LongTensor, ####参考音频token
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)

View File

@ -1,16 +1,13 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from AR.modules.embedding_onnx import SinePositionalEmbedding
from AR.modules.embedding_onnx import TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm
from AR.modules.transformer_onnx import TransformerEncoder
from AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
@ -25,12 +22,13 @@ default_config = {
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs(
logits,
previous_tokens = None,
previous_tokens=None,
temperature: float = 1.0,
top_k = None,
top_p = None,
top_k=None,
top_p=None,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
@ -38,19 +36,27 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -66,7 +72,7 @@ def logits_to_probs(
def multinomial_sample_one_no_sync(
probs_sort
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@ -78,7 +84,9 @@ def sample(
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@ -98,8 +106,18 @@ class OnnxEncoder(nn.Module):
class T2SFirstStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@ -113,8 +131,8 @@ class T2SFirstStageDecoder(nn.Module):
def forward(self, x, prompt):
y = prompt
x_example = x[:,:,0] * 0.0
#N, 1, 512
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
@ -131,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:,:,0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
@ -144,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@ -159,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
class T2SStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@ -183,14 +221,18 @@ class T2SStageDecoder(nn.Module):
}
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
y_example = y_pos[:,:,0] * 0.0
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@ -249,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num
@ -285,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_example = x[:,:,0] * 0.0
x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@ -302,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@ -316,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0), value=False
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:

View File

@ -1,8 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch
import torch.nn.functional as F
from typing import Tuple
def sequence_mask(length, max_length=None):
if max_length is None:
@ -67,14 +69,18 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len-lengths).unsqueeze(-1)
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
return expaned_lengths<0
return expaned_lengths < 0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
@ -105,9 +111,7 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
@ -156,19 +160,21 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@ -176,7 +182,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
@ -188,18 +194,19 @@ def sample(
previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
)
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
def dpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@ -214,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
return losses.mean(), chosen_rewards, rejected_rewards
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# dummy token; we'll ignore the losses on these tokens later
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
per_token_logps_target = torch.gather(
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens):
def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf])
return new_y
def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf])
return new_y
bs = len(y_lens)
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
elif process_item_idx==1:
elif process_item_idx == 1:
new_y = lost_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
@ -256,7 +276,7 @@ def make_reject_y(y_o, y_lens):
pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim = 0)
reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens

View File

@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
self._reset_parameters()
@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
key_padding_mask,
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested

View File

@ -1,13 +1,10 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
@ -46,9 +43,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@ -59,18 +54,30 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(embed_dim, embed_dim),
**factory_kwargs,
)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty(
(embed_dim, self.kdim),
**factory_kwargs,
)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty(
(embed_dim, self.vdim),
**factory_kwargs,
)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(3 * embed_dim, embed_dim),
**factory_kwargs,
)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
@ -78,13 +85,11 @@ class MultiheadAttention(Module):
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
torch.empty(3 * embed_dim, **factory_kwargs),
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
else:
@ -92,7 +97,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@ -106,7 +114,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:

View File

@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

View File

@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)

View File

@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
lr = self.end_lr
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
)
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
@ -70,7 +66,13 @@ if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0,
)
lrs = []
for i in range(25000):

View File

@ -16,8 +16,7 @@
import contextlib
import logging
from collections import defaultdict
from typing import List
from typing import Tuple
from typing import List, Tuple
import torch
from torch import Tensor
@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
group_params_names: name for each parameter in group,
which is List[str].
"""
batches = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
])
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -177,12 +167,11 @@ class ScaledAdam(BatchedOptimizer):
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True, ):
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter")
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period, )
clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups,
self.parameters_names):
with self.batched_params(group["params"],
group_params_names) as batches:
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
if (len(batches[0][1]) ==
0): # if len(first state) == 0: not yet initialized
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period,
*param_rms.shape, **kwargs)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(self,
group: dict,
tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients")
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(
clipping_update_period, device=p.device)
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
for n in range(0, 5):
index = min(
clipping_update_period - 1,
(clipping_update_period // 4) * n, )
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 /
clipping_update_period
if "num_clipped" in first_state else 0.0)
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor):
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
dim=list(range(1, batch_grad.ndim)))
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(batch_param_names,
for name, sumsq_orig, rms, grad in zip(
batch_param_names,
batch_sumsq_orig,
batch_rms_orig, batch_grad):
batch_rms_orig,
batch_grad,
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0), )
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True, )
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
(dominant_proportion, dominant_sumsq, dominant_rms,
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
(
dominant_proportion,
dominant_sumsq,
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt())
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = step + 1
def _size_update(self,
def _size_update(
self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
state: dict,
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state[
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (-size_lr * (bias_correction2**0.5) *
scale_grads.sum(dim=0) / denom)
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"]
if "zero_step" in state else 0)
bias_correction2 = 1 - beta2**(this_step + 1)
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2**(state["step"] + 1)
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]

View File

@ -24,18 +24,18 @@ def multi_head_attention_forward_patched(
dropout_p: float,
out_proj_weight,
out_proj_bias,
training = True,
key_padding_mask = None,
need_weights = True,
attn_mask = None,
use_separate_proj_weight = False,
q_proj_weight = None,
k_proj_weight = None,
v_proj_weight = None,
static_k = None,
static_v = None,
average_attn_weights = True,
is_causal = False,
training=True,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
use_separate_proj_weight=False,
q_proj_weight=None,
k_proj_weight=None,
v_proj_weight=None,
static_k=None,
static_v=None,
average_attn_weights=True,
is_causal=False,
cache=None,
):
r"""
@ -155,9 +155,7 @@ def multi_head_attention_forward_patched(
cache=cache,
)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@ -210,45 +208,33 @@ def multi_head_attention_forward_patched(
# longer causal.
is_causal = False
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert embed_dim == embed_dim_to_check, (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
)
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
@ -311,9 +297,7 @@ def multi_head_attention_forward_patched(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@ -337,34 +321,26 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert static_k.size(0) == bsz * num_heads, (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
)
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert static_v.size(0) == bsz * num_heads, (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
)
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@ -380,9 +356,7 @@ def multi_head_attention_forward_patched(
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
@ -401,14 +375,10 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -417,9 +387,7 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -448,13 +416,9 @@ def multi_head_attention_forward_patched(
v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

View File

@ -3,6 +3,7 @@ from torch.nn.functional import (
_canonical_mask,
)
def multi_head_attention_forward_patched(
query,
key,
@ -31,7 +32,6 @@ def multi_head_attention_forward_patched(
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
# set up shape vars
_, _, embed_dim = query.shape
attn_mask = _canonical_mask(
@ -77,12 +77,8 @@ def multi_head_attention_forward_patched(
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(-1, 1, attn_output.size(1))

View File

@ -58,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -150,13 +148,9 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs.
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
return below_threshold - above_threshold
@ -178,18 +172,16 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor
)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@ -317,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x)
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
return nn.Sequential(
balancer,
DoubleSwish(),

View File

@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError("only bool and floating types of key_padding_mask are supported")
if self.norm_first:
x = x + self._sa_block(

View File

@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

View File

@ -30,9 +30,7 @@ class GruutPhonemizer:
"«": "«",
"»": "»",
}
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@ -53,13 +51,8 @@ class GruutPhonemizer:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
return " ".join(words)
def transform(self, phonemes):

View File

@ -3,7 +3,9 @@
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}

View File

@ -2,12 +2,12 @@ import re
def str2bool(str):
return True if str.lower() == 'true' else False
return True if str.lower() == "true" else False
def get_newest_ckpt(string_list):
# 定义一个正则表达式模式,用于匹配字符串中的数字
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
extracted_info = []
@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
step = int(match.group(2))
extracted_info.append((epoch, step, string))
# 按照 epoch 后面的数字和 step 后面的数字进行排序
sorted_info = sorted(
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
# 获取最新的 ckpt 文件名
newest_ckpt = sorted_info[0][2]
return newest_ckpt
@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
# 文本存在且不为空时 return True
def check_txt_file(file_path):
try:
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
text = file.readline().strip()
assert text.strip() != ''
assert text.strip() != ""
return text
except Exception:
return False

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import torch
from typeguard import check_argument_types

View File

@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
def write_args(args, path):
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write("\n==> args:\n")

View File

@ -23,9 +23,7 @@ class Snake(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:

View File

@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
activation_results = anti_alias_activation_cuda.forward(
inputs, up_ftr, down_ftr, alpha, beta
)
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
return activation_results
@ -61,17 +59,11 @@ class Activation1d(nn.Module):
if self.act.__class__.__name__ == "Snake":
beta = self.act.alpha.data # Snake uses same params for alpha and beta
else:
beta = (
self.act.beta.data
) # Snakebeta uses different params for alpha and beta
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
alpha = self.act.alpha.data
if (
not self.act.alpha_logscale
): # Exp baked into cuda kernel, cancel it out with a log
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
alpha = torch.log(alpha)
beta = torch.log(beta)
x = FusedAntiAliasActivation.apply(
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
)
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
return x

View File

@ -58,17 +58,13 @@ def load():
srcpath / "anti_alias_activation.cpp",
srcpath / "anti_alias_activation_cuda.cu",
]
anti_alias_activation_cuda = _cpp_extention_load_helper(
"anti_alias_activation_cuda", sources, extra_cuda_flags
)
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
return anti_alias_activation_cuda
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")

View File

@ -27,9 +27,7 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2

View File

@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left : -self.pad_right]
return x
@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,

View File

@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
)
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(
self.convs2
) # Total number of conv layers
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@ -283,9 +265,7 @@ class BigVGAN(
self.num_upsamples = len(h.upsample_rates)
# Pre-conv
self.conv_pre = weight_norm(
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
if h.resblock == "1":
@ -293,9 +273,7 @@ class BigVGAN(
elif h.resblock == "2":
resblock_class = AMPBlock2
else:
raise ValueError(
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
)
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
# Transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
@ -320,22 +298,14 @@ class BigVGAN(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(
resblock_class(h, ch, k, d, activation=h.activation)
)
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
# Post-conv
activation_post = (
activations.Snake(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snake"
else (
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snakebeta"
else None
)
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
)
if activation_post is None:
raise NotImplementedError(
@ -346,9 +316,7 @@ class BigVGAN(
# Whether to use bias for the final conv_post. Default to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm(
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
# Weight initialization
for i in range(len(self.ups)):

View File

@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
),
]
)
self.conv_post = norm_f(
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
)
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList(
[
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
for rs in self.mpd_reshapes
]
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
super().__init__()
self.resolution = resolution
assert (
len(self.resolution) == 3
), f"MRD layer requires list with len=3, got {self.resolution}"
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print(
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
)
norm_f = (
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
)
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
),
]
)
self.conv_post = norm_f(
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
)
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert (
len(self.resolutions) == 3
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
assert len(self.resolutions) == 3, (
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
)
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset
@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
super().__init__()
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList(
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
)
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__()
self.cfg = cfg
@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations):
out_chs = min(
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
)
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
self.convs.append(
weight_norm(
nn.Conv2d(
@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
in_chs,
out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding(
(self.kernel_size[0], self.kernel_size[0])
),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
)
@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
# Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
"cqtd_bins_per_octaves", [24, 36, 48]
)
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
self.discriminators = nn.ModuleList(
[
@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
super().__init__()
self.discrimiantor = nn.ModuleList(list_discriminator)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []

View File

@ -35,9 +35,7 @@ def inference(a, h):
with torch.no_grad():
for i, filname in enumerate(filelist):
# Load the ground truth audio and resample if necessary
wav, sr = librosa.load(
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
)
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@ -48,9 +46,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)

View File

@ -61,9 +61,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)

View File

@ -122,9 +122,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
B, C, T = wav.shape
if match_stride:
assert (
hop_length == window_length // 4
), "For match_stride, hop must equal n_fft // 4"
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
@ -154,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
)
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@ -181,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
@ -196,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
@ -210,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
# Loss functions
def feature_loss(
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
@ -225,7 +212,6 @@ def feature_loss(
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
@ -242,7 +228,6 @@ def discriminator_loss(
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs:

View File

@ -86,9 +86,7 @@ def mel_spectrogram(
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
@ -96,9 +94,7 @@ def mel_spectrogram(
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
@ -150,17 +146,13 @@ def get_dataset_filelist(a):
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first validation file: {validation_files[0]}")
@ -171,9 +163,7 @@ def get_dataset_filelist(a):
for x in fi.read().split("\n")
if len(x) > 0
]
print(
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
)
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files
@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))):
assert os.path.exists(
self.audio_files[i]
), f"{self.audio_files[i]} not found"
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
def __getitem__(
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try:
filename = self.audio_files[index]
@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
# Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil(
self.segment_size
* (source_sampling_rate / self.sampling_rate)
)
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
else:
target_segment_size = self.segment_size
# Compute upper bound index for the random chunk
random_chunk_upper_bound = max(
0, audio.shape[0] - target_segment_size
)
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
# Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size:
@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
assert (
source_sampling_rate == self.sampling_rate
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
assert source_sampling_rate == self.sampling_rate, (
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
)
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[
:,
mel_start
* self.hop_size : (mel_start + frames_per_seg)
* self.hop_size,
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
]
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
mel = torch.nn.functional.pad(
mel, (0, frames_per_seg - mel.size(2)), "constant"
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram(
@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
# Shape sanity checks
assert (
audio.shape[1] == mel.shape[2] * self.hop_size
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), (
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else:
print(
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
)
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
return self[random.randrange(len(self))]
def __len__(self):

View File

@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()

View File

@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations, Snake CUDA vs. Torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
@ -57,7 +54,6 @@ def test_anti_alias_activation():
)
if __name__ == "__main__":
from alias_free_activation.cuda import load

View File

@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
def get_mel(x, h):
return mel_spectrogram(
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
)
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def load_checkpoint(filepath, device):
@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test script to check CUDA kernel correctness."
)
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
parser.add_argument(
"--checkpoint_file",
type=str,
@ -109,9 +105,7 @@ if __name__ == "__main__":
diff += test_result.mean(dim=-1).item()
diff /= num_sample
if (
diff <= 2e-3
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
@ -175,8 +169,8 @@ if __name__ == "__main__":
audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results
print(

View File

@ -77,24 +77,18 @@ def train(rank, a, h):
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print(
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
# Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print(
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False):
print(
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
)
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input
@ -114,9 +108,7 @@ def train(rank, a, h):
if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint(
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
)
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(
a.checkpoint_path,
prefix="do_",
@ -143,9 +135,7 @@ def train(rank, a, h):
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate,
@ -156,12 +146,8 @@ def train(rank, a, h):
optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
# Define training and validation datasets
@ -169,9 +155,7 @@ def train(rank, a, h):
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK
"""
training_filelist, validation_filelist, list_unseen_validation_filelist = (
get_dataset_filelist(a)
)
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(
training_filelist,
@ -324,33 +308,26 @@ def train(rank, a, h):
h.fmax_for_loss,
)
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if (
"nonspeech" not in mode
): # Skips if the name of dataset (in mode string) contains "nonspeech"
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq
y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = (
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
)
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
# MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y[0],
os.path.join(
@ -373,9 +350,7 @@ def train(rank, a, h):
steps,
h.sampling_rate,
)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y_g_hat[0, 0],
os.path.join(
@ -487,15 +462,11 @@ def train(rank, a, h):
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
y_df_hat_r, y_df_hat_g
)
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
y_ds_hat_r, y_ds_hat_g
)
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
@ -505,17 +476,11 @@ def train(rank, a, h):
# Whether to freeze D for initial training steps
if steps >= a.freeze_step:
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
mpd.parameters(), clip_grad_norm
)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
mrd.parameters(), clip_grad_norm
)
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
optim_d.step()
else:
print(
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
)
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
grad_norm_mpd = 0.0
grad_norm_mrd = 0.0
@ -523,9 +488,7 @@ def train(rank, a, h):
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
lambda_melloss = h.get(
"lambda_melloss", 45.0
) # Defaults to 45 in BigVGAN-v1 if not set
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss
@ -542,27 +505,19 @@ def train(rank, a, h):
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step:
loss_gen_all = (
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
print(
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
)
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
loss_gen_all = loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(
generator.parameters(), clip_grad_norm
)
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to stdout
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
print(
f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, "
@ -577,11 +532,7 @@ def train(rank, a, h):
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint(
checkpoint_path,
{
"generator": (
generator.module if h.num_gpus > 1 else generator
).state_dict()
},
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
)
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint(
@ -598,9 +549,7 @@ def train(rank, a, h):
# Tensorboard summary logging
if steps % a.summary_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to tensorboard
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@ -612,12 +561,8 @@ def train(rank, a, h):
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar(
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
)
sw.add_scalar(
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
)
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation
@ -660,9 +605,7 @@ def train(rank, a, h):
scheduler_d.step()
if rank == 0:
print(
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
)
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
def main():
@ -674,12 +617,8 @@ def main():
parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument(
"--input_training_file", default="tests/LibriTTS/train-full.txt"
)
parser.add_argument(
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
)
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
parser.add_argument(
"--list_input_unseen_wavs_dir",

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,9 @@
import os
import sys
import threading
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -19,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '', ',', '.', '-'])
punctuation = set(["!", "?", "", ",", ".", "-"])
def get_first(text:str) -> str:
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
def merge_short_text_in_array(texts: str, threshold: int) -> list:
if (len(texts)) < 2:
return texts
result = []
@ -39,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(text) > 0:
if len(result) == 0:
result.append(text)
else:
@ -47,28 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(f'############ {i18n("切分文本")} ############')
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(f'############ {i18n("提取文本Bert特征")} ############')
print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text=="":
if phones is None or norm_text == "":
continue
res={
res = {
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
@ -76,11 +74,11 @@ class TextPreprocessor:
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n")
if len(text) == 0:
return []
if (text[0] not in splits and len(get_first(text)) < 4):
if text[0] not in splits and len(get_first(text)) < 4:
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
@ -96,18 +94,18 @@ class TextPreprocessor:
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
if len(text.strip()) == 0:
continue
if not re.sub("\W+", "", text):
# 检测一下,如果是纯符号,就跳过。
continue
if (text[-1] not in splits): text += "" if lang != "en" else "."
if text[-1] not in splits:
text += "" if lang != "en" else "."
# 解决句子过长导致Bert报错的问题
if (len(text) > 510):
if len(text) > 510:
texts.extend(split_big_text(text))
else:
texts.append(text)
@ -116,10 +114,12 @@ class TextPreprocessor:
print(texts)
return texts
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
# language = language.replace("all_","")
@ -127,17 +127,17 @@ class TextPreprocessor:
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"zh",version)
return self.get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version)
return self.get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros(
@ -145,8 +145,8 @@ class TextPreprocessor:
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
textlist = []
langlist = []
if language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
@ -179,15 +179,14 @@ class TextPreprocessor:
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True)
return self.get_phones_and_bert("." + text, language, version, final=True)
return phones, bert, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
@ -202,14 +201,14 @@ class TextPreprocessor:
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"):
language = language.replace("all_","")
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
@ -220,10 +219,9 @@ class TextPreprocessor:
return feature
def filter_text(self,texts):
_text=[]
if all(text in [None, " ", "\n",""] for text in texts):
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
@ -232,9 +230,8 @@ class TextPreprocessor:
_text.append(text)
return _text
def replace_consecutive_punctuation(self,text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
def replace_consecutive_punctuation(self, text):
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result

View File

@ -1,40 +1,56 @@
import re
from typing import Callable
punctuation = set(['!', '?', '', ',', '.', '-'," "])
punctuation = set(["!", "?", "", ",", ".", "-", " "])
METHODS = dict()
def get_method(name:str)->Callable:
def get_method(name: str) -> Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
def get_method_names() -> list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
}
def split_big_text(text, max_len=510):
# 定义全角和半角标点符号
punctuation = "".join(splits)
# 切割文本
segments = re.split('([' + punctuation + '])', text)
segments = re.split("([" + punctuation + "])", text)
# 初始化结果列表和当前片段
result = []
current_segment = ''
current_segment = ""
for segment in segments:
# 如果当前片段加上新的片段长度超过max_len就将当前片段加入结果列表并重置当前片段
@ -51,7 +67,6 @@ def split_big_text(text, max_len=510):
return result
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
@ -90,7 +105,7 @@ def cut1(inp):
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
@ -123,6 +138,7 @@ def cut2(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按中文句号。切
@register_method("cut3")
def cut3(inp):
@ -131,26 +147,28 @@ def cut3(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
#按英文句号.切
# 按英文句号.切
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# 按标点符号切
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '', '', '', '', '', ';', '', ''}
punds = {",", ".", ";", "?", "!", "", "", "", "", "", ";", "", ""}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
@ -166,8 +184,6 @@ def cut5(inp):
return "\n".join(opt)
if __name__ == '__main__':
if __name__ == "__main__":
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

View File

@ -1,6 +1,13 @@
import os
import sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
v_to_u=False,
neutral_tone_with_five=True,
)

View File

@ -32,6 +32,7 @@ default_config = {
"EOS": 1024,
}
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"]
config["model"]["dropout"] = float(config["model"]["dropout"])
@ -40,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
t2s_model = t2s_model.eval()
return t2s_model
@torch.jit.script
def logits_to_probs(
logits,
@ -56,39 +58,35 @@ def logits_to_probs(
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@torch.jit.script
def sample(
logits,
@ -99,15 +97,20 @@ def sample(
repetition_penalty: float = 1.0,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
logits=logits,
previous_tokens=previous_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@torch.jit.script
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -157,6 +160,7 @@ class DictToAttrRecursive(dict):
except KeyError:
raise AttributeError(f"Attribute {item} not found")
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@ -170,6 +174,7 @@ class T2SMLP:
x = F.linear(x, self.w2, self.b2)
return x
@torch.jit.script
class T2SBlock:
def __init__(
@ -205,7 +210,7 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
if padding_mask is None:
return x
@ -214,7 +219,7 @@ class T2SBlock:
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
@ -231,22 +236,20 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
if padding_mask is not None:
for i in range(batch_size):
# mask = padding_mask[i,:,0]
if self.false.device!= padding_mask.device:
if self.false.device != padding_mask.device:
self.false = self.false.to(padding_mask.device)
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
x_item = x[i,idx,:].unsqueeze(0)
attn_item = attn[i,idx,:].unsqueeze(0)
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
x_item = x[i, idx, :].unsqueeze(0)
attn_item = attn[i, idx, :].unsqueeze(0)
x_item = x_item + attn_item
x_item = F.layer_norm(
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x_item = x_item + self.mlp.forward(x_item)
x_item = F.layer_norm(
x_item,
@ -255,13 +258,11 @@ class T2SBlock:
self.norm_b2,
self.norm_eps2,
)
x[i,idx,:] = x_item.squeeze(0)
x[i, idx, :] = x_item.squeeze(0)
x = self.to_mask(x, padding_mask)
else:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@ -272,7 +273,7 @@ class T2SBlock:
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
@ -288,14 +289,12 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@ -306,37 +305,35 @@ class T2SBlock:
)
return x, k_cache, v_cache
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
k_cache : list[torch.Tensor] = []
v_cache : list[torch.Tensor] = []
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
k_cache: list[torch.Tensor] = []
v_cache: list[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
k_cache.append(k_cache_)
v_cache.append(v_cache_)
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
k_cache: list[torch.Tensor],
v_cache: list[torch.Tensor]):
def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path)
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
@ -347,7 +344,7 @@ class VitsModel(nn.Module):
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@ -359,12 +356,13 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
class T2SModel(nn.Module):
def __init__(self,raw_t2s:Text2SemanticLightningModule):
def __init__(self, raw_t2s: Text2SemanticLightningModule):
super(T2SModel, self).__init__()
self.model_dim = raw_t2s.model.model_dim
self.embedding_dim = raw_t2s.model.embedding_dim
@ -373,7 +371,7 @@ class T2SModel(nn.Module):
self.vocab_size = raw_t2s.model.vocab_size
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
# self.p_dropout = float(raw_t2s.model.p_dropout)
self.EOS:int = int(raw_t2s.model.EOS)
self.EOS: int = int(raw_t2s.model.EOS)
self.norm_first = raw_t2s.model.norm_first
assert self.EOS == self.vocab_size - 1
self.hz = 50
@ -392,12 +390,7 @@ class T2SModel(nn.Module):
for i in range(self.num_layers):
layer = h.layers[i]
t2smlp = T2SMLP(
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
)
t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
block = T2SBlock(
self.num_head,
@ -412,7 +405,7 @@ class T2SModel(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
@ -426,19 +419,26 @@ class T2SModel(nn.Module):
self.top_k = int(raw_t2s.config["inference"]["top_k"])
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
def forward(
self,
prompts: LongTensor,
ref_seq: LongTensor,
text_seq: LongTensor,
ref_bert: torch.Tensor,
text_bert: torch.Tensor,
top_k: LongTensor,
):
bert = torch.cat([ref_bert.T, text_bert.T], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
x = self.ar_text_embedding(all_phoneme_ids)
x = x + self.bert_proj(bert.transpose(1, 2))
x:torch.Tensor = self.ar_text_position(x)
x: torch.Tensor = self.ar_text_position(x)
early_stop_num = self.early_stop_num
#[1,N,512] [1,N]
# [1,N,512] [1,N]
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
y = prompts
# x_example = x[:,:,0] * 0.0
@ -464,11 +464,13 @@ class T2SModel(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
idx = 0
top_k = int(top_k)
@ -480,17 +482,19 @@ class T2SModel(nn.Module):
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
y = torch.concat([y, samples], dim=1)
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
stop = False
# for idx in range(1, 50):
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
if(idx<11):###至少预测出10个token不然不给停止0.4s
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
@ -507,20 +511,22 @@ class T2SModel(nn.Module):
break
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
y[0,-1] = 0
y[0, -1] = 0
return y[:, -idx:].unsqueeze(0)
bert_path = os.environ.get(
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
)
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
@torch.jit.script
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
@ -529,39 +535,45 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
# [sum(word2ph), 1024]
return phone_level_feature
class MyBertModel(torch.nn.Module):
def __init__(self, bert_model):
super(MyBertModel, self).__init__()
self.bert = bert_model
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
def forward(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
return build_phone_level_feature(res, word2ph)
class SSLModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.ssl = cnhubert.get_model().model
def forward(self, ref_audio_16k)-> torch.Tensor:
def forward(self, ref_audio_16k) -> torch.Tensor:
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
return ssl_content
class ExportSSLModel(torch.nn.Module):
def __init__(self,ssl:SSLModel):
def __init__(self, ssl: SSLModel):
super().__init__()
self.ssl = ssl
def forward(self, ref_audio:torch.Tensor):
def forward(self, ref_audio: torch.Tensor):
return self.ssl(ref_audio)
@torch.jit.export
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
audio = resamplex(ref_audio,src_sr,dst_sr).float()
def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
audio = resamplex(ref_audio, src_sr, dst_sr).float()
return audio
def export_bert(output_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path)
@ -569,33 +581,34 @@ def export_bert(output_path):
ref_bert_inputs = tokenizer(text, return_tensors="pt")
word2ph = []
for c in text:
if c in ['','','','',",",".","?"]:
if c in ["", "", "", "", ",", ".", "?"]:
word2ph.append(1)
else:
word2ph.append(2)
ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
my_bert_model = MyBertModel(bert_model)
ref_bert_inputs = {
'input_ids': ref_bert_inputs['input_ids'],
'attention_mask': ref_bert_inputs['attention_mask'],
'token_type_ids': ref_bert_inputs['token_type_ids'],
'word2ph': ref_bert_inputs['word2ph']
"input_ids": ref_bert_inputs["input_ids"],
"attention_mask": ref_bert_inputs["attention_mask"],
"token_type_ids": ref_bert_inputs["token_type_ids"],
"word2ph": ref_bert_inputs["word2ph"],
}
torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
output_path = os.path.join(output_path, "bert_model.pt")
my_bert_model.save(output_path)
print('#### exported bert ####')
print("#### exported bert ####")
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
if not os.path.exists(output_path):
os.makedirs(output_path)
print(f"目录已创建: {output_path}")
@ -605,21 +618,22 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print('#### exported ssl ####')
print("#### exported ssl ####")
export_bert(output_path)
else:
s = ExportSSLModel(ssl)
print(f"device: {device}")
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device)
@ -633,18 +647,18 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print('#### get_raw_t2s_model ####')
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device)
print('#### script t2s_m ####')
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
@ -657,32 +671,28 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert,
top_k))
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
)
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print('#### exported gpt_sovits ####')
print("#### exported gpt_sovits ####")
@torch.jit.script
def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
return ref_audio_16k,ref_audio_sr
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
return ref_audio_16k, ref_audio_sr
@torch.jit.script
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
class GPT_SoVITS(nn.Module):
def __init__(self, t2s:T2SModel,vits:VitsModel):
def __init__(self, t2s: T2SModel, vits: VitsModel):
super().__init__()
self.t2s = t2s
self.vits = vits
@ -709,12 +719,11 @@ class GPT_SoVITS(nn.Module):
def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args()
gpt_path = args.gpt_model
@ -725,7 +734,7 @@ def test():
tokenizer = AutoTokenizer.from_pretrained(bert_path)
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
# bert = MyBertModel(bert_model)
my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
# dict_s1 = torch.load(gpt_path, map_location="cuda")
# raw_t2s = get_raw_t2s_model(dict_s1)
@ -739,78 +748,79 @@ def test():
# ssl = ExportSSLModel(SSLModel()).to('cuda')
# ssl.eval()
ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
# gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
ref_seq = torch.LongTensor([ref_seq_id])
ref_bert = ref_bert_T.T.to(ref_seq.device)
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
test_bert = tokenizer(text, return_tensors="pt")
word2ph = []
for c in text:
if c in ['','','','',"?",",","."]:
if c in ["", "", "", "", "?", ",", "."]:
word2ph.append(1)
else:
word2ph.append(2)
test_bert['word2ph'] = torch.Tensor(word2ph).int()
test_bert["word2ph"] = torch.Tensor(word2ph).int()
test_bert = my_bert(
test_bert['input_ids'].to('cuda'),
test_bert['attention_mask'].to('cuda'),
test_bert['token_type_ids'].to('cuda'),
test_bert['word2ph'].to('cuda')
test_bert["input_ids"].to("cuda"),
test_bert["attention_mask"].to("cuda"),
test_bert["token_type_ids"].to("cuda"),
test_bert["word2ph"].to("cuda"),
)
text_seq = torch.LongTensor([text_seq_id])
text_bert = text_bert_T.T.to(text_seq.device)
print('text_bert:',text_bert.shape,text_bert)
print('test_bert:',test_bert.shape,test_bert)
print(torch.allclose(text_bert.to('cuda'),test_bert))
print("text_bert:", text_bert.shape, text_bert)
print("test_bert:", test_bert.shape, test_bert)
print(torch.allclose(text_bert.to("cuda"), test_bert))
print('text_seq:',text_seq.shape)
print('text_bert:',text_bert.shape,text_bert.type())
print("text_seq:", text_seq.shape)
print("text_bert:", text_bert.shape, text_bert.type())
#[1,N]
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
print('ref_audio:',ref_audio.shape)
# [1,N]
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
print("ref_audio:", ref_audio.shape)
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
print('start ssl')
ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
print("start ssl")
ssl_content = ssl(ref_audio)
print('start gpt_sovits:')
print('ssl_content:',ssl_content.shape)
print('ref_audio_sr:',ref_audio_sr.shape)
print('ref_seq:',ref_seq.shape)
ref_seq=ref_seq.to('cuda')
print('text_seq:',text_seq.shape)
text_seq=text_seq.to('cuda')
print('ref_bert:',ref_bert.shape)
ref_bert=ref_bert.to('cuda')
print('text_bert:',text_bert.shape)
text_bert=text_bert.to('cuda')
print("start gpt_sovits:")
print("ssl_content:", ssl_content.shape)
print("ref_audio_sr:", ref_audio_sr.shape)
print("ref_seq:", ref_seq.shape)
ref_seq = ref_seq.to("cuda")
print("text_seq:", text_seq.shape)
text_seq = text_seq.to("cuda")
print("ref_bert:", ref_bert.shape)
ref_bert = ref_bert.to("cuda")
print("text_bert:", text_bert.shape)
text_bert = text_bert.to("cuda")
top_k = torch.LongTensor([5]).to('cuda')
top_k = torch.LongTensor([5]).to("cuda")
with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
print('start write wav')
print("start write wav")
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
import text
import json
def export_symbel(version='v2'):
if version=='v1':
def export_symbel(version="v2"):
if version == "v1":
symbols = text._symbol_to_id_v1
with open("onnx/symbols_v1.json", "w") as file:
json.dump(symbols, file, indent=4)
@ -819,15 +829,16 @@ def export_symbel(version='v2'):
with open("onnx/symbols_v2.json", "w") as file:
json.dump(symbols, file, indent=4)
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory")
parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
parser.add_argument('--device', help="Device to use")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument("--device", help="Device to use")
args = parser.parse_args()
export(
@ -840,9 +851,11 @@ def main():
export_bert_and_ssl=args.export_common_model,
)
import inference_webui
if __name__ == "__main__":
inference_webui.is_half=False
inference_webui.dtype=torch.float32
inference_webui.is_half = False
inference_webui.dtype = torch.float32
main()
# test()

View File

@ -32,7 +32,6 @@ now_dir = os.getcwd()
class MelSpectrgram(torch.nn.Module):
def __init__(
self,
dtype,
@ -48,14 +47,12 @@ class MelSpectrgram(torch.nn.Module):
):
super().__init__()
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
self.n_fft:int = n_fft
self.hop_size:int = hop_size
self.win_size:int = win_size
self.center:bool = center
self.n_fft: int = n_fft
self.hop_size: int = hop_size
self.win_size: int = win_size
self.center: bool = center
def forward(self, y):
y = torch.nn.functional.pad(
@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
):
T_min = fea_ref.size(2)
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
cfm_res = self.cfm(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
)
cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
fea_ref = fea_todo_chunk[:, :, -T_min:]
@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
spec_min = -12
spec_max = 2
@torch.jit.script
def norm_spec(x):
spec_min = -12
@ -212,7 +208,6 @@ def denorm_spec(x):
class ExportGPTSovitsHalf(torch.nn.Module):
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
super().__init__()
self.hps = hps
@ -231,15 +226,15 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False,
)
# self.dtype = dtype
self.filter_length:int = hps.data.filter_length
self.sampling_rate:int = hps.data.sampling_rate
self.hop_length:int = hps.data.hop_length
self.win_length:int = hps.data.win_length
self.filter_length: int = hps.data.filter_length
self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length
def forward(
self,
ssl_content,
ref_audio_32k:torch.FloatTensor,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0,
phoneme_ids1,
bert1,
@ -255,18 +250,14 @@ class ExportGPTSovitsHalf(torch.nn.Module):
center=False,
).to(ssl_content.dtype)
codes = self.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0)
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
pred_semantic = self.t2s_m(
prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
ge = self.vq_model.create_ge(refer)
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
return fea_ref, fea_todo, mel2
class GPTSoVITSV3(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, bigvgan):
super().__init__()
@ -303,9 +295,9 @@ class GPTSoVITSV3(torch.nn.Module):
def forward(
self,
ssl_content,
ref_audio_32k:torch.FloatTensor,
phoneme_ids0:torch.LongTensor,
phoneme_ids1:torch.LongTensor,
ref_audio_32k: torch.FloatTensor,
phoneme_ids0: torch.LongTensor,
phoneme_ids1: torch.LongTensor,
bert1,
bert2,
top_k: torch.LongTensor,
@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
):
# current_time = datetime.now()
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = []
idx = 0
@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
complete_len = chunk_len - fea_todo_chunk.shape[-1]
if complete_len != 0:
fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
fea_todo_chunk = torch.cat(
[
fea_todo_chunk,
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
],
2,
)
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
idx += chunk_len
@ -343,13 +343,13 @@ class GPTSoVITSV3(torch.nn.Module):
wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length]
def init_bigvgan():
global bigvgan_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
% (now_dir,),
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
use_cuda_kernel=False,
) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode
@ -467,10 +467,7 @@ def export_cfm(
cfm = e_cfm.cfm
B, T = mu.size(0), mu.size(1)
x = (
torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
* temperature
)
x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
print("x:", x.shape, x.dtype)
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@ -565,11 +562,7 @@ def export():
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = sovits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device)
@ -626,10 +619,7 @@ def export():
"create_ge": refer,
}
trace_vq_model = torch.jit.trace_module(
sovits.vq_model, inputs, optimize=True
)
trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
trace_vq_model.save("onnx/ad/vq_model.pt")
print(fea_ref.shape, fea_ref.dtype, ge.shape)
@ -714,9 +704,7 @@ def export():
idx += chunk_len
cfm_res, fea_ref, mel2 = export_cfm_(
fea_ref, fea_todo_chunk, mel2, sample_steps
)
cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
cfm_resss.append(cfm_res)
continue
@ -726,9 +714,7 @@ def export():
with torch.inference_mode():
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
bigvgan_model_ = torch.jit.trace(
bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
)
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
wav_gen = bigvgan_model(cmf_res)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
@ -748,7 +734,6 @@ def test_export(
bigvgan,
output,
):
# hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
@ -773,13 +758,9 @@ def test_export(
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert(
@ -799,8 +780,18 @@ def test_export(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time)
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
)
chunk_len = 934 - fea_ref.shape[2]
print(fea_ref.shape, fea_todo.shape, mel2.shape)
@ -812,7 +803,6 @@ def test_export(
wav_gen_length = fea_todo.shape[2] * 256
while 1:
current_time = datetime.now()
print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@ -861,7 +851,6 @@ def test_export1(
gpt_sovits_v3,
output,
):
# hps = sovits.hps
ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
@ -886,14 +875,10 @@ def test_export1(
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
print("ssl_content:", ssl_content.shape, ssl_content.dtype)
ref_audio_32k,_ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
phones1, bert1, norm_text1 = get_phones_and_bert(
@ -913,11 +898,19 @@ def test_export1(
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("start inference %s", current_time)
print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
print(
ssl_content.shape,
ref_audio_32k.shape,
phoneme_ids0.shape,
phoneme_ids1.shape,
bert1.shape,
bert2.shape,
top_k.shape,
)
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
wav_gen = torch.cat([wav_gen,zero_wav_torch],0)
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
audio = wav_gen.cpu().detach().numpy()
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@ -929,7 +922,6 @@ import time
def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
# cfm = ExportCFM(sovits.cfm)
@ -959,10 +951,7 @@ def test_():
# t2s_m.top_k = 15
logger.info("t2s_m ok")
vq_model: torch.jit.ScriptModule = torch.jit.load(
"onnx/ad/vq_model.pt", map_location=device
)
vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
# vq_model = torch.jit.optimize_for_inference(vq_model)
# vq_model = vq_model.half().to(device)
vq_model.eval()
@ -1020,8 +1009,9 @@ def test_():
# "out2.wav",
# )
def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt",map_location=device)
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1(
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
# gpt_sovits_v3,

View File

@ -27,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
from module.commons import sequence_mask
class TextEmbedding(nn.Module):
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
@ -129,8 +130,8 @@ class DiT(nn.Module):
return ckpt_forward
def forward(#x, prompt_x, x_lens, t, style,cond
self,#d is channel,n is T
def forward( # x, prompt_x, x_lens, t, style,cond
self, # d is channel,n is T
x0: float["b n d"], # nosied input audio # noqa: F722
cond0: float["b n d"], # masked cond audio # noqa: F722
x_lens,
@ -142,13 +143,11 @@ class DiT(nn.Module):
drop_audio_cond=False, # cfg for cond audio
drop_text=False, # cfg for text
# mask: bool["b n"] | None = None, # noqa: F722
):
x=x0.transpose(2,1)
cond=cond0.transpose(2,1)
text=text0.transpose(2,1)
mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
x = x0.transpose(2, 1)
cond = cond0.transpose(2, 1)
text = text0.transpose(2, 1)
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@ -157,8 +156,8 @@ class DiT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
dt = self.d_embed(dt_base_bootstrap)
t+=dt
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
t += dt
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@ -391,6 +391,7 @@ class Attention(nn.Module):
# Attention processor
# from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor:
@ -545,6 +546,7 @@ class JointAttnProcessor:
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()

View File

@ -1,6 +1,3 @@
from . import cnhubert, whisper_enc
content_module_map = {
'cnhubert': cnhubert,
'whisper': whisper_enc
}
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}

View File

@ -1,10 +1,11 @@
import torch
import os
from transformers import logging as tf_logging
tf_logging.set_verbosity_error()
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
@ -19,21 +20,19 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self, base_path:str=None):
def __init__(self, base_path: str = None):
super().__init__()
if base_path is None:
base_path = cnhubert_base_path
if os.path.exists(base_path):...
else:raise FileNotFoundError(base_path)
if os.path.exists(base_path):
...
else:
raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_path, local_files_only=True
)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats

View File

@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
return feature

View File

@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
i18n = I18nAuto()
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
def synthesize(
GPT_model_path,
SoVITS_model_path,
ref_audio_path,
ref_text_path,
ref_language,
target_text_path,
target_language,
output_path,
):
# Read reference text
with open(ref_text_path, 'r', encoding='utf-8') as file:
with open(ref_text_path, "r", encoding="utf-8") as file:
ref_text = file.read()
# Read target text
with open(target_text_path, 'r', encoding='utf-8') as file:
with open(target_text_path, "r", encoding="utf-8") as file:
target_text = file.read()
# Change model weights
@ -21,11 +31,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
change_sovits_weights(sovits_path=SoVITS_model_path)
# Synthesize audio
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(target_language), top_p=1, temperature=1)
text_language=i18n(target_language),
top_p=1,
temperature=1,
)
result_list = list(synthesis_result)
@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
print(f"Audio saved to {output_wav_path}")
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
parser.add_argument('--target_text', required=True, help="Path to the target text file")
parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
parser.add_argument('--output_path', required=True, help="Path to the output directory")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument(
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
)
parser.add_argument("--target_text", required=True, help="Path to the target text file")
parser.add_argument(
"--target_language",
required=True,
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
help="Language of the target text",
)
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args()
synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
synthesize(
args.gpt_model,
args.sovits_model,
args.ref_audio,
args.ref_text,
args.ref_language,
args.target_text,
args.target_language,
args.output_path,
)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle('GPT-SoVITS GUI')
self.setWindowTitle("GPT-SoVITS GUI")
self.setGeometry(800, 450, 950, 850)
self.setStyleSheet("""
@ -65,7 +66,8 @@ class GPTSoVITSGUI(QMainWindow):
license_text = (
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
)
license_label = QLabel(license_text)
license_label.setWordWrap(True)
@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text = QTextEdit()
self.output_text.setReadOnly(True)
self.add_drag_drop_events([
self.add_drag_drop_events(
[
self.GPT_model_input,
self.SoVITS_model_input,
self.ref_audio_input,
self.ref_text_input,
self.target_text_input,
self.output_input,
])
]
)
self.synthesize_button = QPushButton("合成")
self.synthesize_button.clicked.connect(self.synthesize)
@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
def upload_ref_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.ref_text_input.setText(content)
def upload_target_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.target_text_input.setText(content)
@ -284,11 +288,13 @@ class GPTSoVITSGUI(QMainWindow):
change_sovits_weights(sovits_path=SoVITS_model_path)
self.SoVITS_Path = SoVITS_model_path
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox)
text_language=target_language_combobox,
)
result_list = list(synthesis_result)
@ -303,7 +309,7 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text.append("处理结果:\n" + result)
if __name__ == '__main__':
if __name__ == "__main__":
app = QApplication(sys.argv)
mainWin = GPTSoVITSGUI()
mainWin.show()

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,19 @@
'''
"""
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
"""
import random
import os
import re
import logging
import json
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
@ -27,8 +29,10 @@ import torch
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
analytics.version_check = lambda: None
except:
...
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@ -43,15 +47,15 @@ gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
version = os.environ.get("version", "v2")
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
@ -68,30 +72,30 @@ else:
# device = "cpu"
dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja", # 全部按日文识别
i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
}
dict_language_v2 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("粤语"): "all_yue",#全部按中文识别
i18n("韩文"): "all_ko",#全部按韩文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("粤英混合"): "yue",#按粤英混合识别####不变
i18n("韩英混合"): "ko",#按韩英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja", # 全部按日文识别
i18n("粤语"): "all_yue", # 全部按中文识别
i18n("韩文"): "all_ko", # 全部按韩文识别
i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
}
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
cut_method = {
i18n("不切"):"cut0",
i18n("不切"): "cut0",
i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3",
@ -118,22 +122,33 @@ gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version
def inference(text, text_lang,
def inference(
text,
text_lang,
ref_audio_path,
aux_ref_audio_paths,
prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty, sample_steps, super_sampling,
):
prompt_lang,
top_k,
top_p,
temperature,
text_split_method,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
):
seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
inputs={
inputs = {
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
@ -144,12 +159,12 @@ def inference(text, text_lang,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":actual_seed,
"batch_size": int(batch_size),
"speed_factor": float(speed_factor),
"split_bucket": split_bucket,
"return_fragment": False,
"fragment_interval": fragment_interval,
"seed": actual_seed,
"parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty,
"sample_steps": int(sample_steps),
@ -159,11 +174,12 @@ def inference(text, text_lang,
for item in tts_pipeline.run(inputs):
yield item, actual_seed
except NO_PROMPT_ERROR:
gr.Warning(i18n('V3不支持无参考文本模式请填写参考文本'))
gr.Warning(i18n("V3不支持无参考文本模式请填写参考文本"))
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
parts = re.split("(\d+)", s)
# 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
@ -171,52 +187,67 @@ def custom_sort_key(s):
def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
"choices": sorted(GPT_names, key=custom_sort_key),
"__type__": "update",
}
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
_ =[[],[]]
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
pretrained_sovits_name = [
"GPT_SoVITS/pretrained_models/s2G488k.pth",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
path_sovits_v3,
]
pretrained_gpt_name = [
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
]
_ = [[], []]
for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name, pretrained_sovits_name = _
if os.path.exists("./weight.json"):
pass
else:
with open("./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
with open("./weight.json", "w", encoding="utf-8") as file:
json.dump({"GPT": {}, "SoVITS": {}}, file)
with open("./weight.json", 'r', encoding="utf-8") as file:
with open("./weight.json", "r", encoding="utf-8") as file:
weight_data = file.read()
weight_data=json.loads(weight_data)
gpt_path = os.environ.get(
"gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
sovits_path = os.environ.get(
"sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
if isinstance(gpt_path,list):
weight_data = json.loads(weight_data)
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
if isinstance(gpt_path, list):
gpt_path = gpt_path[0]
if isinstance(sovits_path,list):
if isinstance(sovits_path, list):
sovits_path = sovits_path[0]
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
for path in SoVITS_weight_root + GPT_weight_root:
os.makedirs(path, exist_ok=True)
SoVITS_weight_root=["SoVITS_weights","SoVITS_weights_v2","SoVITS_weights_v3"]
GPT_weight_root=["GPT_weights","GPT_weights_v2","GPT_weights_v3"]
for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True)
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names = [i for i in pretrained_sovits_name]
for path in SoVITS_weight_root:
for name in os.listdir(path):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
if name.endswith(".pth"):
SoVITS_names.append("%s/%s" % (path, name))
GPT_names = [i for i in pretrained_gpt_name]
for path in GPT_weight_root:
for name in os.listdir(path):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
if name.endswith(".ckpt"):
GPT_names.append("%s/%s" % (path, name))
return SoVITS_names, GPT_names
@ -224,72 +255,110 @@ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
from process_ckpt import get_sovits_version_from_path_fast
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
global version, dict_language
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 and not os.path.exists(path_sovits_v3):
info= path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
tts_pipeline.init_vits_weights(sovits_path)
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
dict_language = dict_language_v1 if tts_pipeline.configs.version == "v1" else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
prompt_text_update, prompt_language_update = (
{"__type__": "update"},
{"__type__": "update", "value": prompt_language},
)
else:
prompt_text_update = {'__type__':'update', 'value':''}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
prompt_text_update = {"__type__": "update", "value": ""}
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
if text_language in list(dict_language.keys()):
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
else:
text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")}
if model_version=="v3":
visible_sample_steps=True
visible_inp_refs=False
text_update = {"__type__": "update", "value": ""}
text_language_update = {"__type__": "update", "value": i18n("中文")}
if model_version == "v3":
visible_sample_steps = True
visible_inp_refs = False
else:
visible_sample_steps=False
visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
visible_sample_steps = False
visible_inp_refs = True
yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},
prompt_text_update,
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "visible": visible_sample_steps},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
{"__type__": "update", "visible": True if model_version == "v3" else False},
)
with open("./weight.json") as f:
data = f.read()
data = json.loads(data)
data["SoVITS"][version] = sovits_path
with open("./weight.json", "w") as f:
f.write(json.dumps(data))
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
data["SoVITS"][version]=sovits_path
with open("./weight.json","w")as f:f.write(json.dumps(data))
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+ "<br>"
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
)
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=sorted(GPT_names, key=custom_sort_key),
value=gpt_path,
interactive=True,
)
SoVITS_dropdown = gr.Dropdown(
label=i18n("SoVITS模型列表"),
choices=sorted(SoVITS_names, key=custom_sort_key),
value=sovits_path,
interactive=True,
)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"), file_count="multiple")
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
ref_text_free = gr.Checkbox(
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
value=False,
interactive=True,
show_label=True,
)
gr.Markdown(
i18n("使用无参考文本模式时建议使用微调的GPT")
+ "<br>"
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
)
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
@ -298,42 +367,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
with gr.Row():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
batch_size = gr.Slider(
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
)
with gr.Row():
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
fragment_interval = gr.Slider(
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
)
speed_factor = gr.Slider(
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row():
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row():
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
temperature = gr.Slider(
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
)
repetition_penalty = gr.Slider(
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
)
with gr.Column():
with gr.Row():
how_to_cut = gr.Dropdown(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
choices=[
i18n("不切"),
i18n("凑四句一切"),
i18n("凑50字一切"),
i18n("按中文句号。切"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True, scale=1
interactive=True,
scale=1,
)
super_sampling = gr.Checkbox(
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
)
super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
with gr.Row():
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
value=True,
interactive=True,
show_label=True,
)
with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1)
seed = gr.Number(label=i18n("随机种子"), value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音"))
@ -341,40 +434,67 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
inference,
[
text,text_language, inp_ref, inp_refs,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty, sample_steps, super_sampling,
text,
text_language,
inp_ref,
inp_refs,
prompt_text,
prompt_language,
top_k,
top_p,
temperature,
how_to_cut,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
],
[output, seed],
)
stop_infer.click(tts_pipeline.stop, [], [])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
SoVITS_dropdown.change(
change_sovits_weights,
[SoVITS_dropdown, prompt_language, text_language],
[prompt_language, text_language, prompt_text, prompt_language, text, text_language],
)
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
gr.Markdown(
value=i18n(
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
)
)
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
choices=[
i18n("不切"),
i18n("凑四句一切"),
i18n("凑50字一切"),
i18n("按中文句号。切"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
cut_text = gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
if len(text_inp.strip()) == 0 or text_inp == []:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
@ -383,8 +503,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
if __name__ == '__main__':
app.queue().launch(#concurrency_count=511, max_size=1022
if __name__ == "__main__":
app.queue().launch( # concurrency_count=511, max_size=1022
server_name="0.0.0.0",
inbrowser=True,
share=is_share,

View File

@ -18,7 +18,7 @@ class Encoder(nn.Module):
p_dropout=0.0,
window_size=4,
isflow=False,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@ -56,9 +56,7 @@ class Encoder(nn.Module):
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow:
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
@ -74,9 +72,7 @@ class Encoder(nn.Module):
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x, g_l, torch.IntTensor([self.hidden_channels])
)
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
@ -99,7 +95,7 @@ class Decoder(nn.Module):
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@ -131,9 +127,7 @@ class Decoder(nn.Module):
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
@ -153,9 +147,7 @@ class Decoder(nn.Module):
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.weight_norm()
return module
else:
@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.remove_weight_norm()
else:
remove_weight_norm(module, name)
@ -567,7 +523,7 @@ class FFT(nn.Module):
proximal_bias=False,
proximal_init=True,
isflow=False,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@ -579,9 +535,7 @@ class FFT(nn.Module):
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
if isflow:
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
@ -622,18 +576,14 @@ class FFT(nn.Module):
if g is not None:
g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x, g_l, torch.IntTensor([self.hidden_channels])
)
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)

View File

@ -7,6 +7,7 @@ from module import commons
from typing import Optional
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
@ -43,7 +44,7 @@ class Encoder(nn.Module):
p_dropout=0.0,
window_size=4,
isflow=True,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@ -65,13 +66,9 @@ class Encoder(nn.Module):
if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = (
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
)
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
logging.debug(self.gin_channels, self.cond_layer_idx)
assert (
self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers"
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
@ -121,7 +118,9 @@ class Encoder(nn.Module):
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
):
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = norm_layers_1(x + y)
@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@ -224,7 +217,7 @@ class MultiHeadAttention(nn.Module):
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = (output.transpose(2, 3).contiguous().view(b, d, -1))
output = output.transpose(2, 3).contiguous().view(b, d, -1)
return output, p_attn
def _matmul_with_relative_values(self, x, y):
@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
slice_end_position = slice_start_position + 2 * length - 1
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -395,12 +380,6 @@ class MRTE(nn.Module):
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x

View File

@ -28,9 +28,7 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
return kl
@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)

View File

@ -30,6 +30,7 @@
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange, repeat
@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
self.epsilon = epsilon
self.commitment_weight = commitment_weight
@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
quantized_out = 0.0
residual = x
@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)

View File

@ -5,11 +5,14 @@ import torch
import torch.utils.data
from tqdm import tqdm
from module.mel_processing import spectrogram_torch,spec_to_mel_torch
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from text import cleaned_text_to_sequence
import torch.nn.functional as F
from tools.my_utils import load_audio
version = os.environ.get('version',None)
version = os.environ.get("version", None)
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
@ -34,7 +37,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
@ -42,7 +45,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@ -67,7 +70,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@ -102,7 +105,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@ -120,8 +123,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
return spec, audio_norm
@ -137,12 +141,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first", ssl.shape, wav.shape)
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
len_mel = mel.shape[1]
if self.val:
reference_mel = mel[:, :len_mel // 3]
reference_mel = mel[:, : len_mel // 3]
return reference_mel, ssl, wav, mel
dir = random.randint(0, 1)
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@ -150,20 +153,29 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if dir == 0:
reference_mel = mel[:, :sep_point]
ssl = ssl[:, :, sep_point:]
wav2 = wav[:, sep_point * self.hop_length:]
wav2 = wav[:, sep_point * self.hop_length :]
mel = mel[:, sep_point:]
else:
reference_mel = mel[:, sep_point:]
ssl = ssl[:, :, :sep_point]
wav2 = wav[:, :sep_point * self.hop_length]
wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point]
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -175,9 +187,7 @@ class TextAudioSpeakerCollate():
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -205,22 +215,24 @@ class TextAudioSpeakerCollate():
row = batch[ids_sorted_decreasing[i]]
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@ -244,7 +256,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
@ -252,7 +264,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@ -277,7 +289,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@ -304,15 +316,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min=-12
self.spec_max=2
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
@ -323,7 +336,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@ -338,25 +351,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
return (ssl, spec, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio=torch.FloatTensor(audio_array)#/32768
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24=torch.FloatTensor(audio_array24)#/32768
audio_array24 = load_audio(
filename, 24000
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24 = torch.FloatTensor(audio_array24) # /32768
audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
spec1 = spectrogram_torch(
audio_norm24,
self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel)
mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape)
return spec, mel
@ -370,9 +393,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollateV3:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -383,12 +407,10 @@ class TextAudioSpeakerCollateV3():
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
#ssl, spec, wav,mel, text
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
#(ssl, spec,mel, text)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -402,7 +424,7 @@ class TextAudioSpeakerCollateV3():
# max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[3].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
@ -426,11 +448,11 @@ class TextAudioSpeakerCollateV3():
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
# wav = row[2]
@ -438,15 +460,17 @@ class TextAudioSpeakerCollateV3():
# wav_lengths[i] = wav.size(1)
mel = row[2]
mel_padded[i, :, :mel.size(1)] = mel
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@ -470,7 +494,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
@ -478,7 +502,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@ -503,7 +527,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@ -530,15 +554,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min=-12
self.spec_max=2
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
@ -546,10 +571,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@ -564,27 +589,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
return (ssl, spec, wav, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio=torch.FloatTensor(audio_array)#/32768
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24=torch.FloatTensor(audio_array24)#/32768
audio_array24 = load_audio(
filename, 24000
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24 = torch.FloatTensor(audio_array24) # /32768
audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
spec1 = spectrogram_torch(
audio_norm24,
self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel)
mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape)
return spec, mel,audio_norm
return spec, mel, audio_norm
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
@ -596,9 +631,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3b():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollateV3b:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -609,12 +645,10 @@ class TextAudioSpeakerCollateV3b():
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
#ssl, spec, wav,mel, text
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
#(ssl, spec,mel, text)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@ -627,7 +661,7 @@ class TextAudioSpeakerCollateV3b():
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[4].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
@ -651,28 +685,40 @@ class TextAudioSpeakerCollateV3b():
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
mel = row[3]
mel_padded[i, :, :mel.size(1)] = mel
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[4]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return (
ssl_padded,
spec_padded,
mel_padded,
ssl_lengths,
spec_lengths,
text_padded,
text_lengths,
wav_padded,
wav_lengths,
mel_lengths,
)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
Maintain similar input lengths in a batch.
@ -736,12 +782,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
num_samples_bucket = self.num_samples_per_bucket[i]
rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
ids_bucket = ids_bucket[self.rank::self.num_replicas]
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
batches.append(batch)
if self.shuffle:

View File

@ -65,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l

View File

@ -47,9 +47,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
@ -79,20 +77,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
@ -103,16 +95,10 @@ def mel_spectrogram_torch(
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),

View File

@ -1,4 +1,5 @@
import warnings
warnings.filterwarnings("ignore")
import math
@ -15,6 +16,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
@ -46,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -89,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -100,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -115,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -135,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
@ -147,13 +125,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -188,7 +162,7 @@ class TextEncoder(nn.Module):
kernel_size,
p_dropout,
latent_channels=192,
version = "v2",
version="v2",
):
super().__init__()
self.out_channels = out_channels
@ -235,26 +209,22 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1,test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
@ -358,9 +328,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -370,14 +338,9 @@ class PosteriorEncoder(nn.Module):
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -392,7 +355,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
@ -400,6 +363,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask
return stats, x_mask
class WNEncoder(nn.Module):
def __init__(
self,
@ -432,9 +396,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -457,9 +419,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
@ -479,9 +439,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -634,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@ -736,10 +692,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@ -757,9 +710,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@ -799,13 +750,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@ -818,9 +765,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -868,8 +813,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version = "v2",
**kwargs
version="v2",
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@ -900,7 +845,7 @@ class SynthesizerTrn(nn.Module):
n_layers,
kernel_size,
p_dropout,
version = version,
version=version,
)
self.dec = Generator(
inter_channels,
@ -921,12 +866,10 @@ class SynthesizerTrn(nn.Module):
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
# self.version=os.environ.get("version","v1")
if(self.version=="v1"):
if self.version == "v1":
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
@ -943,13 +886,11 @@ class SynthesizerTrn(nn.Module):
self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@ -957,24 +898,16 @@ class SynthesizerTrn(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
return (
o,
@ -987,24 +920,18 @@ class SynthesizerTrn(nn.Module):
)
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
if(self.version=="v1"):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1013,39 +940,34 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
if (self.version == "v1"):
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if(type(refer)==list):
ges=[]
if type(refer) == list:
ges = []
for _refer in refer:
ge=get_ge(_refer)
ge = get_ge(_refer)
ges.append(ge)
ge=torch.stack(ges,0).mean(0)
ge = torch.stack(ges, 0).mean(0)
else:
ge=get_ge(refer)
ge = get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge,speed
)
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -1057,11 +979,10 @@ class SynthesizerTrn(nn.Module):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
def __init__(self, in_channels, dit):
super().__init__()
self.sigma_min = 1e-6
@ -1075,41 +996,54 @@ class CFM(torch.nn.Module):
def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype) * temperature
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0
mu=mu.transpose(2,1)
mu = mu.transpose(2, 1)
t = 0
d = 1 / n_timesteps
for j in range(n_timesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
if inference_cfg_rate>1e-5:
neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
v_pred = self.estimator(
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
).transpose(2, 1)
if inference_cfg_rate > 1e-5:
neg = self.estimator(
x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=True,
drop_text=True,
).transpose(2, 1)
v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0
return x
def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
b, _, t = x1.shape
t = torch.rand([b], device=mu.device, dtype=x1.dtype)
x0 = torch.randn_like(x1,device=mu.device)
x0 = torch.randn_like(x1, device=mu.device)
vt = x1 - x0
xt = x0 + t[:, None, None] * vt
dt = torch.zeros_like(t,device=mu.device)
dt = torch.zeros_like(t, device=mu.device)
prompt = torch.zeros_like(x1)
for i in range(b):
prompt[i, :, :prompt_lens[i]] = x1[i, :, :prompt_lens[i]]
xt[i, :, :prompt_lens[i]] = 0
gailv=0.3# if ttime()>1736250488 else 0.1
prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
xt[i, :, : prompt_lens[i]] = 0
gailv = 0.3 # if ttime()>1736250488 else 0.1
if random.random() < gailv:
base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
d = 1/torch.pow(2, base)
d = 1 / torch.pow(2, base)
d_input = d.clone()
d_input[d_input < 1e-2] = 0
# with torch.no_grad():
@ -1117,29 +1051,32 @@ class CFM(torch.nn.Module):
# v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
x_mid = xt + d[:, None, None] * v_pred_1
# v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
vt = (v_pred_1 + v_pred_2) / 2
vt = vt.detach()
dt = 2*d
dt = 2 * d
vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
loss = 0
for i in range(b):
loss += self.criterion(vt_pred[i, :, prompt_lens[i]:x_lens[i]], vt[i, :, prompt_lens[i]:x_lens[i]])
loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
loss /= b
return loss
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
param.requires_grad = False
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -1161,8 +1098,8 @@ class SynthesizerTrnV3(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -1183,110 +1120,111 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels
self.version = version
self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if self.freeze_quantizer==True:
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if self.freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
def forward(
self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
): # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()#
self.ssl_proj.eval() #
self.quantizer.eval()
self.enc_p.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(
fea, mel_lengths, ge
) ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
return cfm_loss
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None,speed=1):
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
if speed==1:
sizee=int(codes.size(2)*2.5*1.5)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
if speed == 1:
sizee = int(codes.size(2) * 2.5 * 1.5)
else:
sizee=int(codes.size(2)*2.5*1.5/speed)+1
sizee = int(codes.size(2) * 2.5 * 1.5 / speed) + 1
y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)
class SynthesizerTrnV3b(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -1307,8 +1245,8 @@ class SynthesizerTrnV3b(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -1328,47 +1266,52 @@ class SynthesizerTrnV3b(nn.Module):
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.freeze_quantizer=freeze_quantizer
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths): # ssl_lengths no need now
with autocast(enabled=False):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
# ge = self.ref_enc(y * y_mask, y_mask)#change back, new spec setting is whole 24k
# ge=None
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
@ -1377,51 +1320,59 @@ class SynthesizerTrnV3b(nn.Module):
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
learned_mel = self.linear_mel(fea)
B=ssl.shape[0]
prompt_len_max = mel_lengths*2/3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)#
minn=min(mel.shape[-1],fea.shape[-1])
mel=mel[:,:,:minn]
fea=fea[:,:,:minn]
cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)#fea==cond,y_lengths==target_mel_lengths#ge not need
return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
B = ssl.shape[0]
prompt_len_max = mel_lengths * 2 / 3
prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long) #
minn = min(mel.shape[-1], fea.shape[-1])
mel = mel[:, :, :minn]
fea = fea[:, :, :minn]
cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea) # fea==cond,y_lengths==target_mel_lengths#ge not need
return (
commit_loss,
cfm_loss,
F.mse_loss(learned_mel, mel),
o,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
@torch.no_grad()
def decode_encp(self, codes,text, refer,ge=None):
def decode_encp(self, codes, text, refer, ge=None):
# print(2333333,refer.shape)
# ge=None
if(ge==None):
if ge == None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2)*2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2)*2.5*1.5)]).to(codes.device)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea,ge
return fea, ge
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)

View File

@ -14,6 +14,7 @@ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
@ -42,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -85,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -96,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -111,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -131,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
@ -143,13 +120,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -232,7 +205,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y_mask = torch.ones_like(y[:1, :1, :])
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
@ -244,8 +217,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
@ -331,9 +304,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -343,14 +314,9 @@ class PosteriorEncoder(nn.Module):
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -365,7 +331,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
@ -373,6 +339,7 @@ class Encoder(nn.Module):
stats = self.proj(x) * x_mask
return stats, x_mask
class WNEncoder(nn.Module):
def __init__(
self,
@ -405,9 +372,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -430,9 +395,7 @@ class Generator(torch.nn.Module):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
@ -452,9 +415,7 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -463,7 +424,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g:Optional[torch.Tensor]=None):
def forward(self, x, g: Optional[torch.Tensor] = None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@ -607,9 +568,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@ -709,10 +668,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@ -730,9 +686,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@ -772,13 +726,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@ -791,9 +741,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -842,7 +790,7 @@ class SynthesizerTrn(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v2",
**kwargs
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@ -894,9 +842,7 @@ class SynthesizerTrn(nn.Module):
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
# self.version=os.environ.get("version","v1")
if self.version == "v1":
@ -921,9 +867,9 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@ -933,9 +879,7 @@ class SynthesizerTrn(nn.Module):
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge, speed
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
@ -949,11 +893,9 @@ class SynthesizerTrn(nn.Module):
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(
self,
in_channels,dit
):
def __init__(self, in_channels, dit):
super().__init__()
# self.sigma_min = 1e-6
@ -963,27 +905,34 @@ class CFM(torch.nn.Module):
# self.criterion = torch.nn.MSELoss()
def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0):
def forward(
self,
mu: torch.Tensor,
x_lens: torch.LongTensor,
prompt: torch.Tensor,
n_timesteps: torch.LongTensor,
temperature: float = 1.0,
):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device,dtype=mu.dtype)
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
ntimesteps = int(n_timesteps)
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x,dtype=mu.dtype)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0
mu=mu.transpose(2,1)
t = torch.tensor(0.0,dtype=x.dtype,device=x.device)
d = torch.tensor(1.0/ntimesteps,dtype=x.dtype,device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
mu = mu.transpose(2, 1)
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu).transpose(2, 1)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
# if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
@ -995,24 +944,28 @@ class CFM(torch.nn.Module):
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad=False
param.requires_grad = False
@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -1034,8 +987,8 @@ class SynthesizerTrnV3(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs):
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -1056,41 +1009,38 @@ class SynthesizerTrnV3(nn.Module):
self.gin_channels = gin_channels
self.version = version
self.model_dim=512
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
freeze_quantizer
inter_channels2=512
self.bridge=nn.Sequential(
nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
nn.LeakyReLU()
)
self.wns1=Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8,gin_channels=gin_channels)
self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
if freeze_quantizer==True:
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
@ -1098,19 +1048,18 @@ class SynthesizerTrnV3(nn.Module):
def create_ge(self, refer):
refer_lengths = compile_ref_length(refer)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:,:704] * refer_mask, refer_mask)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
def forward(self, codes, text,ge,speed=1):
y_lengths1=compile_codes_length(codes)
def forward(self, codes, text, ge, speed=1):
y_lengths1 = compile_codes_length(codes)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge,speed)
fea=self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")##BCT
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
####more wn paramter to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea
@ -1118,4 +1067,4 @@ class SynthesizerTrnV3(nn.Module):
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)

View File

@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
@ -156,9 +152,7 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
@ -479,9 +473,7 @@ class ConvFlow(nn.Module):
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout)
@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
output = self.fc(output)
@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
if mask is not None:
mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1]
slf_attn_mask = (
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
# spectral
x = self.spectral(x)
@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out)
logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
loss_kl = kl_divergence.mean()
z = posterior.rsample()
@ -825,9 +807,7 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
device=x.device, dtype=x.dtype
)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
@ -856,9 +836,7 @@ class ActNorm(nn.Module):
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
self.n_split = n_split
self.no_jacobian = no_jacobian
w_init = torch.linalg.qr(
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = (
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
if reverse:
if hasattr(self, "weight_inv"):

View File

@ -31,32 +31,15 @@ class MRTE(nn.Module):
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()

View File

@ -87,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.

View File

@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
**spline_kwargs,
)
return outputs, logabsdet
@ -175,8 +175,7 @@ def rational_quadratic_spline(
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
@ -190,12 +189,9 @@ def rational_quadratic_spline(
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator

View File

@ -1,22 +1,22 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
import os
import json
import os
import soundfile
from text import cleaned_text_to_sequence
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -101,22 +101,22 @@ class T2SModel(nn.Module):
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model)
# self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1]
#[1,N,512] [1,N]
# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
@ -130,13 +130,11 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
@ -148,13 +146,13 @@ class T2SModel(nn.Module):
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1 : "ref_length"},
"text_seq": {1 : "text_length"},
"ref_bert": {0 : "ref_length"},
"text_bert": {0 : "text_length"},
"ssl_content": {2 : "ssl_length"},
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
@ -165,11 +163,11 @@ class T2SModel(nn.Module):
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1 : "x_length"},
"prompts": {1 : "prompts_length"},
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
@ -180,23 +178,23 @@ class T2SModel(nn.Module):
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1 : "iy_length"},
"ik": {1 : "ik_length"},
"iv": {1 : "iv_length"},
"iy_emb": {1 : "iy_emb_length"},
"ix_example": {1 : "ix_example_length"},
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
@ -207,7 +205,7 @@ class VitsModel(nn.Module):
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@ -219,7 +217,7 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
@ -235,12 +233,16 @@ class GptSoVits(nn.Module):
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio
@ -254,12 +256,12 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1 : "text_length"},
"pred_semantic": {2 : "pred_length"},
"ref_audio": {1 : "audio_length"},
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False
verbose=False,
)
@ -277,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
try:
os.mkdir(f"onnx/{project_name}")
@ -325,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":

View File

@ -12,8 +12,9 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get('version', None)
version = os.environ.get("version", None)
import traceback
import os.path
from text.cleaner import clean_text
@ -33,13 +34,13 @@ from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
@ -53,8 +54,10 @@ if os.path.exists(txt_path) == False:
# device = "mps"
else:
device = "cpu"
if os.path.exists(bert_pretrained_dir):...
else:raise FileNotFoundError(bert_pretrained_dir)
if os.path.exists(bert_pretrained_dir):
...
else:
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
@ -83,12 +86,10 @@ if os.path.exists(txt_path) == False:
def process(data, res):
for name, text, lan in data:
try:
name=clean_path(name)
name = clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","), lan, version
)
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("", ","), lan, version)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
@ -128,9 +129,7 @@ if os.path.exists(txt_path) == False:
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
if language in language_v1_to_language_v2.keys():
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except:

View File

@ -2,26 +2,30 @@
import sys
import os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import numpy as np
from scipy.io import wavfile
import librosa
now_dir = os.getcwd()
sys.path.append(now_dir)
from tools.my_utils import load_audio,clean_path
from tools.my_utils import load_audio, clean_path
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
@ -36,90 +40,95 @@ from tools.my_utils import load_audio,clean_path
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95
alpha=0.5
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
model=cnhubert.get_model()
model = cnhubert.get_model()
# is_half=False
if(is_half==True):
model=model.half().to(device)
if is_half == True:
model = model.half().to(device)
else:
model = model.to(device)
nan_fails=[]
def name2go(wav_name,wav_path):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return
nan_fails = []
def name2go(wav_name, wav_path):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
tmp_audio = librosa.resample(
tmp_audio32b, orig_sr=32000, target_sr=16000
)#不是重采样问题
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True):
tensor_wav16=tensor_wav16.half().to(device)
if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append((wav_name,wav_path))
print("nan filtered:%s"%wav_name)
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
nan_fails.append((wav_name, wav_path))
print("nan filtered:%s" % wav_name)
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl,hubert_path)
my_save(ssl, hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]:
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name)
if (inp_wav_dir != "" and inp_wav_dir != None):
wav_name = clean_path(wav_name)
if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
else:
wav_path=wav_name
wav_path = wav_name
wav_name = os.path.basename(wav_name)
name2go(wav_name,wav_path)
name2go(wav_name, wav_path)
except:
print(line,traceback.format_exc())
print(line, traceback.format_exc())
if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
if len(nan_fails) > 0 and is_half == True:
is_half = False
model = model.float()
for wav in nan_fails:
try:
name2go(wav[0],wav[1])
name2go(wav[0], wav[1])
except:
print(wav_name,traceback.format_exc())
print(wav_name, traceback.format_exc())

View File

@ -10,8 +10,10 @@ opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
if os.path.exists(pretrained_s2G):...
else:raise FileNotFoundError(pretrained_s2G)
if os.path.exists(pretrained_s2G):
...
else:
raise FileNotFoundError(pretrained_s2G)
# version=os.environ.get("version","v2")
size = os.path.getsize(pretrained_s2G)
if size < 82978 * 1024:
@ -25,6 +27,7 @@ elif size < 700 * 1024 * 1024:
else:
version = "v3"
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import sys
@ -33,11 +36,13 @@ now_dir = os.getcwd()
sys.path.append(now_dir)
import logging
import utils
if version!="v3":
if version != "v3":
from module.models import SynthesizerTrn
else:
from module.models import SynthesizerTrnV3 as SynthesizerTrn
from tools.my_utils import clean_path
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G
@ -66,7 +71,7 @@ if os.path.exists(semantic_path) == False:
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
version=version,
**hps.model
**hps.model,
)
if is_half == True:
vq_model = vq_model.half().to(device)
@ -103,7 +108,7 @@ if os.path.exists(semantic_path) == False:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=clean_path(wav_name)
wav_name = clean_path(wav_name)
wav_name = os.path.basename(wav_name)
# name2go(name,lines1)
name2go(wav_name, lines1)

View File

@ -8,31 +8,37 @@ from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
'''
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
"""
00:v1
01:v2
02:v3
03:v3lora
'''
"""
from io import BytesIO
def my_save2(fea,path):
def my_save2(fea, path):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo
with open(path, "wb") as f: f.write(data)
data = b"03" + data[2:] ###temp for v3lora only, todo
with open(path, "wb") as f:
f.write(data)
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@ -43,7 +49,7 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"]=lora_rank
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
@ -51,41 +57,48 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
except:
return traceback.format_exc()
head2version={
b'00':["v1","v1",False],
b'01':["v2","v2",False],
b'02':["v2","v3",False],
b'03':["v2","v3",True],
head2version = {
b"00": ["v1", "v1", False],
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
}
hash_pretrained_dict={
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
}
import hashlib
def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192)
with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash=get_hash_from_file(sovits_path)
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
with open(sovits_path,"rb")as f:version=f.read(2)
if version!=b"PK":
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3=False
size=os.path.getsize(sovits_path)
'''
if_lora_v3 = False
size = os.path.getsize(sovits_path)
"""
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
'''
"""
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
@ -93,15 +106,16 @@ def get_sovits_version_from_path_fast(sovits_path):
else:
version = "v2"
model_version = "v3"
return version,model_version,if_lora_v3
return version, model_version, if_lora_v3
def load_sovits_new(sovits_path):
f=open(sovits_path,"rb")
meta=f.read(2)
if meta!="PK":
data = b'PK' + f.read()
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != "PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)

View File

@ -5,25 +5,24 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
import platform
from pathlib import Path
import torch
import platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
from AR.utils import get_newest_ckpt
from process_ckpt import my_save
@ -35,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self.if_save_latest = if_save_latest
@ -48,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if (
self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
@ -73,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
# torch.save(
# print(os.environ)
if(os.environ.get("LOCAL_RANK","0")=="0"):
if os.environ.get("LOCAL_RANK", "0") == "0":
my_save(
to_save_od,
"%s/%s-e%s.ckpt"
@ -110,7 +106,7 @@ def main(args):
dirpath=ckpt_dir,
)
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
os.environ["MASTER_ADDR"]="localhost"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["USE_LIBUV"] = "0"
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
@ -121,9 +117,9 @@ def main(args):
devices=-1 if torch.cuda.is_available() else 1,
benchmark=False,
fast_dev_run=False,
strategy = DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
) if torch.cuda.is_available() else "auto",
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
if torch.cuda.is_available()
else "auto",
precision=config["train"]["precision"],
logger=logger,
num_sanity_val_steps=0,
@ -131,9 +127,7 @@ def main(args):
use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题!
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,

View File

@ -1,37 +1,41 @@
import warnings
warnings.filterwarnings("ignore")
import utils
import os
import utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from module import commons
from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler,
TextAudioSpeakerCollate,
TextAudioSpeakerLoader,
)
from module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee
torch.backends.cudnn.benchmark = False
@ -47,7 +51,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:
@ -75,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
@ -129,19 +132,27 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)
net_g = SynthesizerTrn(
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
)
for name, param in net_g.named_parameters():
if not param.requires_grad:
print(name, "not requires_grad")
@ -194,7 +205,7 @@ def run(rank, n_gpus, hps):
try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
net_d,
optim_d,
) # D多半加载没事
@ -202,11 +213,11 @@ def run(rank, n_gpus, hps):
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
@ -214,37 +225,55 @@ def run(rank, n_gpus, hps):
# traceback.print_exc()
epoch_str = 1
global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict(
)
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
) ##测试不加载优化器
if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
if (
hps.train.pretrained_s2D != ""
and hps.train.pretrained_s2D != None
and os.path.exists(hps.train.pretrained_s2D)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
print("loaded pretrained %s" % hps.train.pretrained_s2D,
print(
"loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
) if torch.cuda.is_available() else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
)
if torch.cuda.is_available()
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
),
)
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
optim_g,
gamma=hps.train.lr_decay,
last_epoch=-1,
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
optim_d,
gamma=hps.train.lr_decay,
last_epoch=-1,
)
for _ in range(epoch_str):
scheduler_g.step()
@ -286,9 +315,7 @@ def run(rank, n_gpus, hps):
print("training done")
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@ -312,17 +339,38 @@ def train_and_evaluate(
text_lengths,
) in enumerate(tqdm(train_loader)):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -351,9 +399,7 @@ def train_and_evaluate(
hps.data.mel_fmin,
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
@ -365,15 +411,14 @@ def train_and_evaluate(
hps.data.mel_fmax,
)
y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
y_d_hat_r,
y_d_hat_g,
)
loss_disc_all = loss_disc
optim_d.zero_grad()
@ -406,7 +451,8 @@ def train_and_evaluate(
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader)
epoch,
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
@ -430,25 +476,37 @@ def train_and_evaluate(
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict=None
try:###Some people installed the wrong version of matplotlib.
image_dict = None
try: ###Some people installed the wrong version of matplotlib.
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy()
y_mel[0].data.cpu().numpy(),
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
y_hat_mel[0].data.cpu().numpy(),
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
mel[0].data.cpu().numpy(),
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
stats_ssl[0].data.cpu().numpy(),
),
}
except:pass
if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,)
else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,)
except:
pass
if image_dict:
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)
else:
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
@ -458,7 +516,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
),
)
utils.save_checkpoint(
@ -467,7 +526,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(global_step),
),
)
else:
@ -477,7 +537,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
),
)
utils.save_checkpoint(
@ -486,7 +547,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(233333333333),
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
@ -541,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
ssl = ssl.to(device)
text, text_lengths = text.to(device), text_lengths.to(device)
for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
) if torch.cuda.is_available() else generator.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
y_hat, mask, *_ = (
generator.module.infer(
ssl,
spec,
spec_lengths,
text,
text_lengths,
test=test,
)
if torch.cuda.is_available()
else generator.infer(
ssl,
spec,
spec_lengths,
text,
text_lengths,
test=test,
)
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
@ -569,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
image_dict.update(
{
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
y_hat_mel[0].cpu().numpy(),
),
}
)
audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
{
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
},
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
},
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

View File

@ -1,29 +1,37 @@
import warnings
warnings.filterwarnings("ignore")
import utils
import os
import utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from module import commons
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
@ -43,7 +51,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:
@ -71,7 +78,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
@ -125,17 +132,21 @@ def run(rank, n_gpus, hps):
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)
net_g = SynthesizerTrn(
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
# for name, param in net_g.named_parameters():
@ -143,7 +154,7 @@ def run(rank, n_gpus, hps):
# print(name, "not requires_grad")
optim_g = torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
@ -171,11 +182,11 @@ def run(rank, n_gpus, hps):
# logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_*.pth"),
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
@ -183,17 +194,24 @@ def run(rank, n_gpus, hps):
# traceback.print_exc()
epoch_str = 1
global_step = 0
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) if torch.cuda.is_available() else net_g.load_state_dict(
)
if torch.cuda.is_available()
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
) ##测试不加载优化器
# if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
# if rank == 0:
@ -209,9 +227,7 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
# optim_d, gamma=hps.train.lr_decay, last_epoch=-1
# )
@ -221,7 +237,7 @@ def run(rank, n_gpus, hps):
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
@ -257,7 +273,16 @@ def run(rank, n_gpus, hps):
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
rank,
epoch,
hps,
nets,
optims,
schedulers,
scaler,
loaders,
logger,
writers,
):
net_g, net_d = nets
optim_g, optim_d = optims
@ -281,19 +306,33 @@ def train_and_evaluate(
# text,
# text_lengths,
# ) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -304,8 +343,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
@ -315,12 +364,15 @@ def train_and_evaluate(
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch,
100. * batch_idx / len(train_loader)))
100.0 * batch_idx / len(train_loader),
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
@ -334,7 +386,8 @@ def train_and_evaluate(
writer=writer,
global_step=global_step,
# images=image_dict,
scalars=scalar_dict)
scalars=scalar_dict,
)
# if global_step % hps.train.eval_interval == 0:
# # evaluate(hps, net_g, eval_loader, writer_eval)
@ -344,7 +397,6 @@ def train_and_evaluate(
# # if keep_ckpts > 0:
# # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
@ -354,7 +406,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
),
)
# utils.save_checkpoint(
@ -373,7 +426,8 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
),
)
# utils.save_checkpoint(

View File

@ -1,35 +1,45 @@
import warnings
warnings.filterwarnings("ignore")
import utils
import os
import utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from collections import OrderedDict as od
from random import randint
from module import commons
from peft import LoraConfig, get_peft_model
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
)
from peft import LoraConfig, get_peft_model
from process_ckpt import savee
from collections import OrderedDict as od
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
@ -43,7 +53,6 @@ device = "cpu" # cuda以外的设备等mps优化后加入
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:
@ -62,7 +71,7 @@ def main():
def run(rank, n_gpus, hps):
global global_step,no_grad_names,save_root,lora_rank
global global_step, no_grad_names, save_root, lora_rank
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)
@ -71,7 +80,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
@ -119,21 +128,24 @@ def run(rank, n_gpus, hps):
persistent_workers=True,
prefetch_factor=4,
)
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
os.makedirs(save_root,exist_ok=True)
lora_rank=int(hps.train.lora_rank)
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root, exist_ok=True)
lora_rank = int(hps.train.lora_rank)
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
def get_model(hps):return SynthesizerTrn(
def get_model(hps):
return SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)
def get_optim(net_g):
return torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
@ -141,61 +153,66 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas,
eps=hps.train.eps,
)
def model2cuda(net_g,rank):
def model2cuda(net_g, rank):
if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
return net_g
try:# 如果能加载自动resume
try: # 如果能加载自动resume
net_g = get_model(hps)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g=get_optim(net_g)
net_g = model2cuda(net_g, rank)
optim_g = get_optim(net_g)
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(save_root, "G_*.pth"),
net_g,
optim_g,
)
epoch_str+=1
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
except: # 如果首次不能加载加载pretrain
# traceback.print_exc()
epoch_str = 1
global_step = 0
net_g = get_model(hps)
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print("loaded pretrained %s" % hps.train.pretrained_s2G,
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
),
)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
net_g = model2cuda(net_g, rank)
optim_g = get_optim(net_g)
no_grad_names=set()
no_grad_names = set()
for name, param in net_g.named_parameters():
if not param.requires_grad:
no_grad_names.add(name.replace("module.",""))
no_grad_names.add(name.replace("module.", ""))
# print(name, "not requires_grad")
# print(no_grad_names)
# os._exit(233333)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
for _ in range(epoch_str):
scheduler_g.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None
print("start training from epoch %s"%epoch_str)
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
@ -227,9 +244,8 @@ def run(rank, n_gpus, hps):
scheduler_g.step()
print("training done")
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@ -241,18 +257,32 @@ def train_and_evaluate(
global global_step
net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@ -262,8 +292,18 @@ def train_and_evaluate(
text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
cfm_loss = net_g(
ssl,
spec,
mel,
ssl_lengths,
spec_lengths,
text,
text_lengths,
mel_lengths,
use_grad_ckpt=hps.train.grad_ckpt,
)
loss_gen_all = cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
@ -273,18 +313,17 @@ def train_and_evaluate(
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict)
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
@ -294,9 +333,7 @@ def train_and_evaluate(
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(global_step)
),
os.path.join(save_root, "G_{}.pth".format(global_step)),
)
else:
utils.save_checkpoint(
@ -304,21 +341,19 @@ def train_and_evaluate(
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(233333333333)
),
os.path.join(save_root, "G_{}.pth".format(233333333333)),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
sim_ckpt=od()
sim_ckpt = od()
for key in ckpt:
# if "cfm"not in key:
# print(key)
if key not in no_grad_names:
sim_ckpt[key]=ckpt[key].half().cpu()
sim_ckpt[key] = ckpt[key].half().cpu()
logger.info(
"saving ckpt %s_e%s:%s"
% (
@ -326,10 +361,11 @@ def train_and_evaluate(
epoch,
savee(
sim_ckpt,
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
epoch,
global_step,
hps,lora_rank=lora_rank
hps,
lora_rank=lora_rank,
),
)
)

View File

@ -3,19 +3,25 @@ import re
# jieba静音
import jieba
jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
)
)
from split_lang import LangSplitter
def full_en(text):
pattern = r'^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
pattern = r"^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text))
@ -34,7 +40,7 @@ def full_cjk(text):
(0x2EBF0, 0x2EE5D), # CJK Extension H
]
pattern = r'[0-9、-〜。!?.!?… ]+$'
pattern = r"[0-9、-〜。!?.!?… ]+$"
cjk_text = ""
for char in text:
@ -45,7 +51,7 @@ def full_cjk(text):
return cjk_text
def split_jako(tag_lang,item):
def split_jako(tag_lang, item):
if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else:
@ -53,28 +59,28 @@ def split_jako(tag_lang,item):
lang_list: list[dict] = []
tag = 0
for match in re.finditer(pattern, item['text']):
for match in re.finditer(pattern, item["text"]):
if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
if tag < len(item['text']):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
if tag < len(item["text"]):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
return lang_list
def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]['text'] += item['text']
if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]["text"] += item["text"]
else:
lang_list.append(item)
return lang_list
class LangSegmenter():
class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = {
"zh": "zh",
@ -87,7 +93,6 @@ class LangSegmenter():
"en": "en",
}
def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
substr = lang_splitter.split_by_lang(text=text)
@ -95,18 +100,18 @@ class LangSegmenter():
lang_list: list[dict] = []
for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text}
dict_item = {"lang": item.lang, "text": item.text}
# 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']):
dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item)
if full_en(dict_item["text"]):
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item)
continue
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if dict_item["lang"] != "ja":
ja_list = split_jako("ja", dict_item)
if not ja_list:
ja_list.append(dict_item)
@ -115,8 +120,8 @@ class LangSegmenter():
ko_list: list[dict] = []
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)
if ko_list:
temp_list.extend(ko_list)
@ -126,26 +131,26 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
# 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
continue
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if temp_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
return lang_list
@ -155,4 +160,3 @@ if __name__ == "__main__":
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text))

View File

@ -10,18 +10,19 @@ from text import symbols2 as symbols_v2
_symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
_symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}
def cleaned_text_to_sequence(cleaned_text, version=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
if version is None:version=os.environ.get('version', 'v2')
"""
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
else:
phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]
return phones

View File

@ -98,9 +98,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
return replaced_text
@ -114,7 +112,9 @@ def text_normalize(text):
return dest_text
punctuation_set=set(punctuation)
punctuation_set = set(punctuation)
def jyuping_to_initials_finals_tones(jyuping_syllables):
initials_finals = []
tones = []
@ -159,12 +159,14 @@ def jyuping_to_initials_finals_tones(jyuping_syllables):
assert len(initials_finals) == len(tones)
###魔改为辅音+带音调的元音
phones=[]
for a,b in zip(initials_finals,tones):
if(b not in [-1,0]):###防止粤语和普通话重合开头加Y如果是标点不加。
todo="%s%s"%(a,b)
else:todo=a
if(todo not in punctuation_set):todo="Y%s"%todo
phones = []
for a, b in zip(initials_finals, tones):
if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y如果是标点不加。
todo = "%s%s" % (a, b)
else:
todo = a
if todo not in punctuation_set:
todo = "Y%s" % todo
phones.append(todo)
# return initials_finals, tones, word2ph

View File

@ -18,6 +18,7 @@ pinyin_to_symbol_map = {
import jieba_fast
import logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg
@ -37,7 +38,7 @@ rep_map = {
"/": ",",
"": "-",
"~": "",
"":"",
"": "",
}
tone_modifier = ToneSandhi()
@ -49,9 +50,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
return replaced_text
@ -62,17 +61,15 @@ def replace_punctuation_with_en(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
return replaced_text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
@ -87,9 +84,7 @@ def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)

View File

@ -19,17 +19,24 @@ pinyin_to_symbol_map = {
import jieba_fast
import logging
jieba_fast.setLogLevel(logging.CRITICAL)
import jieba_fast.posseg as psg
# is_g2pw_str = os.environ.get("is_g2pw", "True")##默认开启
# is_g2pw = False#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
is_g2pw = True # True if is_g2pw_str.lower() == 'true' else False
if is_g2pw:
# print("当前使用g2pw进行拼音推理")
from text.g2pw import G2PWPinyin, correct_pronunciation
parent_directory = os.path.dirname(current_file_path)
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
v_to_u=False,
neutral_tone_with_five=True,
)
rep_map = {
"": ",",
@ -46,7 +53,7 @@ rep_map = {
"/": ",",
"": "-",
"~": "",
"":"",
"": "",
}
tone_modifier = ToneSandhi()
@ -58,9 +65,7 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)
return replaced_text
@ -77,9 +82,7 @@ def _get_initials_finals(word):
finals = []
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
@ -87,31 +90,66 @@ def _get_initials_finals(word):
return initials, finals
must_erhua = {
"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
}
must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"}
not_erhua = {
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
"狗儿", "少儿"
"虐儿",
"为儿",
"护儿",
"瞒儿",
"救儿",
"替儿",
"有儿",
"一儿",
"我儿",
"俺儿",
"妻儿",
"拐儿",
"聋儿",
"乞儿",
"患儿",
"幼儿",
"孤儿",
"婴儿",
"婴幼儿",
"连体儿",
"脑瘫儿",
"流浪儿",
"体弱儿",
"混血儿",
"蜜雪儿",
"舫儿",
"祖儿",
"美儿",
"应采儿",
"可儿",
"侄儿",
"孙儿",
"侄孙儿",
"女儿",
"男儿",
"红孩儿",
"花儿",
"虫儿",
"马儿",
"鸟儿",
"猪儿",
"猫儿",
"狗儿",
"少儿",
}
def _merge_erhua(initials: list[str],
finals: list[str],
word: str,
pos: str) -> list[list[str]]:
def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]:
"""
Do erhub.
"""
# fix er1
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "" and phn == 'er1':
finals[i] = 'er2'
if i == len(finals) - 1 and word[i] == "" and phn == "er1":
finals[i] = "er2"
# 发音
if word not in must_erhua and (word in not_erhua or
pos in {"a", "j", "nr"}):
if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}):
return initials, finals
# "……" 等情况直接返回
@ -124,9 +162,13 @@ def _merge_erhua(initials: list[str],
new_initials = []
new_finals = []
for i, phn in enumerate(finals):
if i == len(finals) - 1 and word[i] == "" and phn in {
"er2", "er5"
} and word[-2:] not in not_erhua and new_finals:
if (
i == len(finals) - 1
and word[i] == ""
and phn in {"er2", "er5"}
and word[-2:] not in not_erhua
and new_finals
):
phn = "er" + new_finals[-1][-1]
new_initials.append(initials[i])
@ -160,7 +202,7 @@ def _g2p(segments):
# assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, [])
finals = sum(finals, [])
print("pypinyin结果",initials,finals)
print("pypinyin结果", initials, finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
@ -171,19 +213,19 @@ def _g2p(segments):
sub_finals = []
now_word_length = pre_word_length + len(word)
if pos == 'eng':
if pos == "eng":
pre_word_length = now_word_length
continue
word_pinyins = pinyins[pre_word_length:now_word_length]
# 多音字消歧
word_pinyins = correct_pronunciation(word,word_pinyins)
word_pinyins = correct_pronunciation(word, word_pinyins)
for pinyin in word_pinyins:
if pinyin[0].isalpha():
sub_initials.append(to_initials(pinyin))
sub_finals.append(to_finals_tone3(pinyin,neutral_tone_with_five=True))
sub_finals.append(to_finals_tone3(pinyin, neutral_tone_with_five=True))
else:
sub_initials.append(pinyin)
sub_finals.append(pinyin)
@ -259,18 +301,18 @@ def replace_punctuation_with_en(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)
return replaced_text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
def text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
@ -283,6 +325,7 @@ def text_normalize(text):
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization

View File

@ -19,55 +19,57 @@ special = [
def clean_text(text, language, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
if(language not in language_module_map):
language="en"
text=" "
if language not in language_module_map:
language = "en"
text = " "
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol, version)
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
if hasattr(language_module,"text_normalize"):
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
if hasattr(language_module, "text_normalize"):
norm_text = language_module.text_normalize(text)
else:
norm_text=text
if language == "zh" or language=="yue":##########
norm_text = text
if language == "zh" or language == "yue": ##########
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
elif language == "en":
phones = language_module.g2p(norm_text)
if len(phones) < 4:
phones = [','] + phones
phones = [","] + phones
word2ph = None
else:
phones = language_module.g2p(norm_text)
word2ph = None
phones = ['UNK' if ph not in symbols else ph for ph in phones]
phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text
def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:version=os.environ.get('version', 'v2')
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean","yue":"cantonese"}
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
"""
特殊静音段sp符号处理
"""
text = text.replace(special_s, ",")
language_module = __import__("text."+language_module_map[language],fromlist=[language_module_map[language]])
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
norm_text = language_module.text_normalize(text)
phones = language_module.g2p(norm_text)
new_ph = []
@ -81,8 +83,9 @@ def clean_special(text, language, special_s, target_symbol, version=None):
def text_to_sequence(text, language, version=None):
version = os.environ.get('version',version)
if version is None:version='v2'
version = os.environ.get("version", version)
if version is None:
version = "v2"
phones = clean_text(text)
return cleaned_text_to_sequence(phones, version)

View File

@ -9,17 +9,17 @@ import unicodedata
# 后缀计量单位替换表
measurement_map = {
"m": ["meter", "meters"],
'km': ["kilometer", "kilometers"],
"km": ["kilometer", "kilometers"],
"km/h": ["kilometer per hour", "kilometers per hour"],
"ft": ["feet", "feet"],
"L": ["liter", "liters"],
"tbsp": ["tablespoon", "tablespoons"],
'tsp': ["teaspoon", "teaspoons"],
"tsp": ["teaspoon", "teaspoons"],
"h": ["hour", "hours"],
"min": ["minute", "minutes"],
"s": ["second", "seconds"],
"°C": ["degree celsius", "degrees celsius"],
"°F": ["degree fahrenheit", "degrees fahrenheit"]
"°F": ["degree fahrenheit", "degrees fahrenheit"],
}
@ -27,37 +27,38 @@ measurement_map = {
_inflect = inflect.engine()
# 转化数字序数词
_ordinal_number_re = re.compile(r'\b([0-9]+)\. ')
_ordinal_number_re = re.compile(r"\b([0-9]+)\. ")
# 我听说好像对于数字正则识别其实用 \d 会好一点
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
# 时间识别
_time_re = re.compile(r'\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b')
_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")
# 后缀计量单位识别
_measurement_re = re.compile(r'\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b')
_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")
# 前后 £ 识别 ( 写了识别两边某一边的,但是不知道为什么失败了┭┮﹏┭┮ )
_pounds_re_start = re.compile(r'£([0-9\.\,]*[0-9]+)')
_pounds_re_end = re.compile(r'([0-9\.\,]*[0-9]+)£')
_pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)")
_pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£")
# 前后 $ 识别
_dollars_re_start = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_dollars_re_end = re.compile(r'([(0-9\.\,]*[0-9]+)\$')
_dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$")
# 小数的识别
_decimal_number_re = re.compile(r'([0-9]+\.\s*[0-9]+)')
_decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)")
# 分数识别 (形式 "3/4" )
_fraction_re = re.compile(r'([0-9]+/[0-9]+)')
_fraction_re = re.compile(r"([0-9]+/[0-9]+)")
# 序数词识别
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
# 数字处理
_number_re = re.compile(r'[0-9]+')
_number_re = re.compile(r"[0-9]+")
def _convert_ordinal(m):
"""
@ -70,8 +71,10 @@ def _convert_ordinal(m):
ordinal = _inflect.ordinal(m.group(1))
return ordinal + ", "
def _remove_commas(m):
return m.group(1).replace(',', '')
return m.group(1).replace(",", "")
def _expand_time(m):
"""
@ -82,12 +85,12 @@ def _expand_time(m):
output: "one o'clock p.m. / four o'clock am. / one thirty p.m."
"""
hours, minutes = map(int, m.group(1, 2))
period = 'a.m.' if hours < 12 else 'p.m.'
period = "a.m." if hours < 12 else "p.m."
if hours > 12:
hours -= 12
hour_word = _inflect.number_to_words(hours)
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ''
minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ""
if minutes == 0:
return f"{hour_word} o'clock {period}"
@ -103,7 +106,7 @@ def _expand_measurement(m):
sign = m.group(3)
ptr = 1
# 想不到怎么方便的取数字又懒得改正则1.2 反正也是复数读法,干脆直接去掉 "."
num = int(m.group(1).replace(sign, '').replace(".",''))
num = int(m.group(1).replace(sign, "").replace(".", ""))
decimal_part = m.group(2)
# 上面判断的漏洞,比如 0.1 的情况,在这里排除了
if decimal_part == None and num == 1:
@ -116,23 +119,24 @@ def _expand_pounds(m):
没找到特别规范的说明和美元的处理一样其实可以把两个合并在一起
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' pounds' # Unexpected format
return match + " pounds" # Unexpected format
pounds = int(parts[0]) if parts[0] else 0
pence = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if pounds and pence:
pound_unit = 'pound' if pounds == 1 else 'pounds'
penny_unit = 'penny' if pence == 1 else 'pence'
return '%s %s and %s %s' % (pounds, pound_unit, pence, penny_unit)
pound_unit = "pound" if pounds == 1 else "pounds"
penny_unit = "penny" if pence == 1 else "pence"
return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit)
elif pounds:
pound_unit = 'pound' if pounds == 1 else 'pounds'
return '%s %s' % (pounds, pound_unit)
pound_unit = "pound" if pounds == 1 else "pounds"
return "%s %s" % (pounds, pound_unit)
elif pence:
penny_unit = 'penny' if pence == 1 else 'pence'
return '%s %s' % (pence, penny_unit)
penny_unit = "penny" if pence == 1 else "pence"
return "%s %s" % (pence, penny_unit)
else:
return 'zero pounds'
return "zero pounds"
def _expand_dollars(m):
"""
@ -142,23 +146,24 @@ def _expand_dollars(m):
output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents"
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' dollars' # Unexpected format
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s and %s %s' % (dollars, dollar_unit, cents, cent_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return 'zero dollars'
return "zero dollars"
# 小数的处理
def _expand_decimal_number(m):
@ -168,11 +173,11 @@ def _expand_decimal_number(m):
output: "thirteen point two three four"
"""
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
words = []
# 遍历字符串中的每个字符
for char in parts[1]:
if char == '.':
if char == ".":
words.append("point")
else:
words.append(char)
@ -196,39 +201,41 @@ def _expend_fraction(m):
| 3/2 | three halves |
"""
match = m.group(0)
numerator, denominator = map(int, match.split('/'))
numerator, denominator = map(int, match.split("/"))
numerator_part = _inflect.number_to_words(numerator)
if denominator == 2:
if numerator == 1:
denominator_part = 'half'
denominator_part = "half"
else:
denominator_part = 'halves'
denominator_part = "halves"
elif denominator == 1:
return f'{numerator_part}'
return f"{numerator_part}"
else:
denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
if numerator > 1:
denominator_part += 's'
denominator_part += "s"
return f"{numerator_part} {denominator_part}"
return f'{numerator_part} {denominator_part}'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
return "two thousand"
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num, andword="")
def normalize(text):
@ -238,7 +245,7 @@ def normalize(text):
"""
text = re.sub(_ordinal_number_re, _convert_ordinal, text)
text = re.sub(r'(?<!\d)-|-(?!\d)', ' minus ', text)
text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_time_re, _expand_time, text)
text = re.sub(_measurement_re, _expand_measurement, text)
@ -251,19 +258,20 @@ def normalize(text):
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
text = ''.join(char for char in unicodedata.normalize('NFD', text)
if unicodedata.category(char) != 'Mn') # Strip accents
text = "".join(
char for char in unicodedata.normalize("NFD", text) if unicodedata.category(char) != "Mn"
) # Strip accents
text = re.sub("%", " percent", text)
text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
text = re.sub(r"(?i)i\.e\.", "that is", text)
text = re.sub(r"(?i)e\.g\.", "for example", text)
# 增加纯大写单词拆分
text = re.sub(r'(?<!^)(?<![\s])([A-Z])', r' \1', text)
text = re.sub(r"(?<!^)(?<![\s])([A-Z])", r" \1", text)
return text
if __name__ == '__main__':
if __name__ == "__main__":
# 我觉得其实可以把切分结果展示出来只读或者修改不影响传给TTS的实际text
# 然后让用户确认后再输入给 TTS可以让用户检查自己有没有不标准的输入
print(normalize("1. test ordinal number 1st"))

View File

@ -11,6 +11,7 @@ from text.symbols2 import symbols
from builtins import str as unicode
from text.en_normalization.expend import normalize
from nltk.tokenize import TweetTokenizer
word_tokenize = TweetTokenizer().tokenize
from nltk import pos_tag
@ -121,9 +122,9 @@ def replace_phs(phs):
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}\s])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}\s])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
@ -182,6 +183,7 @@ def read_dict_new():
return g2p_dict
def hot_reload_hot(g2p_dict):
with open(CMU_DICT_HOT_PATH) as f:
line = f.readline()
@ -258,9 +260,12 @@ class en_G2p(G2p):
del self.cmu[word.lower()]
# 修正多音字
self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
self.homograph2features["complex"] = (
["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
"JJ",
)
def __call__(self, text):
# tokenization
@ -279,7 +284,7 @@ class en_G2p(G2p):
elif len(word) == 1:
# 单读 A 发音修正, 这里需要原格式 o_word 判断大写
if o_word == "A":
pron = ['EY1']
pron = ["EY1"]
else:
pron = self.cmu[word][0]
# g2p_en 原版多音字处理
@ -288,7 +293,7 @@ class en_G2p(G2p):
if pos.startswith(pos1):
pron = pron1
# pos1比pos长仅出现在read
elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
elif len(pos) < len(pos1) and pos == pos1[: len(pos)]:
pron = pron1
else:
pron = pron2
@ -301,7 +306,6 @@ class en_G2p(G2p):
return prons[:-1]
def qryword(self, o_word):
word = o_word.lower()
@ -319,7 +323,7 @@ class en_G2p(G2p):
for w in word:
# 单读 A 发音修正, 此处不存在大写的情况
if w == "a":
phones.extend(['EY1'])
phones.extend(["EY1"])
elif not w.isalpha():
phones.extend([w])
else:
@ -330,23 +334,23 @@ class en_G2p(G2p):
if re.match(r"^([a-z]+)('s)$", word):
phones = self.qryword(word[:-2])[:]
# P T K F TH HH 无声辅音结尾 's 发 ['S']
if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
phones.extend(['S'])
if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
phones.extend(["S"])
# S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z']
elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
phones.extend(['AH0', 'Z'])
elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
phones.extend(["AH0", "Z"])
# B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z']
# AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
# ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z']
else:
phones.extend(['Z'])
phones.extend(["Z"])
return phones
# 尝试进行分词,应对复合词
comps = wordsegment.segment(word.lower())
# 无法分词的送回去预测
if len(comps)==1:
if len(comps) == 1:
return self.predict(word)
# 可以分词的递归处理

View File

@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
from typing import Dict
from typing import List
from typing import Tuple
@ -23,21 +24,24 @@ import numpy as np
from .utils import tokenize_and_map
ANCHOR_CHAR = ''
ANCHOR_CHAR = ""
def prepare_onnx_input(tokenizer,
def prepare_onnx_input(
tokenizer,
labels: List[str],
char2phonemes: Dict[str, List[int]],
chars: List[str],
texts: List[str],
query_ids: List[int],
use_mask: bool=False,
window_size: int=None,
max_len: int=512) -> Dict[str, np.array]:
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
window_size=window_size, texts=texts, query_ids=query_ids)
window_size=window_size, texts=texts, query_ids=query_ids
)
input_ids = []
token_type_ids = []
attention_masks = []
@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(
tokenizer=tokenizer, text=text)
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=tokens,
text2token=text2token,
token2text=token2text)
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
input_id = list(
np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[
query_id] + 1 # [CLS] token locate at first place
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
input_ids.append(input_id)
token_type_ids.append(token_type_id)
@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
position_ids.append(position_id)
outputs = {
'input_ids': np.array(input_ids).astype(np.int64),
'token_type_ids': np.array(token_type_ids).astype(np.int64),
'attention_masks': np.array(attention_masks).astype(np.int64),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64),
'position_ids': np.array(position_ids).astype(np.int64),
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),
}
return outputs
def _truncate_texts(window_size: int, texts: List[str],
query_ids: List[int]) -> Tuple[List[str], List[int]]:
def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
return truncated_texts, truncated_query_ids
def _truncate(max_len: int,
text: str,
query_id: int,
tokens: List[str],
text2token: List[int],
token2text: List[Tuple[int]]):
def _truncate(
max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
@ -137,14 +131,16 @@ def _truncate(max_len: int,
start = token2text[token_start][0]
end = token2text[token_end - 1][1]
return (text[start:end], query_id - start, tokens[token_start:token_end], [
i - token_start if i is not None else None
for i in text2token[start:end]
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]],
)
def get_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
return labels, char2phonemes

View File

@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")
class G2PWPinyin(Pinyin):
def __init__(self, model_dir='G2PWModel/', model_source=None,
def __init__(
self,
model_dir="G2PWModel/",
model_source=None,
enable_non_tradional_chinese=True,
v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False,
**kwargs,
):
self._g2pw = G2PWOnnxConverter(
model_dir=model_dir,
style='pinyin',
style="pinyin",
model_source=model_source,
enable_non_tradional_chinese=enable_non_tradional_chinese,
)
self._converter = Converter(
self._g2pw, v_to_u=v_to_u,
self._g2pw,
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi,
)
@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):
class Converter(UltimateConverter):
def __init__(self, g2pw_instance, v_to_u=False,
neutral_tone_with_five=False,
tone_sandhi=False, **kwargs):
def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
super(Converter, self).__init__(
v_to_u=v_to_u,
neutral_tone_with_five=neutral_tone_with_five,
tone_sandhi=tone_sandhi, **kwargs)
v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
)
self._g2pw = g2pw_instance
def convert(self, words, style, heteronym, errors, strict, **kwargs):
pys = []
if RE_HANS.match(words):
pys = self._to_pinyin(words, style=style, heteronym=heteronym,
errors=errors, strict=strict)
pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
post_data = self.post_pinyin(words, heteronym, pys)
if post_data is not None:
pys = post_data
pys = self.convert_styles(
pys, words, style, heteronym, errors, strict)
pys = self.convert_styles(pys, words, style, heteronym, errors, strict)
else:
py = self.handle_nopinyin(words, style=style, errors=errors,
heteronym=heteronym, strict=strict)
py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
if py:
pys.extend(py)
@ -73,13 +75,11 @@ class Converter(UltimateConverter):
g2pw_pinyin = self._g2pw(han)
if not g2pw_pinyin: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
return super(Converter, self).convert(
han, Style.TONE, heteronym, errors, strict, **kwargs)
return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)
for i, item in enumerate(g2pw_pinyin[0]):
if item is None: # g2pw 不支持的汉字改为使用 pypinyin 原有逻辑
py = super(Converter, self).convert(
han[i], Style.TONE, heteronym, errors, strict, **kwargs)
py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
pinyins.extend(py)
else:
pinyins.append([to_tone(item)])
@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
if lst:
new_lst_list.append(lst)
else:
new_lst_list.append([''])
new_lst_list.append([""])
return new_lst_list
@ -127,17 +127,17 @@ def get_dict():
def read_dict():
polyphonic_dict = {}
with open(PP_DICT_PATH,encoding="utf-8") as f:
with open(PP_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()
with open(PP_FIX_DICT_PATH,encoding="utf-8") as f:
with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
line = f.readline()
while line:
key, value_str = line.split(':')
key, value_str = line.split(":")
value = eval(value_str.strip())
polyphonic_dict[key.strip()] = value
line = f.readline()

View File

@ -2,6 +2,7 @@
# This code is modified from https://github.com/GitYCC/g2pW
import warnings
warnings.filterwarnings("ignore")
import json
import os
@ -14,6 +15,7 @@ from typing import Tuple
import numpy as np
import onnxruntime
onnxruntime.set_default_logger_severity(3)
from opencc import OpenCC
from transformers import AutoTokenizer
@ -26,21 +28,23 @@ from .dataset import prepare_onnx_input
from .utils import load_config
from ..zh_normalization.char_convert import tranditional_to_simplified
model_version = '1.1'
model_version = "1.1"
def predict(session, onnx_input: Dict[str, Any],
labels: List[str]) -> Tuple[List[str], List[float]]:
def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
all_preds = []
all_confidences = []
probs = session.run([], {
"input_ids": onnx_input['input_ids'],
"token_type_ids": onnx_input['token_type_ids'],
"attention_mask": onnx_input['attention_masks'],
"phoneme_mask": onnx_input['phoneme_masks'],
"char_ids": onnx_input['char_ids'],
"position_ids": onnx_input['position_ids']
})[0]
probs = session.run(
[],
{
"input_ids": onnx_input["input_ids"],
"token_type_ids": onnx_input["token_type_ids"],
"attention_mask": onnx_input["attention_masks"],
"phoneme_mask": onnx_input["phoneme_masks"],
"char_ids": onnx_input["char_ids"],
"position_ids": onnx_input["position_ids"],
},
)[0]
preds = np.argmax(probs, axis=1).tolist()
max_probs = []
@ -52,17 +56,17 @@ def predict(session, onnx_input: Dict[str, Any],
return all_preds, all_confidences
def download_and_decompress(model_dir: str='G2PWModel/'):
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
zip_dir = os.path.join(parent_directory,"G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory,"G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory,"G2PWModel")
zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, 'wb') as f:
with open(zip_dir, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
@ -75,12 +79,15 @@ def download_and_decompress(model_dir: str='G2PWModel/'):
return model_dir
class G2PWOnnxConverter:
def __init__(self,
model_dir: str='G2PWModel/',
style: str='bopomofo',
model_source: str=None,
enable_non_tradional_chinese: bool=False):
def __init__(
self,
model_dir: str = "G2PWModel/",
style: str = "bopomofo",
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)
sess_options = onnxruntime.SessionOptions()
@ -88,41 +95,59 @@ class G2PWOnnxConverter:
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2
try:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
except:
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
self.config = load_config(
config_path=os.path.join(uncompress_path, 'config.py'),
use_default=True)
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(uncompress_path,
'POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(uncompress_path,
'MONOPHONIC_CHARS.txt')
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split('\t')
for line in open(polyphonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.non_polyphonic = {
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', ''
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
}
self.non_monophonic = {'', ''}
self.non_monophonic = {"", ""}
self.monophonic_chars = [
line.split('\t')
for line in open(monophonic_chars_path, encoding='utf-8').read()
.strip().split('\n')
line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
self.labels, self.char2phonemes = get_char_phoneme_labels(
polyphonic_chars=self.polyphonic_chars
) if self.config.use_char_phoneme else get_phoneme_labels(
polyphonic_chars=self.polyphonic_chars)
self.labels, self.char2phonemes = (
get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
if self.config.use_char_phoneme
else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
)
self.chars = sorted(list(self.char2phonemes.keys()))
@ -131,41 +156,29 @@ class G2PWOnnxConverter:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)
self.monophonic_chars_dict = {
char: phoneme
for char, phoneme in self.monophonic_chars
}
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)
self.pos_tags = [
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
]
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
with open(
os.path.join(uncompress_path,
'bopomofo_to_pinyin_wo_tune_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]
with open(
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
'r',
encoding='utf-8') as fr:
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC('s2tw')
self.cc = OpenCC("s2tw")
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
assert tone in '12345'
assert tone in "12345"
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
@ -185,8 +198,7 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(
sentences=sentences)
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
@ -199,14 +211,12 @@ class G2PWOnnxConverter:
texts=texts,
query_ids=query_ids,
use_mask=self.config.use_mask,
window_size=None)
window_size=None,
)
preds, confidences = predict(
session=self.session_g2pW,
onnx_input=onnx_input,
labels=self.labels)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
preds = [pred.split(" ")[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@ -214,15 +224,12 @@ class G2PWOnnxConverter:
return results
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(
sent_s, neutral_tone_with_five=True, style=Style.TONE3)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
@ -230,8 +237,7 @@ class G2PWOnnxConverter:
query_ids.append(i)
sent_ids.append(sent_id)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(
self.monophonic_chars_dict[char])
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])

View File

@ -15,6 +15,7 @@
Credits
This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import re
@ -24,14 +25,14 @@ def wordize_and_map(text: str):
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
match_space = re.match(r"^ +", text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
text = text[len(space_str) :]
continue
match_en = re.match(r'^[a-zA-Z0-9]+', text)
match_en = re.match(r"^[a-zA-Z0-9]+", text)
if match_en:
en_word = match_en.group(0)
@ -42,7 +43,7 @@ def wordize_and_map(text: str):
index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word)
text = text[len(en_word):]
text = text[len(en_word) :]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
tokens.append("[UNK]")
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
word_token_len = len(re.sub(r"^##", "", word_token))
index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)
@ -85,53 +85,51 @@ def tokenize_and_map(tokenizer, text: str):
def _load_config(config_path: os.PathLike):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
spec = importlib.util.spec_from_file_location("__init__", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config
default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
"manual_seed": 1313,
"model_source": "bert-base-chinese",
"window_size": 32,
"num_workers": 2,
"use_mask": True,
"use_char_phoneme": False,
"use_conditional": True,
"param_conditional": {
"affect_location": "softmax",
"bias": True,
"char-linear": True,
"pos-linear": False,
"char+pos-second": True,
"char+pos-second_lowrank": False,
"lowrank_size": 0,
"char+pos-second_fm": False,
"fm_size": 0,
"fix_mode": None,
"count_json": "train.count.json",
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
"lr": 5e-5,
"val_interval": 200,
"num_iter": 10000,
"use_focal": False,
"param_focal": {"alpha": 0.0, "gamma": 0.7},
"use_pos": True,
"param_pos ": {
"weight": 0.1,
"pos_joint_training": True,
"train_pos_path": "train.pos",
"valid_pos_path": "dev.pos",
"test_pos_path": "test.pos",
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}
def load_config(config_path: os.PathLike, use_default: bool=False):
def load_config(config_path: os.PathLike, use_default: bool = False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():

View File

@ -2,43 +2,51 @@
import re
import os
import hashlib
try:
import pyopenjtalk
current_file_path = os.path.dirname(__file__)
# 防止win下无法读取模型
if os.name == 'nt':
if os.name == "nt":
python_dir = os.getcwd()
OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', OPEN_JTALK_DICT_DIR)):
if (OPEN_JTALK_DICT_DIR[:len(python_dir)].upper() == python_dir.upper()):
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR,python_dir))
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
else:
import shutil
if not os.path.exists('TEMP'):
os.mkdir('TEMP')
if not os.path.exists("TEMP"):
os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ja")):
os.mkdir(os.path.join("TEMP", "ja"))
if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
shutil.copytree(pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"), os.path.join("TEMP", "ja", "open_jtalk_dic"), )
shutil.copytree(
pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
os.path.join("TEMP", "ja", "open_jtalk_dic"),
)
OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', current_file_path)):
if (current_file_path[:len(python_dir)].upper() == python_dir.upper()):
current_file_path = os.path.join(os.path.relpath(current_file_path,python_dir))
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
if current_file_path[: len(python_dir)].upper() == python_dir.upper():
current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
else:
if not os.path.exists('TEMP'):
os.mkdir('TEMP')
if not os.path.exists("TEMP"):
os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ja")):
os.mkdir(os.path.join("TEMP", "ja"))
if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
shutil.copyfile(os.path.join(current_file_path, "ja_userdic", "userdict.csv"),os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"))
shutil.copyfile(
os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
)
current_file_path = os.path.join("TEMP", "ja")
def get_hash(fp: str) -> str:
hash_md5 = hashlib.md5()
with open(fp, "rb") as f:
@ -51,9 +59,12 @@ try:
USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
# 如果没有用户词典就生成一个如果有就检查md5如果不一样就重新生成
if os.path.exists(USERDIC_CSV_PATH):
if not os.path.exists(USERDIC_BIN_PATH) or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r",encoding='utf-8').read():
if (
not os.path.exists(USERDIC_BIN_PATH)
or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
):
pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
with open(USERDIC_HASH_PATH, "w", encoding='utf-8') as f:
with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
f.write(get_hash(USERDIC_CSV_PATH))
if os.path.exists(USERDIC_BIN_PATH):
@ -61,11 +72,13 @@ try:
except Exception:
# print(e)
import pyopenjtalk
# failed to load user dictionary, ignore.
pass
from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@ -123,9 +136,9 @@ def post_replace_ph(ph):
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
@ -152,7 +165,7 @@ def preprocess_jap(text, with_prosody=False):
text += p.split(" ")
if i < len(marks):
if marks[i] == " ":# 防止意外的UNK
if marks[i] == " ": # 防止意外的UNK
continue
text += [marks[i].replace(" ", "")]
return text
@ -165,6 +178,7 @@ def text_normalize(text):
text = replace_consecutive_punctuation(text)
return text
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
"""Extract phoneme + prosoody symbol sequence from input full-context labels.
@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
return phones
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def _numeric_feature_by_regex(regex, s):
match = re.search(regex, s)
@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
return -50
return int(match.group(1))
def g2p(norm_text, with_prosody=True):
phones = preprocess_jap(norm_text, with_prosody)
phones = [post_replace_ph(i) for i in phones]

View File

@ -9,39 +9,43 @@ import importlib
import os
# 防止win下无法读取模型
if os.name == 'nt':
if os.name == "nt":
class win_G2p(G2p):
def check_mecab(self):
super().check_mecab()
spam_spec = importlib.util.find_spec("eunjeon")
non_found = spam_spec is None
if non_found:
print('you have to install eunjeon. install it...')
print("you have to install eunjeon. install it...")
else:
installpath = spam_spec.submodule_search_locations[0]
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
import sys
from eunjeon import Mecab as _Mecab
class Mecab(_Mecab):
def get_dicpath(installpath):
if not (re.match(r'^[A-Za-z0-9_/\\:.\-]*$', installpath)):
if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
import shutil
python_dir = os.getcwd()
if (installpath[:len(python_dir)].upper() == python_dir.upper()):
dicpath = os.path.join(os.path.relpath(installpath,python_dir),'data','mecabrc')
else:
if not os.path.exists('TEMP'):
os.mkdir('TEMP')
if not os.path.exists(os.path.join('TEMP', 'ko')):
os.mkdir(os.path.join('TEMP', 'ko'))
if os.path.exists(os.path.join('TEMP', 'ko', 'ko_dict')):
shutil.rmtree(os.path.join('TEMP', 'ko', 'ko_dict'))
shutil.copytree(os.path.join(installpath, 'data'), os.path.join('TEMP', 'ko', 'ko_dict'))
dicpath = os.path.join('TEMP', 'ko', 'ko_dict', 'mecabrc')
python_dir = os.getcwd()
if installpath[: len(python_dir)].upper() == python_dir.upper():
dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
else:
dicpath=os.path.abspath(os.path.join(installpath, 'data/mecabrc'))
if not os.path.exists("TEMP"):
os.mkdir("TEMP")
if not os.path.exists(os.path.join("TEMP", "ko")):
os.mkdir(os.path.join("TEMP", "ko"))
if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))
shutil.copytree(
os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
)
dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
else:
dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
return dicpath
def __init__(self, dicpath=get_dicpath(installpath)):
@ -55,10 +59,14 @@ if os.name == 'nt':
from text.symbols2 import symbols
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
_korean_classifiers = (
"군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
)
# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
_hangul_divided = [
(re.compile("%s" % x[0]), x[1])
for x in [
# ('ㄳ', 'ㄱㅅ'), # g2pk2, A Syllable-ending Rule
# ('ㄵ', 'ㄴㅈ'),
# ('ㄶ', 'ㄴㅎ'),
@ -70,79 +78,86 @@ _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
# ('ㄿ', 'ㄹㅍ'),
# ('ㅀ', 'ㄹㅎ'),
# ('ㅄ', 'ㅂㅅ'),
('', 'ㅗㅏ'),
('', 'ㅗㅐ'),
('', 'ㅗㅣ'),
('', 'ㅜㅓ'),
('', 'ㅜㅔ'),
('', 'ㅜㅣ'),
('', 'ㅡㅣ'),
('', 'ㅣㅏ'),
('', 'ㅣㅐ'),
('', 'ㅣㅓ'),
('', 'ㅣㅔ'),
('', 'ㅣㅗ'),
('', 'ㅣㅜ')
]]
("", "ㅗㅏ"),
("", "ㅗㅐ"),
("", "ㅗㅣ"),
("", "ㅜㅓ"),
("", "ㅜㅔ"),
("", "ㅜㅣ"),
("", "ㅡㅣ"),
("", "ㅣㅏ"),
("", "ㅣㅐ"),
("", "ㅣㅓ"),
("", "ㅣㅔ"),
("", "ㅣㅗ"),
("", "ㅣㅜ"),
]
]
# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', '에이'),
('b', ''),
('c', ''),
('d', ''),
('e', ''),
('f', '에프'),
('g', ''),
('h', '에이치'),
('i', '아이'),
('j', '제이'),
('k', '케이'),
('l', ''),
('m', ''),
('n', ''),
('o', ''),
('p', ''),
('q', ''),
('r', '아르'),
('s', '에스'),
('t', ''),
('u', ''),
('v', '브이'),
('w', '더블유'),
('x', '엑스'),
('y', '와이'),
('z', '제트')
]]
_latin_to_hangul = [
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
for x in [
("a", "에이"),
("b", ""),
("c", ""),
("d", ""),
("e", ""),
("f", "에프"),
("g", ""),
("h", "에이치"),
("i", "아이"),
("j", "제이"),
("k", "케이"),
("l", ""),
("m", ""),
("n", ""),
("o", ""),
("p", ""),
("q", ""),
("r", "아르"),
("s", "에스"),
("t", ""),
("u", ""),
("v", "브이"),
("w", "더블유"),
("x", "엑스"),
("y", "와이"),
("z", "제트"),
]
]
# List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('t͡ɕ','ʧ'),
('d͡ʑ','ʥ'),
('ɲ','n^'),
('ɕ','ʃ'),
('ʷ','w'),
('ɭ','l`'),
('ʎ','ɾ'),
('ɣ','ŋ'),
('ɰ','ɯ'),
('ʝ','j'),
('ʌ','ə'),
('ɡ','g'),
('\u031a','#'),
('\u0348','='),
('\u031e',''),
('\u0320',''),
('\u0339','')
]]
_ipa_to_lazy_ipa = [
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
for x in [
("t͡ɕ", "ʧ"),
("d͡ʑ", "ʥ"),
("ɲ", "n^"),
("ɕ", "ʃ"),
("ʷ", "w"),
("ɭ", "l`"),
("ʎ", "ɾ"),
("ɣ", "ŋ"),
("ɰ", "ɯ"),
("ʝ", "j"),
("ʌ", "ə"),
("ɡ", "g"),
("\u031a", "#"),
("\u0348", "="),
("\u031e", ""),
("\u0320", ""),
("\u0339", ""),
]
]
def fix_g2pk2_error(text):
new_text = ""
i = 0
while i < len(text) - 4:
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == '':
new_text += text[i:i+3] + ' ' + ''
if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "":
new_text += text[i : i + 3] + " " + ""
i += 5
else:
new_text += text[i]
@ -166,20 +181,20 @@ def divide_hangul(text):
def hangul_number(num, sino=True):
'''Reference https://github.com/Kyubyong/g2pK'''
num = re.sub(',', '', num)
"""Reference https://github.com/Kyubyong/g2pK"""
num = re.sub(",", "", num)
if num == '0':
return ''
if not sino and num == '20':
return '스무'
if num == "0":
return ""
if not sino and num == "20":
return "스무"
digits = '123456789'
names = '일이삼사오육칠팔구'
digits = "123456789"
names = "일이삼사오육칠팔구"
digit2name = {d: n for d, n in zip(digits, names)}
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
@ -188,75 +203,75 @@ def hangul_number(num, sino=True):
i = len(num) - i - 1
if sino:
if i == 0:
name = digit2name.get(digit, '')
name = digit2name.get(digit, "")
elif i == 1:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일십", "")
else:
if i == 0:
name = digit2mod.get(digit, '')
name = digit2mod.get(digit, "")
elif i == 1:
name = digit2dec.get(digit, '')
if digit == '0':
name = digit2dec.get(digit, "")
if digit == "0":
if i % 4 == 0:
last_three = spelledout[-min(3, len(spelledout)):]
if ''.join(last_three) == '':
spelledout.append('')
last_three = spelledout[-min(3, len(spelledout)) :]
if "".join(last_three) == "":
spelledout.append("")
continue
else:
spelledout.append('')
spelledout.append("")
continue
if i == 2:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일백", "")
elif i == 3:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일천", "")
elif i == 4:
name = digit2name.get(digit, '') + ''
name = name.replace('일만', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일만", "")
elif i == 5:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일십", "")
elif i == 6:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일백", "")
elif i == 7:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
name = digit2name.get(digit, "") + ""
name = name.replace("일천", "")
elif i == 8:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 9:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 10:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 11:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 12:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 13:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 14:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
elif i == 15:
name = digit2name.get(digit, '') + ''
name = digit2name.get(digit, "") + ""
spelledout.append(name)
return ''.join(elem for elem in spelledout)
return "".join(elem for elem in spelledout)
def number_to_hangul(text):
'''Reference https://github.com/Kyubyong/g2pK'''
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
"""Reference https://github.com/Kyubyong/g2pK"""
tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
for token in tokens:
num, classifier = token
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
spelledout = hangul_number(num, sino=False)
else:
spelledout = hangul_number(num, sino=True)
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")
# digit by digit for remaining digits
digits = '0123456789'
names = '영일이삼사오육칠팔구'
digits = "0123456789"
names = "영일이삼사오육칠팔구"
for d, n in zip(digits, names):
text = text.replace(d, n)
return text
@ -265,19 +280,23 @@ def number_to_hangul(text):
def korean_to_lazy_ipa(text):
text = latin_to_hangul(text)
text = number_to_hangul(text)
text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
for regex, replacement in _ipa_to_lazy_ipa:
text = re.sub(regex, replacement, text)
return text
_g2p=G2p()
_g2p = G2p()
def korean_to_ipa(text):
text = latin_to_hangul(text)
text = number_to_hangul(text)
text = _g2p(text)
text = fix_g2pk2_error(text)
text = korean_to_lazy_ipa(text)
return text.replace('ʧ','').replace('ʥ','')
return text.replace("ʧ", "").replace("ʥ", "")
def post_replace_ph(ph):
rep_map = {
@ -301,12 +320,13 @@ def post_replace_ph(ph):
ph = ""
return ph
def g2p(text):
text = latin_to_hangul(text)
text = _g2p(text)
text = divide_hangul(text)
text = fix_g2pk2_error(text)
text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
# text = "".join([post_replace_ph(i) for i in text])
text = [post_replace_ph(i) for i in text]
return text

View File

@ -1,4 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-")

View File

@ -1,4 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-")
@ -395,24 +394,404 @@ arpa = {
"SH",
}
ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停'
ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停"
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}
yue_symbols = {
"Yeot3",
"Yip1",
"Yyu3",
"Yeng4",
"Yut5",
"Yaan5",
"Ym5",
"Yaan6",
"Yang1",
"Yun4",
"Yon2",
"Yui5",
"Yun2",
"Yat3",
"Ye",
"Yeot1",
"Yoeng5",
"Yoek2",
"Yam2",
"Yeon6",
"Yu6",
"Yiu3",
"Yaang6",
"Yp5",
"Yai4",
"Yoek4",
"Yit6",
"Yam5",
"Yoeng6",
"Yg1",
"Yk3",
"Yoe4",
"Yam3",
"Yc",
"Yyu4",
"Yyut1",
"Yiu4",
"Ying3",
"Yip3",
"Yaap3",
"Yau3",
"Yan4",
"Yau1",
"Yap4",
"Yk6",
"Yok3",
"Yai1",
"Yeot6",
"Yan2",
"Yoek6",
"Yt1",
"Yoi1",
"Yit5",
"Yn4",
"Yaau3",
"Yau4",
"Yuk6",
"Ys",
"Yuk",
"Yin6",
"Yung6",
"Ya",
"You",
"Yaai5",
"Yau5",
"Yoi3",
"Yaak3",
"Yaat3",
"Ying2",
"Yok5",
"Yeng2",
"Yyut3",
"Yam1",
"Yip5",
"You1",
"Yam6",
"Yaa5",
"Yi6",
"Yek4",
"Yyu2",
"Yuk5",
"Yaam1",
"Yang2",
"Yai",
"Yiu6",
"Yin4",
"Yok4",
"Yot3",
"Yui2",
"Yeoi5",
"Yyun6",
"Yyu5",
"Yoi5",
"Yeot2",
"Yim4",
"Yeoi2",
"Yaan1",
"Yang6",
"Yong1",
"Yaang4",
"Yung5",
"Yeon1",
"Yin2",
"Ya3",
"Yaang3",
"Yg",
"Yk2",
"Yaau5",
"Yut1",
"Yt5",
"Yip4",
"Yung4",
"Yj",
"Yong3",
"Ya1",
"Yg6",
"Yaau6",
"Yit3",
"Yun3",
"Ying1",
"Yn2",
"Yg4",
"Yl",
"Yp3",
"Yn3",
"Yak1",
"Yang5",
"Yoe6",
"You2",
"Yap2",
"Yak2",
"Yt3",
"Yot5",
"Yim2",
"Yi1",
"Yn6",
"Yaat5",
"Yaam3",
"Yoek5",
"Ye3",
"Yeon4",
"Yaa2",
"Yu3",
"Yim6",
"Ym",
"Yoe3",
"Yaai2",
"Ym2",
"Ya6",
"Yeng6",
"Yik4",
"Yot4",
"Yaai4",
"Yyun3",
"Yu1",
"Yoeng1",
"Yaap2",
"Yuk3",
"Yoek3",
"Yeng5",
"Yeoi1",
"Yiu2",
"Yok1",
"Yo1",
"Yoek1",
"Yoeng2",
"Yeon5",
"Yiu1",
"Yoeng4",
"Yuk2",
"Yat4",
"Yg5",
"Yut4",
"Yan6",
"Yin3",
"Yaa6",
"Yap1",
"Yg2",
"Yoe5",
"Yt4",
"Ya5",
"Yo4",
"Yyu1",
"Yak3",
"Yeon2",
"Yong4",
"Ym1",
"Ye2",
"Yaang5",
"Yoi2",
"Yeng3",
"Yn",
"Yyut4",
"Yau",
"Yaak2",
"Yaan4",
"Yek2",
"Yin1",
"Yi5",
"Yoe2",
"Yei5",
"Yaat6",
"Yak5",
"Yp6",
"Yok6",
"Yei2",
"Yaap1",
"Yyut5",
"Yi4",
"Yim1",
"Yk5",
"Ye4",
"Yok2",
"Yaam6",
"Yat2",
"Yon6",
"Yei3",
"Yyu6",
"Yeot5",
"Yk4",
"Yai6",
"Yd",
"Yg3",
"Yei6",
"Yau2",
"Yok",
"Yau6",
"Yung3",
"Yim5",
"Yut6",
"Yit1",
"Yon3",
"Yat1",
"Yaam2",
"Yyut2",
"Yui6",
"Yt2",
"Yek6",
"Yt",
"Ye6",
"Yang3",
"Ying6",
"Yaau1",
"Yeon3",
"Yng",
"Yh",
"Yang4",
"Ying5",
"Yaap6",
"Yoeng3",
"Yyun4",
"You3",
"Yan5",
"Yat5",
"Yot1",
"Yun1",
"Yi3",
"Yaa1",
"Yaap4",
"You6",
"Yaang2",
"Yaap5",
"Yaa3",
"Yaak6",
"Yeng1",
"Yaak1",
"Yo5",
"Yoi4",
"Yam4",
"Yik1",
"Ye1",
"Yai5",
"Yung1",
"Yp2",
"Yui4",
"Yaak4",
"Yung2",
"Yak4",
"Yaat4",
"Yeoi4",
"Yut2",
"Yin5",
"Yaau4",
"Yap6",
"Yb",
"Yaam4",
"Yw",
"Yut3",
"Yong2",
"Yt6",
"Yaai6",
"Yap5",
"Yik5",
"Yun6",
"Yaam5",
"Yun5",
"Yik3",
"Ya2",
"Yyut6",
"Yon4",
"Yk1",
"Yit4",
"Yak6",
"Yaan2",
"Yuk1",
"Yai2",
"Yik2",
"Yaat2",
"Yo3",
"Ykw",
"Yn5",
"Yaa",
"Ye5",
"Yu4",
"Yei1",
"Yai3",
"Yyun5",
"Yip2",
"Yaau2",
"Yiu5",
"Ym4",
"Yeoi6",
"Yk",
"Ym6",
"Yoe1",
"Yeoi3",
"Yon",
"Yuk4",
"Yaai3",
"Yaa4",
"Yot6",
"Yaang1",
"Yei4",
"Yek1",
"Yo",
"Yp",
"Yo6",
"Yp4",
"Yan3",
"Yoi",
"Yap3",
"Yek3",
"Yim3",
"Yz",
"Yot2",
"Yoi6",
"Yit2",
"Yu5",
"Yaan3",
"Yan1",
"Yon5",
"Yp1",
"Yong5",
"Ygw",
"Yak",
"Yat6",
"Ying4",
"Yu2",
"Yf",
"Ya4",
"Yon1",
"You4",
"Yik6",
"Yui1",
"Yaat1",
"Yeot4",
"Yi2",
"Yaai1",
"Yek5",
"Ym3",
"Yong6",
"You5",
"Yyun1",
"Yn1",
"Yo2",
"Yip6",
"Yui3",
"Yaak5",
"Yyun2",
}
# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)+list(ko_symbols)#+list(yue_symbols)###直接这么加yue顺序乱了
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))
# print(len(symbols))
symbols+=["[","]"]##日文新增上升下降调型
symbols+=sorted(list(ko_symbols))
symbols+=sorted(list(yue_symbols))##新加的yue统一摆在后头#已查过开头加Y后没有重复韩文显然不会重复
symbols += ["[", "]"] ##日文新增上升下降调型
symbols += sorted(list(ko_symbols))
symbols += sorted(list(yue_symbols)) ##新加的yue统一摆在后头#已查过开头加Y后没有重复韩文显然不会重复
# print(len(symbols))
if __name__ == "__main__":
print(len(symbols))
'''
"""
粤语
732-353=379
韩文+粤语
732-322=410
'''
"""

View File

@ -510,12 +510,7 @@ class ToneSandhi:
# e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5"
elif (
len(word) > 1
and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words:
finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -525,25 +520,18 @@ class ToneSandhi:
finals[-1] = finals[-1][:-1] + "5"
# 个做量词
elif (
ge_idx >= 1
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5"
else:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word)
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list):
# conventional neural in Chinese
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, [])
return finals
@ -561,9 +549,7 @@ class ToneSandhi:
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]
):
if word.find("") != -1 and all([item.isnumeric() for item in word if item != ""]):
return finals
# "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
@ -697,13 +683,10 @@ class ToneSandhi:
return new_seg
# the first and the second words are all_tone_three
def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
@ -715,10 +698,7 @@ class ToneSandhi:
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
@ -732,13 +712,10 @@ class ToneSandhi:
return len(word) == 2 and word[0] == word[1]
# the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
@ -750,10 +727,7 @@ class ToneSandhi:
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:

File diff suppressed because one or more lines are too long

View File

@ -21,25 +21,29 @@ from .num import verbalize_digit
def _time_num2str(num_string: str) -> str:
"""A special case for verbalizing number in time."""
result = num2str(num_string.lstrip('0'))
if num_string.startswith('0'):
result = DIGITS['0'] + result
result = num2str(num_string.lstrip("0"))
if num_string.startswith("0"):
result = DIGITS["0"] + result
return result
# 时刻表达式
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
r':([0-5][0-9])'
r'(:([0-5][0-9]))?')
RE_TIME = re.compile(
r"([0-1]?[0-9]|2[0-3])"
r":([0-5][0-9])"
r"(:([0-5][0-9]))?"
)
# 时间范围如8:30-12:30
RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])'
r':([0-5][0-9])'
r'(:([0-5][0-9]))?'
r'(~|-)'
r'([0-1]?[0-9]|2[0-3])'
r':([0-5][0-9])'
r'(:([0-5][0-9]))?')
RE_TIME_RANGE = re.compile(
r"([0-1]?[0-9]|2[0-3])"
r":([0-5][0-9])"
r"(:([0-5][0-9]))?"
r"(~|-)"
r"([0-1]?[0-9]|2[0-3])"
r":([0-5][0-9])"
r"(:([0-5][0-9]))?"
)
def replace_time(match) -> str:
@ -62,31 +66,33 @@ def replace_time(match) -> str:
second_2 = match.group(9)
result = f"{num2str(hour)}"
if minute.lstrip('0'):
if minute.lstrip("0"):
if int(minute) == 30:
result += ""
else:
result += f"{_time_num2str(minute)}"
if second and second.lstrip('0'):
if second and second.lstrip("0"):
result += f"{_time_num2str(second)}"
if is_range:
result += ""
result += f"{num2str(hour_2)}"
if minute_2.lstrip('0'):
if minute_2.lstrip("0"):
if int(minute) == 30:
result += ""
else:
result += f"{_time_num2str(minute_2)}"
if second_2 and second_2.lstrip('0'):
if second_2 and second_2.lstrip("0"):
result += f"{_time_num2str(second_2)}"
return result
RE_DATE = re.compile(r'(\d{4}|\d{2})年'
r'((0?[1-9]|1[0-2])月)?'
r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')
RE_DATE = re.compile(
r"(\d{4}|\d{2})年"
r"((0?[1-9]|1[0-2])月)?"
r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"
)
def replace_date(match) -> str:
@ -110,8 +116,7 @@ def replace_date(match) -> str:
# 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
RE_DATE2 = re.compile(
r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')
RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")
def replace_date2(match) -> str:

View File

@ -18,10 +18,7 @@ from pypinyin.constants import SUPPORT_UCS4
# 全角半角转换
# 英文字符全角 -> 半角映射表 (num: 52)
F2H_ASCII_LETTERS = {
ord(char) + 65248: ord(char)
for char in string.ascii_letters
}
F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}
# 英文字符半角 -> 全角映射表
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}
@ -37,26 +34,29 @@ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}
# 空格 (num: 1)
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}
F2H_SPACE = {"\u3000": " "}
H2F_SPACE = {" ": "\u3000"}
# 非"有拼音的汉字"的字符串可用于NSW提取
if SUPPORT_UCS4:
RE_NSW = re.compile(r'(?:[^'
r'\u3007' #
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF]
r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F]
r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D]
r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F]
r'])+')
RE_NSW = re.compile(
r"(?:[^"
r"\u3007" #
r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
r"\U00020000-\U0002A6DF" # CJK扩展B:[20000-2A6DF]
r"\U0002A703-\U0002B73F" # CJK扩展C:[2A700-2B73F]
r"\U0002B740-\U0002B81D" # CJK扩展D:[2B740-2B81D]
r"\U0002F80A-\U0002FA1F" # CJK兼容扩展:[2F800-2FA1F]
r"])+"
)
else:
RE_NSW = re.compile( # pragma: no cover
r'(?:[^'
r'\u3007' #
r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF]
r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF]
r'\uf900-\ufaff' # CJK兼容:[F900-FAFF]
r'])+')
r"(?:[^"
r"\u3007" #
r"\u3400-\u4dbf" # CJK扩展A:[3400-4DBF]
r"\u4e00-\u9fff" # CJK基本:[4E00-9FFF]
r"\uf900-\ufaff" # CJK兼容:[F900-FAFF]
r"])+"
)

View File

@ -15,23 +15,26 @@
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""
import re
from collections import OrderedDict
from typing import List
DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
1: '',
2: '',
3: '',
4: '',
8: '亿',
})
DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
UNITS = OrderedDict(
{
1: "",
2: "",
3: "",
4: "",
8: "亿",
}
)
COM_QUANTIFIERS = '(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"
# 分数表达式
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")
def replace_frac(match) -> str:
@ -52,7 +55,7 @@ def replace_frac(match) -> str:
# 百分数表达式
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")
def replace_percentage(match) -> str:
@ -72,7 +75,7 @@ def replace_percentage(match) -> str:
# 整数表达式
# 带负号的整数 -10
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
RE_INTEGER = re.compile(r"(-)" r"(\d+)")
def replace_negative_num(match) -> str:
@ -92,7 +95,7 @@ def replace_negative_num(match) -> str:
# 编号-无符号整形
# 00078
RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")
def replace_default_num(match):
@ -110,15 +113,11 @@ def replace_default_num(match):
# RE_ASMD = re.compile(
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
)
asmd_map = {"+": "", "-": "", "×": "", "÷": "", "=": "等于"}
asmd_map = {
'+': '',
'-': '',
'×': '',
'÷': '',
'=': '等于'
}
def replace_asmd(match) -> str:
"""
@ -132,24 +131,25 @@ def replace_asmd(match) -> str:
# 次方专项
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")
power_map = {
'': '0',
'¹': '1',
'²': '2',
'³': '3',
'': '4',
'': '5',
'': '6',
'': '7',
'': '8',
'': '9',
'ˣ': 'x',
'ʸ': 'y',
'': 'n'
"": "0",
"¹": "1",
"²": "2",
"³": "3",
"": "4",
"": "5",
"": "6",
"": "7",
"": "8",
"": "9",
"ˣ": "x",
"ʸ": "y",
"": "n",
}
def replace_power(match) -> str:
"""
Args:
@ -166,10 +166,10 @@ def replace_power(match) -> str:
# 数字表达式
# 纯小数
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
# 正整数 + 量词
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")
def replace_positive_quantifier(match) -> str:
@ -220,7 +220,9 @@ RE_RANGE = re.compile(
[-~] # 匹配范围分隔符
((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
(?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
""", re.VERBOSE)
""",
re.VERBOSE,
)
def replace_range(match) -> str:
@ -239,7 +241,9 @@ def replace_range(match) -> str:
# ~至表达式
RE_TO_RANGE = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
)
def replace_to_range(match) -> str:
"""
@ -248,71 +252,66 @@ def replace_to_range(match) -> str:
Returns:
str
"""
result = match.group(0).replace('~', '')
result = match.group(0).replace("~", "")
return result
def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
stripped = value_string.lstrip('0')
def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
stripped = value_string.lstrip("0")
if len(stripped) == 0:
return []
elif len(stripped) == 1:
if use_zero and len(stripped) < len(value_string):
return [DIGITS['0'], DIGITS[stripped]]
return [DIGITS["0"], DIGITS[stripped]]
else:
return [DIGITS[stripped]]
else:
largest_unit = next(
power for power in reversed(UNITS.keys()) if power < len(stripped))
largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
first_part = value_string[:-largest_unit]
second_part = value_string[-largest_unit:]
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
second_part)
return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
def verbalize_cardinal(value_string: str) -> str:
if not value_string:
return ''
return ""
# 000 -> '零' , 0 -> '零'
value_string = value_string.lstrip('0')
value_string = value_string.lstrip("0")
if len(value_string) == 0:
return DIGITS['0']
return DIGITS["0"]
result_symbols = _get_value(value_string)
# verbalized number starting with '一十*' is abbreviated as `十*`
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
'1'] and result_symbols[1] == UNITS[1]:
if len(result_symbols) >= 2 and result_symbols[0] == DIGITS["1"] and result_symbols[1] == UNITS[1]:
result_symbols = result_symbols[1:]
return ''.join(result_symbols)
return "".join(result_symbols)
def verbalize_digit(value_string: str, alt_one=False) -> str:
result_symbols = [DIGITS[digit] for digit in value_string]
result = ''.join(result_symbols)
result = "".join(result_symbols)
if alt_one:
result = result.replace("", "")
return result
def num2str(value_string: str) -> str:
integer_decimal = value_string.split('.')
integer_decimal = value_string.split(".")
if len(integer_decimal) == 1:
integer = integer_decimal[0]
decimal = ''
decimal = ""
elif len(integer_decimal) == 2:
integer, decimal = integer_decimal
else:
raise ValueError(
f"The value string: '${value_string}' has more than one point in it."
)
raise ValueError(f"The value string: '${value_string}' has more than one point in it.")
result = verbalize_cardinal(integer)
decimal = decimal.rstrip('0')
decimal = decimal.rstrip("0")
if decimal:
# '.22' is verbalized as '零点二二'
# '3.20' is verbalized as '三点二
result = result if result else ""
result += '' + verbalize_digit(decimal)
result += "" + verbalize_digit(decimal)
return result

View File

@ -21,10 +21,8 @@ from .num import verbalize_digit
# 移动139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
# 联通130、131、132、156、155、186、185、176
# 电信133、153、189、180、181、177
RE_MOBILE_PHONE = re.compile(
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
RE_MOBILE_PHONE = re.compile(r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
# 全国统一的号码400开头
RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
@ -32,14 +30,12 @@ RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
def phone2str(phone_string: str, mobile=True) -> str:
if mobile:
sp_parts = phone_string.strip('+').split()
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sp_parts])
sp_parts = phone_string.strip("+").split()
result = "".join([verbalize_digit(part, alt_one=True) for part in sp_parts])
return result
else:
sil_parts = phone_string.split('-')
result = ''.join(
[verbalize_digit(part, alt_one=True) for part in sil_parts])
sil_parts = phone_string.split("-")
result = "".join([verbalize_digit(part, alt_one=True) for part in sil_parts])
return result

View File

@ -17,7 +17,7 @@ from .num import num2str
# 温度表达式,温度会影响负号的读法
# -3°C 零下三度
RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')
RE_TEMPERATURE = re.compile(r"(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)")
measure_dict = {
"cm2": "平方厘米",
"cm²": "平方厘米",
@ -35,7 +35,7 @@ measure_dict = {
"ml": "毫升",
"m": "",
"mm": "毫米",
"s": ""
"s": "",
}

Some files were not shown because too many files have changed in this diff Show More