more code refactor

This commit is contained in:
Blaise 2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@ -16,7 +16,7 @@ __all__ = [
"DistributedBucketSampler",
]
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class DistributedBucketSampler(Sampler[T_co]):
@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
sort batches
"""
def __init__(self,
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int]=None,
rank: Optional[int]=None,
shuffle: bool=True,
seed: int=0,
drop_last: bool=False,
batch_size: int=32) -> None:
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 32,
) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(
self.
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) /
self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
len(self.dataset) / self.num_replicas
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
id_with_lengths.sort(key=lambda x: x[1])
return id_with_lengths
def make_buckets(self, bucket_width: float=2.0):
def make_buckets(self, bucket_width: float = 2.0):
buckets = []
cur = []
max_sec = bucket_width
@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size:(b + 1) *
grouped_batch_size] for b in range(n_batch)
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
shuffle(batches)
indices = list(itertools.chain(*batches))
@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)

View File

@ -6,14 +6,21 @@ from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
def __init__(
self,
config,
train_semantic_path,
train_phoneme_path,
dev_semantic_path=None,
dev_phoneme_path=None,
):
super().__init__()
self.config = config
self.train_semantic_path = train_semantic_path
self.train_phoneme_path = train_phoneme_path
self.dev_semantic_path = dev_semantic_path
self.dev_phoneme_path = dev_phoneme_path
self.num_workers = self.config['data']['num_workers']
self.num_workers = self.config["data"]["num_workers"]
def prepare_data(self):
pass
@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config['data']['max_sec'],
pad_val=self.config['data']['pad_val'])
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = self.config['train']['batch_size']
sampler = DistributedBucketSampler(
self._train_dataset, batch_size=batch_size)
batch_size = self.config["train"]["batch_size"]
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
batch_size=batch_size,
@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
collate_fn=self._train_dataset.collate,
num_workers=self.num_workers,
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
def val_dataloader(self):
@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule):
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate,
num_workers=max(self.num_workers,12),
num_workers=max(self.num_workers, 12),
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
# 这个会使用到嘛?
@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate)
collate_fn=self._train_dataset.collate,
)

View File

@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback,os
import traceback, os
from typing import Dict
from typing import List
import numpy as np
import pandas as pd
import torch,json
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence
# from config import exp_dir
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
seq = sequences[0]
ndim = seq.ndim
@ -28,18 +31,20 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
ndim - axis - 1)
padded_seq = np.pad(
seq, padding, mode='constant', constant_values=pad_value)
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
return batch
class Text2SemanticDataset(Dataset):
"""dataset class for text tokens to semantic model training."""
def __init__(self,
def __init__(
self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
@ -48,24 +53,30 @@ class Text2SemanticDataset(Dataset):
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25) -> None:
max_ps_ratio: int = 25,
) -> None:
super().__init__()
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
)
# get dict
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.basename(phoneme_path)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
# pad for semantic tokens
@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
# self.hz=int(data[:-2])#
self.hz=int(os.environ.get("hz","25hz")[:-2])
self.hz = int(os.environ.get("hz", "25hz")[:-2])
# max seconds of semantic token
self.max_sec = max_sec
@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
def init_batch(self):
semantic_data_len = len(self.semantic_data)
phoneme_data_len = len(self.phoneme_data.keys())
@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data['item_name'][i]
item_name = self.semantic_data["item_name"][i]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data['semantic_audio'][i]
semantic_str = self.semantic_data["semantic_audio"][i]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
# 过滤掉太长的样本
if len(semantic_ids) > self.max_sec * self.hz:#########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
if (
len(semantic_ids) > self.max_sec * self.hz
): #########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
num_deleted_bigger += 1
continue
# (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
try:
phoneme_ids = cleaned_text_to_sequence(phoneme)
@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25#每秒多少个phone
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
idx += 1
self.item_names.append(item_name)
min_num=100#20直接不补#30补了也不存ckpt
leng =len(self.semantic_phoneme)
if(leng<min_num):
tmp1=self.semantic_phoneme
tmp2=self.item_names
self.semantic_phoneme=[]
self.item_names=[]
for _ in range(max(2,int(min_num/leng))):
self.semantic_phoneme+=tmp1
self.item_names+=tmp2
min_num = 100 # 20直接不补#30补了也不存ckpt
leng = len(self.semantic_phoneme)
if leng < min_num:
tmp1 = self.semantic_phoneme
tmp2 = self.item_names
self.semantic_phoneme = []
self.item_names = []
for _ in range(max(2, int(min_num / leng))):
self.semantic_phoneme += tmp1
self.item_names += tmp2
if num_not_in > 0:
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
)
'''
"""
there are 31 semantic datas not in phoneme datas
deleted 34 audios who's duration are bigger than 54 seconds
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
dataset.__len__(): 366463
'''
"""
# 345410 for LibriTTS
print("dataset.__len__():", self.__len__())
@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
# semantic tokens target
semantic_ids_len = len(semantic_ids)
flag=0
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
else:flag=1
if(flag==1):
if os.path.exists(path_bert) == True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1
if flag == 1:
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature=None
bert_feature = None
else:
assert bert_feature.shape[-1] == len(phoneme_ids)
return {
'idx': idx,
'phoneme_ids': phoneme_ids,
'phoneme_ids_len': phoneme_ids_len,
'semantic_ids': semantic_ids,
'semantic_ids_len': semantic_ids_len,
'bert_feature': bert_feature,
"idx": idx,
"phoneme_ids": phoneme_ids,
"phoneme_ids_len": phoneme_ids_len,
"semantic_ids": semantic_ids,
"semantic_ids_len": semantic_ids_len,
"bert_feature": bert_feature,
}
def get_sample_length(self, idx: int):
@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
semantic_ids_lens: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
bert_padded.zero_()
for idx, item in enumerate(examples):
bert = item['bert_feature']
if(bert!=None):
bert_padded[idx, :, :bert.shape[-1]] = bert
bert = item["bert_feature"]
if bert != None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
@ -276,20 +293,20 @@ class Text2SemanticDataset(Dataset):
}
if __name__ == '__main__':
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
dataset = Text2SemanticDataset(
phoneme_path=root_dir + 'phoneme_train.npy',
semantic_path=root_dir + 'semantic_train.tsv')
phoneme_path=root_dir + "phoneme_train.npy",
semantic_path=root_dir + "semantic_train.tsv",
)
batch_size = 12
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False)
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
)
for i, batch in enumerate(dataloader):
if(i%1000==0):print(i)
if i % 1000 == 0:
print(i)
# if i == 0:
# print('batch["ids"]:', batch["ids"])
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],

View File

@ -1,5 +1,6 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os,sys
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict
@ -12,29 +13,35 @@ from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir,is_train=True):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
self.config = config
self.top_k = 3
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
pretrained_s1=config.get("pretrained_s1")
if(pretrained_s1 and is_train):
pretrained_s1 = config.get("pretrained_s1")
if pretrained_s1 and is_train:
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
)
if is_train:
self.automatic_optimization = False
self.save_hyperparameters()
self.eval_dir = output_dir / 'eval'
self.eval_dir = output_dir / "eval"
self.eval_dir.mkdir(parents=True, exist_ok=True)
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
loss, acc = self.model.forward(
batch['phoneme_ids'], batch['phoneme_ids_len'],
batch['semantic_ids'], batch['semantic_ids_len'],
batch['bert_feature'])
batch["phoneme_ids"],
batch["phoneme_ids_len"],
batch["semantic_ids"],
batch["semantic_ids_len"],
batch["bert_feature"],
)
self.manual_backward(loss)
if batch_idx > 0 and batch_idx % 4 == 0:
opt.step()
@ -47,22 +54,27 @@ class Text2SemanticLightningModule(LightningModule):
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
"lr",
scheduler.get_last_lr()[0],
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
f"top_{self.top_k}_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
def validation_step(self, batch: Dict, batch_idx: int):
return
def validation_step(self, batch: Dict, batch_idx: int):return
# # get loss
# loss, acc = self.model.forward(
# batch['phoneme_ids'], batch['phoneme_ids_len'],
@ -100,10 +112,9 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append([
name_param_pair[0]
for name_param_pair in self.model.named_parameters()
])
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,
@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule):
clipping_scale=2.0,
parameters_names=parameters_names,
show_dominant_parameters=False,
clipping_update_period=1000, )
clipping_update_period=1000,
)
return {
"optimizer": lm_opt,
"lr_scheduler": {
"scheduler":
WarmupCosineLRSchedule(
"scheduler": WarmupCosineLRSchedule(
lm_opt,
init_lr=self.config['optimizer']['lr_init'],
peak_lr=self.config['optimizer']['lr'],
end_lr=self.config['optimizer']['lr_end'],
warmup_steps=self.config['optimizer']['warmup_steps'],
total_steps=self.config['optimizer']['decay_steps'])
}
init_lr=self.config["optimizer"]["lr_init"],
peak_lr=self.config["optimizer"]["lr"],
end_lr=self.config["optimizer"]["lr_end"],
warmup_steps=self.config["optimizer"]["warmup_steps"],
total_steps=self.config["optimizer"]["decay_steps"],
)
},
}

View File

@ -3,7 +3,12 @@ import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
@ -22,35 +27,39 @@ default_config = {
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024
"EOS": 1024,
}
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config['model']["hidden_dim"]
self.embedding_dim = config['model']["embedding_dim"]
self.num_head = config['model']["head"]
self.num_layers = config['model']["n_layer"]
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config['model']["vocab_size"]
self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
self.p_dropout = config['model']["dropout"]
self.EOS = config['model']["EOS"]
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout)
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
@ -59,28 +68,30 @@ class Text2SemanticDecoder(nn.Module):
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first, ),
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None, )
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(
self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS, )
ignore_index=self.EOS,
)
def forward(self, x, x_lens, y, y_lens, bert_feature):
'''
"""
x: phoneme_ids
y: semantic_ids
'''
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1, ),
diagonal=1,
),
(x_len, 0),
value=False, )
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len))
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
@ -122,26 +138,28 @@ class Text2SemanticDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = F.cross_entropy(logits, targets, reduction='sum')
loss = F.cross_entropy(logits, targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(self,
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature)
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
@ -198,23 +218,24 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(
y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(self,
x,#####全部文本token
def infer_panel(
self,
x, #####全部文本token
x_lens,
prompts,####参考音频token
prompts, ####参考音频token
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@ -224,75 +245,81 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache={
"all_stage":self.num_layers,
"k":[None]*self.num_layers,###根据配置自己手写
"v":[None]*self.num_layers,
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb":None,##只需要对最新的samples求emb再拼历史的就行
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer":1,
"stage":0
"first_infer": 1,
"stage": 0,
}
for idx in tqdm(range(1500)):
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
cache["y_emb"]=y_emb
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# x 和逐渐增长的 y 一起输入给模型
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
xy_pos = torch.concat([x, y_pos], dim=1)
else:
xy_pos=y_pos[:,-1:]
xy_pos = y_pos[:, -1:]
y_len = y_pos.shape[1]
###以下3个不做缓存
if (cache["first_infer"] == 1):
if cache["first_infer"] == 1:
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True, )
y_attn_mask = F.pad(###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
else:
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
# pdb.set_trace()
###缓存重头戏
# print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,cache=cache )
logits = self.ar_predict_layer(xy_dec[:, -1])##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
samples = sample(
logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
cache["first_infer"]=0
return y,idx
cache["first_infer"] = 0
return y, idx

View File

@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
@ -9,7 +10,7 @@ def sequence_mask(length, max_length=None):
return x.unsqueeze(0) < length.unsqueeze(1)
def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1):
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits,
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep),
logits.size(-1)) # Safety check
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits,
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
from typing import Optional, Tuple
def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
@ -115,7 +114,7 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
previous_tokens=previous_tokens.squeeze()
previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
@ -159,4 +158,3 @@ def sample(
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

View File

@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
# Monkey-patch: rebind the `multi_head_attention_forward` attribute of the
# `torch.nn.functional` module (imported above as `F`) to the project's patched
# implementation. The rebinding is module-global, so any code that looks the
# function up through `torch.nn.functional` after this import sees the patched
# version, which accepts an extra `cache` keyword.
F.multi_head_attention_forward=multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
@ -89,53 +91,58 @@ class MultiheadAttention(Module):
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None, ) -> None:
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (self.kdim == embed_dim and
self.vdim == embed_dim)
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs))
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs))
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs))
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs))
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters()
else:
@ -143,7 +150,8 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
@ -156,7 +164,8 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
@ -194,10 +203,11 @@ class MultiheadAttention(Module):
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor]=None,
need_weights: bool=True,
attn_mask: Optional[Tensor]=None,
average_attn_weights: bool=True,cache=None
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -251,23 +261,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (self.in_proj_bias is not None and
query.dtype != self.in_proj_bias.dtype):
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (self.in_proj_weight is not None and
query.dtype != self.in_proj_weight.dtype):
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
@ -288,29 +301,41 @@ class MultiheadAttention(Module):
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input")
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (query, key, value, self.in_proj_weight,
self.in_proj_bias, self.out_proj.weight,
self.out_proj.bias, )
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args]):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU")
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@ -322,17 +347,21 @@ class MultiheadAttention(Module):
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None else attn_mask,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0
if attn_mask is not None else None, )
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}")
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
@ -343,9 +372,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
@ -370,7 +397,9 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
@ -390,7 +419,9 @@ class MultiheadAttention(Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:

View File

@ -10,7 +10,8 @@ class TokenEmbedding(nn.Module):
self,
embedding_dim: int,
vocab_size: int,
dropout: float=0.0, ):
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module):
return self.word_embeddings.weight
# Return the embedding-table row for a single token id.
# The row is taken with the slice `index : index + 1` rather than plain
# integer indexing, so the result keeps its leading (row) dimension.
# NOTE(review): assumes `self.word_embeddings` is an embedding layer whose
# `.weight` is indexed by token id along dim 0 — confirm against __init__.
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1]
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
@ -36,9 +37,10 @@ class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float=0.0,
scale: bool=False,
alpha: bool=False, ):
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.embedding_dim))
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module):
# Add positional information to `x` and apply dropout.
#   1. `extend_pe(x)` is called first so that `self.pe` is available for the
#      current input (see that method for how the table is built).
#   2. A 2-D input gets a trailing singleton dimension appended; any other
#      rank is used unchanged.
#   3. The input is multiplied by `self.x_scale`, and `self.alpha * self.pe`
#      truncated to the first `x.size(1)` entries along dim 1 is added.
#   4. The sum goes through `self.dropout`.
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)

View File

@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
"""
def __init__(self,
def __init__(
self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0):
current_step=0,
):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr]
# Write a learning rate into every param group of the wrapped optimizer.
# The rates currently stored in the param groups are first snapshotted into
# `self._last_lr`.
# NOTE: the `lr` argument is NOT used. The assignment that applied it is
# commented out below, and every group is set to `self.end_lr` instead.
# NOTE(review): this looks like a deliberate override that pins the LR —
# confirm before restoring `g["lr"] = lr`.
def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g['lr'] = self.end_lr### locked: use a linear (fixed) learning rate
g["lr"] = self.end_lr ### locked: use a linear (fixed) learning rate
def step(self):
if self._current_step < self.warmup_steps:
@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps)
self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr=lr=self.end_lr=0.002###锁定用线性###不听话,直接锁定!
self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定!
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr
if __name__ == '__main__':
if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
)
lrs = []
for i in range(25000):
s.step()

View File

@ -1,9 +1,16 @@
from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
# import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,cache=None
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache
average_attn_weights=average_attn_weights,
cache=cache,
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
target_type=query.dtype,
)
if is_causal and attn_mask is None:
@ -184,56 +205,81 @@ def multi_head_attention_forward_patched(
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
if(cache!=None):
if(cache["first_infer"]==1):
cache["k"][cache["stage"]]=k
q, k, v = _in_projection(
query,
key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v
else:###12个layer每个都要留自己的cache_kv
cache["v"][cache["stage"]] = v
else: ###12个layer每个都要留自己的cache_kv
# print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
cache["k"][cache["stage"]] = torch.cat(
[cache["k"][cache["stage"]], k], 0
) ##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]]
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# if attn_mask is not None:
# attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask)
@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
assert key_padding_mask.shape == (
bsz,
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
else:
@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

View File

@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
@ -100,7 +101,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
@ -125,7 +127,12 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor(
@ -134,7 +141,8 @@ def _compute_scale_factor(
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -145,12 +153,13 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
min=0, max=max_factor
)
return below_threshold - above_threshold
@ -161,7 +170,8 @@ def _compute_sign_factor(
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -171,18 +181,18 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@ -233,14 +243,15 @@ class ActivationBalancer(torch.nn.Module):
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob:
sign_gain_factor = 0.5
@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
else:
sign_factor = None
@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
self.channel_dim,
)
else:
return _no_op(x)
# Factory that returns an `nn.Sequential` of two stages:
#   1. an `ActivationBalancer` constructed with `d_model` as its first
#      positional argument and `channel_dim`, `max_abs`, `min_prob` forwarded
#      by keyword (all other balancer settings keep their own defaults);
#   2. a `DoubleSwish()` activation.
# Note the defaults here (`max_abs=10.0`, `min_prob=0.25`) are this factory's
# own and are passed explicitly to the balancer.
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential(
balancer,
DoubleSwish(), )
DoubleSwish(),
)

View File

@ -28,24 +28,26 @@ class LayerNorm(nn.Module):
def __init__(
self,
normalized_shape: _shape_t,
eps: float=1e-5,
elementwise_affine: bool=True,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None, ) -> None:
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, ) # type: ignore[assignment]
self.normalized_shape = tuple(
normalized_shape) # type: ignore[arg-type]
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (F.layer_norm(
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps, ), embedding, )
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight,
self.bias, self.eps)
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__))
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float=1e-5,
eps: float = 1e-5,
device=None,
dtype=None, ) -> None:
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
@ -123,9 +132,11 @@ class TransformerEncoder(nn.Module):
def forward(
self,
src: Tensor,
mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,
return_layer_states: bool=False,cache=None ) -> Tensor:
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0])
if self.norm is not None:
@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
output = src
for mod in self.layers:
output = mod(output,
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None:
output = self.norm(output)
@ -171,40 +187,44 @@ class TransformerEncoderLayer(nn.Module):
self,
d_model: int,
nhead: int,
dim_feedforward: int=2048,
dropout: float=0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
batch_first: bool=False,
norm_first: bool=False,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear,
layer_norm_cls: nn.Module=LayerNorm,
layer_norm_eps: float=1e-5,
adaptive_layer_norm=False, ) -> None:
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead)
# import os
# os._exit(2333333)
self.self_attn = MultiheadAttention(
d_model,#512 16
d_model, # 512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs, )
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
**factory_kwargs)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
**factory_kwargs)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -251,8 +269,10 @@ class TransformerEncoderLayer(nn.Module):
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask):
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,cache=cache )
src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
stage_embedding, )
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
@ -298,9 +322,11 @@ class TransformerEncoderLayer(nn.Module):
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os
# os._exit(23333)
x = self.self_attn(
@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0]
need_weights=False,
cache=cache,
)[0]
return self.dropout1(x)
# feed forward block
@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return weight * self.norm(input) + bias
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

View File

@ -27,46 +27,44 @@ class GruutPhonemizer:
"": "",
"": "",
"«": "«",
"»": "»"
"»": "»",
}
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text)
return text.strip()
def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes:
return ''
if word.phonemes[0] in ['', '|']:
return ""
if word.phonemes[0] in ["", "|"]:
return word.text.strip()
phonemes = ''.join(word.phonemes)
phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex
phonemes = re.sub(r'[ˈˌː͡]', '', phonemes)
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip()
def phonemize(self, text: str, espeak: bool=False) -> str:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(
text_to_phonemize, lang="en-us", espeak=espeak)
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
return ' '.join(words)
return " ".join(words)
def transform(self, phonemes):
# convert phonemes to ids
# dictionary is in symbols.py
return [
self.symbol_to_id[p] for p in phonemes
if p in self.symbol_to_id.keys()
]
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
if __name__ == "__main__":

View File

@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
PAD = '_'
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")

View File

@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path):
assert path.endswith('.yaml')
with open(path, 'w') as f:
assert path.endswith(".yaml")
with open(path, "w") as f:
f.write(yaml.dump(config))
f.close()
def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args)
if not name.startswith('_'))
with open(path, 'a') as args_file:
args_file.write('==> torch version: {}\n'.format(torch.__version__))
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
args_file.write('==> Cmd:\n')
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write('\n==> args:\n')
args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v)))
args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close()

View File

@ -11,23 +11,30 @@ logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
Wav2Vec2FeatureExtractor,
HubertModel,
Wav2Vec2Model,
)
import utils
import torch.nn as nn
cnhubert_base_path=None
cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self):
super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
)
def forward(self, x):
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats
# class CNHubertLarge(nn.Module):
# def __init__(self):
# super().__init__()
@ -59,12 +66,12 @@ class CNHubert(nn.Module):
# return feats
def get_model():
model = CNHubert()
model.eval()
return model
# def get_large_model():
# model = CNHubertLarge()
# model.eval()
@ -80,18 +87,18 @@ def get_model():
# model.eval()
# return model
def get_content(hmodel, wav_16k_tensor):
with torch.no_grad():
feats = hmodel(wav_16k_tensor)
return feats.transpose(1,2)
return feats.transpose(1, 2)
if __name__ == '__main__':
if __name__ == "__main__":
model = get_model()
src_path = "/Users/Shared/原音频2.wav"
wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
model = model
wav_16k_tensor = wav_16k_tensor
feats = get_content(model,wav_16k_tensor)
feats = get_content(model, wav_16k_tensor)
print(feats.shape)

View File

@ -3,13 +3,15 @@ import torch
def get_model():
import whisper
model = whisper.load_model("small", device='cpu')
model = whisper.load_model("small", device="cpu")
return model.encoder
def get_content(model=None, wav_16k_tensor=None):
from whisper import log_mel_spectrogram, pad_or_trim
dev = next(model.parameters()).device
mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
# if torch.cuda.is_available():
@ -17,6 +19,7 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2)
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
return feature

View File

@ -4,11 +4,22 @@ from torch import nn
from torch.nn import functional as F
from module import commons
from module. modules import LayerNorm
from module.modules import LayerNorm
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,isflow=False, **kwargs):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=False,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@ -24,15 +35,34 @@ class Encoder(nn.Module):
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
@ -43,11 +73,10 @@ class Encoder(nn.Module):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x,
g_l,
torch.IntTensor([self.hidden_channels]))
x, g_l, torch.IntTensor([self.hidden_channels])
)
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
@ -60,7 +89,18 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@ -79,11 +119,33 @@ class Decoder(nn.Module):
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
@ -91,7 +153,9 @@ class Decoder(nn.Module):
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
@ -111,7 +175,18 @@ class Decoder(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % n_heads == 0
@ -136,8 +211,14 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
@ -166,28 +247,46 @@ class MultiHeadAttention(nn.Module):
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
@ -217,10 +316,13 @@ class MultiHeadAttention(nn.Module):
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
@ -234,10 +336,14 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final
def _absolute_position_to_relative_position(self, x):
@ -247,11 +353,13 @@ class MultiHeadAttention(nn.Module):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
@ -267,7 +375,16 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -329,27 +446,43 @@ class Depthwise_Separable_Conv1D(nn.Module):
padding=0,
dilation=1,
bias=True,
padding_mode='zeros', # TODO: refine this type
padding_mode="zeros", # TODO: refine this type
device=None,
dtype=None
dtype=None,
):
super().__init__()
self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias,
padding_mode=padding_mode, device=device, dtype=dtype)
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias,
device=device, dtype=dtype)
self.depth_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name='weight')
self.point_conv = weight_norm(self.point_conv, name='weight')
self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self):
self.depth_conv = remove_weight_norm(self.depth_conv, name='weight')
self.point_conv = remove_weight_norm(self.point_conv, name='weight')
self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
self.point_conv = remove_weight_norm(self.point_conv, name="weight")
class Depthwise_Separable_TransposeConv1D(nn.Module):
@ -363,48 +496,79 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
output_padding=0,
bias=True,
dilation=1,
padding_mode='zeros', # TODO: refine this type
padding_mode="zeros", # TODO: refine this type
device=None,
dtype=None
dtype=None,
):
super().__init__()
self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
groups=in_channels, stride=stride, output_padding=output_padding,
padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode,
device=device, dtype=dtype)
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias,
device=device, dtype=dtype)
self.depth_conv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
output_padding=output_padding,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input):
return self.point_conv(self.depth_conv(input))
def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name='weight')
self.point_conv = weight_norm(self.point_conv, name='weight')
self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self):
remove_weight_norm(self.depth_conv, name='weight')
remove_weight_norm(self.point_conv, name='weight')
remove_weight_norm(self.depth_conv, name="weight")
remove_weight_norm(self.point_conv, name="weight")
def weight_norm_modules(module, name='weight', dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
module.weight_norm()
return module
else:
return weight_norm(module, name, dim)
def remove_weight_norm_modules(module, name='weight'):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
module.remove_weight_norm()
else:
remove_weight_norm(module, name)
class FFT(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
proximal_bias=False, proximal_init=True, isflow = False, **kwargs):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers=1,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
isflow=False,
**kwargs
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@ -415,9 +579,11 @@ class FFT(nn.Module):
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
@ -426,14 +592,29 @@ class FFT(nn.Module):
self.norm_layers_1 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias,
proximal_init=proximal_init))
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g = None):
def forward(self, x, x_mask, g=None):
"""
x: decoder input
h: encoder output
@ -441,17 +622,18 @@ class FFT(nn.Module):
if g is not None:
g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x,
g_l,
torch.IntTensor([self.hidden_channels]))
x, g_l, torch.IntTensor([self.hidden_channels])
)
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
@ -463,9 +645,9 @@ class FFT(nn.Module):
return x
class TransformerCouplingLayer(nn.Module):
def __init__(self,
def __init__(
self,
channels,
hidden_channels,
kernel_size,
@ -475,7 +657,7 @@ class TransformerCouplingLayer(nn.Module):
filter_channels=0,
mean_only=False,
wn_sharing_parameter=None,
gin_channels = 0
gin_channels=0,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
@ -487,18 +669,31 @@ class TransformerCouplingLayer(nn.Module):
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
self.enc = (
Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
isflow=True,
gin_channels=gin_channels,
)
if wn_sharing_parameter is None
else wn_sharing_parameter
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
@ -506,7 +701,7 @@ class TransformerCouplingLayer(nn.Module):
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask

View File

@ -1,7 +1,5 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
@ -12,7 +10,7 @@ def init_weights(m, mean=0.0, std=0.01):
def get_padding(kernel_size, dilation=1):
    """Return half of the dilated kernel extent, truncated toward zero.

    Evaluates ``int((kernel_size * dilation - dilation) / 2)``, i.e. half of
    ``dilation * (kernel_size - 1)``.

    Args:
        kernel_size: kernel width.
        dilation: dilation factor (default 1).

    Returns:
        The result as a Python ``int``.
    """
    # NOTE(review): presumably used as the ``padding`` argument of a 1-D
    # convolution so that odd kernels preserve length -- confirm at call sites.
    return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
@ -30,7 +28,9 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)

    Element-wise value of
    ``(logs_q - logs_p) - 0.5
    + 0.5 * (exp(2 * logs_p) + (m_p - m_q) ** 2) * exp(-2 * logs_q)``.

    Args:
        m_p, logs_p: first pair of parameters (P).
        m_q, logs_q: second pair of parameters (Q).

    Returns:
        A tensor of element-wise values, broadcast over the inputs; no
        reduction is applied.
    """
    # NOTE(review): the formula matches the KL divergence between diagonal
    # Gaussians when ``logs_*`` are log standard deviations -- confirm callers
    # pass log-std and not log-variance.
    kl = (logs_q - logs_p) - 0.5
    # Out-of-place add, applied exactly once: an in-place ``+=`` would raise
    # whenever the right-hand side broadcasts to a larger shape than ``kl``.
    kl = kl + 0.5 * (
        torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)
    ) * torch.exp(-2.0 * logs_q)
    return kl
@ -64,15 +64,15 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
@ -139,7 +139,7 @@ def generate_path(duration, mask):
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
@ -157,7 +157,7 @@ def clip_grad_value_(parameters, clip_value, norm_type=2):
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
@ -170,7 +170,7 @@ def squeeze(x, x_mask=None, n_sqz=2):
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
else:
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask

View File

@ -76,10 +76,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
print("kmeans start ... ")
for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange(
means, "c d -> () c d"
)
dists = -(diffs ** 2).sum(dim=-1)
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module):
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
@ -147,7 +148,7 @@ class EuclideanCodebook(nn.Module):
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
#broadcast_tensors(self.buffers())
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
@ -165,7 +166,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
#broadcast_tensors(self.buffers())
# broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
@ -246,6 +247,7 @@ class VectorQuantization(nn.Module):
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
@ -256,22 +258,31 @@ class VectorQuantization(nn.Module):
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code)
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
    """Create ``num_quantizers`` ``VectorQuantization`` stages.

    Args:
        num_quantizers: keyword-only number of stages to build.
        **kwargs: forwarded unchanged to every ``VectorQuantization``.
    """
    super().__init__()
    stages = []
    for _ in range(num_quantizers):
        stages.append(VectorQuantization(**kwargs))
    # Registered as a ModuleList and stored on ``self.layers``.
    self.layers = nn.ModuleList(stages)
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0
residual = x
@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor:
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
@ -358,7 +374,7 @@ class ResidualVectorQuantization(nn.Module):
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor:
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]

View File

@ -1,6 +1,6 @@
import time,logging
import time, logging
import os
import random,traceback
import random, traceback
import numpy as np
import torch
import torch.utils.data
@ -16,9 +16,11 @@ import torch
import requests
from scipy.io import wavfile
from io import BytesIO
# from config import exp_dir
from my_utils import load_audio
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@ -27,30 +29,31 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
def __init__(self, hparams, val=False):
exp_dir=hparams.exp_dir
self.path2="%s/2-name2text.txt"%exp_dir
self.path4="%s/4-cnhubert"%exp_dir
self.path5="%s/5-wav32k"%exp_dir
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
names4=set([name[:-3]for name in list(os.listdir(self.path4))])#去除.pt后缀
names5=set(os.listdir(self.path5))
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text=list(set(self.phoneme_data)&names4&names5)
tmp=self.audiopaths_sid_text
leng=len(tmp)
min_num=100
if(leng<min_num):
self.audiopaths_sid_text=[]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
self.max_wav_value = hparams.max_wav_value
@ -74,15 +77,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
skipped_phone += 1
continue
size=os.path.getsize("%s/%s"%(self.path5,audiopath))
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2
if (54 > duration > 0.6 or self.val):
if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length))
else:
@ -90,7 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
continue
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
print("total left: ", len(audiopaths_sid_text_new))
assert len(audiopaths_sid_text_new)>1#至少能凑够batch size这里todo
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
@ -98,30 +101,41 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, wav = self.get_audio("%s/%s"%(self.path5,audiopath))
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu")
if(ssl.shape[-1]!=spec.shape[-1]):
typee=ssl.dtype
ssl=F.pad(ssl.float(),(0,1),mode="replicate").to(typee)
ssl.requires_grad=False
ssl = torch.load(
"%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
)
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
except:
traceback.print_exc()
spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100*self.hop_length)
ssl=torch.zeros(1,768,100)
text=text[-1:]
wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100)
text = text[-1:]
print("load audio or ssl error!!!!!!", audiopath)
# print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
return (ssl, spec, wav, text)
def get_audio(self, filename):
    """Load one audio file and return its spectrogram and waveform tensor.

    Args:
        filename: path passed to ``load_audio`` together with
            ``self.sampling_rate``.

    Returns:
        ``(spec, audio_norm)``: the ``spectrogram_torch`` output with
        dimension 0 squeezed, and the waveform as a float tensor with a
        leading dimension of size 1.
    """
    # load_audio already returns samples normalized to the -1..1 range, so no
    # division by 32768 is needed here.
    audio_array = load_audio(filename, self.sampling_rate)
    audio_norm = torch.FloatTensor(audio_array).unsqueeze(0)
    spec = spectrogram_torch(
        audio_norm,
        self.filter_length,
        self.sampling_rate,
        self.hop_length,
        self.win_length,
        center=False,
    )
    spec = torch.squeeze(spec, 0)
    return spec, audio_norm
@ -137,33 +151,45 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel):
    """Split one sample into a reference mel part and a remaining part.

    In validation mode (``self.val`` truthy) the reference is the first
    ``len_mel // 3`` frames of ``mel`` and ``ssl``, ``wav`` and ``mel`` are
    returned unchanged.  Otherwise a split frame ``sep_point`` is drawn
    uniformly from ``[len_mel // 3, len_mel // 3 * 2]`` (inclusive) and a
    coin flip decides which side becomes the reference; the other side of
    ``ssl``, ``wav`` and ``mel`` is returned.

    Args:
        ssl: tensor sliced along its last (third) axis.
        wav: tensor sliced along its second axis in units of
            ``self.hop_length`` samples per frame.
        mel: tensor sliced along its second axis.

    Returns:
        ``(reference_mel, ssl, wav, mel)`` after slicing.

    Raises:
        AssertionError: if the ``ssl`` length and ``wav`` length divided by
            ``self.hop_length`` differ by 3 or more, before or after slicing.
    """
    assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
        "first",
        ssl.shape,
        wav.shape,
    )

    len_mel = mel.shape[1]
    if self.val:
        reference_mel = mel[:, : len_mel // 3]
        return reference_mel, ssl, wav, mel

    # Exactly one draw each for side and split point; ``direction`` replaces
    # a local that shadowed the builtin ``dir``.
    direction = random.randint(0, 1)
    sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))

    if direction == 0:
        # Head is the reference, tail is kept.
        reference_mel = mel[:, :sep_point]
        ssl = ssl[:, :, sep_point:]
        wav2 = wav[:, sep_point * self.hop_length :]
        mel = mel[:, sep_point:]
    else:
        # Tail is the reference, head is kept.
        reference_mel = mel[:, sep_point:]
        ssl = ssl[:, :, :sep_point]
        wav2 = wav[:, : sep_point * self.hop_length]
        mel = mel[:, :sep_point]

    assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
        ssl.shape,
        wav.shape,
        wav2.shape,
        mel.shape,
        sep_point,
        self.hop_length,
        sep_point * self.hop_length,
        direction,
    )
    return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -176,8 +202,8 @@ class TextAudioSpeakerCollate():
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -205,23 +231,31 @@ class TextAudioSpeakerCollate():
row = batch[ids_sorted_decreasing[i]]
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
"""
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths
# print(233333333333333,self.lengths,dir(dataset))
@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
for i in range(len(buckets)):
len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket
@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
# add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample
ids_bucket = ids_bucket[self.rank::self.num_replicas]
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching
for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch)
if self.shuffle:

View File

@ -22,9 +22,9 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1-dr)**2)
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
@ -36,7 +36,7 @@ def generator_loss(disc_outputs):
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1-dg)**2)
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
@ -55,14 +55,19 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def mle_loss(z, m, logs, logdet, mask):
    """Return the scalar negative log-likelihood loss.

    Args:
        z: values being scored.
        m: means, same shape as ``z``.
        logs: log scale terms; ``exp(-2 * logs)`` weights the squared error.
        logdet: log-Jacobian-determinant terms, summed and subtracted.
        mask: mask broadcast against ``z``; its broadcast sum is the
            normalizer.

    Returns:
        A zero-dimensional tensor.
    """
    # Negative normal log-likelihood without the constant term.
    nll = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m) ** 2))
    # Subtract the log Jacobian determinant.
    nll = nll - torch.sum(logdet)
    # Average across batch, channel and time axes (normalize exactly once).
    nll = nll / torch.sum(torch.ones_like(z) * mask)
    # Add the remaining constant term.
    nll = nll + 0.5 * math.log(2 * math.pi)
    return nll

View File

@ -49,21 +49,37 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute a magnitude spectrogram of ``y`` with ``torch.stft``.

    Args:
        y: waveform tensor; a message is printed if any value lies outside
            [-1, 1].  # assumes shape (batch, samples) -- TODO confirm
        n_fft: FFT size.
        sampling_rate: unused here; kept for interface compatibility.
        hop_size: hop length between frames.
        win_size: Hann window length.
        center: forwarded to ``torch.stft`` (default ``False``).

    Returns:
        ``sqrt(re ** 2 + im ** 2 + 1e-6)`` of the one-sided STFT.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    # Windows are cached in the module-level ``hann_window`` dict, keyed by
    # window size, dtype and device.  The dict is only mutated, never rebound,
    # so no ``global`` statement is required.
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    # Reflect-pad both ends by (n_fft - hop_size) / 2 samples, exactly once.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    # ``return_complex=False`` is deprecated in torch.stft; request the complex
    # result and view it as (..., 2) real/imag pairs, which has the same layout.
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)

    # Magnitude; the small epsilon is added inside the square root.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Project a spectrogram onto a mel filterbank and normalize it.

    Args:
        spec: spectrogram tensor, left-multiplied by the mel basis.
        n_fft: FFT size passed to ``librosa_mel_fn``.
        num_mels: number of mel bands.
        sampling_rate: sample rate passed to ``librosa_mel_fn``.
        fmin: lowest filterbank frequency.
        fmax: highest filterbank frequency.

    Returns:
        ``spectral_normalize_torch(mel_basis @ spec)``.
    """
    # Key the cached basis on every parameter that shapes the filterbank, plus
    # dtype and device.  A key built from fmax/dtype/device alone would hand
    # back a stale basis when n_fft, num_mels, sampling_rate or fmin change.
    # A tuple key cannot collide with string keys stored in the same dict.
    basis_key = (n_fft, num_mels, sampling_rate, fmin, fmax, spec.dtype, spec.device)
    if basis_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[basis_key] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device
wnsize_dtype_device = str(win_size) + '_' + dtype_device
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

View File

@ -16,8 +16,17 @@ from module.quantize import ResidualVectorQuantizer
from text import symbols
from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
p_dropout,
n_flows=4,
gin_channels=0,
):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
@ -108,9 +141,13 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -135,7 +172,8 @@ class DurationPredictor(nn.Module):
class TextEncoder(nn.Module):
def __init__(self,
def __init__(
self,
out_channels,
hidden_channels,
filter_channels,
@ -143,7 +181,8 @@ class TextEncoder(nn.Module):
n_layers,
kernel_size,
p_dropout,
latent_channels=192):
latent_channels=192,
):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
@ -160,17 +199,14 @@ class TextEncoder(nn.Module):
hidden_channels,
filter_channels,
n_heads,
n_layers//2,
n_layers // 2,
kernel_size,
p_dropout)
p_dropout,
)
self.encoder_text = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@ -179,21 +215,25 @@ class TextEncoder(nn.Module):
hidden_channels,
filter_channels,
n_heads,
n_layers//2,
n_layers // 2,
kernel_size,
p_dropout)
p_dropout,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
if test == 1 :
text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
@ -208,9 +248,9 @@ class TextEncoder(nn.Module):
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0,1)
def decode_latent(self, codes, y_mask, refer,refer_mask, ge):
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
@ -224,15 +264,18 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(self,
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
gin_channels=0,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module):
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
gin_channels=gin_channels, mean_only=True))
modules.ResidualCouplingLayer(
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
mean_only=True,
)
)
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module):
class PosteriorEncoder(nn.Module):
def __init__(self,
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module):
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module):
class WNEncoder(nn.Module):
def __init__(self,
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -312,11 +375,20 @@ class WNEncoder(nn.Module):
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -325,24 +397,45 @@ class WNEncoder(nn.Module):
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2)))
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -373,7 +466,7 @@ class Generator(torch.nn.Module):
return x
def remove_weight_norm(self):
print('Removing weight norm...')
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module):
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=(get_padding(kernel_size, 1), 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class ReferenceEncoder(nn.Module):
'''
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
'''
"""
def __init__(self, spec_channels, gin_channels=0):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [weight_norm(nn.Conv2d(in_channels=filters[i],
convs = [
weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1))) for i in range(K)]
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(input_size=ref_enc_filters[-1] * out_channels,
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True)
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
def forward(self, inputs):
@ -527,23 +675,31 @@ class Quantizer_module(torch.nn.Module):
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
def forward(self, x):
d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) - 2 * torch.matmul(x, self.embedding.weight.T)
d = (
torch.sum(x**2, 1, keepdim=True)
+ torch.sum(self.embedding.weight**2, 1)
- 2 * torch.matmul(x, self.embedding.weight.T)
)
min_indicies = torch.argmin(d, 1)
z_q = self.embedding(min_indicies)
return z_q, min_indicies
class Quantizer(torch.nn.Module):
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList([
Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)
])
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
def forward(self, xin):
#B, C, T
# B, C, T
B, C, T = xin.shape
xin = xin.transpose(1, 2)
x = xin.reshape(-1, self.embed_dim)
@ -553,28 +709,31 @@ class Quantizer(torch.nn.Module):
for _x, m in zip(x, self.quantizer_modules):
_z_q, _min_indicies = m(_x)
z_q.append(_z_q)
min_indicies.append(_min_indicies) #B * T,
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
return z_q, loss, codes.transpose(1, 2)
def embed(self, x):
#idx: N, 4, T
x=x.transpose(1, 2)
# idx: N, 4, T
x = x.transpose(1, 2)
x = torch.split(x, 1, 2)
ret = []
for q, embed in zip(x, self.quantizer_modules):
q = embed.embedding(q.squeeze(-1))
ret.append(q)
ret = torch.cat(ret, -1)
return ret.transpose(1, 2) #N, C, T
return ret.transpose(1, 2) # N, C, T
class CodePredictor(nn.Module):
def __init__(self,
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
@ -583,7 +742,7 @@ class CodePredictor(nn.Module):
p_dropout,
n_q=8,
dims=1024,
ssl_dim=768
ssl_dim=768,
):
super().__init__()
self.hidden_channels = hidden_channels
@ -594,19 +753,18 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.out_proj = nn.Conv1d(hidden_channels, (n_q-1) * dims, 1)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
self.dims = dims
def forward(self, x, x_mask, refer, codes, infer=False):
x = x.detach()
x = self.vq_proj(x * x_mask) * x_mask
@ -614,7 +772,9 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -626,22 +786,22 @@ class CodePredictor(nn.Module):
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
print('Top-10 Accuracy:', top3_acc, "%")
print("Top-10 Accuracy:", top3_acc, "%")
pred_codes = torch.argmax(logits, dim=-1)
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
print('Top-1 Accuracy:', acc, "%")
print("Top-1 Accuracy:", acc, "%")
return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
@ -662,8 +822,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
**kwargs
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -691,28 +851,44 @@ class SynthesizerTrn(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
p_dropout,
)
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
if freeze_quantizer:
self.ssl_proj.requires_grad_(False)
self.quantizer.requires_grad_(False)
@ -721,56 +897,85 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.mrte.requires_grad_(False)
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
with autocast(enabled=False):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=ge)
return o, commit_loss, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
return (
o,
commit_loss,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o,y_mask, (z, z_p, m_p, logs_p)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes,text, refer, noise_scale=0.5):
def decode(self, codes, text, refer, noise_scale=0.5):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
y_lengths = torch.LongTensor([codes.size(2)*2]).to(codes.device)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -781,4 +986,4 @@ class SynthesizerTrn(nn.Module):
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)

View File

@ -32,7 +32,15 @@ class LayerNorm(nn.Module):
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@ -44,13 +52,22 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers-1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
@ -70,7 +87,8 @@ class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
@ -83,11 +101,18 @@ class DDSConv(nn.Module):
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size ** i
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
groups=channels, dilation=dilation, padding=padding
))
self.convs_sep.append(
nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
@ -108,11 +133,19 @@ class DDSConv(nn.Module):
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size,
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
@ -123,15 +156,22 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate ** i
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
in_layer = torch.nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@ -141,7 +181,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
@ -155,21 +195,18 @@ class WN(torch.nn.Module):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(
x_in,
g_l,
n_channels_tensor)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:,self.hidden_channels:,:]
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
@ -186,24 +223,76 @@ class WN(torch.nn.Module):
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
@ -231,12 +320,30 @@ class ResBlock1(torch.nn.Module):
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
@ -280,14 +387,14 @@ class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels,1))
self.logs = nn.Parameter(torch.zeros(channels,1))
self.m = nn.Parameter(torch.zeros(channels, 1))
self.logs = nn.Parameter(torch.zeros(channels, 1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1,2])
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
@ -295,7 +402,8 @@ class ElementwiseAffine(nn.Module):
class ResidualCouplingLayer(nn.Module):
def __init__(self,
def __init__(
self,
channels,
hidden_channels,
kernel_size,
@ -303,7 +411,8 @@ class ResidualCouplingLayer(nn.Module):
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False):
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
@ -315,18 +424,25 @@ class ResidualCouplingLayer(nn.Module):
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
logs = torch.zeros_like(m)
@ -334,7 +450,7 @@ class ResidualCouplingLayer(nn.Module):
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
@ -343,7 +459,15 @@ class ResidualCouplingLayer(nn.Module):
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
@ -354,13 +478,15 @@ class ConvFlow(nn.Module):
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
@ -368,30 +494,33 @@ class ConvFlow(nn.Module):
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins:]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(x1,
x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails='linear',
tail_bound=self.tail_bound
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1,2])
logdet = torch.sum(logabsdet * x_mask, [1, 2])
if not reverse:
return x, logdet
else:
return x
class LinearNorm(nn.Module):
def __init__(self,
def __init__(
self,
in_channels,
out_channels,
bias=True,
@ -417,10 +546,10 @@ class Mish(nn.Module):
class Conv1dGLU(nn.Module):
'''
"""
Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper.
'''
"""
def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__()
@ -438,7 +567,8 @@ class Conv1dGLU(nn.Module):
class ConvNorm(nn.Module):
def __init__(self,
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
@ -451,16 +581,18 @@ class ConvNorm(nn.Module):
super(ConvNorm, self).__init__()
if padding is None:
assert (kernel_size % 2 == 1)
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels,
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias)
bias=bias,
)
if spectral_norm:
self.conv = nn.utils.spectral_norm(self.conv)
@ -471,9 +603,9 @@ class ConvNorm(nn.Module):
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
"""Multi-Head Attention module"""
def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False):
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
super().__init__()
self.n_head = n_head
@ -484,7 +616,9 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout)
@ -504,12 +638,9 @@ class MultiHeadAttention(nn.Module):
q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1,
len_x, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1,
len_x, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1,
len_x, d_v) # (n*b) x lv x dv
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
if mask is not None:
slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
@ -518,8 +649,9 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(
sz_b, len_x, -1) # b x lq x (n*dv)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = self.fc(output)
@ -528,7 +660,7 @@ class MultiHeadAttention(nn.Module):
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
"""Scaled Dot-Product Attention"""
def __init__(self, temperature, dropout):
super().__init__()
@ -551,14 +683,17 @@ class ScaledDotProductAttention(nn.Module):
class MelStyleEncoder(nn.Module):
''' MelStyleEncoder '''
"""MelStyleEncoder"""
def __init__(self, n_mel_channels=80,
def __init__(
self,
n_mel_channels=80,
style_hidden=128,
style_vector_dim=256,
style_kernel_size=5,
style_head=2,
dropout=0.1):
dropout=0.1,
):
super(MelStyleEncoder, self).__init__()
self.in_dim = n_mel_channels
self.hidden_dim = style_hidden
@ -573,7 +708,7 @@ class MelStyleEncoder(nn.Module):
nn.Dropout(self.dropout),
LinearNorm(self.hidden_dim, self.hidden_dim),
Mish(),
nn.Dropout(self.dropout)
nn.Dropout(self.dropout),
)
self.temporal = nn.Sequential(
@ -581,9 +716,13 @@ class MelStyleEncoder(nn.Module):
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
)
self.slf_attn = MultiHeadAttention(self.n_head, self.hidden_dim,
self.hidden_dim // self.n_head, self.hidden_dim // self.n_head,
self.dropout)
self.slf_attn = MultiHeadAttention(
self.n_head,
self.hidden_dim,
self.hidden_dim // self.n_head,
self.hidden_dim // self.n_head,
self.dropout,
)
self.fc = LinearNorm(self.hidden_dim, self.out_dim)
@ -598,11 +737,13 @@ class MelStyleEncoder(nn.Module):
return out
def forward(self, x, mask=None):
x = x.transpose(1,2)
x = x.transpose(1, 2)
if mask is not None:
mask = (mask.int()==0).squeeze(1)
mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1]
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
slf_attn_mask = (
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
# spectral
x = self.spectral(x)
@ -644,7 +785,9 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out)
logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
kl_divergence = D.kl_divergence(
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
loss_kl = kl_divergence.mean()
z = posterior.rsample()
@ -656,11 +799,12 @@ class MelStyleEncoderVAE(nn.Module):
if manual_latent is None:
if random_sample:
dev = next(self.parameters()).device
posterior = D.Normal(torch.zeros(1, self.z_latent_dim, device=dev),
torch.ones(1, self.z_latent_dim, device=dev))
posterior = D.Normal(
torch.zeros(1, self.z_latent_dim, device=dev),
torch.ones(1, self.z_latent_dim, device=dev),
)
z = posterior.rsample()
else:
enc_out = self.ref_encoder(inputs.transpose(1, 2))
mu = self.fc1(enc_out)
z = mu
@ -681,7 +825,9 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
device=x.device, dtype=x.dtype
)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
@ -707,10 +853,12 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m ** 2)
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
bias_init = (
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
@ -720,19 +868,21 @@ class ActNorm(nn.Module):
class InvConvNear(nn.Module):
def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
super().__init__()
assert (n_split % 2 == 0)
assert n_split % 2 == 0
self.channels = channels
self.n_split = n_split
self.no_jacobian = no_jacobian
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
w_init = torch.linalg.qr(
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
b, c, t = x.size()
assert (c % self.n_split == 0)
assert c % self.n_split == 0
if x_mask is None:
x_mask = 1
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
@ -740,7 +890,11 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
x = (
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
if reverse:
if hasattr(self, "weight_inv"):

View File

@ -5,46 +5,74 @@ from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
class MRTE(nn.Module):
def __init__(self,
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer = 2
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads)
self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size,out_channels, 1)
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if(ge==None):ge=0
if ge == None:
ge = 0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
elif test == 1:
x = ssl_enc + ge
elif test ==2:
x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
else:
raise ValueError("test should be 0,1,2")
else:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x
class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256):
def __init__(
self,
mel_n_channels=80,
model_num_layers=2,
model_hidden_size=256,
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module):
class MELEncoder(nn.Module):
def __init__(self,
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers):
n_layers,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -86,8 +116,8 @@ class MELEncoder(nn.Module):
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
@ -96,10 +126,15 @@ class WN(torch.nn.Module):
self.res_skip_layers = torch.nn.ModuleList()
for i in range(n_layers):
dilation = dilation_rate ** i
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer)
@ -110,7 +145,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name='weight')
res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x):
@ -120,15 +155,13 @@ class WN(torch.nn.Module):
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
acts = fused_add_tanh_sigmoid_multiply(
x_in,
n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts)
output = output + res_skip_acts[:,self.hidden_channels:,:]
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = x + res_acts
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output
@ -149,12 +182,11 @@ def fused_add_tanh_sigmoid_multiply(input, n_channels):
return acts
if __name__ == '__main__':
content_enc = torch.randn(3,192,100)
content_mask = torch.ones(3,1,100)
ref_mel = torch.randn(3,128,30)
ref_mask = torch.ones(3,1,30)
if __name__ == "__main__":
content_enc = torch.randn(3, 192, 100)
content_mask = torch.ones(3, 1, 100)
ref_mel = torch.randn(3, 128, 30)
ref_mask = torch.ones(3, 1, 30)
model = MRTE()
out = model(content_enc,content_mask,ref_mel,ref_mask)
out = model(content_enc, content_mask, ref_mel, ref_mask)
print(out.shape)

View File

@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
"""
n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q:
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.')
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.

View File

@ -9,26 +9,24 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs,
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {
'tails': tails,
'tail_bound': tail_bound
}
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
@ -46,29 +44,28 @@ def piecewise_rational_quadratic_transform(inputs,
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
) - 1
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(inputs,
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails='linear',
tail_bound=1.,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == 'linear':
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs,
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError('{} tails are not implemented.'.format(tails))
raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(inputs,
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0., right=1., bottom=0., top=1.,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain')
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins')
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins')
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs,
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
@ -150,15 +159,13 @@ def rational_quadratic_spline(inputs,
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (((inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta)
+ input_heights * (input_delta - input_derivatives)))
b = (input_heights * input_derivatives
- (inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta))
c = - input_delta * (inputs - input_cumheights)
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs,
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2))
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs,
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2)
+ input_derivatives * theta_one_minus_theta)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2))
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

View File

@ -1,50 +1,81 @@
import os,torch,sys
import os, torch, sys
from subprocess import Popen
now_dir = os.getcwd()
sys.path.append(now_dir)
from config import text_path,wav_dir,n_card,n_process_per_card,exp_name,n_parts,exp_dir
os.makedirs("%s/logs_s1"%exp_dir,exist_ok=True)
os.makedirs("%s/logs_s2"%exp_dir,exist_ok=True)
from config import (
text_path,
wav_dir,
n_card,
exp_name,
n_parts,
exp_dir,
)
os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True)
os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True)
##############step1
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/1-get-text.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/1-get-text.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
opt=[]
opt = []
for i_part in range(n_parts):
txt_path = "%s/2-name2text-%s.txt" % (exp_dir, i_part)
with open(txt_path,"r")as f:
opt+=f.read().strip("\n").split("\n")
with open(txt_path, "r") as f:
opt += f.read().strip("\n").split("\n")
os.remove(txt_path)
with open("%s/2-name2text.txt"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n")
with open("%s/2-name2text.txt" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")
############step2
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
#############step3
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/3-get-semantic.py %s %s %s %s %s"%(text_path,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/3-get-semantic.py %s %s %s %s %s" % (
text_path,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
opt=["item_name semantic_audio"]
opt = ["item_name semantic_audio"]
for i_part in range(n_parts):
semantic_path = "%s/6-name2semantic-%s.tsv" % (exp_dir, i_part)
with open(semantic_path,"r")as f:
opt+=f.read().strip("\n").split("\n")
with open(semantic_path, "r") as f:
opt += f.read().strip("\n").split("\n")
os.remove(semantic_path)
with open("%s/6-name2semantic.tsv"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n")
with open("%s/6-name2semantic.tsv" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")

View File

@ -2,16 +2,16 @@
import os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir= os.environ.get("opt_dir")
bert_pretrained_dir= os.environ.get("bert_pretrained_dir")
is_half=eval(os.environ.get("is_half","True"))
import sys,numpy as np,traceback,pdb
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
is_half = eval(os.environ.get("is_half", "True"))
import sys, numpy as np, traceback, pdb
import os.path
from glob import glob
from tqdm import tqdm
@ -31,25 +31,29 @@ import numpy as np
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
if(os.path.exists(txt_path)==False):
bert_dir="%s/3-bert"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(bert_dir,exist_ok=True)
device="cuda:0"
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True)
device = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model=AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if (is_half == True):
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
@ -67,51 +71,55 @@ if(os.path.exists(txt_path)==False):
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def process(data,res):
for name,text,lan in data:
def process(data, res):
for name, text, lan in data:
try:
name=os.path.basename(name)
phones, word2ph, norm_text=clean_text(text.replace("%", '-').replace('', ','),lan)
path_bert="%s/%s.pt"%(bert_dir,name)
if (os.path.exists(path_bert) == False and lan == "zh"):
name = os.path.basename(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","), lan
)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert)
my_save(bert_feature, path_bert)
phones = " ".join(phones)
# res.append([name,phones])
res.append([name,phones, word2ph, norm_text])
res.append([name, phones, word2ph, norm_text])
except:
print(name, text, traceback.format_exc())
todo=[]
res=[]
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
todo = []
res = []
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
language_v1_to_language_v2={
"ZH":"zh",
"zh":"zh",
"JP":"ja",
"jp":"ja",
"JA":"ja",
"ja":"ja",
"EN":"en",
"en":"en",
"En":"en",
language_v1_to_language_v2 = {
"ZH": "zh",
"zh": "zh",
"JP": "ja",
"jp": "ja",
"JA": "ja",
"ja": "ja",
"EN": "en",
"en": "en",
"En": "en",
}
for line in lines[int(i_part)::int(all_parts)]:
for line in lines[int(i_part) :: int(all_parts)]:
try:
wav_name,spk_name,language,text=line.split("|")
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
todo.append([wav_name,text,language_v1_to_language_v2.get(language,language)])
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
except:
print(line,traceback.format_exc())
process(todo,res)
opt=[]
for name,phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s"%(name,phones, word2ph, norm_text))
with open(txt_path,"w",encoding="utf8")as f:
f.write("\n".join(opt)+"\n")
print(line, traceback.format_exc())
process(todo, res)
opt = []
for name, phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
with open(txt_path, "w", encoding="utf8") as f:
f.write("\n".join(opt) + "\n")

View File

@ -1,20 +1,23 @@
# -*- coding: utf-8 -*-
import sys,os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
is_half=eval(os.environ.get("is_half","True"))
import sys, os
import pdb,traceback,numpy as np,logging
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
is_half = eval(os.environ.get("is_half", "True"))
import pdb, traceback, numpy as np, logging
from scipy.io import wavfile
import librosa,torch
import librosa, torch
now_dir = os.getcwd()
sys.path.append(now_dir)
from my_utils import load_audio
@ -32,63 +35,75 @@ from my_utils import load_audio
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95
alpha=0.5
device="cuda:0"
model=cnhubert.get_model()
if(is_half==True):
model=model.half().to(device)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
device = "cuda:0"
model = cnhubert.get_model()
if is_half == True:
model = model.half().to(device)
else:
model = model.to(device)
def name2go(wav_name):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return
wav_path="%s/%s"%(inp_wav_dir,wav_name)
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio = librosa.resample(
tmp_audio32, orig_sr=32000, target_sr=16000
)
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + (
(1 - alpha) * 32768
) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32, orig_sr=32000, target_sr=16000)
tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True):
tensor_wav16=tensor_wav16.half().to(device)
if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:return
ssl = (
model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"]
.transpose(1, 2)
.cpu()
) # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
# torch.save(ssl,hubert_path )
my_save(ssl,hubert_path )
my_save(ssl, hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]:
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=os.path.basename(wav_name)
wav_name = os.path.basename(wav_name)
name2go(wav_name)
except:
print(line,traceback.format_exc())
print(line, traceback.format_exc())

View File

@ -1,24 +1,27 @@
import os
# Job parameters are handed over by the launcher through environment variables.
inp_text = os.environ.get("inp_text")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
# Only override GPU visibility when a value was actually supplied: assigning
# None into os.environ raises TypeError.
_cuda_visible_devices = os.environ.get("_CUDA_VISIBLE_DEVICES")
if _cuda_visible_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices
opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
# NOTE(review): eval() executes whatever string the environment holds. Kept
# as-is for compatibility; if only "True"/"False" are ever passed, a plain
# string comparison would be safer — confirm with the launcher.
is_half = eval(os.environ.get("is_half", "True"))
import math, traceback
import multiprocessing
import sys,pdb
import sys, pdb
now_dir = os.getcwd()
sys.path.append(now_dir)
from random import shuffle
import torch.multiprocessing as mp
from glob import glob
from tqdm import tqdm
import logging,librosa,utils,torch
import logging, librosa, utils, torch
from module.models import SynthesizerTrn
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G
@ -30,52 +33,58 @@ logging.getLogger("numba").setLevel(logging.WARNING)
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
# NOTE(review): in this copy every pre-refactor line is followed by its
# reformatted replacement and leading indentation is missing, so statement
# nesting (in particular, which statements the existence check below guards)
# cannot be read off these lines — confirm against the real file.
# Input feature directory and the per-shard output path under opt_dir.
hubert_dir="%s/4-cnhubert"%(opt_dir)
semantic_path="%s/6-name2semantic-%s.tsv"%(opt_dir,i_part)
if(os.path.exists(semantic_path)==False):
os.makedirs(opt_dir,exist_ok=True)
hubert_dir = "%s/4-cnhubert" % (opt_dir)
semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir, exist_ok=True)
device="cuda:0"
device = "cuda:0"
# Build the model from the hyper-parameters stored in s2config_path.
hps = utils.get_hparams_from_file(s2config_path)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if(is_half==True):
vq_model=vq_model.half().to(device)
**hps.model
)
# Move the model to the device, converting to half precision when is_half is set.
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
# utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
# Load the "weight" entry of the checkpoint non-strictly and print the result
# returned by load_state_dict.
print(vq_model.load_state_dict(torch.load(pretrained_s2G,map_location="cpu")["weight"], strict=False))
print(
vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
)
)
def name2go(wav_name, lines):
    """Convert one utterance's cached features into a semantic-token line.

    Loads ``<hubert_dir>/<wav_name>.pt``, passes it through
    ``vq_model.extract_latent`` and appends ``wav_name``, a tab, and the
    space-separated token ids to ``lines``. Returns without appending when the
    feature file does not exist.
    """
    hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
    if not os.path.exists(hubert_path):
        return
    ssl_content = torch.load(hubert_path, map_location="cpu")
    # Match the model's precision before moving the features to the device.
    if is_half:
        ssl_content = ssl_content.half()
    ssl_content = ssl_content.to(device)
    codes = vq_model.extract_latent(ssl_content)
    semantic = " ".join(str(token) for token in codes[0, 0, :].tolist())
    lines.append("%s\t%s" % (wav_name, semantic))
# Read the "wav_path|speaker|language|text" manifest.
with open(inp_text, "r", encoding="utf8") as f:
    lines = f.read().strip("\n").split("\n")

# Collect one output line per utterance of this worker's shard: every
# `all_parts`-th manifest line starting at offset `i_part`.
lines1 = []
for line in lines[int(i_part) :: int(all_parts)]:
    try:
        # Unpacking doubles as validation that the line has exactly 4 fields.
        wav_name, spk_name, language, text = line.split("|")
        wav_name = os.path.basename(wav_name)
        name2go(wav_name, lines1)
    except Exception:
        # Best effort: report the offending line and continue. Catching
        # Exception rather than using a bare `except:` keeps Ctrl-C /
        # SystemExit able to stop the run.
        print(line, traceback.format_exc())
with open(semantic_path, "w", encoding="utf8") as f:
    f.write("\n".join(lines1))

View File

@ -6,49 +6,56 @@ import cn2an
from pypinyin import lazy_pinyin, Style
import sys
sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master")
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi
current_file_path = os.path.dirname(__file__)
# Lookup table built from the tab-separated file shipped next to this module:
# first column is the key, second column the value. The file is opened in a
# `with` block so the handle is closed once the table is built.
with open(os.path.join(current_file_path, "opencpop-strict.txt")) as _table_file:
    pinyin_to_symbol_map = {
        line.split("\t")[0]: line.strip().split("\t")[1] for line in _table_file
    }
import jieba.posseg as psg
# Folds full-width / CJK punctuation (and a few ASCII stand-ins) onto the
# small punctuation set used downstream.
# NOTE(review): in the reviewed copy the non-ASCII keys and the ellipsis value
# were blank (""), which collapses them into a single empty key and makes the
# alternation regex built from this table match the empty string. They are
# restored here to the full-width counterparts of their targets — confirm
# against the upstream file.
rep_map = {
    "：": ",",
    "；": ",",
    "，": ",",
    "。": ".",
    "！": "!",
    "？": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "/": ",",
    "—": "-",
}
tone_modifier = ToneSandhi()
def replace_punctuation(text):
    """Normalise punctuation and drop every unsupported character.

    Each ``rep_map`` key found in ``text`` is replaced by its value; afterwards
    anything that is neither in the ``\\u4e00-\\u9fa5`` range nor listed in
    ``punctuation`` is removed.
    """
    # NOTE(review): the four literals in this chain were blank in the reviewed
    # copy (making the call a no-op); restored to the interjection rewrites
    # believed to be upstream — confirm.
    text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map))
    replaced_text = pattern.sub(lambda match: rep_map[match.group()], text)
    # Escape the punctuation so "-" can never form a range inside the class,
    # whatever position it has in the list.
    allowed = r"\u4e00-\u9fa5" + re.escape("".join(punctuation))
    replaced_text = re.sub("[^" + allowed + "]+", "", replaced_text)
    return replaced_text
def g2p(text):
    """Split ``text`` after each punctuation mark and convert the pieces.

    Returns the ``(phones, word2ph)`` pair produced by ``_g2p`` for the
    non-blank sentences.
    """
    boundary = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = []
    for piece in re.split(boundary, text):
        if piece.strip() != "":
            sentences.append(piece)
    phones, word2ph = _g2p(sentences)
    return phones, word2ph
@ -56,10 +63,10 @@ def g2p(text):
def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)
@ -72,17 +79,16 @@ def _g2p(segments):
for seg in segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg)
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
initials = []
finals = []
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut:
if pos == 'eng':
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos,
sub_finals)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
initials.append(sub_initials)
finals.append(sub_finals)
@ -91,7 +97,7 @@ def _g2p(segments):
finals = sum(finals, [])
#
for c, v in zip(initials, finals):
raw_pinyin = c+v
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
@ -102,40 +108,40 @@ def _g2p(segments):
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c+v_without_tone
assert tone in '12345'
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# 多音节
v_rep_map = {
"uei": 'ui',
'iou': 'iu',
'uen': 'un',
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c+v_rep_map[v_without_tone]
pinyin = c + v_rep_map[v_without_tone]
else:
# 单音节
pinyin_rep_map = {
'ing': 'ying',
'i': 'yi',
'in': 'yin',
'u': 'wu',
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
'v': 'yu',
'e': 'e',
'i': 'y',
'u': 'w',
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]]+pinyin[1:]
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ')
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
@ -144,9 +150,8 @@ def _g2p(segments):
return phones_list, word2ph
def text_normalize(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
numbers = re.findall(r"\d+(?:\.?\d+)?", text)
for number in numbers:
text = text.replace(number, cn2an.an2cn(number), 1)
text = replace_punctuation(text)
@ -154,7 +159,7 @@ def text_normalize(text):
return text
if __name__ == '__main__':
if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?"
text = "你好"

View File

@ -1,29 +1,27 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
# Language code -> module providing text_normalize() and g2p() for it.
language_module_map = dict(zh=chinese, ja=japanese, en=english)
# (marker, language, pause symbol): when the marker occurs in text of that
# language, clean_text hands the text to clean_special with the symbol.
# NOTE(review): the second marker was blank ("") in the reviewed copy; an
# empty string is contained in every text, so all "zh" input would take the
# special path. Restored to the full-width yen sign believed to be upstream —
# confirm.
special = [
    ("%", "zh", "SP"),
    ("￥", "zh", "SP2"),
    ("^", "zh", "SP3"),
    # ('@', 'zh', "SP4")  # dropped, to stay consistent with the second version
]
def clean_text(text, language):
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol)
language_module = language_module_map[language]
norm_text = language_module.text_normalize(text)
if(language=="zh"):
if language == "zh":
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
else:
phones = language_module.g2p(norm_text)
word2ph=None
word2ph = None
for ph in phones:
assert ph in symbols
@ -41,17 +39,17 @@ def clean_special(text, language, special_s, target_symbol):
new_ph = []
for ph in phones:
assert ph in symbols
if ph == ',':
if ph == ",":
new_ph.append(target_symbol)
else:
new_ph.append(ph)
return new_ph
def text_to_sequence(text, language):
    """Clean ``text`` for ``language`` and convert its phones to symbol ids."""
    # clean_text takes (text, language); the previous call omitted the
    # language and therefore raised TypeError on every invocation.
    cleaned = clean_text(text, language)
    # NOTE(review): clean_text's final return is outside the reviewed hunk,
    # while its clean_special path returns a bare phone list. Accept either a
    # bare list or a tuple whose first item is the phone list — confirm.
    phones = cleaned[0] if isinstance(cleaned, tuple) else cleaned
    return cleaned_text_to_sequence(phones)
if __name__ == '__main__':
print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh'))
if __name__ == "__main__":
print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))

View File

@ -8,20 +8,87 @@ from string import punctuation
from text import symbols
current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep')
CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle')
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
_g2p = G2p()
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'}
arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
def replace_phs(phs):
rep_map = {
';': ',',
':': ',',
'\'': '-',
'"': '-'
}
rep_map = {";": ",", ":": ",", "'": "-", '"': "-"}
phs_new = []
for ph in phs:
if ph in symbols:
@ -29,9 +96,10 @@ def replace_phs(phs):
elif ph in rep_map.keys():
phs_new.append(rep_map[ph])
else:
print('ph not in symbols: ', ph)
print("ph not in symbols: ", ph)
return phs_new
def read_dict():
g2p_dict = {}
start_line = 49
@ -41,13 +109,13 @@ def read_dict():
while line:
if line_index >= start_line:
line = line.strip()
word_split = line.split(' ')
word_split = line.split(" ")
word = word_split[0]
syllable_split = word_split[1].split(' - ')
syllable_split = word_split[1].split(" - ")
g2p_dict[word] = []
for syllable in syllable_split:
phone_split = syllable.split(' ')
phone_split = syllable.split(" ")
g2p_dict[word].append(phone_split)
line_index = line_index + 1
@ -57,13 +125,13 @@ def read_dict():
def cache_dict(g2p_dict, file_path):
    """Pickle ``g2p_dict`` into ``file_path``, overwriting any existing file."""
    with open(file_path, "wb") as cache_file:
        pickle.dump(g2p_dict, cache_file)
def get_dict():
if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, 'rb') as pickle_file:
with open(CACHE_PATH, "rb") as pickle_file:
g2p_dict = pickle.load(pickle_file)
else:
g2p_dict = read_dict()
@ -71,6 +139,7 @@ def get_dict():
return g2p_dict
eng_dict = get_dict()
@ -78,8 +147,8 @@ def text_normalize(text):
# todo: eng text normalize
return text.replace(";", ",")
def g2p(text):
def g2p(text):
phones = []
words = re.split(r"([,;.\-\?\!\s+])", text)
for w in words:
@ -97,6 +166,7 @@ def g2p(text):
return replace_phs(phones)
if __name__ == "__main__":
# print(get_dict())
print(g2p("hello"))

View File

@ -8,57 +8,63 @@ from text import symbols
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
('', 'パーセント')
]]
_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("", "パーセント")]]
# List of (consonant, sokuon) pairs:
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'Q([↑↓]*[kg])', r'k#\1'),
(r'Q([↑↓]*[tdjʧ])', r't#\1'),
(r'Q([↑↓]*[sʃ])', r's\1'),
(r'Q([↑↓]*[pb])', r'p#\1')
]]
_real_sokuon = [
(re.compile("%s" % x[0]), x[1])
for x in [
(r"Q([↑↓]*[kg])", r"k#\1"),
(r"Q([↑↓]*[tdjʧ])", r"t#\1"),
(r"Q([↑↓]*[sʃ])", r"s\1"),
(r"Q([↑↓]*[pb])", r"p#\1"),
]
]
# List of (consonant, hatsuon) pairs:
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'N([↑↓]*[pbm])', r'm\1'),
(r'N([↑↓]*[ʧʥj])', r'n^\1'),
(r'N([↑↓]*[tdn])', r'n\1'),
(r'N([↑↓]*[kg])', r'ŋ\1')
]]
_real_hatsuon = [
(re.compile("%s" % x[0]), x[1])
for x in [
(r"N([↑↓]*[pbm])", r"m\1"),
(r"N([↑↓]*[ʧʥj])", r"n^\1"),
(r"N([↑↓]*[tdn])", r"n\1"),
(r"N([↑↓]*[kg])", r"ŋ\1"),
]
]
def post_replace_ph(ph):
    """Map one phone or punctuation token onto the symbol inventory.

    A token listed in the replacement table is first swapped for its
    replacement; the result is returned if it is in ``symbols``, otherwise
    ``"UNK"`` is returned.
    """
    # NOTE(review): the non-ASCII keys and the ellipsis value were blank ("")
    # in the reviewed copy, collapsing into one empty key. Restored to the
    # full-width counterparts of their targets — confirm against upstream.
    rep_map = {
        "：": ",",
        "；": ",",
        "，": ",",
        "。": ".",
        "！": "!",
        "？": "?",
        "\n": ".",
        "·": ",",
        "、": ",",
        "...": "…",
    }
    ph = rep_map.get(ph, ph)
    # A single membership test replaces the former "in" / "not in" pair.
    return ph if ph in symbols else "UNK"
def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text)
@ -66,7 +72,7 @@ def symbols_to_japanese(text):
def preprocess_jap(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
"""Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
text = symbols_to_japanese(text)
sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text)
@ -77,13 +83,15 @@ def preprocess_jap(text):
text += p.split(" ")
if i < len(marks):
text += [marks[i].replace(' ', '')]
text += [marks[i].replace(" ", "")]
return text
def text_normalize(text):
    """Return ``text`` unchanged; no Japanese normalisation is applied yet."""
    # todo: jap text normalize
    return text
def g2p(norm_text):
phones = preprocess_jap(norm_text)
phones = [post_replace_ph(i) for i in phones]
@ -91,7 +99,7 @@ def g2p(norm_text):
return phones
if __name__ == '__main__':
if __name__ == "__main__":
for line in open("../../../Downloads/transcript_utf8.txt").readlines():
text = line.split(":")[1]
phones = g2p(text)

View File

@ -1,24 +1,397 @@
import os
# Punctuation kept as symbols. An earlier variant also listed "@" as the SP
# pause marker. The third entry is the ellipsis; it was blank in the reviewed
# copy but is visible in the superseded definition.
punctuation = ["!", "?", "…", ",", "."]
punctuation.append("-")
# Punctuation plus the pause / unknown markers.
pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
pad = "_"

# fmt: off
c = [
    "AA", "EE", "OO", "b", "c", "ch", "d", "f", "g", "h", "j", "k", "l",
    "m", "n", "p", "q", "r", "s", "sh", "t", "w", "x", "y", "z", "zh",
]

v_without_tone = [
    "E", "En", "a", "ai", "an", "ang", "ao", "e", "ei", "en", "eng", "er",
    "i", "i0", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "ir",
    "iu", "o", "ong", "ou", "u", "ua", "uai", "uan", "uang", "ui", "un",
    "uo", "v", "van", "ve", "vn",
]

# Every entry of v_without_tone with each tone digit 1-5 appended, grouped by
# tone ("E1" ... "vn1", "E2" ... "vn2", ...): the same 195 entries, in the
# same order, as the former hand-written list.
v = [final + tone for tone in "12345" for final in v_without_tone]

# japanese
ja_symbols = [
    "I", "N", "U", "a", "b", "by", "ch", "cl", "d", "dy", "e", "f", "g",
    "gy", "h", "hy", "i", "j", "k", "ky", "m", "my", "n", "ny", "o", "p",
    "py", "r", "ry", "s", "sh", "t", "ts", "u", "v", "w", "y", "z",
]

arpa = {
    "AH0", "S", "AH1", "EY2", "AE2", "EH0", "OW2", "UH0", "NG", "B", "G",
    "AY0", "M", "AA0", "F", "AO0", "ER2", "UH1", "IY1", "AH2", "DH", "IY0",
    "EY1", "IH0", "K", "N", "W", "IY2", "T", "AA1", "ER1", "EH2", "OY0",
    "UH2", "UW1", "Z", "AW2", "AW1", "V", "UW2", "AA2", "ER", "AW0", "UW0",
    "R", "OW1", "EH1", "ZH", "AE0", "IH2", "IH", "Y", "JH", "P", "AY1",
    "EY0", "OY2", "TH", "HH", "D", "ER0", "CH", "AO1", "AE1", "AO2", "OY1",
    "AY2", "IH1", "OW0", "L", "SH",
}
# fmt: on

# Final inventory: duplicates shared between the groups are removed and the
# result is sorted, so the order is the same on every run.
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))

if __name__ == "__main__":
    print(len(symbols))

View File

@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin
from pypinyin import Style
class ToneSandhi():
class ToneSandhi:
def __init__(self):
self.must_neural_tone_words = {
'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
'难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊',
'里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去',
'软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号',
'认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当',
'蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻',
'舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂',
'胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆',
'老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂',
'精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿',
'窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台',
'码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算',
'白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
'琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快',
'爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜',
'溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔',
'棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事',
'木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
'收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
'抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实',
'扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头',
'念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼',
'干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数',
'屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气',
'实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈',
'姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
'大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴',
'嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦',
'咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝',
'叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹',
'功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息',
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨',
'父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅',
'幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱',
'凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱',
'扫把', '惦记'
"麻烦",
"麻利",
"鸳鸯",
"高粱",
"骨头",
"骆驼",
"马虎",
"首饰",
"馒头",
"馄饨",
"风筝",
"难为",
"队伍",
"阔气",
"闺女",
"门道",
"锄头",
"铺盖",
"铃铛",
"铁匠",
"钥匙",
"里脊",
"里头",
"部分",
"那么",
"道士",
"造化",
"迷糊",
"连累",
"这么",
"这个",
"运气",
"过去",
"软和",
"转悠",
"踏实",
"跳蚤",
"跟头",
"趔趄",
"财主",
"豆腐",
"讲究",
"记性",
"记号",
"认识",
"规矩",
"见识",
"裁缝",
"补丁",
"衣裳",
"衣服",
"衙门",
"街坊",
"行李",
"行当",
"蛤蟆",
"蘑菇",
"薄荷",
"葫芦",
"葡萄",
"萝卜",
"荸荠",
"苗条",
"苗头",
"苍蝇",
"芝麻",
"舒服",
"舒坦",
"舌头",
"自在",
"膏药",
"脾气",
"脑袋",
"脊梁",
"能耐",
"胳膊",
"胭脂",
"胡萝",
"胡琴",
"胡同",
"聪明",
"耽误",
"耽搁",
"耷拉",
"耳朵",
"老爷",
"老实",
"老婆",
"老头",
"老太",
"翻腾",
"罗嗦",
"罐头",
"编辑",
"结实",
"红火",
"累赘",
"糨糊",
"糊涂",
"精神",
"粮食",
"簸箕",
"篱笆",
"算计",
"算盘",
"答应",
"笤帚",
"笑语",
"笑话",
"窟窿",
"窝囊",
"窗户",
"稳当",
"稀罕",
"称呼",
"秧歌",
"秀气",
"秀才",
"福气",
"祖宗",
"砚台",
"码头",
"石榴",
"石头",
"石匠",
"知识",
"眼睛",
"眯缝",
"眨巴",
"眉毛",
"相声",
"盘算",
"白净",
"痢疾",
"痛快",
"疟疾",
"疙瘩",
"疏忽",
"畜生",
"生意",
"甘蔗",
"琵琶",
"琢磨",
"琉璃",
"玻璃",
"玫瑰",
"玄乎",
"狐狸",
"状元",
"特务",
"牲口",
"牙碜",
"牌楼",
"爽快",
"爱人",
"热闹",
"烧饼",
"烟筒",
"烂糊",
"点心",
"炊帚",
"灯笼",
"火候",
"漂亮",
"滑溜",
"溜达",
"温和",
"清楚",
"消息",
"浪头",
"活泼",
"比方",
"正经",
"欺负",
"模糊",
"槟榔",
"棺材",
"棒槌",
"棉花",
"核桃",
"栅栏",
"柴火",
"架势",
"枕头",
"枇杷",
"机灵",
"本事",
"木头",
"木匠",
"朋友",
"月饼",
"月亮",
"暖和",
"明白",
"时候",
"新鲜",
"故事",
"收拾",
"收成",
"提防",
"挖苦",
"挑剔",
"指甲",
"指头",
"拾掇",
"拳头",
"拨弄",
"招牌",
"招呼",
"抬举",
"护士",
"折腾",
"扫帚",
"打量",
"打算",
"打点",
"打扮",
"打听",
"打发",
"扎实",
"扁担",
"戒指",
"懒得",
"意识",
"意思",
"情形",
"悟性",
"怪物",
"思量",
"怎么",
"念头",
"念叨",
"快活",
"忙活",
"志气",
"心思",
"得罪",
"张罗",
"弟兄",
"开通",
"应酬",
"庄稼",
"干事",
"帮手",
"帐篷",
"希罕",
"师父",
"师傅",
"巴结",
"巴掌",
"差事",
"工夫",
"岁数",
"屁股",
"尾巴",
"少爷",
"小气",
"小伙",
"将就",
"对头",
"对付",
"寡妇",
"家伙",
"客气",
"实在",
"官司",
"学问",
"学生",
"字号",
"嫁妆",
"媳妇",
"媒人",
"婆家",
"娘家",
"委屈",
"姑娘",
"姐夫",
"妯娌",
"妥当",
"妖精",
"奴才",
"女婿",
"头发",
"太阳",
"大爷",
"大方",
"大意",
"大夫",
"多少",
"多么",
"外甥",
"壮实",
"地道",
"地方",
"在乎",
"困难",
"嘴巴",
"嘱咐",
"嘟囔",
"嘀咕",
"喜欢",
"喇嘛",
"喇叭",
"商量",
"唾沫",
"哑巴",
"哈欠",
"哆嗦",
"咳嗽",
"和尚",
"告诉",
"告示",
"含糊",
"吓唬",
"后头",
"名字",
"名堂",
"合同",
"吆喝",
"叫唤",
"口袋",
"厚道",
"厉害",
"千斤",
"包袱",
"包涵",
"匀称",
"勤快",
"动静",
"动弹",
"功夫",
"力气",
"前头",
"刺猬",
"刺激",
"别扭",
"利落",
"利索",
"利害",
"分析",
"出息",
"凑合",
"凉快",
"冷战",
"冤枉",
"冒失",
"养活",
"关系",
"先生",
"兄弟",
"便宜",
"使唤",
"佩服",
"作坊",
"体面",
"位置",
"似的",
"伙计",
"休息",
"什么",
"人家",
"亲戚",
"亲家",
"交情",
"云彩",
"事情",
"买卖",
"主意",
"丫头",
"丧气",
"两口",
"东西",
"东家",
"世故",
"不由",
"不在",
"下水",
"下巴",
"上头",
"上司",
"丈夫",
"丈人",
"一辈",
"那个",
"菩萨",
"父亲",
"母亲",
"咕噜",
"邋遢",
"费用",
"冤家",
"甜头",
"介绍",
"荒唐",
"大人",
"泥鳅",
"幸福",
"熟悉",
"计划",
"扑腾",
"蜡烛",
"姥爷",
"照顾",
"喉咙",
"吉他",
"弄堂",
"蚂蚱",
"凤凰",
"拖沓",
"寒碜",
"糟蹋",
"倒腾",
"报复",
"逻辑",
"盘缠",
"喽啰",
"牢骚",
"咖喱",
"扫把",
"惦记",
}
self.must_not_neural_tone_words = {
"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎"
"男子",
"女子",
"分子",
"原子",
"量子",
"莲子",
"石子",
"瓜子",
"电子",
"人人",
"虎虎",
}
self.punc = ":,;。?!“”‘’':,;.?!"
@ -72,14 +463,15 @@ class ToneSandhi():
# word: "家里"
# pos: "s"
# finals: ['ia1', 'i3']
def _neural_sandhi(self, word: str, pos: str,
finals: List[str]) -> List[str]:
def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
for j, item in enumerate(word):
if j - 1 >= 0 and item == word[j - 1] and pos[0] in {
"n", "v", "a"
} and word not in self.must_not_neural_tone_words:
if (
j - 1 >= 0
and item == word[j - 1]
and pos[0] in {"n", "v", "a"}
and word not in self.must_not_neural_tone_words
):
finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("")
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
@ -89,9 +481,12 @@ class ToneSandhi():
# e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5"
elif len(word) > 1 and word[-1] in "们子" and pos in {
"r", "n"
} and word not in self.must_not_neural_tone_words:
elif (
len(word) > 1
and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -100,21 +495,26 @@ class ToneSandhi():
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
finals[-1] = finals[-1][:-1] + "5"
# 个做量词
elif (ge_idx >= 1 and
(word[ge_idx - 1].isnumeric() or
word[ge_idx - 1] in "几有两半多各整每做是")) or word == '':
elif (
ge_idx >= 1
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5"
else:
if word in self.must_neural_tone_words or word[
-2:] in self.must_neural_tone_words:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word)
finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]]
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list):
# conventional neural in Chinese
if word in self.must_neural_tone_words or word[
-2:] in self.must_neural_tone_words:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, [])
return finals
@ -126,15 +526,15 @@ class ToneSandhi():
else:
for i, char in enumerate(word):
# "不" before tone4 should be bu2, e.g. 不怕
if char == "" and i + 1 < len(word) and finals[i +
1][-1] == "4":
if char == "" and i + 1 < len(word) and finals[i + 1][-1] == "4":
finals[i] = finals[i][:-1] + "2"
return finals
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]):
[item.isnumeric() for item in word if item != ""]
):
return finals
# "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
@ -161,10 +561,10 @@ class ToneSandhi():
first_subword = word_list[0]
first_begin_idx = word.find(first_subword)
if first_begin_idx == 0:
second_subword = word[len(first_subword):]
second_subword = word[len(first_subword) :]
new_word_list = [first_subword, second_subword]
else:
second_subword = word[:-len(first_subword)]
second_subword = word[: -len(first_subword)]
new_word_list = [second_subword, first_subword]
return new_word_list
@ -182,18 +582,19 @@ class ToneSandhi():
elif len(word_list[0]) == 1:
finals[1] = finals[1][:-1] + "2"
else:
finals_list = [
finals[:len(word_list[0])], finals[len(word_list[0]):]
]
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
if len(finals_list) == 2:
for i, sub in enumerate(finals_list):
# e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2:
finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \
finals_list[0][-1][-1] == "3":
elif (
i == 1
and not self._all_tone_three(sub)
and finals_list[i][0][-1] == "3"
and finals_list[0][-1][-1] == "3"
):
finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, [])
# split idiom into two words who's length is 2
@ -222,7 +623,7 @@ class ToneSandhi():
new_seg.append((word, pos))
last_word = word[:]
if last_word == "":
new_seg.append((last_word, 'd'))
new_seg.append((last_word, "d"))
last_word = ""
return new_seg
@ -236,12 +637,21 @@ class ToneSandhi():
new_seg = []
# function 1
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and i + 1 < len(seg) and seg[i - 1][
0] == seg[i + 1][0] and seg[i - 1][1] == "v":
if (
i - 1 >= 0
and word == ""
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0]
else:
if i - 2 >= 0 and seg[i - 1][0] == "" and seg[i - 2][
0] == word and pos == "v":
if (
i - 2 >= 0
and seg[i - 1][0] == ""
and seg[i - 2][0] == word
and pos == "v"
):
continue
else:
new_seg.append([word, pos])
@ -257,22 +667,27 @@ class ToneSandhi():
# the first and the second words are all_tone_three
def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and self._all_tone_three(
sub_finals_list[i - 1]) and self._all_tone_three(
sub_finals_list[i]) and not merge_last[i - 1]:
if (
i - 1 >= 0
and self._all_tone_three(sub_finals_list[i - 1])
and self._all_tone_three(sub_finals_list[i])
and not merge_last[i - 1]
):
# if the last word is a reduplication, do not merge, because reduplication needs to be handled by _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len(
seg[i - 1][0]) + len(seg[i][0]) <= 3:
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
@ -287,21 +702,27 @@ class ToneSandhi():
# the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \
merge_last[i - 1]:
if (
i - 1 >= 0
and sub_finals_list[i - 1][-1][-1] == "3"
and sub_finals_list[i][0][-1] == "3"
and not merge_last[i - 1]
):
# if the last word is a reduplication, do not merge, because reduplication needs to be handled by _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len(
seg[i - 1][0]) + len(seg[i][0]) <= 3:
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
@ -313,14 +734,13 @@ class ToneSandhi():
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and seg[i-1][0] != "#":
if i - 1 >= 0 and word == "" and seg[i - 1][0] != "#":
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
else:
new_seg.append([word, pos])
return new_seg
def _merge_reduplication(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if new_seg and word == new_seg[-1][0]:
@ -329,8 +749,7 @@ class ToneSandhi():
new_seg.append([word, pos])
return new_seg
def pre_merge_for_modify(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
seg = self._merge_bu(seg)
try:
seg = self._merge_yi(seg)
@ -349,8 +768,7 @@ class ToneSandhi():
seg = self._merge_er(seg)
return seg
def modified_tone(self, word: str, pos: str,
finals: List[str]) -> List[str]:
def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals)