Merge branch 'main' into main

RVC-Boss 2024-01-17 15:51:17 +08:00 committed by GitHub
commit e205abb87d
68 changed files with 6179 additions and 3322 deletions

.gitignore (new file)
View File

@@ -0,0 +1,2 @@
env
runtime

View File

@@ -16,7 +16,7 @@ __all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
    sort batches
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if (
            self.drop_last and len(self.dataset) % self.num_replicas != 0
        ):  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas)
                / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas
            )  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [
                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
                for b in range(n_batch)
            ]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
@@ -129,8 +132,9 @@ class DistributedBucketSampler(Sampler[T_co]):
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
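
The padding branch above repeats already-seen indices until the index list reaches total_size, so every replica can draw exactly num_samples items; each rank would then take a strided slice, as in torch's own DistributedSampler (that step is not shown in this hunk). A minimal standalone sketch of the padding rule with hypothetical sizes (10 items, 4 replicas):

import math

indices = list(range(10))
num_replicas = 4
num_samples = math.ceil(len(indices) / num_replicas)  # 3 samples per rank
total_size = num_samples * num_replicas                # 12

padding_size = total_size - len(indices)               # 2
if padding_size <= len(indices):
    indices += indices[:padding_size]                  # reuse the first few indices
else:
    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]

assert len(indices) == total_size
rank0 = indices[0:total_size:num_replicas]             # hypothetical per-rank strided slice
print(rank0)                                           # [0, 4, 8]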

View File

@@ -6,14 +6,21 @@ from torch.utils.data import DataLoader


class Text2SemanticDataModule(LightningDataModule):
    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    def prepare_data(self):
        pass
@@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
@@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
        #     pad_val=self.config['data']['pad_val'])

    def train_dataloader(self):
        batch_size = self.config["train"]["batch_size"]
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
@@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16,
        )

    def val_dataloader(self):
@@ -54,7 +61,7 @@ class Text2SemanticDataModule(LightningDataModule):
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16,
        )

    # Is this actually used anywhere?
@@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
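
The module above only reads the config keys visible in this hunk (data.num_workers, data.max_sec, data.pad_val, train.batch_size), so a hedged, minimal config sketch would look like this (values are placeholders, not taken from the commit):

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8},
}
# dm = Text2SemanticDataModule(config, train_semantic_path="...", train_phoneme_path="...")
# dm.setup(stage="fit"); loader = dm.train_dataloader()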

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys

# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
@@ -14,8 +15,10 @@ from torch.utils.data import Dataset
from transformers import AutoTokenizer

from text import cleaned_text_to_sequence

# from config import exp_dir


def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    seq = sequences[0]
    ndim = seq.ndim
@@ -28,18 +31,20 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = (
            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
        )
        padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
        padded_sequences.append(padded_seq)
    batch = np.stack(padded_sequences)
    return batch


class Text2SemanticDataset(Dataset):
    """dataset class for text tokens to semantic model training."""

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
@@ -48,13 +53,18 @@ class Text2SemanticDataset(Dataset):
        # min value of phoneme/sec
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
    ) -> None:
        super().__init__()

        self.semantic_data = pd.read_csv(
            semantic_path, delimiter="\t", encoding="utf-8"
        )
        # get dict
        self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
        self.path3 = "%s/3-bert" % (
            os.path.basename(phoneme_path)
        )  # "%s/3-bert"%exp_dir#bert_dir
        self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)
@@ -64,7 +74,8 @@ class Text2SemanticDataset(Dataset):
        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
@@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")

    def init_batch(self):
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())
@@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
        for i in range(semantic_data_len):
            # iterate sequentially first
            # get str
            item_name = self.semantic_data["item_name"][i]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
                num_not_in += 1
                continue

            semantic_str = self.semantic_data["semantic_audio"][i]
            # get token list
            semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
            # (T); no need to reshape to (1, T), because len() is taken later
            # filter out samples that are too long
            if (
                len(semantic_ids) > self.max_sec * self.hz
            ):  # 1: estimate total duration from the token count and drop clips over 60s (set in config); 40*25=1k
                num_deleted_bigger += 1
                continue
            # (T,); this is fast enough to do once up front instead of per item in __getitem__
            phoneme = phoneme.split(" ")

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme)
@@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
                num_not_in += 1
                continue
            # if len(phoneme_ids) > 400:  # 2: use a fixed limit of semantic / 2.5 instead
            if (
                len(phoneme_ids) > self.max_sec * self.hz / 2.5
            ):  # 2: use a fixed limit of semantic / 2.5 instead
                num_deleted_ps += 1
                continue
            # if len(semantic_ids) > 1000:  # 3
@@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            if (
                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
            ):  # 4: phonemes per second, allowed range 3~25
                num_deleted_ps += 1
                # print(item_name)
                continue
@@ -162,7 +178,7 @@ class Text2SemanticDataset(Dataset):
        min_num = 100  # with 20 no padding is done at all; with 30 it is padded but no ckpt gets saved
        leng = len(self.semantic_phoneme)
        if leng < min_num:
            tmp1 = self.semantic_phoneme
            tmp2 = self.item_names
            self.semantic_phoneme = []
@@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
        print(
            f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
        )
        """
        there are 31 semantic datas not in phoneme datas
        deleted 34 audios who's duration are bigger than 54 seconds
        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        dataset.__len__(): 366463
        """
        # 345410 for LibriTTS
        print("dataset.__len__():", self.__len__())
@@ -206,20 +222,22 @@ class Text2SemanticDataset(Dataset):
        flag = 0
        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert) == True:
            bert_feature = torch.load(path_bert, map_location="cpu")
        else:
            flag = 1
        if flag == 1:
            # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
            bert_feature = None
        else:
            assert bert_feature.shape[-1] == len(phoneme_ids)
        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": phoneme_ids_len,
            "semantic_ids": semantic_ids,
            "semantic_ids_len": semantic_ids_len,
            "bert_feature": bert_feature,
        }

    def get_sample_length(self, idx: int):
@@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
        semantic_ids_lens: List[int] = []
        # return

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@@ -256,8 +273,8 @@ class Text2SemanticDataset(Dataset):
        bert_padded.zero_()

        for idx, item in enumerate(examples):
            bert = item["bert_feature"]
            if bert != None:
                bert_padded[idx, :, : bert.shape[-1]] = bert

        return {
@@ -276,20 +293,20 @@ class Text2SemanticDataset(Dataset):
        }


if __name__ == "__main__":
    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
    dataset = Text2SemanticDataset(
        phoneme_path=root_dir + "phoneme_train.npy",
        semantic_path=root_dir + "semantic_train.tsv",
    )

    batch_size = 12
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
    )
    for i, batch in enumerate(dataloader):
        if i % 1000 == 0:
            print(i)
        # if i == 0:
        #     print('batch["ids"]:', batch["ids"])
        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],

View File

@@ -1,5 +1,6 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os, sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict
@@ -18,23 +19,29 @@ class Text2SemanticLightningModule(LightningModule):
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
                    torch.load(pretrained_s1, map_location="cpu")["weight"]
                )
            )
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
@@ -47,22 +54,27 @@ class Text2SemanticLightningModule(LightningModule):
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        return
        # # get loss
        # loss, acc = self.model.forward(
        #     batch['phoneme_ids'], batch['phoneme_ids_len'],
@@ -100,10 +112,9 @@ class Text2SemanticLightningModule(LightningModule):
    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
        )
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
@@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule):
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                )
            },
        }
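
training_step above relies on Lightning's manual optimization (automatic_optimization = False) and only steps the optimizer on every 4th batch, which together with a zero_grad call (not shown in this hunk) behaves like 4-step gradient accumulation. A minimal sketch of the same pattern in plain PyTorch, with a toy model:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum = 4  # mirrors the `batch_idx % 4 == 0` check above

for batch_idx in range(12):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    loss.backward()                      # gradients keep accumulating across batches
    if batch_idx > 0 and batch_idx % accum == 0:
        opt.step()                       # apply the accumulated update
        opt.zero_grad()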

View File

@@ -3,7 +3,12 @@ import torch
from tqdm import tqdm

from AR.models.utils import make_pad_mask
from AR.models.utils import (
    topk_sampling,
    sample,
    logits_to_probs,
    multinomial_sample_one_no_sync,
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
@@ -22,35 +27,39 @@ default_config = {
    "p_dropout": 0.0,
    "vocab_size": 1024 + 1,
    "phoneme_vocab_size": 512,
    "EOS": 1024,
}


class Text2SemanticDecoder(nn.Module):
    def __init__(self, config, norm_first=False, top_k=3):
        super(Text2SemanticDecoder, self).__init__()
        self.model_dim = config["model"]["hidden_dim"]
        self.embedding_dim = config["model"]["embedding_dim"]
        self.num_head = config["model"]["head"]
        self.num_layers = config["model"]["n_layer"]
        self.norm_first = norm_first
        self.vocab_size = config["model"]["vocab_size"]
        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
        self.p_dropout = config["model"]["dropout"]
        self.EOS = config["model"]["EOS"]
        self.norm_first = norm_first
        assert self.EOS == self.vocab_size - 1
        # should be same as num of kmeans bin
        # assert self.EOS == 1024
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(
            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
        )
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True
        )
        self.ar_audio_embedding = TokenEmbedding(
            self.embedding_dim, self.vocab_size, self.p_dropout
        )
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True
        )

        self.h = TransformerEncoder(
            TransformerEncoderLayer(
@@ -59,26 +68,28 @@ class Text2SemanticDecoder(nn.Module):
                dim_feedforward=self.model_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=self.num_layers,
            norm=LayerNorm(self.model_dim) if norm_first else None,
        )

        self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
        self.loss_fct = nn.CrossEntropyLoss(reduction="sum")

        self.ar_accuracy_metric = MulticlassAccuracy(
            self.vocab_size,
            top_k=top_k,
            average="micro",
            multidim_average="global",
            ignore_index=self.EOS,
        )

    def forward(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        """
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
@@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module):
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
@@ -122,24 +138,26 @@ class Text2SemanticDecoder(nn.Module):
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        # loss
        # from feiteng: the longer the duration per step, the larger the gradient update should be, hence sum
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc

    # need to check how this differs from forward, and what to feed as prompts when there is no semantic
    def infer(
        self,
        x,
        x_lens,
        prompts,
        bert_feature,
        top_k: int = -100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
    ):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
@@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module):
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.h(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1.0, temperature=temperature
            )

            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
                stop = True

            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                stop = True
            if stop:
                if prompts.shape[1] == y.shape[1]:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                    print("bad zero prediction")
                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                break
            # the semantic_ids generated this round are appended to the previous y to form the new y
@@ -198,21 +218,22 @@ class Text2SemanticDecoder(nn.Module):
        return y

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # shift by one position
        return targets[:, :-1], targets[:, 1:]

    def infer_panel(
        self,
        x,  ##### all text tokens
        x_lens,
        prompts,  #### reference-audio tokens
        bert_feature,
        top_k: int = -100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
    ):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
@@ -233,62 +254,68 @@ class Text2SemanticDecoder(nn.Module):
            # "logits": None,  ### the original already only computes the tail and concatenates, so no need to care
            # "xy_dec": None,  ### not needed, only the last position is used for the logits
            "first_infer": 1,
            "stage": 0,
        }
        for idx in tqdm(range(1500)):
            if cache["first_infer"] == 1:
                y_emb = self.ar_audio_embedding(y)
            else:
                y_emb = torch.cat(
                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
                )
            cache["y_emb"] = y_emb
            y_pos = self.ar_audio_position(y_emb)
            # feed x together with the gradually growing y into the model
            if cache["first_infer"] == 1:
                xy_pos = torch.concat([x, y_pos], dim=1)
            else:
                xy_pos = y_pos[:, -1:]
            y_len = y_pos.shape[1]
            ### the following three are not cached
            if cache["first_infer"] == 1:
                x_attn_mask_pad = F.pad(
                    x_attn_mask,
                    (0, y_len),  ### extend the all-zero xx block with an all-one xy block, (x, x+y)
                    value=True,
                )
                y_attn_mask = F.pad(  ### extend the upper-right ones of yy with zeros on the xy side, (y, x+y)
                    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                    (x_len, 0),
                    value=False,
                )
                xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                    y.device
                )
            else:
                ### rightmost column (this is wrong)
                # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
                # xy_attn_mask[:,-1]=False
                ### bottom row (this is correct)
                xy_attn_mask = torch.zeros(
                    (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
                )
            # pdb.set_trace()
            ### the cache is the key part
            # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
            logits = self.ar_predict_layer(
                xy_dec[:, -1]
            )  ## no change needed: with the cache there is only one frame, same as taking the last frame
            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
            samples = sample(
                logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
            )[0].unsqueeze(0)
            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
                stop = True

            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                stop = True
            if stop:
                if prompts.shape[1] == y.shape[1]:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                    print("bad zero prediction")
                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                break
            # the semantic_ids generated this round are appended to the previous y to form the new y
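
The mask construction used in forward/infer above lets every position attend to all of x (the phoneme/text prefix) while y (the semantic tokens) is causally masked; a small sketch of the combined mask for hypothetical lengths x_len=2, y_len=3 (True means "blocked"):

import torch
import torch.nn.functional as F

x_len, y_len = 2, 3
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool), (0, y_len), value=True
)  # x rows: may look at x, never at y
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)  # y rows: may look at all of x and at past/current y
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
print(xy_attn_mask.int())
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)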

View File

@@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
@@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
@@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits,
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
@@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits,
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
@@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
from typing import Optional, Tuple


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
@@ -159,4 +158,3 @@ def sample(
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
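
top_k_top_p_filtering above masks logits in place so that a later softmax + multinomial draw can only pick surviving tokens; a toy illustration of the top-k half with made-up scores:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # hypothetical scores for a 4-token vocab
top_k = 2

# Keep only the top-k logits; everything else is pushed to -inf, as in the helper above.
kth_best = torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(logits < kth_best, -float("Inf"))

probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(probs)       # only the two largest logits keep non-zero probability
print(next_token)  # index 0 or 1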

View File

@@ -13,8 +13,10 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

F.multi_head_attention_forward = multi_head_attention_forward_patched


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
@@ -89,53 +91,58 @@ class MultiheadAttention(Module):
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
@@ -143,7 +150,8 @@ class MultiheadAttention(Module):
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
@@ -156,7 +164,8 @@ class MultiheadAttention(Module):
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
@@ -197,7 +206,8 @@ class MultiheadAttention(Module):
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
@@ -251,23 +261,26 @@ class MultiheadAttention(Module):
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
@@ -288,29 +301,41 @@ class MultiheadAttention(Module):
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,
@@ -322,17 +347,21 @@ class MultiheadAttention(Module):
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
@@ -343,9 +372,7 @@ class MultiheadAttention(Module):
                query, key = [x.transpose(1, 0) for x in (query, key)]
                value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
@@ -370,7 +397,9 @@ class MultiheadAttention(Module):
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                cache=cache,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
@@ -390,7 +419,9 @@ class MultiheadAttention(Module):
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                cache=cache,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
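
The assignment F.multi_head_attention_forward = multi_head_attention_forward_patched at the top of this file swaps the stock attention kernel for a cache-aware one at import time; a tiny self-contained sketch of that monkey-patching pattern (toy namespace and function, not the real torch API):

import types

# Toy stand-in for torch.nn.functional with one "stock" function.
F = types.SimpleNamespace(attention=lambda q, k, v: (q, None))

def attention_patched(q, k, v, cache=None):
    # Same contract as the stock function, plus an optional cache argument.
    if cache is not None:
        cache.setdefault("calls", 0)
        cache["calls"] += 1
    return q, None

F.attention = attention_patched   # every later caller of F.attention sees the patched version

cache = {}
F.attention("q", "k", "v", cache=cache)
print(cache)                      # {'calls': 1}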

View File

@ -10,7 +10,8 @@ class TokenEmbedding(nn.Module):
self, self,
embedding_dim: int, embedding_dim: int,
vocab_size: int, vocab_size: int,
dropout: float=0.0, ): dropout: float = 0.0,
):
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -38,7 +39,8 @@ class SinePositionalEmbedding(nn.Module):
embedding_dim: int, embedding_dim: int,
dropout: float = 0.0, dropout: float = 0.0,
scale: bool = False, scale: bool = False,
alpha: bool=False, ): alpha: bool = False,
):
super().__init__() super().__init__()
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim) pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse: if self.reverse:
position = torch.arange( position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else: else:
position = torch.arange( position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
-(math.log(10000.0) / self.embedding_dim)) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term) pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
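The block above fills the table with the standard sinusoidal encoding, pe[t, 2i] = sin(t * 10000^(-2i/d)) and pe[t, 2i+1] = cos(t * 10000^(-2i/d)). A minimal standalone sketch of the same construction (assuming an even embedding_dim):

    import math
    import torch

    def sine_positional_table(length: int, dim: int) -> torch.Tensor:
        # pe[t, 2i] = sin(t / 10000^(2i/dim)), pe[t, 2i+1] = cos(t / 10000^(2i/dim))
        position = torch.arange(length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(10000.0) / dim)
        )
        pe = torch.zeros(length, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, length, dim), matching the buffer shape above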

View File

@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
""" """
def __init__(self, def __init__(
self,
optimizer, optimizer,
init_lr, init_lr,
peak_lr, peak_lr,
end_lr, end_lr,
warmup_steps=10000, warmup_steps=10000,
total_steps=400000, total_steps=400000,
current_step=0): current_step=0,
):
self.init_lr = init_lr self.init_lr = init_lr
self.peak_lr = peak_lr self.peak_lr = peak_lr
self.end_lr = end_lr self.end_lr = end_lr
@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr] self._last_lr = [self.lr]
def set_lr(self, lr): def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups] self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups: for g in self.optimizer.param_groups:
# g['lr'] = lr # g['lr'] = lr
g['lr'] = self.end_lr### locked: always use the end (linear) LR g["lr"] = self.end_lr  ### locked: always use the end (linear) LR
def step(self): def step(self):
if self._current_step < self.warmup_steps: if self._current_step < self.warmup_steps:
@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
else: else:
decay_ratio = (self._current_step - self.warmup_steps) / ( decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps) self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0: if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError( raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings." "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@ -62,18 +65,12 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
return self.lr return self.lr
if __name__ == "__main__":
if __name__ == '__main__':
m = nn.Linear(10, 10) m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4) opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule( s = WarmupCosineLRSchedule(
opt, opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
1e-6, )
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
lrs = [] lrs = []
for i in range(25000): for i in range(25000):
s.step() s.step()
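Only the linear warmup and the decay_ratio bookkeeping are visible in these hunks; the cosine step itself falls outside the diff. A closed-form sketch of what such a warmup-plus-cosine schedule typically evaluates to (the half-cosine interpolation is an assumption, not taken from this file):

    import math

    def warmup_cosine_lr(step, init_lr, peak_lr, end_lr, warmup_steps, total_steps):
        # Linear warmup from init_lr to peak_lr over the first warmup_steps steps ...
        if step < warmup_steps:
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # ... then an assumed half-cosine decay from peak_lr down to end_lr.
        decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
        return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * decay_ratio))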

View File

@ -1,9 +1,16 @@
from torch.nn.functional import * from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
# import torch # import torch
# Tensor = torch.Tensor # Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union # from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched( def multi_head_attention_forward_patched(
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,
@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None, static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None, static_v: Optional[Tensor] = None,
average_attn_weights: bool = True, average_attn_weights: bool = True,
is_causal: bool = False,cache=None is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
""" """
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops): if has_torch_function(tens_ops):
return handle_torch_function( return handle_torch_function(
multi_head_attention_forward, multi_head_attention_forward,
@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight, v_proj_weight=v_proj_weight,
static_k=static_k, static_k=static_k,
static_v=static_v, static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache average_attn_weights=average_attn_weights,
cache=cache,
) )
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the # is batched, run the computation and before returning squeeze the
@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask", mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask), other_type=_none_or_dtype(attn_mask),
other_name="attn_mask", other_name="attn_mask",
target_type=query.dtype target_type=query.dtype,
) )
if is_causal and attn_mask is None: if is_causal and attn_mask is None:
@ -184,51 +205,76 @@ def multi_head_attention_forward_patched(
check_other=False, check_other=False,
) )
if key_padding_mask is not None: if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it. # We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no # Turn off use of is_causal hint, as the merged mask is no
# longer causal. # longer causal.
is_causal = False is_causal = False
assert embed_dim == embed_dim_to_check, \ assert (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor): if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing # embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc') head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else: else:
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight: if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used # allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \ assert (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else: else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
# #
# compute in-projection # compute in-projection
# #
if not use_separate_proj_weight: if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else: else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" assert (
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" q_proj_weight is not None
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" ), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None: if in_proj_bias is None:
b_q = b_k = b_v = None b_q = b_k = b_v = None
else: else:
b_q, b_k, b_v = in_proj_bias.chunk(3) b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) q, k, v = _in_projection(
if(cache!=None): query,
if(cache["first_infer"]==1): key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape) # print(0,cache["k"].shape)
cache["v"][cache["stage"]] = v cache["v"][cache["stage"]] = v
else: ### each of the 12 layers keeps its own cache_kv else:  ### each of the 12 layers keeps its own cache_kv
# print(1,cache["k"].shape) # print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1但是proj的时候可能transpose了所以时序到0维了 cache["k"][cache["stage"]] = torch.cat(
[cache["k"][cache["stage"]], k], 0
) ##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0) cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape) # print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0] src_len = cache["k"][cache["stage"]].shape[0]
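For orientation, the patched attention only reads four cache keys here ("first_infer", "stage", "k", "v"); a plausible initialization for incremental decoding is sketched below. The per-layer list layout is inferred from cache["k"][cache["stage"]]; the layer count and field semantics are assumptions.

    num_layers = 12  # hypothetical; matches the "each of the 12 layers" comment above
    cache = {
        "first_infer": 1,          # 1 on the first forward pass, 0 on later steps
        "stage": 0,                # index of the layer currently being evaluated
        "k": [None] * num_layers,  # per-layer key cache, grown along dim 0 each step
        "v": [None] * num_layers,  # per-layer value cache
    }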
@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len) correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size: if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len) correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size: if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else: else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second) # add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None: if bias_k is not None and bias_v is not None:
@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \ assert (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim, \ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k k = static_k
if static_v is None: if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \ assert (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim, \ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v v = static_v
# add zero attention along batch dimension (now first) # add zero attention along batch dimension (now first)
if add_zero_attn: if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim) zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) k = torch.cat(
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None: if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1)) attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None: if key_padding_mask is not None:
@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(
# merge key padding and attention masks # merge key padding and attention masks
if key_padding_mask is not None: if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \ assert key_padding_mask.shape == (
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" bsz,
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ src_len,
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None: if attn_mask is None:
attn_mask = key_padding_mask attn_mask = key_padding_mask
else: else:
@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape B, Nt, E = q.shape
q_scaled = q / math.sqrt(E) q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None: if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else: else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim) k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) attn_output = scaled_dot_product_attention(
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

View File

@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving. # floors), should be expectation-preserving.
floor = -0.043637 floor = -0.043637
ceil = 1.2 ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor) d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
) + torch.rand_like(deriv) deriv
)
if __name__ == "__main__": if __name__ == "__main__":
# for self-testing only. # for self-testing only.
assert d_scaled.min() >= 0.0 assert d_scaled.min() >= 0.0
@ -100,7 +101,8 @@ class ActivationBalancerFunction(torch.autograd.Function):
x: Tensor, x: Tensor,
scale_factor: Tensor, scale_factor: Tensor,
sign_factor: Optional[Tensor], sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor: channel_dim: int,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
ctx.channel_dim = channel_dim ctx.channel_dim = channel_dim
@ -125,7 +127,12 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1) scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, ) return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor( def _compute_scale_factor(
@ -134,7 +141,8 @@ def _compute_scale_factor(
min_abs: float, min_abs: float,
max_abs: float, max_abs: float,
gain_factor: float, gain_factor: float,
max_factor: float, ) -> Tensor: max_factor: float,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim] sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -145,12 +153,13 @@ def _compute_scale_factor(
else: else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs. # x_abs)_mean , min_abs.
below_threshold = ( below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( min=0, max=max_factor
min=0, max=max_factor) )
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor) min=0, max=max_factor
)
return below_threshold - above_threshold return below_threshold - above_threshold
@ -161,7 +170,8 @@ def _compute_sign_factor(
min_positive: float, min_positive: float,
max_positive: float, max_positive: float,
gain_factor: float, gain_factor: float,
max_factor: float, ) -> Tensor: max_factor: float,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim] sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -171,18 +181,18 @@ def _compute_sign_factor(
else: else:
# 0 if proportion_positive >= min_positive, else can be # 0 if proportion_positive >= min_positive, else can be
# as large as max_factor. # as large as max_factor.
factor1 = ((min_positive - proportion_positive) * factor1 = (
(gain_factor / min_positive)).clamp_( (min_positive - proportion_positive) * (gain_factor / min_positive)
min=0, max=max_factor) ).clamp_(min=0, max=max_factor)
if max_positive == 1.0: if max_positive == 1.0:
factor2 = 0.0 factor2 = 0.0
else: else:
# 0 if self.proportion_positive <= max_positive, else can be # 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor. # as large as -max_factor.
factor2 = ((proportion_positive - max_positive) * factor2 = (
(gain_factor / (1.0 - max_positive))).clamp_( (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
min=0, max=max_factor) ).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2 sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1: # require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float) assert not isinstance(sign_factor, float)
@ -240,7 +250,8 @@ class ActivationBalancer(torch.nn.Module):
scale_gain_factor: float = 0.02, scale_gain_factor: float = 0.02,
min_abs: float = 0.2, min_abs: float = 0.2,
max_abs: float = 100.0, max_abs: float = 100.0,
min_prob: float=0.1, ): min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__() super(ActivationBalancer, self).__init__()
self.num_channels = num_channels self.num_channels = num_channels
self.channel_dim = channel_dim self.channel_dim = channel_dim
@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
torch.jit.is_tracing()):
return _no_op(x) return _no_op(x)
count = self.cpu_count count = self.cpu_count
@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive, self.min_positive,
self.max_positive, self.max_positive,
gain_factor=self.sign_gain_factor / prob, gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, ) max_factor=self.max_factor,
)
else: else:
sign_factor = None sign_factor = None
@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs, min_abs=self.min_abs,
max_abs=self.max_abs, max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob, gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, ) max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
x, x,
scale_factor, scale_factor,
sign_factor, sign_factor,
self.channel_dim, ) self.channel_dim,
)
else: else:
return _no_op(x) return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, def BalancedDoubleSwish(
min_prob=0.25) -> nn.Sequential: d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
""" """
ActivationBalancer -> DoubleSwish ActivationBalancer -> DoubleSwish
""" """
balancer = ActivationBalancer( balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob) d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential( return nn.Sequential(
balancer, balancer,
DoubleSwish(), ) DoubleSwish(),
)
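A minimal usage sketch for the helper above, assuming the module is importable; the shapes are illustrative, and channel_dim=-1 means the feature dimension is last:

    import torch

    act = BalancedDoubleSwish(512)     # ActivationBalancer followed by DoubleSwish
    y = act(torch.randn(2, 50, 512))   # accepts (batch, time, channels) activations
    assert y.shape == (2, 50, 512)     # both stages preserve the input shape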

View File

@ -31,21 +31,23 @@ class LayerNorm(nn.Module):
eps: float = 1e-5, eps: float = 1e-5,
elementwise_affine: bool = True, elementwise_affine: bool = True,
device=None, device=None,
dtype=None, ) -> None: dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__() super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral): if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment # mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment] normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple( self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
normalized_shape) # type: ignore[arg-type]
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.elementwise_affine = elementwise_affine
if self.elementwise_affine: if self.elementwise_affine:
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)) torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter( self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)) torch.empty(self.normalized_shape, **factory_kwargs)
)
else: else:
self.register_parameter("weight", None) self.register_parameter("weight", None)
self.register_parameter("bias", None) self.register_parameter("bias", None)
@ -60,21 +62,27 @@ class LayerNorm(nn.Module):
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple): if isinstance(input, tuple):
input, embedding = input input, embedding = input
return (F.layer_norm( return (
F.layer_norm(
input, input,
self.normalized_shape, self.normalized_shape,
self.weight, self.weight,
self.bias, self.bias,
self.eps, ), embedding, ) self.eps,
),
embedding,
)
assert embedding is None assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight, return F.layer_norm(
self.bias, self.eps) input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return ( return (
"{normalized_shape}, eps={eps}, " "{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)) "elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module): class IdentityNorm(nn.Module):
@ -83,7 +91,8 @@ class IdentityNorm(nn.Module):
d_model: int, d_model: int,
eps: float = 1e-5, eps: float = 1e-5,
device=None, device=None,
dtype=None, ) -> None: dtype=None,
) -> None:
super(IdentityNorm, self).__init__() super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any = None) -> Tensor: def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
@ -125,7 +134,9 @@ class TransformerEncoder(nn.Module):
src: Tensor, src: Tensor,
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool=False,cache=None ) -> Tensor: return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
Args: Args:
@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod( output = mod(
output, output,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache) src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0]) layer_states.append(output[0])
if self.norm is not None: if self.norm is not None:
@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
output = src output = src
for mod in self.layers: for mod in self.layers:
output = mod(output, output = mod(
output,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache) src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None: if self.norm is not None:
output = self.norm(output) output = self.norm(output)
@ -184,7 +200,8 @@ class TransformerEncoderLayer(nn.Module):
linear2_feedforward_cls: nn.Module = nn.Linear, linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm, layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5, layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False, ) -> None: adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__() super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead) # print(233333333333,d_model,nhead)
@ -197,14 +214,17 @@ class TransformerEncoderLayer(nn.Module):
batch_first=batch_first, batch_first=batch_first,
linear1_cls=linear1_self_attention_cls, linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls, linear2_cls=linear2_self_attention_cls,
**factory_kwargs, ) **factory_kwargs,
)
# Implementation of Feedforward model # Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, self.linear1 = linear1_feedforward_cls(
**factory_kwargs) d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, self.linear2 = linear2_feedforward_cls(
**factory_kwargs) dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout) self.dropout1 = nn.Dropout(dropout)
@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm: if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm( norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
d_model, eps=layer_norm_eps, **factory_kwargs)
else: else:
norm2 = layer_norm_cls( norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm: if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1) self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -252,7 +270,9 @@ class TransformerEncoderLayer(nn.Module):
self, self,
src: Tensor, src: Tensor,
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor: src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer. r"""Pass the input through the encoder layer.
Args: Args:
@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None: if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype _skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point( if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask): src_key_padding_mask
):
raise AssertionError( raise AssertionError(
"only bool and floating types of key_padding_mask are supported" "only bool and floating types of key_padding_mask are supported"
) )
@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block( x = x + self._sa_block(
self.norm1(x, stage_embedding), self.norm1(x, stage_embedding),
src_mask, src_mask,
src_key_padding_mask,cache=cache ) src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding)) x = x + self._ff_block(self.norm2(x, stage_embedding))
else: else:
x = self.norm1( x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding, ) stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding) x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple: if is_src_tuple:
@ -298,7 +322,9 @@ class TransformerEncoderLayer(nn.Module):
self, self,
x: Tensor, x: Tensor,
attn_mask: Optional[Tensor], attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor: key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask) # print(x.shape,attn_mask.shape,key_padding_mask)
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os # import os
@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x, x,
attn_mask=attn_mask, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0] need_weights=False,
cache=cache,
)[0]
return self.dropout1(x) return self.dropout1(x)
# feed forward block # feed forward block
@ -334,14 +362,17 @@ class AdaptiveLayerNorm(nn.Module):
weight, bias = torch.split( weight, bias = torch.split(
self.project_layer(embedding), self.project_layer(embedding),
split_size_or_sections=self.d_model, split_size_or_sections=self.d_model,
dim=-1, ) dim=-1,
)
return (weight * self.norm(input) + bias, embedding) return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split( weight, bias = torch.split(
self.project_layer(embedding), self.project_layer(embedding),
split_size_or_sections=self.d_model, split_size_or_sections=self.d_model,
dim=-1, ) dim=-1,
)
return weight * self.norm(input) + bias return weight * self.norm(input) + bias
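The adaptive norm above conditions a plain norm on an embedding: a projection (assumed to map d_model to 2 * d_model) supplies a per-position scale and shift. A self-contained sketch of that idea with hypothetical shapes:

    import torch
    import torch.nn as nn

    d_model = 512
    project_layer = nn.Linear(d_model, 2 * d_model)  # assumed shape of self.project_layer
    norm = nn.LayerNorm(d_model)
    x = torch.randn(4, 10, d_model)
    embedding = torch.randn(4, 10, d_model)
    weight, bias = torch.split(project_layer(embedding), d_model, dim=-1)
    out = weight * norm(x) + bias  # same expression as the forward() above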
def _get_clones(module, N): def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

View File

@ -27,46 +27,44 @@ class GruutPhonemizer:
"": "", "": "",
"": "", "": "",
"«": "«", "«": "«",
"»": "»" "»": "»",
} }
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])" self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str: def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text) text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text) text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text) text = regex.sub(r"\pZ+", r" ", text)
return text.strip() return text.strip()
def _convert_punctuation(self, word: Word) -> str: def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes: if not word.phonemes:
return '' return ""
if word.phonemes[0] in ['‖', '|']: if word.phonemes[0] in ["‖", "|"]:
return word.text.strip() return word.text.strip()
phonemes = ''.join(word.phonemes) phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex # remove modifier characters ˈˌː with regex
phonemes = re.sub(r'[ˈˌː͡]', '', phonemes) phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip() return phonemes.strip()
def phonemize(self, text: str, espeak: bool = False) -> str: def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text) text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [ sents: List[Sentence] = [
sent sent
for sent in self._phonemizer( for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
text_to_phonemize, lang="en-us", espeak=espeak)
] ]
words: List[str] = [ words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents) self._convert_punctuation(word) for word in itertools.chain(*sents)
] ]
return ' '.join(words) return " ".join(words)
def transform(self, phonemes): def transform(self, phonemes):
# convert phonemes to ids # convert phonemes to ids
# dictionary is in symbols.py # dictionary is in symbols.py
return [ return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
self.symbol_to_id[p] for p in phonemes
if p in self.symbol_to_id.keys()
]
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
PAD = '_' PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” ' PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'" IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ") SPACE_ID = SYMBOLS.index(" ")
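transform() in the phonemizer above indexes a symbol_to_id table; the usual way to build that table from this list (an assumption, since the mapping is constructed elsewhere) is:

    SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
    # e.g. SYMBOL_TO_ID["_"] == 0 and SYMBOL_TO_ID[" "] == SPACE_ID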

View File

@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path): def save_config_to_yaml(config, path):
assert path.endswith('.yaml') assert path.endswith(".yaml")
with open(path, 'w') as f: with open(path, "w") as f:
f.write(yaml.dump(config)) f.write(yaml.dump(config))
f.close() f.close()
def write_args(args, path): def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args) args_dict = dict(
if not name.startswith('_')) (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
with open(path, 'a') as args_file: )
args_file.write('==> torch version: {}\n'.format(torch.__version__)) with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write( args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
args_file.write('==> Cmd:\n') )
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv)) args_file.write(str(sys.argv))
args_file.write('\n==> args:\n') args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()): for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v))) args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close() args_file.close()

View File

@ -11,23 +11,30 @@ logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import ( from transformers import (
Wav2Vec2FeatureExtractor, Wav2Vec2FeatureExtractor,
HubertModel, HubertModel,
Wav2Vec2Model,
) )
import utils import utils
import torch.nn as nn import torch.nn as nn
cnhubert_base_path = None cnhubert_base_path = None
class CNHubert(nn.Module): class CNHubert(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path) self.model = HubertModel.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path) self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
)
def forward(self, x): def forward(self, x):
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"] feats = self.model(input_values)["last_hidden_state"]
return feats return feats
# class CNHubertLarge(nn.Module): # class CNHubertLarge(nn.Module):
# def __init__(self): # def __init__(self):
# super().__init__() # super().__init__()
@ -59,12 +66,12 @@ class CNHubert(nn.Module):
# return feats # return feats
def get_model(): def get_model():
model = CNHubert() model = CNHubert()
model.eval() model.eval()
return model return model
# def get_large_model(): # def get_large_model():
# model = CNHubertLarge() # model = CNHubertLarge()
# model.eval() # model.eval()
@ -80,13 +87,14 @@ def get_model():
# model.eval() # model.eval()
# return model # return model
def get_content(hmodel, wav_16k_tensor): def get_content(hmodel, wav_16k_tensor):
with torch.no_grad(): with torch.no_grad():
feats = hmodel(wav_16k_tensor) feats = hmodel(wav_16k_tensor)
return feats.transpose(1, 2) return feats.transpose(1, 2)
if __name__ == '__main__': if __name__ == "__main__":
model = get_model() model = get_model()
src_path = "/Users/Shared/原音频2.wav" src_path = "/Users/Shared/原音频2.wav"
wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
@ -94,4 +102,3 @@ if __name__ == '__main__':
wav_16k_tensor = wav_16k_tensor wav_16k_tensor = wav_16k_tensor
feats = get_content(model, wav_16k_tensor) feats = get_content(model, wav_16k_tensor)
print(feats.shape) print(feats.shape)

View File

@ -3,13 +3,15 @@ import torch
def get_model(): def get_model():
import whisper import whisper
model = whisper.load_model("small", device='cpu')
model = whisper.load_model("small", device="cpu")
return model.encoder return model.encoder
def get_content(model=None, wav_16k_tensor=None): def get_content(model=None, wav_16k_tensor=None):
from whisper import log_mel_spectrogram, pad_or_trim from whisper import log_mel_spectrogram, pad_or_trim
dev = next(model.parameters()).device dev = next(model.parameters()).device
mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
# if torch.cuda.is_available(): # if torch.cuda.is_available():
@ -17,6 +19,7 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2 feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频" assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad(): with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2) feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
return feature return feature
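A minimal invocation sketch for the two helpers above; the file name is hypothetical and the clip has to be shorter than 30 s to satisfy the assert:

    import librosa
    import torch

    model = get_model()                               # whisper "small" encoder on CPU
    wav, _ = librosa.load("reference.wav", sr=16000)  # 16 kHz mono, under 30 s
    feats = get_content(model, torch.from_numpy(wav))
    print(feats.shape)                                # (1, encoder_dim, n_mel_frames // 2)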

View File

@ -1,34 +1,31 @@
import os import os
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth") sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base") cnhubert_base_path = os.environ.get(
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large") "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872) infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui) infer_ttswebui = int(infer_ttswebui)
if("_CUDA_VISIBLE_DEVICES"in os.environ): if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import sys,torch,numpy as np import numpy as np
from pathlib import Path import librosa,torch
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not available if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
from random import shuffle
from AR.utils import get_newest_ckpt
from glob import glob
from tqdm import tqdm
from feature_extractor import cnhubert from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path cnhubert.cnhubert_base_path=cnhubert_base_path
from io import BytesIO
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
from text.cleaner import text_to_sequence, clean_text from text.cleaner import clean_text
from time import time as ttime from time import time as ttime
from module.mel_processing import spectrogram_torch from module.mel_processing import spectrogram_torch
from my_utils import load_audio from my_utils import load_audio
@ -36,8 +33,12 @@ from my_utils import load_audio
device = "cuda" device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path) tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if(is_half==True):bert_model=bert_model.half().to(device) if is_half == True:
else:bert_model=bert_model.to(device) bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
# bert_model=bert_model.to(device) # bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
@ -55,58 +56,95 @@ def get_bert_feature(text, word2ph):
# if(is_half==True):phone_level_feature=phone_level_feature.half() # if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T return phone_level_feature.T
n_semantic = 1024 n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu") dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"] hps=dict_s2["config"]
class DictToAttrRecursive:
class DictToAttrRecursive(dict):
def __init__(self, input_dict): def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items(): for key, value in input_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
# if the value is a dict, recurse through the constructor value = DictToAttrRecursive(value)
setattr(self, key, DictToAttrRecursive(value)) self[key] = value
else:
setattr(self, key, value) setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
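A quick illustration of what the wrapper above provides: nested dict values become attribute-accessible while plain dict indexing keeps working (the values here are made up):

    cfg = DictToAttrRecursive({"data": {"sampling_rate": 32000}})
    assert cfg.data.sampling_rate == 32000
    assert cfg["data"]["sampling_rate"] == 32000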
hps = DictToAttrRecursive(hps) hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz" hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu") dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"] config = dict_s1["config"]
ssl_model = cnhubert.get_model() ssl_model = cnhubert.get_model()
if(is_half==True):ssl_model=ssl_model.half().to(device) if is_half == True:
else:ssl_model=ssl_model.to(device) ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model) **hps.model
if(is_half==True):vq_model=vq_model.half().to(device) )
else:vq_model=vq_model.to(device) if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50 hz = 50
max_sec = config['data']['max_sec'] max_sec = config["data"]["max_sec"]
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False) t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"]) t2s_model.load_state_dict(dict_s1["weight"])
if(is_half==True):t2s_model=t2s_model.half() if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device) t2s_model = t2s_model.to(device)
t2s_model.eval() t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()]) total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6)) print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename): def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate)) audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio) audio = torch.FloatTensor(audio)
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False) spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec return spec
dict_language={
"中文":"zh", dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
"英文":"en",
"日文":"ja"
}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime() t0 = ttime()
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
@ -114,9 +152,15 @@ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # Paimon wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
wav16k = torch.from_numpy(wav16k) wav16k = torch.from_numpy(wav16k)
if(is_half==True):wav16k=wav16k.half().to(device) if is_half == True:
else:wav16k=wav16k.to(device) wav16k = wav16k.half().to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float() else:
wav16k = wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content) codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
t1 = ttime() t1 = ttime()
@ -126,14 +170,24 @@ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
phones1 = cleaned_text_to_sequence(phones1) phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n") texts = text.split("\n")
audio_opt = [] audio_opt = []
zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32) zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
for text in texts: for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language) phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2) phones2 = cleaned_text_to_sequence(phones2)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device) if prompt_language == "zh":
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device) bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device) else:
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
if text_language == "zh":
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1) bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@ -149,45 +203,79 @@ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
prompt, prompt,
bert, bert,
# prompt_phone_len=ph_offset, # prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'], top_k=config["inference"]["top_k"],
early_stop_num=hz * max_sec) early_stop_num=hz * max_sec,
)
t3 = ttime() t3 = ttime()
# print(pred_semantic.shape,idx) # print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)# mq needs one extra unsqueeze here pred_semantic = pred_semantic[:, -idx:].unsqueeze(
    0
) # .unsqueeze(0)# mq needs one extra unsqueeze here
refer = get_spepc(hps, ref_wav_path) # .to(device) refer = get_spepc(hps, ref_wav_path) # .to(device)
if(is_half==True):refer=refer.half().to(device) if is_half == True:
else:refer=refer.to(device) refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]### try reconstructing without the prompt part audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ### try reconstructing without the prompt part
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav)
t4 = ttime() t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
splits = {
    "，",
    "。",
    "？",
    "！",
    ",",
    ".",
    "?",
    "!",
    "~",
    ":",
    "：",
    "—",
    "…",
}  # ellipses ("……") are deliberately not treated as split points
splits={"，","。","？","！",",",".","?","!","~",":","：","—","…",}# ellipses ("……") are deliberately not treated as split points
def split(todo_text): def split(todo_text):
todo_text = todo_text.replace("……", "。").replace("——", "，") todo_text = todo_text.replace("……", "。").replace("——", "，")
if (todo_text[-1] not in splits): todo_text += "。" if todo_text[-1] not in splits:
    todo_text += "。"
i_split_head = i_split_tail = 0 i_split_head = i_split_tail = 0
len_text = len(todo_text) len_text = len(todo_text)
todo_texts = [] todo_texts = []
while (1): while 1:
if (i_split_head >= len_text): break # the text is guaranteed to end with punctuation, so just break; the last segment was already appended in the previous pass if i_split_head >= len_text:
    break # the text is guaranteed to end with punctuation, so just break; the last segment was already appended in the previous pass
if todo_text[i_split_head] in splits:
i_split_head += 1 i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head]) todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head i_split_tail = i_split_head
else: else:
i_split_head += 1 i_split_head += 1
return todo_texts return todo_texts
def cut1(inp): def cut1(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
inps = split(inp) inps = split(inp)
split_idx = list(range(0, len(inps), 5)) split_idx = list(range(0, len(inps), 5))
split_idx[-1] = None split_idx[-1] = None
if(len(split_idx)>1): if len(split_idx) > 1:
opts = [] opts = []
for idx in range(len(split_idx) - 1): for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
@ -195,61 +283,64 @@ def cut1(inp):
opts = [inp] opts = [inp]
return "\n".join(opts) return "\n".join(opts)
def cut2(inp): def cut2(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
inps = split(inp) inps = split(inp)
if(len(inps)<2):return [inp] if len(inps) < 2:
return [inp]
opts = [] opts = []
summ = 0 summ = 0
tmp_str = "" tmp_str = ""
for i in range(len(inps)): for i in range(len(inps)):
summ += len(inps[i]) summ += len(inps[i])
tmp_str += inps[i] tmp_str += inps[i]
if(summ>50): if summ > 50:
summ = 0 summ = 0
opts.append(tmp_str) opts.append(tmp_str)
tmp_str = "" tmp_str = ""
if(tmp_str!=""):opts.append(tmp_str) if tmp_str != "":
if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起 opts.append(tmp_str)
if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1] opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1] opts = opts[:-1]
return "\n".join(opts) return "\n".join(opts)
def cut3(inp): def cut3(inp):
inp = inp.strip("\n") inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")]) return "\n".join(["%s" % item for item in inp.strip("").split("")])
with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown( gr.Markdown(
value= value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
) )
# with gr.Tabs(): # with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")): # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group(): with gr.Group():
gr.Markdown( gr.Markdown(value="*请上传并填写参考信息")
value=
"*请上传并填写参考信息"
)
with gr.Row(): with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath") inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text = gr.Textbox(label="参考音频的文本", value="") prompt_text = gr.Textbox(label="参考音频的文本", value="")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文") prompt_language = gr.Dropdown(
gr.Markdown( label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
value=
"*请填写需要合成的目标文本"
) )
gr.Markdown(value="*请填写需要合成的目标文本")
with gr.Row(): with gr.Row():
text = gr.Textbox(label="需要合成的文本", value="") text = gr.Textbox(label="需要合成的文本", value="")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文") text_language = gr.Dropdown(
label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
)
inference_button = gr.Button("合成语音", variant="primary") inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音") output = gr.Audio(label="输出的语音")
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output]) inference_button.click(
get_tts_wav,
gr.Markdown( [inp_ref, prompt_text, prompt_language, text, text_language],
value= [output],
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
) )
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row(): with gr.Row():
text_inp = gr.Textbox(label="需要合成的切分前文本", value="") text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary") button1 = gr.Button("凑五句一切", variant="primary")
@ -259,10 +350,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
button1.click(cut1, [text_inp], [text_opt]) button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt]) button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt]) button3.click(cut3, [text_inp], [text_opt])
gr.Markdown( gr.Markdown(value="后续将支持混合语种编码文本输入。")
value=
"后续将支持混合语种编码文本输入。"
)
app.queue(concurrency_count=511, max_size=1022).launch( app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0", server_name="0.0.0.0",
View File
@ -8,7 +8,18 @@ from module.modules import LayerNorm
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,isflow=False, **kwargs): def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
window_size=4,
isflow=False,
**kwargs
):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
@ -24,15 +35,34 @@ class Encoder(nn.Module):
self.ffn_layers = nn.ModuleList() self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList() self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers): for i in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) self.attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
window_size=window_size,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow: if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight') self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
def forward(self, x, x_mask, g=None): def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
@ -45,9 +75,8 @@ class Encoder(nn.Module):
cond_offset = i * 2 * self.hidden_channels cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply( x = commons.fused_add_tanh_sigmoid_multiply(
x, x, g_l, torch.IntTensor([self.hidden_channels])
g_l, )
torch.IntTensor([self.hidden_channels]))
y = self.attn_layers[i](x, x, attn_mask) y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_1[i](x + y) x = self.norm_layers_1[i](x + y)
@ -60,7 +89,18 @@ class Encoder(nn.Module):
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
@ -79,11 +119,33 @@ class Decoder(nn.Module):
self.ffn_layers = nn.ModuleList() self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList() self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers): for i in range(self.n_layers):
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) self.self_attn_layers.append(
MultiHeadAttention(
hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels)) self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) self.ffn_layers.append(
FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask): def forward(self, x, x_mask, h, h_mask):
@ -91,7 +153,9 @@ class Decoder(nn.Module):
x: decoder input x: decoder input
h: encoder output h: encoder output
""" """
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
@ -111,7 +175,18 @@ class Decoder(nn.Module):
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): def __init__(
self,
channels,
out_channels,
n_heads,
p_dropout=0.0,
window_size=None,
heads_share=True,
block_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__() super().__init__()
assert channels % n_heads == 0 assert channels % n_heads == 0
@ -136,8 +211,14 @@ class MultiHeadAttention(nn.Module):
if window_size is not None: if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5 rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) self.emb_rel_k = nn.Parameter(
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
nn.init.xavier_uniform_(self.conv_q.weight) nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight) nn.init.xavier_uniform_(self.conv_k.weight)
@ -166,28 +247,46 @@ class MultiHeadAttention(nn.Module):
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None: if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention." assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
scores_local = self._relative_position_to_absolute_position(rel_logits) scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local scores = scores + scores_local
if self.proximal_bias: if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention." assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
if mask is not None: if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4) scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None: if self.block_length is not None:
assert t_s == t_t, "Local attention is only available for self-attention." assert (
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
scores = scores.masked_fill(block_mask == 0, -1e4) scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn) p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value) output = torch.matmul(p_attn, value)
if self.window_size is not None: if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn) relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) value_relative_embeddings = self._get_relative_embeddings(
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) self.emb_rel_v, t_s
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] )
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn return output, p_attn
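As a standalone illustration (not from the diff) of the score path above: scaled dot-product attention with an additive proximal bias, assumed here to take the common -log(1 + |i - j|) form so that nearby positions are penalised less. Shapes are toy values.

import math
import torch
import torch.nn.functional as F

b, h, t, d = 1, 2, 5, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))
scores = torch.matmul(q / math.sqrt(d), k.transpose(-2, -1))            # [b, h, t, t]
pos = torch.arange(t, dtype=torch.float32)
proximal = -torch.log1p((pos.unsqueeze(0) - pos.unsqueeze(1)).abs())    # assumed bias, shape [t, t]
attn = F.softmax(scores + proximal, dim=-1)                             # broadcast over batch and heads
print(torch.matmul(attn, v).shape)                                      # torch.Size([1, 2, 5, 8])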
def _matmul_with_relative_values(self, x, y): def _matmul_with_relative_values(self, x, y):
@ -217,10 +316,13 @@ class MultiHeadAttention(nn.Module):
if pad_length > 0: if pad_length > 0:
padded_relative_embeddings = F.pad( padded_relative_embeddings = F.pad(
relative_embeddings, relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
else: else:
padded_relative_embeddings = relative_embeddings padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
return used_relative_embeddings return used_relative_embeddings
def _relative_position_to_absolute_position(self, x): def _relative_position_to_absolute_position(self, x):
@ -234,10 +336,14 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1). # Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length]) x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
# Reshape and slice out the padded elements. # Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
return x_final return x_final
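A toy check (not from the diff) of the pad-and-reshape trick used in _relative_position_to_absolute_position above: relative-position logits of shape [b, h, l, 2l-1] are skewed into absolute-position scores of shape [b, h, l, l].

import torch
import torch.nn.functional as F

b, h, l = 1, 1, 4
rel = torch.arange(2 * l - 1, dtype=torch.float32).repeat(b, h, l, 1)   # [b, h, l, 2l-1]
x = F.pad(rel, (0, 1))                                  # one extra column -> [b, h, l, 2l]
x_flat = F.pad(x.view(b, h, l * 2 * l), (0, l - 1))     # extra zeros so the reshape skews the rows
absolute = x_flat.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1 :]       # [b, h, l, l]
print(absolute[0, 0])   # row i holds relative indices (j - i + l - 1) for j = 0..l-1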
def _absolute_position_to_relative_position(self, x): def _absolute_position_to_relative_position(self, x):
@ -247,7 +353,9 @@ class MultiHeadAttention(nn.Module):
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# padd along column # padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape # add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@ -267,7 +375,16 @@ class MultiHeadAttention(nn.Module):
class FFN(nn.Module): class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): def __init__(
self,
in_channels,
out_channels,
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
causal=False,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -329,27 +446,43 @@ class Depthwise_Separable_Conv1D(nn.Module):
padding=0, padding=0,
dilation=1, dilation=1,
bias=True, bias=True,
padding_mode='zeros', # TODO: refine this type padding_mode="zeros", # TODO: refine this type
device=None, device=None,
dtype=None dtype=None,
): ):
super().__init__() super().__init__()
self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, self.depth_conv = nn.Conv1d(
groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, in_channels=in_channels,
padding_mode=padding_mode, device=device, dtype=dtype) out_channels=in_channels,
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, kernel_size=kernel_size,
device=device, dtype=dtype) groups=in_channels,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input): def forward(self, input):
return self.point_conv(self.depth_conv(input)) return self.point_conv(self.depth_conv(input))
def weight_norm(self): def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name='weight') self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name='weight') self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self): def remove_weight_norm(self):
self.depth_conv = remove_weight_norm(self.depth_conv, name='weight') self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
self.point_conv = remove_weight_norm(self.point_conv, name='weight') self.point_conv = remove_weight_norm(self.point_conv, name="weight")
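To motivate the depthwise/pointwise split above, a quick parameter-count comparison against a dense Conv1d (standalone sketch, not from the diff; channel and kernel sizes are arbitrary).

import torch.nn as nn

cin, cout, k = 192, 192, 5
dense = nn.Conv1d(cin, cout, k)
depth = nn.Conv1d(cin, cin, k, groups=cin)       # depthwise: one filter per input channel
point = nn.Conv1d(cin, cout, 1)                  # pointwise: 1x1 channel mixing

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense))                              # 184512
print(count(depth) + count(point))               # 38208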
class Depthwise_Separable_TransposeConv1D(nn.Module): class Depthwise_Separable_TransposeConv1D(nn.Module):
@ -363,48 +496,79 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
output_padding=0, output_padding=0,
bias=True, bias=True,
dilation=1, dilation=1,
padding_mode='zeros', # TODO: refine this type padding_mode="zeros", # TODO: refine this type
device=None, device=None,
dtype=None dtype=None,
): ):
super().__init__() super().__init__()
self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, self.depth_conv = nn.ConvTranspose1d(
groups=in_channels, stride=stride, output_padding=output_padding, in_channels=in_channels,
padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, out_channels=in_channels,
device=device, dtype=dtype) kernel_size=kernel_size,
self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, groups=in_channels,
device=device, dtype=dtype) stride=stride,
output_padding=output_padding,
padding=padding,
dilation=dilation,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
self.point_conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input): def forward(self, input):
return self.point_conv(self.depth_conv(input)) return self.point_conv(self.depth_conv(input))
def weight_norm(self): def weight_norm(self):
self.depth_conv = weight_norm(self.depth_conv, name='weight') self.depth_conv = weight_norm(self.depth_conv, name="weight")
self.point_conv = weight_norm(self.point_conv, name='weight') self.point_conv = weight_norm(self.point_conv, name="weight")
def remove_weight_norm(self): def remove_weight_norm(self):
remove_weight_norm(self.depth_conv, name='weight') remove_weight_norm(self.depth_conv, name="weight")
remove_weight_norm(self.point_conv, name='weight') remove_weight_norm(self.point_conv, name="weight")
def weight_norm_modules(module, name='weight', dim=0): def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
module.weight_norm() module.weight_norm()
return module return module
else: else:
return weight_norm(module, name, dim) return weight_norm(module, name, dim)
def remove_weight_norm_modules(module, name='weight'): def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
module.remove_weight_norm() module.remove_weight_norm()
else: else:
remove_weight_norm(module, name) remove_weight_norm(module, name)
class FFT(nn.Module): class FFT(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., def __init__(
proximal_bias=False, proximal_init=True, isflow = False, **kwargs): self,
hidden_channels,
filter_channels,
n_heads,
n_layers=1,
kernel_size=1,
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
isflow=False,
**kwargs
):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
@ -415,9 +579,11 @@ class FFT(nn.Module):
self.proximal_bias = proximal_bias self.proximal_bias = proximal_bias
self.proximal_init = proximal_init self.proximal_init = proximal_init
if isflow: if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight') self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList() self.self_attn_layers = nn.ModuleList()
@ -426,11 +592,26 @@ class FFT(nn.Module):
self.norm_layers_1 = nn.ModuleList() self.norm_layers_1 = nn.ModuleList()
for i in range(self.n_layers): for i in range(self.n_layers):
self.self_attn_layers.append( self.self_attn_layers.append(
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, MultiHeadAttention(
proximal_init=proximal_init)) hidden_channels,
hidden_channels,
n_heads,
p_dropout=p_dropout,
proximal_bias=proximal_bias,
proximal_init=proximal_init,
)
)
self.norm_layers_0.append(LayerNorm(hidden_channels)) self.norm_layers_0.append(LayerNorm(hidden_channels))
self.ffn_layers.append( self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) FFN(
hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
p_dropout=p_dropout,
causal=True,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None): def forward(self, x, x_mask, g=None):
@ -441,7 +622,9 @@ class FFT(nn.Module):
if g is not None: if g is not None:
g = self.cond_layer(g) g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
if g is not None: if g is not None:
@ -449,9 +632,8 @@ class FFT(nn.Module):
cond_offset = i * 2 * self.hidden_channels cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply( x = commons.fused_add_tanh_sigmoid_multiply(
x, x, g_l, torch.IntTensor([self.hidden_channels])
g_l, )
torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask) y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_0[i](x + y) x = self.norm_layers_0[i](x + y)
@ -463,9 +645,9 @@ class FFT(nn.Module):
return x return x
class TransformerCouplingLayer(nn.Module): class TransformerCouplingLayer(nn.Module):
def __init__(self, def __init__(
self,
channels, channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
@ -475,7 +657,7 @@ class TransformerCouplingLayer(nn.Module):
filter_channels=0, filter_channels=0,
mean_only=False, mean_only=False,
wn_sharing_parameter=None, wn_sharing_parameter=None,
gin_channels = 0 gin_channels=0,
): ):
assert channels % 2 == 0, "channels should be divisible by 2" assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__() super().__init__()
@ -487,7 +669,20 @@ class TransformerCouplingLayer(nn.Module):
self.mean_only = mean_only self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter self.enc = (
Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
isflow=True,
gin_channels=gin_channels,
)
if wn_sharing_parameter is None
else wn_sharing_parameter
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_() self.post.weight.data.zero_()
self.post.bias.data.zero_() self.post.bias.data.zero_()
View File
@ -1,7 +1,5 @@
import math import math
import numpy as np
import torch import torch
from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@ -30,7 +28,9 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q): def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)""" """KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5 kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl return kl
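A small numerical check (not from the diff) that the expression above is the closed-form KL divergence between diagonal Gaussians parameterised by means m and log-standard-deviations logs, compared against torch.distributions.

import torch
from torch.distributions import Normal, kl_divergence

m_p, logs_p = torch.randn(4), torch.randn(4)
m_q, logs_q = torch.randn(4), torch.randn(4)

manual = (logs_q - logs_p) - 0.5 \
    + 0.5 * (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * torch.exp(-2.0 * logs_q)
reference = kl_divergence(Normal(m_p, logs_p.exp()), Normal(m_q, logs_q.exp()))
print(torch.allclose(manual, reference, atol=1e-6))   # True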
@ -64,15 +64,15 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
return ret, ids_str return ret, ids_str
def get_timing_signal_1d( def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float) position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2 num_timescales = channels // 2
log_timescale_increment = ( log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
math.log(float(max_timescale) / float(min_timescale)) / num_timescales - 1
(num_timescales - 1)) )
inv_timescales = min_timescale * torch.exp( inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2]) signal = F.pad(signal, [0, 0, 0, channels % 2])
@ -157,7 +157,7 @@ def clip_grad_value_(parameters, clip_value, norm_type=2):
total_norm += param_norm.item() ** norm_type total_norm += param_norm.item() ** norm_type
if clip_value is not None: if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value) p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type) total_norm = total_norm ** (1.0 / norm_type)
return total_norm return total_norm
View File
@ -76,9 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
print("kmeans start ... ") print("kmeans start ... ")
for _ in tqdm(range(num_iters)): for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange( diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
means, "c d -> () c d"
)
dists = -(diffs**2).sum(dim=-1) dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices buckets = dists.max(dim=-1).indices
@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch. randomly selected vector from the current batch.
""" """
def __init__( def __init__(
self, self,
dim: int, dim: int,
@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module):
): ):
super().__init__() super().__init__()
self.decay = decay self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim) embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size self.codebook_size = codebook_size
@ -246,6 +247,7 @@ class VectorQuantization(nn.Module):
randomly selected vector from the current batch. randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss. commitment_weight (float): Weight for commitment loss.
""" """
def __init__( def __init__(
self, self,
dim: int, dim: int,
@ -256,22 +258,31 @@ class VectorQuantization(nn.Module):
kmeans_init: bool = True, kmeans_init: bool = True,
kmeans_iters: int = 50, kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2, threshold_ema_dead_code: int = 2,
commitment_weight: float = 1., commitment_weight: float = 1.0,
): ):
super().__init__() super().__init__()
_codebook_dim: int = default(codebook_dim, dim) _codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) self.project_in = (
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon self.epsilon = epsilon
self.commitment_weight = commitment_weight self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, self._codebook = EuclideanCodebook(
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, dim=_codebook_dim,
decay=decay, epsilon=epsilon, codebook_size=codebook_size,
threshold_ema_dead_code=threshold_ema_dead_code) kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size self.codebook_size = codebook_size
@property @property
@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation. """Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
""" """
def __init__(self, *, num_quantizers, **kwargs): def __init__(self, *, num_quantizers, **kwargs):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)] [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
) )
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0 quantized_out = 0.0
residual = x residual = x
@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized return quantized_out, out_indices, out_losses, out_quantized
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor: def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x residual = x
all_indices = [] all_indices = []
n_q = n_q or len(self.layers) n_q = n_q or len(self.layers)
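For orientation, a toy version (not from the diff) of the residual-quantization loop from Algorithm 1 of the SoundStream paper cited above, with a plain nearest-neighbour lookup standing in for the learned EuclideanCodebook.

import torch

def nearest(codebook, x):                           # x: [n, d] -> nearest code per row
    return codebook[torch.cdist(x, codebook).argmin(dim=-1)]

codebooks = [torch.randn(8, 4) for _ in range(3)]   # 3 quantizers, 8 codes each, dim 4
x = torch.randn(5, 4)

quantized_out, residual = torch.zeros_like(x), x
for cb in codebooks:                                # each stage quantizes what the previous left over
    q = nearest(cb, residual)
    residual = residual - q
    quantized_out = quantized_out + q

print(float(torch.norm(x - quantized_out)))         # shrinks as more quantizers are stacked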
View File
@ -16,9 +16,11 @@ import torch
import requests import requests
from scipy.io import wavfile from scipy.io import wavfile
from io import BytesIO from io import BytesIO
# from config import exp_dir # from config import exp_dir
from my_utils import load_audio from my_utils import load_audio
class TextAudioSpeakerLoader(torch.utils.data.Dataset): class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
1) loads audio, speaker_id, text pairs 1) loads audio, speaker_id, text pairs
@ -42,14 +44,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for line in lines: for line in lines:
tmp = line.split("\t") tmp = line.split("\t")
if(len(tmp)!=4):continue if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]] self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng = len(tmp) leng = len(tmp)
min_num = 100 min_num = 100
if(leng<min_num): if leng < min_num:
self.audiopaths_sid_text = [] self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))): for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp self.audiopaths_sid_text += tmp
@ -74,7 +77,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text): for audiopath in tqdm(self.audiopaths_sid_text):
try: try:
phoneme = self.phoneme_data[audiopath][0] phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ') phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme) phoneme_ids = cleaned_text_to_sequence(phoneme)
except Exception: except Exception:
print(f"{audiopath} not in self.phoneme_data !") print(f"{audiopath} not in self.phoneme_data !")
@ -82,7 +85,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
continue continue
size = os.path.getsize("%s/%s" % (self.path5, audiopath)) size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2 duration = size / self.sampling_rate / 2
if (54 > duration > 0.6 or self.val): if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids]) audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length)) lengths.append(size // (2 * self.hop_length))
else: else:
@ -100,8 +103,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
try: try:
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath)) spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad(): with torch.no_grad():
ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu") ssl = torch.load(
if(ssl.shape[-1]!=spec.shape[-1]): "%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
)
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
@ -116,12 +121,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return (ssl, spec, wav, text) return (ssl, spec, wav, text)
def get_audio(self, filename): def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768 audio_array = load_audio(
filename, self.sampling_rate
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
# print(filename,audio_array.max(),audio_array.min(),audio_array.mean()) # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
audio = torch.FloatTensor(audio_array) # /32768 audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,self.sampling_rate, self.hop_length, self.win_length,center=False) spec = spectrogram_torch(
audio_norm,
self.filter_length,
self.sampling_rate,
self.hop_length,
self.win_length,
center=False,
)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio_norm return spec, audio_norm
@ -137,7 +151,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return len(self.audiopaths_sid_text) return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel): def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1]- wav.shape[-1]//self.hop_length) < 3, ("first", ssl.shape, wav.shape) assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first",
ssl.shape,
wav.shape,
)
len_mel = mel.shape[1] len_mel = mel.shape[1]
if self.val: if self.val:
@ -157,13 +175,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
wav2 = wav[:, : sep_point * self.hop_length] wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point] mel = mel[:, :sep_point]
assert abs(ssl.shape[-1]- wav2.shape[-1]//self.hop_length) < 3, (ssl.shape, wav.shape,wav2.shape, mel.shape, sep_point,self.hop_length, sep_point*self.hop_length, dir) assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate(): class TextAudioSpeakerCollate:
""" Zero-pads model inputs and targets """Zero-pads model inputs and targets"""
"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False):
self.return_ids = return_ids self.return_ids = return_ids
@ -176,8 +202,8 @@ class TextAudioSpeakerCollate():
""" """
# Right zero-pad all one-hot text sequences to max input length # Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort( _, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]), torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
dim=0, descending=True) )
max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -220,8 +246,16 @@ class TextAudioSpeakerCollate():
text_padded[i, : text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
return (
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
""" """
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths self.lengths = dataset.lengths
# print(233333333333333,self.lengths,dir(dataset)) # print(233333333333333,self.lengths,dir(dataset))
@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
for i in range(len(buckets)): for i in range(len(buckets)):
len_bucket = len(buckets[i]) len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size total_batch_size = self.num_replicas * self.batch_size
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem) num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket return buckets, num_samples_per_bucket
@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample # subsample
ids_bucket = ids_bucket[self.rank :: self.num_replicas] ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching # batching
for j in range(len(ids_bucket) // self.batch_size): for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch) batches.append(batch)
if self.shuffle: if self.shuffle:
View File
@ -24,7 +24,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
dg = dg.float() dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2) r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2) g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss) loss += r_loss + g_loss
r_losses.append(r_loss.item()) r_losses.append(r_loss.item())
g_losses.append(g_loss.item()) g_losses.append(g_loss.item())
@ -55,14 +55,19 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
z_mask = z_mask.float() z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5 kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask) kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask) l = kl / torch.sum(z_mask)
return l return l
def mle_loss(z, m, logs, logdet, mask): def mle_loss(z, m, logs, logdet, mask):
l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term l = torch.sum(logs) + 0.5 * torch.sum(
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l return l
View File
@ -49,21 +49,37 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.: if torch.min(y) < -1.0:
print('min value is ', torch.min(y)) print("min value is ", torch.min(y))
if torch.max(y) > 1.: if torch.max(y) > 1.0:
print('max value is ', torch.max(y)) print("max value is ", torch.max(y))
global hann_window global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], spec = torch.stft(
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec return spec
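For reference, a standalone sketch (not from the diff) of the linear-spectrogram path above on one second of audio, using assumed GPT-SoVITS-style values (32 kHz, n_fft=2048, hop=640, win=2048); return_complex=True followed by .abs() stands in for the deprecated real-valued STFT output.

import torch

sr, n_fft, hop, win = 32000, 2048, 640, 2048            # assumed values
y = torch.randn(1, sr).clamp(-1, 1)                     # [batch, samples], already in [-1, 1]
pad = int((n_fft - hop) / 2)
y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
spec = torch.stft(
    y, n_fft, hop_length=hop, win_length=win,
    window=torch.hann_window(win), center=False, return_complex=True,
).abs()
print(spec.shape)                                       # torch.Size([1, 1025, 50]) = [batch, n_fft // 2 + 1, frames]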
@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis global mel_basis
dtype_device = str(spec.dtype) + '_' + str(spec.device) dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec) spec = spectral_normalize_torch(spec)
return spec return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): def mel_spectrogram_torch(
if torch.min(y) < -1.: y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
print('min value is ', torch.min(y)) ):
if torch.max(y) > 1.: if torch.min(y) < -1.0:
print('max value is ', torch.max(y)) print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window global mel_basis, hann_window
dtype_device = str(y.dtype) + '_' + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + '_' + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], spec = torch.stft(
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
View File
@ -16,8 +16,17 @@ from module.quantize import ResidualVectorQuantizer
from text import symbols from text import symbols
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module): class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): def __init__(
self,
in_channels,
filter_channels,
kernel_size,
p_dropout,
n_flows=4,
gin_channels=0,
):
super().__init__() super().__init__()
filter_channels = in_channels # it needs to be removed from future version. filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels self.in_channels = in_channels
@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList() self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2)) self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows): for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip()) self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1) self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList() self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2)) self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4): for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip()) self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1) self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1) self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1) self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w) h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask) h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q z_q = e_q
for flow in self.post_flows: for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1) z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) logdet_tot_q += torch.sum(
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0 logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask) z0, logdet = self.log_flow(z0, x_mask)
@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows: for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse) z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b] return nll + logq # [b]
else: else:
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows: for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse) z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1) z0, z1 = torch.split(z, [1, 1], 1)
@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -108,9 +141,13 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels) self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels) self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1) self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -135,7 +172,8 @@ class DurationPredictor(nn.Module):
class TextEncoder(nn.Module): class TextEncoder(nn.Module):
def __init__(self, def __init__(
self,
out_channels, out_channels,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
@ -143,7 +181,8 @@ class TextEncoder(nn.Module):
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
latent_channels=192): latent_channels=192,
):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
@@ -162,15 +201,12 @@ class TextEncoder(nn.Module):
n_heads,
n_layers // 2,
kernel_size,
p_dropout,
)
self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@ -181,18 +217,22 @@ class TextEncoder(nn.Module):
n_heads,
n_layers // 2,
kernel_size,
p_dropout,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
@@ -209,8 +249,8 @@ class TextEncoder(nn.Module):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
@ -224,15 +264,18 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
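The (m, logs) pair returned here parameterizes a diagonal Gaussian over the latent; downstream the synthesizer samples from it with a temperature-like noise_scale. A minimal standalone illustration (sketch only, not part of the diff):

import torch

def sample_latent(m: torch.Tensor, logs: torch.Tensor, noise_scale: float = 0.5) -> torch.Tensor:
    # z = mean + std * eps with std = exp(logs) and eps ~ N(0, I);
    # noise_scale < 1 trades diversity for stability, mirroring z_p = m_p + randn_like(m_p) * exp(logs_p) * noise_scale below
    return m + torch.randn_like(m) * torch.exp(logs) * noise_scale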
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module):
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(
modules.ResidualCouplingLayer(
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
mean_only=True,
)
)
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module):
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module):
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module):
class WNEncoder(nn.Module): class WNEncoder(nn.Module):
def __init__(self, def __init__(
self,
in_channels, in_channels,
out_channels, out_channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
dilation_rate, dilation_rate,
n_layers, n_layers,
gin_channels=0): gin_channels=0,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -312,11 +375,20 @@ class WNEncoder(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.norm = modules.LayerNorm(out_channels) self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask out = self.proj(x) * x_mask
@ -325,24 +397,45 @@ class WNEncoder(nn.Module):
class Generator(torch.nn.Module):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
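The stacked ConvTranspose1d layers upsample by the product of upsample_rates, which has to match the hop length of the spectrogram frames fed to the generator. A quick sanity check (the rates below are hypothetical example values, not taken from this config):

import math

upsample_rates = [10, 8, 2, 2]  # illustrative only; real values come from the model config
print(math.prod(upsample_rates))  # total upsampling factor, i.e. audio samples produced per latent frame (here 320)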
@ -373,7 +466,7 @@ class Generator(torch.nn.Module):
return x
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module):
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=(get_padding(kernel_size, 1), 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
@@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class ReferenceEncoder(nn.Module):
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
"""
def __init__(self, spec_channels, gin_channels=0):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
def forward(self, inputs):
@ -527,18 +675,26 @@ class Quantizer_module(torch.nn.Module):
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
def forward(self, x):
d = (
torch.sum(x**2, 1, keepdim=True)
+ torch.sum(self.embedding.weight**2, 1)
- 2 * torch.matmul(x, self.embedding.weight.T)
)
min_indicies = torch.argmin(d, 1)
z_q = self.embedding(min_indicies)
return z_q, min_indicies
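The distance matrix above expands ||x - e||^2 into ||x||^2 + ||e||^2 - 2 x e^T so the argmin over all codebook entries can be taken in one shot. An equivalent standalone sketch (illustrative helper, not part of this file):

import torch

def nearest_code(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    # x: (N, d) input vectors, codebook: (n_e, d) embedding weights
    d = (
        torch.sum(x**2, 1, keepdim=True)
        + torch.sum(codebook**2, 1)
        - 2 * torch.matmul(x, codebook.T)
    )  # (N, n_e) squared distances
    return torch.argmin(d, 1)  # index of the closest code for each row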
class Quantizer(torch.nn.Module):
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@@ -555,7 +711,9 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies)  # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
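The line z_q = xin + (z_q - xin).detach() is the straight-through estimator: the forward pass uses the quantized value while gradients flow back to the encoder as if quantization were the identity. A minimal sketch of the same trick:

import torch

def straight_through(x: torch.Tensor, x_q: torch.Tensor) -> torch.Tensor:
    # output equals x_q in the forward pass, but d(output)/d(x) is the identity,
    # so the encoder still receives gradients through the non-differentiable argmin
    return x + (x_q - x).detach()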
@ -574,7 +732,8 @@ class Quantizer(torch.nn.Module):
class CodePredictor(nn.Module): class CodePredictor(nn.Module):
def __init__(self, def __init__(
self,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
@ -583,7 +742,7 @@ class CodePredictor(nn.Module):
p_dropout, p_dropout,
n_q=8, n_q=8,
dims=1024, dims=1024,
ssl_dim=768 ssl_dim=768,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -594,19 +753,18 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q self.n_q = n_q
self.dims = dims self.dims = dims
def forward(self, x, x_mask, refer, codes, infer=False): def forward(self, x, x_mask, refer, codes, infer=False):
x = x.detach() x = x.detach()
x = self.vq_proj(x * x_mask) * x_mask x = self.vq_proj(x * x_mask) * x_mask
@ -614,7 +772,9 @@ class CodePredictor(nn.Module):
x = x + g x = x + g
x = self.encoder(x * x_mask, x_mask) x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3) logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
target = codes[1:].transpose(0, 1) target = codes[1:].transpose(0, 1)
if not infer: if not infer:
logits = logits.reshape(-1, self.dims) logits = logits.reshape(-1, self.dims)
@ -626,22 +786,22 @@ class CodePredictor(nn.Module):
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
print("Top-10 Accuracy:", top3_acc, "%")
pred_codes = torch.argmax(logits, dim=-1)
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
print("Top-1 Accuracy:", acc, "%")
return pred_codes.transpose(0, 1) return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module): class SynthesizerTrn(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
""" """
def __init__(self, def __init__(
self,
spec_channels, spec_channels,
segment_size, segment_size,
inter_channels, inter_channels,
@ -662,8 +822,8 @@ class SynthesizerTrn(nn.Module):
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
**kwargs): **kwargs
):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@ -691,28 +851,44 @@ class SynthesizerTrn(nn.Module):
n_heads,
n_layers,
kernel_size,
p_dropout,
)
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
ssl_dim = 768
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
if freeze_quantizer: if freeze_quantizer:
self.ssl_proj.requires_grad_(False) self.ssl_proj.requires_grad_(False)
self.quantizer.requires_grad_(False) self.quantizer.requires_grad_(False)
@ -721,34 +897,58 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.mrte.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False)
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
with autocast(enabled=False):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=ge)
return (
o,
commit_loss,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -756,21 +956,26 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
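Read end to end, decode() runs: quantizer.decode(codes) -> optional 2x interpolation (25 Hz semantic tokens up to the 50 Hz latent rate) -> enc_p conditioned on text and the reference embedding ge -> reverse flow -> waveform decoder. A hedged usage sketch; names and shapes below are assumptions, not taken from this diff:

import torch

# hypothetical inputs; real shapes come from extract_latent() and the data pipeline
codes = torch.randint(0, 1024, (1, 1, 100))  # semantic codes for one utterance
text = torch.randint(0, 100, (1, 30))        # phoneme ids
refer = torch.randn(1, 1025, 200)            # reference spectrogram for the style encoder

# model = SynthesizerTrn(...)                # construction and checkpoint loading omitted
# out = model.decode(codes, text, refer, noise_scale=0.5)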
@ -32,7 +32,15 @@ class LayerNorm(nn.Module):
class ConvReluNorm(nn.Module): class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): def __init__(
self,
in_channels,
hidden_channels,
out_channels,
kernel_size,
n_layers,
p_dropout,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -44,13 +52,22 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList() self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels)) self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential( self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers - 1): for _ in range(n_layers - 1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) self.conv_layers.append(
nn.Conv1d(
hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2,
)
)
self.norm_layers.append(LayerNorm(hidden_channels)) self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_() self.proj.weight.data.zero_()
@ -70,7 +87,8 @@ class DDSConv(nn.Module):
""" """
Dialted and Depth-Separable Convolution Dialted and Depth-Separable Convolution
""" """
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
@ -85,9 +103,16 @@ class DDSConv(nn.Module):
for i in range(n_layers): for i in range(n_layers):
dilation = kernel_size**i dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2 padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, self.convs_sep.append(
groups=channels, dilation=dilation, padding=padding nn.Conv1d(
)) channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
)
)
self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels)) self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels)) self.norms_2.append(LayerNorm(channels))
@ -108,11 +133,19 @@ class DDSConv(nn.Module):
class WN(torch.nn.Module): class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): def __init__(
self,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
p_dropout=0,
):
super(WN, self).__init__() super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = (kernel_size,)
self.dilation_rate = dilation_rate self.dilation_rate = dilation_rate
self.n_layers = n_layers self.n_layers = n_layers
self.gin_channels = gin_channels self.gin_channels = gin_channels
@ -123,15 +156,22 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
if gin_channels != 0: if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) cond_layer = torch.nn.Conv1d(
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') gin_channels, 2 * hidden_channels * n_layers, 1
)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers): for i in range(n_layers):
dilation = dilation_rate**i dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2) padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, in_layer = torch.nn.Conv1d(
dilation=dilation, padding=padding) hidden_channels,
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
# last one is not necessary # last one is not necessary
@ -141,7 +181,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs): def forward(self, x, x_mask, g=None, **kwargs):
@ -159,10 +199,7 @@ class WN(torch.nn.Module):
else: else:
g_l = torch.zeros_like(x_in) g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts) res_skip_acts = self.res_skip_layers[i](acts)
@ -186,24 +223,76 @@ class WN(torch.nn.Module):
class ResBlock1(torch.nn.Module): class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__() super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList([ self.convs1 = nn.ModuleList(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], [
padding=get_padding(kernel_size, dilation[0]))), weight_norm(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], Conv1d(
padding=get_padding(kernel_size, dilation[1]))), channels,
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], channels,
padding=get_padding(kernel_size, dilation[2]))) kernel_size,
]) 1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights) self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([ self.convs2 = nn.ModuleList(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, [
padding=get_padding(kernel_size, 1))), weight_norm(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, Conv1d(
padding=get_padding(kernel_size, 1))), channels,
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, channels,
padding=get_padding(kernel_size, 1))) kernel_size,
]) 1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights) self.convs2.apply(init_weights)
def forward(self, x, x_mask=None): def forward(self, x, x_mask=None):
@ -231,12 +320,30 @@ class ResBlock1(torch.nn.Module):
class ResBlock2(torch.nn.Module): class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)): def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__() super(ResBlock2, self).__init__()
self.convs = nn.ModuleList([ self.convs = nn.ModuleList(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], [
padding=get_padding(kernel_size, dilation[0]))), weight_norm(
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], Conv1d(
padding=get_padding(kernel_size, dilation[1]))) channels,
]) channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights) self.convs.apply(init_weights)
def forward(self, x, x_mask=None): def forward(self, x, x_mask=None):
@ -295,7 +402,8 @@ class ElementwiseAffine(nn.Module):
class ResidualCouplingLayer(nn.Module): class ResidualCouplingLayer(nn.Module):
def __init__(self, def __init__(
self,
channels, channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
@ -303,7 +411,8 @@ class ResidualCouplingLayer(nn.Module):
n_layers, n_layers,
p_dropout=0, p_dropout=0,
gin_channels=0, gin_channels=0,
mean_only=False): mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2" assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -315,7 +424,14 @@ class ResidualCouplingLayer(nn.Module):
self.mean_only = mean_only self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) self.enc = WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=p_dropout,
gin_channels=gin_channels,
)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_() self.post.weight.data.zero_()
self.post.bias.data.zero_() self.post.bias.data.zero_()
@ -343,7 +459,15 @@ class ResidualCouplingLayer(nn.Module):
class ConvFlow(nn.Module): class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): def __init__(
self,
in_channels,
filter_channels,
kernel_size,
n_layers,
num_bins=10,
tail_bound=5.0,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
@ -354,8 +478,10 @@ class ConvFlow(nn.Module):
self.half_channels = in_channels // 2 self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj.weight.data.zero_() self.proj.weight.data.zero_()
self.proj.bias.data.zero_() self.proj.bias.data.zero_()
@ -369,16 +495,19 @@ class ConvFlow(nn.Module):
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_derivatives = h[..., 2 * self.num_bins :] unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(x1, x1, logabsdet = piecewise_rational_quadratic_transform(
x1,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=reverse, inverse=reverse,
tails='linear', tails="linear",
tail_bound=self.tail_bound tail_bound=self.tail_bound,
) )
x = torch.cat([x0, x1], 1) * x_mask x = torch.cat([x0, x1], 1) * x_mask
@ -389,9 +518,9 @@ class ConvFlow(nn.Module):
return x return x
class LinearNorm(nn.Module): class LinearNorm(nn.Module):
def __init__(self, def __init__(
self,
in_channels, in_channels,
out_channels, out_channels,
bias=True, bias=True,
@ -417,10 +546,10 @@ class Mish(nn.Module):
class Conv1dGLU(nn.Module): class Conv1dGLU(nn.Module):
''' """
Conv1d + GLU(Gated Linear Unit) with residual connection. Conv1d + GLU(Gated Linear Unit) with residual connection.
For GLU refer to https://arxiv.org/abs/1612.08083 paper. For GLU refer to https://arxiv.org/abs/1612.08083 paper.
''' """
def __init__(self, in_channels, out_channels, kernel_size, dropout): def __init__(self, in_channels, out_channels, kernel_size, dropout):
super(Conv1dGLU, self).__init__() super(Conv1dGLU, self).__init__()
@ -438,7 +567,8 @@ class Conv1dGLU(nn.Module):
class ConvNorm(nn.Module): class ConvNorm(nn.Module):
def __init__(self, def __init__(
self,
in_channels, in_channels,
out_channels, out_channels,
kernel_size=1, kernel_size=1,
@ -451,16 +581,18 @@ class ConvNorm(nn.Module):
super(ConvNorm, self).__init__() super(ConvNorm, self).__init__()
if padding is None: if padding is None:
assert (kernel_size % 2 == 1) assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2) padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(in_channels, self.conv = torch.nn.Conv1d(
in_channels,
out_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
dilation=dilation, dilation=dilation,
bias=bias) bias=bias,
)
if spectral_norm: if spectral_norm:
self.conv = nn.utils.spectral_norm(self.conv) self.conv = nn.utils.spectral_norm(self.conv)
@ -471,9 +603,9 @@ class ConvNorm(nn.Module):
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module ''' """Multi-Head Attention module"""
def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False): def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
super().__init__() super().__init__()
self.n_head = n_head self.n_head = n_head
@ -484,7 +616,9 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k) self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v) self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) self.attention = ScaledDotProductAttention(
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.fc = nn.Linear(n_head * d_v, d_model) self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@ -504,12 +638,9 @@ class MultiHeadAttention(nn.Module):
q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
len_x, d_k) # (n*b) x lq x dk k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
len_x, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1,
len_x, d_v) # (n*b) x lv x dv
if mask is not None: if mask is not None:
slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
@ -518,8 +649,9 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask) output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v) output = output.view(n_head, sz_b, len_x, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view( output = (
sz_b, len_x, -1) # b x lq x (n*dv) output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = self.fc(output) output = self.fc(output)
@ -528,7 +660,7 @@ class MultiHeadAttention(nn.Module):
class ScaledDotProductAttention(nn.Module): class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention ''' """Scaled Dot-Product Attention"""
def __init__(self, temperature, dropout): def __init__(self, temperature, dropout):
super().__init__() super().__init__()
@ -551,14 +683,17 @@ class ScaledDotProductAttention(nn.Module):
class MelStyleEncoder(nn.Module): class MelStyleEncoder(nn.Module):
''' MelStyleEncoder ''' """MelStyleEncoder"""
def __init__(self, n_mel_channels=80, def __init__(
self,
n_mel_channels=80,
style_hidden=128, style_hidden=128,
style_vector_dim=256, style_vector_dim=256,
style_kernel_size=5, style_kernel_size=5,
style_head=2, style_head=2,
dropout=0.1): dropout=0.1,
):
super(MelStyleEncoder, self).__init__() super(MelStyleEncoder, self).__init__()
self.in_dim = n_mel_channels self.in_dim = n_mel_channels
self.hidden_dim = style_hidden self.hidden_dim = style_hidden
@ -573,7 +708,7 @@ class MelStyleEncoder(nn.Module):
nn.Dropout(self.dropout), nn.Dropout(self.dropout),
LinearNorm(self.hidden_dim, self.hidden_dim), LinearNorm(self.hidden_dim, self.hidden_dim),
Mish(), Mish(),
nn.Dropout(self.dropout) nn.Dropout(self.dropout),
) )
self.temporal = nn.Sequential( self.temporal = nn.Sequential(
@ -581,9 +716,13 @@ class MelStyleEncoder(nn.Module):
Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
) )
self.slf_attn = MultiHeadAttention(self.n_head, self.hidden_dim, self.slf_attn = MultiHeadAttention(
self.hidden_dim // self.n_head, self.hidden_dim // self.n_head, self.n_head,
self.dropout) self.hidden_dim,
self.hidden_dim // self.n_head,
self.hidden_dim // self.n_head,
self.dropout,
)
self.fc = LinearNorm(self.hidden_dim, self.out_dim) self.fc = LinearNorm(self.hidden_dim, self.out_dim)
@ -602,7 +741,9 @@ class MelStyleEncoder(nn.Module):
if mask is not None: if mask is not None:
mask = (mask.int() == 0).squeeze(1) mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1] max_len = x.shape[1]
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None slf_attn_mask = (
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
# spectral # spectral
x = self.spectral(x) x = self.spectral(x)
@ -644,7 +785,9 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out) mu = self.fc1(enc_out)
logvar = self.fc2(enc_out) logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar)) posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))) kl_divergence = D.kl_divergence(
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
loss_kl = kl_divergence.mean() loss_kl = kl_divergence.mean()
z = posterior.rsample() z = posterior.rsample()
@ -656,11 +799,12 @@ class MelStyleEncoderVAE(nn.Module):
if manual_latent is None: if manual_latent is None:
if random_sample: if random_sample:
dev = next(self.parameters()).device dev = next(self.parameters()).device
posterior = D.Normal(torch.zeros(1, self.z_latent_dim, device=dev), posterior = D.Normal(
torch.ones(1, self.z_latent_dim, device=dev)) torch.zeros(1, self.z_latent_dim, device=dev),
torch.ones(1, self.z_latent_dim, device=dev),
)
z = posterior.rsample() z = posterior.rsample()
else: else:
enc_out = self.ref_encoder(inputs.transpose(1, 2)) enc_out = self.ref_encoder(inputs.transpose(1, 2))
mu = self.fc1(enc_out) mu = self.fc1(enc_out)
z = mu z = mu
@ -681,7 +825,9 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None: if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
device=x.device, dtype=x.dtype
)
x_len = torch.sum(x_mask, [1, 2]) x_len = torch.sum(x_mask, [1, 2])
if not self.initialized: if not self.initialized:
self.initialize(x, x_mask) self.initialize(x, x_mask)
@ -710,7 +856,9 @@ class ActNorm(nn.Module):
v = m_sq - (m**2) v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) bias_init = (
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init) self.bias.data.copy_(bias_init)
@ -720,19 +868,21 @@ class ActNorm(nn.Module):
class InvConvNear(nn.Module): class InvConvNear(nn.Module):
def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
super().__init__() super().__init__()
assert (n_split % 2 == 0) assert n_split % 2 == 0
self.channels = channels self.channels = channels
self.n_split = n_split self.n_split = n_split
self.no_jacobian = no_jacobian self.no_jacobian = no_jacobian
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] w_init = torch.linalg.qr(
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
if torch.det(w_init) < 0: if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0] w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init) self.weight = nn.Parameter(w_init)
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
b, c, t = x.size() b, c, t = x.size()
assert (c % self.n_split == 0) assert c % self.n_split == 0
if x_mask is None: if x_mask is None:
x_mask = 1 x_mask = 1
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
@ -740,7 +890,11 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2]) x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) x = (
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
if reverse: if reverse:
if hasattr(self, "weight_inv"): if hasattr(self, "weight_inv"):
@ -5,14 +5,16 @@ from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention from module.attentions import MultiHeadAttention
class MRTE(nn.Module): class MRTE(nn.Module):
def __init__(self, def __init__(
self,
content_enc_channels=192, content_enc_channels=192,
hidden_size=512, hidden_size=512,
out_channels=192, out_channels=192,
kernel_size=5, kernel_size=5,
n_heads=4, n_heads=4,
ge_layer = 2 ge_layer=2,
): ):
super(MRTE, self).__init__() super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
@ -21,30 +23,56 @@ class MRTE(nn.Module):
self.c_post = nn.Conv1d(hidden_size, out_channels, 1) self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if ge == None:
ge = 0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
else:
raise ValueError("test should be 0,1,2")
else:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x
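The attn_mask above is simply the outer product of the two padding masks, so attention can only connect valid text positions to valid SSL frames. A shape-level sketch with hypothetical sizes:

import torch

ssl_mask = torch.ones(2, 1, 120)   # (batch, 1, ssl_frames)
text_mask = torch.ones(2, 1, 40)   # (batch, 1, text_len)
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
print(attn_mask.shape)  # torch.Size([2, 1, 120, 40]): one 0/1 entry per (ssl frame, text position) pair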
class SpeakerEncoder(torch.nn.Module):
def __init__(
self,
mel_n_channels=80,
model_num_layers=2,
model_hidden_size=256,
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU() self.relu = nn.ReLU()
@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module):
class MELEncoder(nn.Module): class MELEncoder(nn.Module):
def __init__(self, def __init__(
self,
in_channels, in_channels,
out_channels, out_channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
dilation_rate, dilation_rate,
n_layers): n_layers,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -86,7 +116,7 @@ class MELEncoder(nn.Module):
class WN(torch.nn.Module): class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__() super(WN, self).__init__()
assert(kernel_size % 2 == 1) assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilation_rate = dilation_rate self.dilation_rate = dilation_rate
@ -98,8 +128,13 @@ class WN(torch.nn.Module):
for i in range(n_layers): for i in range(n_layers):
dilation = dilation_rate**i dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2) padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, in_layer = nn.Conv1d(
dilation=dilation, padding=padding) hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = weight_norm(in_layer) in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
@ -110,7 +145,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name='weight') res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def forward(self, x): def forward(self, x):
@ -120,14 +155,12 @@ class WN(torch.nn.Module):
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = x + res_acts
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
@ -149,8 +182,7 @@ def fused_add_tanh_sigmoid_multiply(input, n_channels):
return acts
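fused_add_tanh_sigmoid_multiply implements the standard WaveNet gate: the pre-activation is split in half, one half goes through tanh and the other through a sigmoid, and the two are multiplied. A plain-PyTorch sketch of the same gating, simplified to the no-conditioning case:

import torch

def gated_activation(x_in: torch.Tensor, n_channels: int) -> torch.Tensor:
    # x_in: (batch, 2 * n_channels, time) pre-activation from the dilated conv
    t_act = torch.tanh(x_in[:, :n_channels, :])
    s_act = torch.sigmoid(x_in[:, n_channels:, :])
    return t_act * s_act  # (batch, n_channels, time)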
if __name__ == "__main__":
if __name__ == '__main__':
content_enc = torch.randn(3, 192, 100) content_enc = torch.randn(3, 192, 100)
content_mask = torch.ones(3, 1, 100) content_mask = torch.ones(3, 1, 100)
ref_mel = torch.randn(3, 128, 30) ref_mel = torch.randn(3, 128, 30)
@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with that have an exponential moving average cluster size less than the specified threshold with
a randomly selected vector from the current batch. a randomly selected vector from the current batch.
""" """
def __init__( def __init__(
self, self,
dimension: int = 256, dimension: int = 256,
@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
threshold_ema_dead_code=self.threshold_ema_dead_code, threshold_ema_dead_code=self.threshold_ema_dead_code,
) )
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult: def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor. """Residual vector quantization on the given input tensor.
Args: Args:
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
""" """
n_q = n_q if n_q else self.n_q n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q: if layers and max(layers) >= n_q:
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B.') raise ValueError(
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth. """Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer. and returns indices for each quantizer.
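To make the docstrings above concrete: residual vector quantization runs n_q codebooks in sequence, each stage quantizing the residual left behind by the previous one, so the per-stage indices progressively refine the reconstruction and encode simply returns those indices. A toy sketch of that loop (plain torch with nearest-neighbour lookup; this is a conceptual illustration, not the repository's ResidualVectorQuantizer):

import torch

def toy_residual_vq(x, codebooks):
    # x: (frames, dim); codebooks: list of (codebook_size, dim) tensors
    residual, quantized, codes = x, torch.zeros_like(x), []
    for cb in codebooks:
        idx = torch.cdist(residual, cb).argmin(dim=-1)  # nearest codeword per frame
        q = cb[idx]
        quantized = quantized + q           # stages sum to the final reconstruction
        residual = residual - q             # the next stage sees what is left over
        codes.append(idx)
    return quantized, torch.stack(codes)    # codes: (n_q, frames)

x = torch.randn(100, 8)
quantized, codes = toy_residual_vq(x, [torch.randn(32, 8) for _ in range(4)])
print(quantized.shape, codes.shape)         # (100, 8) and (4, 100)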

@ -9,26 +9,24 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3 DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs, def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
tails=None, tails=None,
tail_bound=1., tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE): min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None: if tails is None:
spline_fn = rational_quadratic_spline spline_fn = rational_quadratic_spline
spline_kwargs = {} spline_kwargs = {}
else: else:
spline_fn = unconstrained_rational_quadratic_spline spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = { spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
'tails': tails,
'tail_bound': tail_bound
}
outputs, logabsdet = spline_fn( outputs, logabsdet = spline_fn(
inputs=inputs, inputs=inputs,
@ -46,29 +44,28 @@ def piecewise_rational_quadratic_transform(inputs,
def searchsorted(bin_locations, inputs, eps=1e-6): def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps bin_locations[..., -1] += eps
return torch.sum( return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
inputs[..., None] >= bin_locations,
dim=-1
) - 1
def unconstrained_rational_quadratic_spline(inputs, def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
tails='linear', tails="linear",
tail_bound=1., tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE): min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs) outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs) logabsdet = torch.zeros_like(inputs)
if tails == 'linear': if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1) constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., 0] = constant
@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs,
outputs[outside_interval_mask] = inputs[outside_interval_mask] outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0 logabsdet[outside_interval_mask] = 0
else: else:
raise RuntimeError('{} tails are not implemented.'.format(tails)) raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( (
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask], inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse, inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative min_derivative=min_derivative,
) )
return outputs, logabsdet return outputs, logabsdet
def rational_quadratic_spline(inputs,
def rational_quadratic_spline(
inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
left=0., right=1., bottom=0., top=1., left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE): min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right: if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain') raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1] num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0: if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins') raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0: if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins') raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1) widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1) cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left cumwidths[..., 0] = left
cumwidths[..., -1] = right cumwidths[..., -1] = right
@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs,
heights = F.softmax(unnormalized_heights, dim=-1) heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1) cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom cumheights[..., 0] = bottom
cumheights[..., -1] = top cumheights[..., -1] = top
@ -150,14 +159,12 @@ def rational_quadratic_spline(inputs,
input_heights = heights.gather(-1, bin_idx)[..., 0] input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse: if inverse:
a = (((inputs - input_cumheights) * (input_derivatives a = (inputs - input_cumheights) * (
+ input_derivatives_plus_one input_derivatives + input_derivatives_plus_one - 2 * input_delta
- 2 * input_delta) ) + input_heights * (input_delta - input_derivatives)
+ input_heights * (input_delta - input_derivatives))) b = input_heights * input_derivatives - (inputs - input_cumheights) * (
b = (input_heights * input_derivatives input_derivatives + input_derivatives_plus_one - 2 * input_delta
- (inputs - input_cumheights) * (input_derivatives )
+ input_derivatives_plus_one
- 2 * input_delta))
c = -input_delta * (inputs - input_cumheights) c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c discriminant = b.pow(2) - 4 * a * c
@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs,
outputs = root * input_bin_widths + input_cumwidths outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root) theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) denominator = input_delta + (
* theta_one_minus_theta) (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta + 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)) + input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet return outputs, -logabsdet
@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs,
theta = (inputs - input_cumwidths) / input_bin_widths theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta) theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) numerator = input_heights * (
+ input_derivatives * theta_one_minus_theta) input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) )
* theta_one_minus_theta) denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta + 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)) + input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet return outputs, logabsdet
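Inside each bin, rational_quadratic_spline above evaluates the monotone rational-quadratic map popularized by neural spline flows: with theta the relative position in the bin, s = bin_height / bin_width, and d_k, d_k1 the derivatives at the two knots, the output is y_k + height * (s*theta**2 + d_k*theta*(1 - theta)) / (s + (d_k + d_k1 - 2*s)*theta*(1 - theta)); the inverse branch solves the corresponding quadratic for theta. A scalar sketch of the forward map for one bin (values are illustrative, not repository code):

def rq_bin_forward(x, x_k, y_k, width, height, d_k, d_k1):
    # monotone rational-quadratic map for a single spline bin
    theta = (x - x_k) / width                  # relative position in [0, 1]
    s = height / width                         # bin slope
    tom = theta * (1.0 - theta)
    numerator = height * (s * theta ** 2 + d_k * tom)
    denominator = s + (d_k + d_k1 - 2.0 * s) * tom
    return y_k + numerator / denominator

# one bin mapping x in [0.0, 0.5] to y in [0.0, 1.0] with unit knot derivatives
print(rq_bin_forward(0.25, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0))   # 0.5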

@ -1,14 +1,30 @@
import os, torch, sys import os, torch, sys
from subprocess import Popen from subprocess import Popen
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from config import text_path,wav_dir,n_card,n_process_per_card,exp_name,n_parts,exp_dir from config import (
text_path,
wav_dir,
n_card,
exp_name,
n_parts,
exp_dir,
)
os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True) os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True)
os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True) os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True)
##############step1 ##############step1
ps = [] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/1-get-text.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/1-get-text.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
@ -21,12 +37,20 @@ for i_part in range(n_parts):
with open(txt_path, "r") as f: with open(txt_path, "r") as f:
opt += f.read().strip("\n").split("\n") opt += f.read().strip("\n").split("\n")
os.remove(txt_path) os.remove(txt_path)
with open("%s/2-name2text.txt"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") with open("%s/2-name2text.txt" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")
############step2 ############step2
ps = [] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
@ -35,7 +59,13 @@ for p in ps:
#############step3 #############step3
ps = [] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/3-get-semantic.py %s %s %s %s %s"%(text_path,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/3-get-semantic.py %s %s %s %s %s" % (
text_path,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
@ -47,4 +77,5 @@ for i_part in range(n_parts):
with open(semantic_path, "r") as f: with open(semantic_path, "r") as f:
opt += f.read().strip("\n").split("\n") opt += f.read().strip("\n").split("\n")
os.remove(semantic_path) os.remove(semantic_path)
with open("%s/6-name2semantic.tsv"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") with open("%s/6-name2semantic.tsv" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")
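All three steps above follow the same fan-out/fan-in pattern: the work list is split into n_parts, one subprocess is launched per part and pinned to a GPU via i_part % n_card, and once every worker has exited the per-part output files are concatenated and deleted. A stripped-down sketch of that pattern (the directory name and the echo stand-in for the worker command are hypothetical):

import os
from subprocess import Popen

n_parts, n_card, out_dir = 4, 2, "merge_demo"   # assumed values for illustration
os.makedirs(out_dir, exist_ok=True)

# fan out: one worker per part, round-robin over the available GPUs
ps = []
for i_part in range(n_parts):
    cmd = "echo part %s on gpu %s > %s/out-%s.txt" % (i_part, i_part % n_card, out_dir, i_part)
    ps.append(Popen(cmd, shell=True))
for p in ps:
    p.wait()

# fan in: concatenate the per-part outputs into one file, then clean up
opt = []
for i_part in range(n_parts):
    part_path = "%s/out-%s.txt" % (out_dir, i_part)
    with open(part_path, "r") as f:
        opt += f.read().strip("\n").split("\n")
    os.remove(part_path)
with open("%s/out.txt" % out_dir, "w") as f:
    f.write("\n".join(opt) + "\n")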

@ -31,6 +31,8 @@ import numpy as np
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path) dir = os.path.dirname(path)
name = os.path.basename(path) name = os.path.basename(path)
@ -38,18 +40,20 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea, tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if(os.path.exists(txt_path)==False): if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir) bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True) os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True) os.makedirs(bert_dir, exist_ok=True)
device = "cuda:0" device = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir) tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if (is_half == True): if is_half == True:
bert_model = bert_model.half().to(device) bert_model = bert_model.half().to(device)
else: else:
bert_model = bert_model.to(device) bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt") inputs = tokenizer(text, return_tensors="pt")
@ -67,13 +71,16 @@ if(os.path.exists(txt_path)==False):
phone_level_feature = torch.cat(phone_level_feature, dim=0) phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T return phone_level_feature.T
def process(data, res): def process(data, res):
for name, text, lan in data: for name, text, lan in data:
try: try:
name = os.path.basename(name) name = os.path.basename(name)
phones, word2ph, norm_text=clean_text(text.replace("%", '-').replace('¥', ','),lan) phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("¥", ","), lan
)
path_bert = "%s/%s.pt" % (bert_dir, name) path_bert = "%s/%s.pt" % (bert_dir, name)
if (os.path.exists(path_bert) == False and lan == "zh"): if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph) bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones) assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert) # torch.save(bert_feature, path_bert)
@ -104,7 +111,9 @@ if(os.path.exists(txt_path)==False):
try: try:
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"]) # todo.append([name,text,"zh"])
todo.append([wav_name,text,language_v1_to_language_v2.get(language,language)]) todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
except: except:
print(line, traceback.format_exc()) print(line, traceback.format_exc())
@ -114,4 +123,3 @@ if(os.path.exists(txt_path)==False):
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)) opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
with open(txt_path, "w", encoding="utf8") as f: with open(txt_path, "w", encoding="utf8") as f:
f.write("\n".join(opt) + "\n") f.write("\n".join(opt) + "\n")
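get_bert_feature above converts character-level BERT hidden states into phone-level features: word2ph records how many phones each character expands to, so each character's hidden vector is repeated that many times, the pieces are concatenated, and the result is transposed to (hidden_dim, n_phones) so it lines up with the phoneme sequence. A minimal sketch of that expansion (plain torch; the hidden size is illustrative):

import torch

def expand_to_phone_level(char_features, word2ph):
    # char_features: (n_chars, hidden); word2ph: phones contributed by each character
    pieces = [char_features[i].repeat(word2ph[i], 1) for i in range(len(word2ph))]
    return torch.cat(pieces, dim=0).T           # (hidden, sum(word2ph))

feats = torch.randn(3, 1024)                    # 3 characters, hidden size 1024
print(expand_to_phone_level(feats, [2, 1, 3]).shape)   # torch.Size([1024, 6])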

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys, os import sys, os
inp_text = os.environ.get("inp_text") inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir") inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name") exp_name = os.environ.get("exp_name")
@ -8,6 +9,7 @@ i_part= os.environ.get("i_part")
all_parts = os.environ.get("all_parts") all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir") opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir") cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
is_half = eval(os.environ.get("is_half", "True")) is_half = eval(os.environ.get("is_half", "True"))
@ -15,6 +17,7 @@ is_half=eval(os.environ.get("is_half","True"))
import pdb, traceback, numpy as np, logging import pdb, traceback, numpy as np, logging
from scipy.io import wavfile from scipy.io import wavfile
import librosa, torch import librosa, torch
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from my_utils import load_audio from my_utils import load_audio
@ -32,6 +35,8 @@ from my_utils import load_audio
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path) dir = os.path.dirname(path)
name = os.path.basename(path) name = os.path.basename(path)
@ -39,6 +44,7 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea, tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir = "%s/4-cnhubert" % (opt_dir) hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir) wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True) os.makedirs(opt_dir, exist_ok=True)
@ -49,30 +55,38 @@ maxx=0.95
alpha = 0.5 alpha = 0.5
device = "cuda:0" device = "cuda:0"
model = cnhubert.get_model() model = cnhubert.get_model()
if(is_half==True): if is_half == True:
model = model.half().to(device) model = model.half().to(device)
else: else:
model = model.to(device) model = model.to(device)
def name2go(wav_name): def name2go(wav_name):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if(os.path.exists(hubert_path)):return if os.path.exists(hubert_path):
return
wav_path = "%s/%s" % (inp_wav_dir, wav_name) wav_path = "%s/%s" % (inp_wav_dir, wav_name)
tmp_audio = load_audio(wav_path, 32000) tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max() tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2: if tmp_max > 2.2:
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max)) print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
return return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + (
tmp_audio = librosa.resample( (1 - alpha) * 32768
tmp_audio32, orig_sr=32000, target_sr=16000 ) * tmp_audio
) tmp_audio = librosa.resample(tmp_audio32, orig_sr=32000, target_sr=16000)
tensor_wav16 = torch.from_numpy(tmp_audio) tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True): if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device) tensor_wav16 = tensor_wav16.half().to(device)
else: else:
tensor_wav16 = tensor_wav16.to(device) tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) ssl = (
if np.isnan(ssl.detach().numpy()).sum()!= 0:return model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"]
.transpose(1, 2)
.cpu()
) # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
return
wavfile.write( wavfile.write(
"%s/%s" % (wav32dir, wav_name), "%s/%s" % (wav32dir, wav_name),
32000, 32000,
@ -81,6 +95,7 @@ def name2go(wav_name):
# torch.save(ssl,hubert_path ) # torch.save(ssl,hubert_path )
my_save(ssl, hubert_path) my_save(ssl, hubert_path)
with open(inp_text, "r", encoding="utf8") as f: with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
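The normalization in name2go above blends peak normalization with the raw waveform: the first term rescales the clip so its peak sits at maxx (0.95) of full scale, the second keeps a (1 - alpha) share of the original amplitude, and both are expressed in int16 units (32768) because the result is written out as a 32 kHz wav. The same arithmetic on a synthetic signal (numpy only; the test tone is made up for illustration):

import numpy as np

maxx, alpha = 0.95, 0.5
audio = 0.3 * np.sin(np.linspace(0, 20 * np.pi, 32000)).astype(np.float32)

peak = np.abs(audio).max()
# peak-normalized component plus a (1 - alpha) share of the original, in int16 units
audio32 = (audio / peak * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * audio
print(np.abs(audio32).max())   # maxx*alpha*32768 + (1 - alpha)*32768*peak, about 20480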

@ -1,4 +1,5 @@
import os import os
inp_text = os.environ.get("inp_text") inp_text = os.environ.get("inp_text")
exp_name = os.environ.get("exp_name") exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part") i_part = os.environ.get("i_part")
@ -11,6 +12,7 @@ is_half=eval(os.environ.get("is_half","True"))
import math, traceback import math, traceback
import multiprocessing import multiprocessing
import sys, pdb import sys, pdb
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from random import shuffle from random import shuffle
@ -19,6 +21,7 @@ from glob import glob
from tqdm import tqdm from tqdm import tqdm
import logging, librosa, utils, torch import logging, librosa, utils, torch
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G # from config import pretrained_s2G
@ -32,7 +35,7 @@ logging.getLogger("numba").setLevel(logging.WARNING)
hubert_dir = "%s/4-cnhubert" % (opt_dir) hubert_dir = "%s/4-cnhubert" % (opt_dir)
semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if(os.path.exists(semantic_path)==False): if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir, exist_ok=True) os.makedirs(opt_dir, exist_ok=True)
device = "cuda:0" device = "cuda:0"
@ -41,21 +44,27 @@ if(os.path.exists(semantic_path)==False):
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model) **hps.model
if(is_half==True): )
if is_half == True:
vq_model = vq_model.half().to(device) vq_model = vq_model.half().to(device)
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
# utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True) # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True) # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
print(vq_model.load_state_dict(torch.load(pretrained_s2G,map_location="cpu")["weight"], strict=False)) print(
vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
)
)
def name2go(wav_name, lines): def name2go(wav_name, lines):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if(os.path.exists(hubert_path)==False):return if os.path.exists(hubert_path) == False:
return
ssl_content = torch.load(hubert_path, map_location="cpu") ssl_content = torch.load(hubert_path, map_location="cpu")
if(is_half==True): if is_half == True:
ssl_content = ssl_content.half().to(device) ssl_content = ssl_content.half().to(device)
else: else:
ssl_content = ssl_content.to(device) ssl_content = ssl_content.to(device)
@ -77,5 +86,5 @@ if(os.path.exists(semantic_path)==False):
name2go(wav_name, lines1) name2go(wav_name, lines1)
except: except:
print(line, traceback.format_exc()) print(line, traceback.format_exc())
with open(semantic_path,"w",encoding="utf8")as f:f.write("\n".join(lines1)) with open(semantic_path, "w", encoding="utf8") as f:
f.write("\n".join(lines1))

@ -1,11 +1,12 @@
import os
import sys
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
import torch import torch
from i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
def savee(ckpt, name, epoch, steps, hps): def savee(ckpt, name, epoch, steps, hps):
try: try:
opt = OrderedDict() opt = OrderedDict()

@ -2,7 +2,7 @@
import os import os
import pdb import pdb
if("_CUDA_VISIBLE_DEVICES"in os.environ): if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse import argparse
import logging import logging
@ -17,14 +17,25 @@ from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config from AR.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high') logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt from AR.utils import get_newest_ckpt
from collections import OrderedDict from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint): class my_model_ckpt(ModelCheckpoint):
def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs): def __init__(
self,
config,
if_save_latest,
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.if_save_latest = if_save_latest self.if_save_latest = if_save_latest
self.if_save_every_weights = if_save_every_weights self.if_save_every_weights = if_save_every_weights
@ -33,25 +44,42 @@ class my_model_ckpt(ModelCheckpoint):
self.config = config self.config = config
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): if not self._should_skip_saving_checkpoint(
trainer
) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer) monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: if (
if(self.if_save_latest==True):####if set to keep only the latest ckpt, clean up all previous ckpts after saving the next one self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if (
self.if_save_latest == True
): ####if set to keep only the latest ckpt, clean up all previous ckpts after saving the next one
to_clean = list(os.listdir(self.dirpath)) to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates) self._save_topk_checkpoint(trainer, monitor_candidates)
if (self.if_save_latest == True): if self.if_save_latest == True:
for name in to_clean: for name in to_clean:
try: try:
os.remove("%s/%s" % (self.dirpath, name)) os.remove("%s/%s" % (self.dirpath, name))
except:pass except:
if(self.if_save_every_weights==True): pass
if self.if_save_every_weights == True:
to_save_od = OrderedDict() to_save_od = OrderedDict()
to_save_od["weight"] = OrderedDict() to_save_od["weight"] = OrderedDict()
dictt = trainer.strategy._lightning_module.state_dict() dictt = trainer.strategy._lightning_module.state_dict()
for key in dictt:to_save_od["weight"][key]=dictt[key].half() for key in dictt:
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1)) torch.save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
self.half_weights_save_dir,
self.exp_name,
trainer.current_epoch + 1,
),
)
self._save_last_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates)
@ -61,41 +89,45 @@ def main(args):
output_dir = Path(config["output_dir"]) output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
ckpt_dir = output_dir / 'ckpt' ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True) ckpt_dir.mkdir(parents=True, exist_ok=True)
seed_everything(config["train"]["seed"], workers=True) seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt( ckpt_callback: ModelCheckpoint = my_model_ckpt(
config=config, config=config,
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"], if_save_latest=config["train"]["if_save_latest"],
if_save_every_weights=config["train"]["if_save_every_weights"],
half_weights_save_dir=config["train"]["half_weights_save_dir"],
exp_name=config["train"]["exp_name"],
save_top_k=-1, save_top_k=-1,
monitor='top_3_acc', monitor="top_3_acc",
mode='max', mode="max",
save_on_train_epoch_end=True, save_on_train_epoch_end=True,
every_n_epochs=config["train"]["save_every_n_epoch"], every_n_epochs=config["train"]["save_every_n_epoch"],
dirpath=ckpt_dir, dirpath=ckpt_dir,
) )
logger = TensorBoardLogger( logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
name=output_dir.stem,
save_dir=output_dir
)
trainer: Trainer = Trainer( trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"], max_epochs=config["train"]["epochs"],
accelerator='gpu', accelerator="gpu",
# val_check_interval=9999999999999999999999,###do not validate # val_check_interval=9999999999999999999999,###do not validate
# check_val_every_n_epoch=None, # check_val_every_n_epoch=None,
limit_val_batches=0, limit_val_batches=0,
devices=-1, devices=-1,
benchmark=False, benchmark=False,
fast_dev_run=False, fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"), strategy=DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
),
precision=config["train"]["precision"], precision=config["train"]["precision"],
logger=logger,num_sanity_val_steps=0, logger=logger,
callbacks=[ckpt_callback]) num_sanity_val_steps=0,
callbacks=[ckpt_callback],
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule( model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir) config, output_dir
)
data_module: Text2SemanticDataModule = Text2SemanticDataModule( data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config, config,
@ -116,14 +148,15 @@ def main(args):
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'-c', "-c",
'--config_file', "--config_file",
type=str, type=str,
default='configs/s1longer.yaml', default="configs/s1longer.yaml",
help='path of config file') help="path of config file",
)
# args for dataset # args for dataset
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv') # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt') # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
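The my_model_ckpt callback above implements "keep only the latest checkpoint" by listing the checkpoint directory before saving, writing the new checkpoint, and then deleting everything that was in the earlier listing; when if_save_every_weights is set it also dumps an fp16 copy of the weights at every save. The cleanup idea reduced to plain Python (hypothetical paths, not the Lightning callback itself):

import os

def save_keeping_only_latest(dirpath, save_fn):
    # save_fn writes the new checkpoint file(s) into dirpath
    to_clean = list(os.listdir(dirpath))    # everything that existed before this save
    save_fn()                               # write the new checkpoint
    for name in to_clean:                   # then drop the older files
        try:
            os.remove(os.path.join(dirpath, name))
        except OSError:
            pass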

@ -1,4 +1,5 @@
import utils, os import utils, os
hps = utils.get_hparams(stage=2) hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch import torch
@ -11,6 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm from tqdm import tqdm
import logging, traceback import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO)
@ -20,37 +22,42 @@ from module import commons
from module.data_utils import ( from module.data_utils import (
TextAudioSpeakerLoader, TextAudioSpeakerLoader,
TextAudioSpeakerCollate, TextAudioSpeakerCollate,
DistributedBucketSampler DistributedBucketSampler,
) )
from module.models import ( from module.models import (
SynthesizerTrn, SynthesizerTrn,
MultiPeriodDiscriminator, MultiPeriodDiscriminator,
) )
from module.losses import ( from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee from process_ckpt import savee
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
###fp32 is already fast on the A100 anyway, so let's try tf32 ###fp32 is already fast on the A100 anyway, so let's try tf32
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')#lowest precision but fastest (only marginally), has no effect on the results torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally), has no effect on the results
# from config import pretrained_s2G,pretrained_s2D # from config import pretrained_s2G,pretrained_s2D
global_step = 0 global_step = 0
def main(): def main():
"""Assume Single Node Multi GPUs Training Only""" """Assume Single Node Multi GPUs Training Only"""
assert torch.cuda.is_available(), "CPU training is not allowed." assert torch.cuda.is_available(), "CPU training is not allowed."
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
os.environ['MASTER_ADDR'] = 'localhost' os.environ["MASTER_ADDR"] = "localhost"
os.environ['MASTER_PORT'] = str(randint(20000, 55555)) os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps): def run(rank, n_gpus, hps):
@ -62,7 +69,12 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank) dist.init_process_group(
backend="gloo" if os.name == "nt" else "nccl",
init_method="env://",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed) torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
@ -70,13 +82,41 @@ def run(rank, n_gpus, hps):
train_sampler = DistributedBucketSampler( train_sampler = DistributedBucketSampler(
train_dataset, train_dataset,
hps.train.batch_size, hps.train.batch_size,
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900], [
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus, num_replicas=n_gpus,
rank=rank, rank=rank,
shuffle=True) shuffle=True,
)
collate_fn = TextAudioSpeakerCollate() collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True, train_loader = DataLoader(
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16) train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=16,
)
# if rank == 0: # if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@ -87,7 +127,8 @@ def run(rank, n_gpus, hps):
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model).cuda(rank) **hps.model,
).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
@ -97,7 +138,10 @@ def run(rank, n_gpus, hps):
te_p = list(map(id, net_g.enc_p.text_embedding.parameters())) te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
et_p = list(map(id, net_g.enc_p.encoder_text.parameters())) et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
mrte_p = list(map(id, net_g.enc_p.mrte.parameters())) mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters()) base_params = filter(
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
net_g.parameters(),
)
# te_p=net_g.enc_p.text_embedding.parameters() # te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters() # et_p=net_g.enc_p.encoder_text.parameters()
@ -107,30 +151,45 @@ def run(rank, n_gpus, hps):
# filter(lambda p: p.requires_grad, net_g.parameters()),###by default all layers use the same lr # filter(lambda p: p.requires_grad, net_g.parameters()),###by default all layers use the same lr
[ [
{"params": base_params, "lr": hps.train.learning_rate}, {"params": base_params, "lr": hps.train.learning_rate},
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, {
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, "params": net_g.enc_p.text_embedding.parameters(),
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate}, "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.encoder_text.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.mrte.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
], ],
hps.train.learning_rate, hps.train.learning_rate,
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps) eps=hps.train.eps,
)
optim_d = torch.optim.AdamW( optim_d = torch.optim.AdamW(
net_d.parameters(), net_d.parameters(),
hps.train.learning_rate, hps.train.learning_rate,
betas=hps.train.betas, betas=hps.train.betas,
eps=hps.train.eps) eps=hps.train.eps,
)
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
try: # auto-resume if a checkpoint can be loaded try: # auto-resume if a checkpoint can be loaded
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
net_d,
optim_d,
) # loading D is usually fine ) # loading D is usually fine
if rank == 0: if rank == 0:
logger.info("loaded D") logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint( _, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
net_g,
optim_g,
) )
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1 # epoch_str = 1
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print( print(
net_g.module.load_state_dict( net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
) )
) ##test: do not load the optimizer ) ##test: do not load the optimizer
if hps.train.pretrained_s2D != "": if hps.train.pretrained_s2D != "":
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1) optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str): for _ in range(epoch_str):
scheduler_g.step() scheduler_g.step()
scheduler_d.step() scheduler_d.step()
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
for epoch in range(epoch_str, hps.train.epochs + 1): for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0: if rank == 0:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval]) # [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None], logger, [writer, writer_eval]) [train_loader, None],
logger,
[writer, writer_eval],
)
else: else:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, train_and_evaluate(
[train_loader, None], None, None) rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step() scheduler_g.step()
scheduler_d.step() scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets net_g, net_d = nets
optim_g, optim_d = optims optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers # scheduler_g, scheduler_d = schedulers
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train() net_g.train()
net_d.train() net_d.train()
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)): for batch_idx, (
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) ssl,
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True) text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \ (
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths) y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch( mel = spec_to_mel_torch(
spec, spec,
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.n_mel_channels, hps.data.n_mel_channels,
hps.data.sampling_rate, hps.data.sampling_rate,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax) hps.data.mel_fmax,
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) )
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1), y_hat.squeeze(1),
hps.data.filter_length, hps.data.filter_length,
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.hop_length, hps.data.hop_length,
hps.data.win_length, hps.data.win_length,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax hps.data.mel_fmax,
) )
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator # Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False): with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
loss_disc_all = loss_disc loss_disc_all = loss_disc
optim_d.zero_grad() optim_d.zero_grad()
scaler.scale(loss_disc_all).backward() scaler.scale(loss_disc_all).backward()
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if rank == 0: if rank == 0:
if global_step % hps.train.log_interval == 0: if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr'] lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info('Train Epoch: {} [{:.0f}%]'.format( logger.info(
epoch, "Train Epoch: {} [{:.0f}%]".format(
100. * batch_idx / len(train_loader))) epoch, 100.0 * batch_idx / len(train_loader)
)
)
logger.info([x.item() for x in losses] + [global_step, lr]) logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, scalar_dict = {
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} "loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update( scalar_dict.update(
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl}) {
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = { image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_org": utils.plot_spectrogram_to_numpy(
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), y_mel[0].data.cpu().numpy()
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), ),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()), "slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
),
} }
utils.summarize( utils.summarize(
writer=writer, writer=writer,
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
scalars=scalar_dict) scalars=scalar_dict,
)
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: if hps.train.if_save_latest == 0:
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
),
) )
utils.save_checkpoint( utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
),
) )
else: else:
utils.save_checkpoint( utils.save_checkpoint(
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g, optim_g,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
),
) )
utils.save_checkpoint( utils.save_checkpoint(
net_d, net_d,
optim_d, optim_d,
hps.train.learning_rate, hps.train.learning_rate,
epoch, epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)), os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
),
) )
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
) )
) )
if rank == 0: if rank == 0:
logger.info('====> Epoch: {}'.format(epoch)) logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval): def evaluate(hps, generator, eval_loader, writer_eval):
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
audio_dict = {} audio_dict = {}
print("Evaluating ...") print("Evaluating ...")
with torch.no_grad(): with torch.no_grad():
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader): for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in enumerate(eval_loader):
print(111) print(111)
spec, spec_lengths = spec.cuda(), spec_lengths.cuda() spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda() y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda() ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda() text, text_lengths = text.cuda(), text_lengths.cuda()
for test in [0, 1]: for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test) ssl, spec, spec_lengths, text, text_lengths, test=test
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch( mel = spec_to_mel_torch(
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.n_mel_channels, hps.data.n_mel_channels,
hps.data.sampling_rate, hps.data.sampling_rate,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax) hps.data.mel_fmax,
)
y_hat_mel = mel_spectrogram_torch( y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(), y_hat.squeeze(1).float(),
hps.data.filter_length, hps.data.filter_length,
@ -373,15 +526,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.hop_length, hps.data.hop_length,
hps.data.win_length, hps.data.win_length,
hps.data.mel_fmin, hps.data.mel_fmin,
hps.data.mel_fmax hps.data.mel_fmax,
)
image_dict.update(
{
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
}
)
audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
) )
image_dict.update({
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
})
audio_dict.update({
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
})
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
global_step=global_step, global_step=global_step,
images=image_dict, images=image_dict,
audios=audio_dict, audios=audio_dict,
audio_sampling_rate=hps.data.sampling_rate audio_sampling_rate=hps.data.sampling_rate,
) )
generator.train() generator.train()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -6,49 +6,56 @@ import cn2an
from pypinyin import lazy_pinyin, Style from pypinyin import lazy_pinyin, Style
import sys import sys
sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master") sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master")
from text.symbols import punctuation from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi from text.tone_sandhi import ToneSandhi
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in pinyin_to_symbol_map = {
open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()} line.split("\t")[0]: line.strip().split("\t")[1]
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba.posseg as psg import jieba.posseg as psg
rep_map = { rep_map = {
'：': ',', "：": ",",
'；': ',', "；": ",",
'，': ',', "，": ",",
'。': '.', "。": ".",
'！': '!', "！": "!",
'？': '?', "？": "?",
'\n': '.', "\n": ".",
"·": ",", "·": ",",
'、': ",", "、": ",",
'...': '…', "...": "…",
'$': '.', "$": ".",
'/': ',', "/": ",",
'—': "-" "—": "-",
} }
tone_modifier = ToneSandhi() tone_modifier = ToneSandhi()
def replace_punctuation(text): def replace_punctuation(text):
text = text.replace("", "").replace("", "") text = text.replace("", "").replace("", "")
pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys())) pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text) replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
def g2p(text): def g2p(text):
pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation)) pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip()!=''] sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
phones, word2ph = _g2p(sentences) phones, word2ph = _g2p(sentences)
return phones, word2ph return phones, word2ph
@ -56,10 +63,10 @@ def g2p(text):
def _get_initials_finals(word): def _get_initials_finals(word):
initials = [] initials = []
finals = [] finals = []
orig_initials = lazy_pinyin( orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin( orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals): for c, v in zip(orig_initials, orig_finals):
initials.append(c) initials.append(c)
finals.append(v) finals.append(v)
@ -72,17 +79,16 @@ def _g2p(segments):
for seg in segments: for seg in segments:
pinyins = [] pinyins = []
# Replace all English words in the sentence # Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg) seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg) seg_cut = psg.lcut(seg)
initials = [] initials = []
finals = [] finals = []
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut: for word, pos in seg_cut:
if pos == 'eng': if pos == "eng":
continue continue
sub_initials, sub_finals = _get_initials_finals(word) sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_finals)
initials.append(sub_initials) initials.append(sub_initials)
finals.append(sub_finals) finals.append(sub_finals)
@ -103,39 +109,39 @@ def _g2p(segments):
tone = v[-1] tone = v[-1]
pinyin = c + v_without_tone pinyin = c + v_without_tone
assert tone in '12345' assert tone in "12345"
if c: if c:
# 多音节 # 多音节
v_rep_map = { v_rep_map = {
"uei": 'ui', "uei": "ui",
'iou': 'iu', "iou": "iu",
'uen': 'un', "uen": "un",
} }
if v_without_tone in v_rep_map.keys(): if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone] pinyin = c + v_rep_map[v_without_tone]
else: else:
# 单音节 # 单音节
pinyin_rep_map = { pinyin_rep_map = {
'ing': 'ying', "ing": "ying",
'i': 'yi', "i": "yi",
'in': 'yin', "in": "yin",
'u': 'wu', "u": "wu",
} }
if pinyin in pinyin_rep_map.keys(): if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin] pinyin = pinyin_rep_map[pinyin]
else: else:
single_rep_map = { single_rep_map = {
'v': 'yu', "v": "yu",
'e': 'e', "e": "e",
'i': 'y', "i": "y",
'u': 'w', "u": "w",
} }
if pinyin[0] in single_rep_map.keys(): if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:] pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ') new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone new_v = new_v + tone
phone = [new_c, new_v] phone = [new_c, new_v]
word2ph.append(len(phone)) word2ph.append(len(phone))
@ -144,9 +150,8 @@ def _g2p(segments):
return phones_list, word2ph return phones_list, word2ph
def text_normalize(text): def text_normalize(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text) numbers = re.findall(r"\d+(?:\.?\d+)?", text)
for number in numbers: for number in numbers:
text = text.replace(number, cn2an.an2cn(number), 1) text = text.replace(number, cn2an.an2cn(number), 1)
text = replace_punctuation(text) text = replace_punctuation(text)
@ -154,7 +159,7 @@ def text_normalize(text):
return text return text
if __name__ == '__main__': if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"
text = "你好" text = "你好"

View File

@ -1,23 +1,21 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
language_module_map = { language_module_map = {"zh": chinese, "ja": japanese, "en": english}
'zh': chinese,
"ja": japanese,
'en': english
}
special = [ special = [
('%', 'zh', "SP"), ("%", "zh", "SP"),
('', 'zh', "SP2"), ("", "zh", "SP2"),
('^', 'zh', "SP3"), ("^", "zh", "SP3"),
# ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧 # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧
] ]
def clean_text(text, language): def clean_text(text, language):
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol) return clean_special(text, language, special_s, target_symbol)
language_module = language_module_map[language] language_module = language_module_map[language]
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
if(language=="zh"): if language == "zh":
phones, word2ph = language_module.g2p(norm_text) phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph) assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph) assert len(norm_text) == len(word2ph)
@ -41,17 +39,17 @@ def clean_special(text, language, special_s, target_symbol):
new_ph = [] new_ph = []
for ph in phones: for ph in phones:
assert ph in symbols assert ph in symbols
if ph == ',': if ph == ",":
new_ph.append(target_symbol) new_ph.append(target_symbol)
else: else:
new_ph.append(ph) new_ph.append(ph)
return new_ph return new_ph
def text_to_sequence(text, language): def text_to_sequence(text, language):
phones = clean_text(text) phones = clean_text(text)
return cleaned_text_to_sequence(phones) return cleaned_text_to_sequence(phones)
if __name__ == '__main__':
print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh'))
if __name__ == "__main__":
print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))

View File

@ -8,20 +8,87 @@ from string import punctuation
from text import symbols from text import symbols
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep') CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle') CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
_g2p = G2p() _g2p = G2p()
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
def replace_phs(phs): def replace_phs(phs):
rep_map = { rep_map = {";": ",", ":": ",", "'": "-", '"': "-"}
';': ',',
':': ',',
'\'': '-',
'"': '-'
}
phs_new = [] phs_new = []
for ph in phs: for ph in phs:
if ph in symbols: if ph in symbols:
@ -29,9 +96,10 @@ def replace_phs(phs):
elif ph in rep_map.keys(): elif ph in rep_map.keys():
phs_new.append(rep_map[ph]) phs_new.append(rep_map[ph])
else: else:
print('ph not in symbols: ', ph) print("ph not in symbols: ", ph)
return phs_new return phs_new
def read_dict(): def read_dict():
g2p_dict = {} g2p_dict = {}
start_line = 49 start_line = 49
@ -41,13 +109,13 @@ def read_dict():
while line: while line:
if line_index >= start_line: if line_index >= start_line:
line = line.strip() line = line.strip()
word_split = line.split(' ') word_split = line.split(" ")
word = word_split[0] word = word_split[0]
syllable_split = word_split[1].split(' - ') syllable_split = word_split[1].split(" - ")
g2p_dict[word] = [] g2p_dict[word] = []
for syllable in syllable_split: for syllable in syllable_split:
phone_split = syllable.split(' ') phone_split = syllable.split(" ")
g2p_dict[word].append(phone_split) g2p_dict[word].append(phone_split)
line_index = line_index + 1 line_index = line_index + 1
@ -57,13 +125,13 @@ def read_dict():
def cache_dict(g2p_dict, file_path): def cache_dict(g2p_dict, file_path):
with open(file_path, 'wb') as pickle_file: with open(file_path, "wb") as pickle_file:
pickle.dump(g2p_dict, pickle_file) pickle.dump(g2p_dict, pickle_file)
def get_dict(): def get_dict():
if os.path.exists(CACHE_PATH): if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, 'rb') as pickle_file: with open(CACHE_PATH, "rb") as pickle_file:
g2p_dict = pickle.load(pickle_file) g2p_dict = pickle.load(pickle_file)
else: else:
g2p_dict = read_dict() g2p_dict = read_dict()
@ -71,6 +139,7 @@ def get_dict():
return g2p_dict return g2p_dict
eng_dict = get_dict() eng_dict = get_dict()
@ -78,8 +147,8 @@ def text_normalize(text):
# todo: eng text normalize # todo: eng text normalize
return text.replace(";", ",") return text.replace(";", ",")
def g2p(text):
def g2p(text):
phones = [] phones = []
words = re.split(r"([,;.\-\?\!\s+])", text) words = re.split(r"([,;.\-\?\!\s+])", text)
for w in words: for w in words:
@ -97,6 +166,7 @@ def g2p(text):
return replace_phs(phones) return replace_phs(phones)
if __name__ == "__main__": if __name__ == "__main__":
# print(get_dict()) # print(get_dict())
print(g2p("hello")) print(g2p("hello"))

View File

@ -8,57 +8,63 @@ from text import symbols
# Regular expression matching Japanese without punctuation marks: # Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile( _japanese_characters = re.compile(
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# Regular expression matching non-Japanese characters or punctuation marks: # Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile( _japanese_marks = re.compile(
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# List of (symbol, Japanese) pairs for marks: # List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ _symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("％", "パーセント")]]
('％', 'パーセント')
]]
# List of (consonant, sokuon) pairs: # List of (consonant, sokuon) pairs:
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ _real_sokuon = [
(r'Q([↑↓]*[kg])', r'k#\1'), (re.compile("%s" % x[0]), x[1])
(r'Q([↑↓]*[tdjʧ])', r't#\1'), for x in [
(r'Q([↑↓]*[sʃ])', r's\1'), (r"Q([↑↓]*[kg])", r"k#\1"),
(r'Q([↑↓]*[pb])', r'p#\1') (r"Q([↑↓]*[tdjʧ])", r"t#\1"),
]] (r"Q([↑↓]*[sʃ])", r"s\1"),
(r"Q([↑↓]*[pb])", r"p#\1"),
]
]
# List of (consonant, hatsuon) pairs: # List of (consonant, hatsuon) pairs:
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ _real_hatsuon = [
(r'N([↑↓]*[pbm])', r'm\1'), (re.compile("%s" % x[0]), x[1])
(r'N([↑↓]*[ʧʥj])', r'n^\1'), for x in [
(r'N([↑↓]*[tdn])', r'n\1'), (r"N([↑↓]*[pbm])", r"m\1"),
(r'N([↑↓]*[kg])', r'ŋ\1') (r"N([↑↓]*[ʧʥj])", r"n^\1"),
]] (r"N([↑↓]*[tdn])", r"n\1"),
(r"N([↑↓]*[kg])", r"ŋ\1"),
]
]
def post_replace_ph(ph): def post_replace_ph(ph):
rep_map = { rep_map = {
'：': ',', "：": ",",
'；': ',', "；": ",",
'，': ',', "，": ",",
'。': '.', "。": ".",
'！': '!', "！": "!",
'？': '?', "？": "?",
'\n': '.', "\n": ".",
"·": ",", "·": ",",
'、': ",", "、": ",",
'...': '…' "...": "…",
} }
if ph in rep_map.keys(): if ph in rep_map.keys():
ph = rep_map[ph] ph = rep_map[ph]
if ph in symbols: if ph in symbols:
return ph return ph
if ph not in symbols: if ph not in symbols:
ph = 'UNK' ph = "UNK"
return ph return ph
def symbols_to_japanese(text): def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese: for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
@ -66,7 +72,7 @@ def symbols_to_japanese(text):
def preprocess_jap(text): def preprocess_jap(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
text = symbols_to_japanese(text) text = symbols_to_japanese(text)
sentences = re.split(_japanese_marks, text) sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text) marks = re.findall(_japanese_marks, text)
@ -77,13 +83,15 @@ def preprocess_jap(text):
text += p.split(" ") text += p.split(" ")
if i < len(marks): if i < len(marks):
text += [marks[i].replace(' ', '')] text += [marks[i].replace(" ", "")]
return text return text
def text_normalize(text): def text_normalize(text):
# todo: jap text normalize # todo: jap text normalize
return text return text
def g2p(norm_text): def g2p(norm_text):
phones = preprocess_jap(norm_text) phones = preprocess_jap(norm_text)
phones = [post_replace_ph(i) for i in phones] phones = [post_replace_ph(i) for i in phones]
@ -91,7 +99,7 @@ def g2p(norm_text):
return phones return phones
if __name__ == '__main__': if __name__ == "__main__":
for line in open("../../../Downloads/transcript_utf8.txt").readlines(): for line in open("../../../Downloads/transcript_utf8.txt").readlines():
text = line.split(":")[1] text = line.split(":")[1]
phones = g2p(text) phones = g2p(text)

View File

@ -1,24 +1,397 @@
import os import os
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ['!', '?', '…', ",", "."]#@是SP停顿 punctuation = ["!", "?", "…", ",", "."] # @是SP停顿
punctuation.append("-") punctuation.append("-")
pu_symbols = punctuation + ["SP", 'SP2', 'SP3', "UNK"] pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] # pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
pad = '_' pad = "_"
c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh'] c = [
v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5'] "AA",
"EE",
"OO",
"b",
"c",
"ch",
"d",
"f",
"g",
"h",
"j",
"k",
"l",
"m",
"n",
"p",
"q",
"r",
"s",
"sh",
"t",
"w",
"x",
"y",
"z",
"zh",
]
v = [
"E1",
"En1",
"a1",
"ai1",
"an1",
"ang1",
"ao1",
"e1",
"ei1",
"en1",
"eng1",
"er1",
"i1",
"i01",
"ia1",
"ian1",
"iang1",
"iao1",
"ie1",
"in1",
"ing1",
"iong1",
"ir1",
"iu1",
"o1",
"ong1",
"ou1",
"u1",
"ua1",
"uai1",
"uan1",
"uang1",
"ui1",
"un1",
"uo1",
"v1",
"van1",
"ve1",
"vn1",
"E2",
"En2",
"a2",
"ai2",
"an2",
"ang2",
"ao2",
"e2",
"ei2",
"en2",
"eng2",
"er2",
"i2",
"i02",
"ia2",
"ian2",
"iang2",
"iao2",
"ie2",
"in2",
"ing2",
"iong2",
"ir2",
"iu2",
"o2",
"ong2",
"ou2",
"u2",
"ua2",
"uai2",
"uan2",
"uang2",
"ui2",
"un2",
"uo2",
"v2",
"van2",
"ve2",
"vn2",
"E3",
"En3",
"a3",
"ai3",
"an3",
"ang3",
"ao3",
"e3",
"ei3",
"en3",
"eng3",
"er3",
"i3",
"i03",
"ia3",
"ian3",
"iang3",
"iao3",
"ie3",
"in3",
"ing3",
"iong3",
"ir3",
"iu3",
"o3",
"ong3",
"ou3",
"u3",
"ua3",
"uai3",
"uan3",
"uang3",
"ui3",
"un3",
"uo3",
"v3",
"van3",
"ve3",
"vn3",
"E4",
"En4",
"a4",
"ai4",
"an4",
"ang4",
"ao4",
"e4",
"ei4",
"en4",
"eng4",
"er4",
"i4",
"i04",
"ia4",
"ian4",
"iang4",
"iao4",
"ie4",
"in4",
"ing4",
"iong4",
"ir4",
"iu4",
"o4",
"ong4",
"ou4",
"u4",
"ua4",
"uai4",
"uan4",
"uang4",
"ui4",
"un4",
"uo4",
"v4",
"van4",
"ve4",
"vn4",
"E5",
"En5",
"a5",
"ai5",
"an5",
"ang5",
"ao5",
"e5",
"ei5",
"en5",
"eng5",
"er5",
"i5",
"i05",
"ia5",
"ian5",
"iang5",
"iao5",
"ie5",
"in5",
"ing5",
"iong5",
"ir5",
"iu5",
"o5",
"ong5",
"ou5",
"u5",
"ua5",
"uai5",
"uan5",
"uang5",
"ui5",
"un5",
"uo5",
"v5",
"van5",
"ve5",
"vn5",
]
v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] v_without_tone = [
"E",
"En",
"a",
"ai",
"an",
"ang",
"ao",
"e",
"ei",
"en",
"eng",
"er",
"i",
"i0",
"ia",
"ian",
"iang",
"iao",
"ie",
"in",
"ing",
"iong",
"ir",
"iu",
"o",
"ong",
"ou",
"u",
"ua",
"uai",
"uan",
"uang",
"ui",
"un",
"uo",
"v",
"van",
"ve",
"vn",
]
# japanese # japanese
ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky', ja_symbols = [
'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'v', 'w', 'y', 'z'] "I",
"N",
"U",
"a",
"b",
"by",
"ch",
"cl",
"d",
"dy",
"e",
"f",
"g",
"gy",
"h",
"hy",
"i",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"p",
"py",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"u",
"v",
"w",
"y",
"z",
]
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols)) symbols = sorted(set(symbols))
if __name__ == '__main__': if __name__ == "__main__":
print(len(symbols)) print(len(symbols))
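The `symbols` list above is only the phoneme inventory; mapping phones to integer IDs is done by `cleaned_text_to_sequence`, which the cleaner hunk imports but which is not part of this diff. A purely illustrative sketch of that mapping, assuming `symbols` from the module above is in scope:

```python
# Illustrative only; the real cleaned_text_to_sequence lives in a file not shown in this diff.
_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def to_sequence(phones):
    # Unknown phones would raise KeyError here (the Japanese hunk above maps them to "UNK" before this stage).
    return [_symbol_to_id[p] for p in phones]

print(to_sequence(["AA", "a1", "."]))
```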

View File

@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin
from pypinyin import Style from pypinyin import Style
class ToneSandhi(): class ToneSandhi:
def __init__(self): def __init__(self):
self.must_neural_tone_words = { self.must_neural_tone_words = {
'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', "麻烦",
'难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', "麻利",
'里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', "鸳鸯",
'软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号', "高粱",
'认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当', "骨头",
'蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻', "骆驼",
'舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', "马虎",
'胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', "首饰",
'老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', "馒头",
'精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', "馄饨",
'窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台', "风筝",
'码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算', "难为",
'白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨', "队伍",
'琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', "阔气",
'爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', "闺女",
'溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', "门道",
'棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事', "锄头",
'木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾', "铺盖",
'收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼', "铃铛",
'抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', "铁匠",
'扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', "钥匙",
'念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', "里脊",
'干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', "里头",
'屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', "部分",
'实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈', "那么",
'姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方', "道士",
'大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', "造化",
'嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', "迷糊",
'咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', "连累",
'叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹', "这么",
'功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息', "这个",
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤', "运气",
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', "过去",
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', "软和",
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨', "转悠",
'父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', "踏实",
'幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', "跳蚤",
'凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', "跟头",
'扫把', '惦记' "趔趄",
"财主",
"豆腐",
"讲究",
"记性",
"记号",
"认识",
"规矩",
"见识",
"裁缝",
"补丁",
"衣裳",
"衣服",
"衙门",
"街坊",
"行李",
"行当",
"蛤蟆",
"蘑菇",
"薄荷",
"葫芦",
"葡萄",
"萝卜",
"荸荠",
"苗条",
"苗头",
"苍蝇",
"芝麻",
"舒服",
"舒坦",
"舌头",
"自在",
"膏药",
"脾气",
"脑袋",
"脊梁",
"能耐",
"胳膊",
"胭脂",
"胡萝",
"胡琴",
"胡同",
"聪明",
"耽误",
"耽搁",
"耷拉",
"耳朵",
"老爷",
"老实",
"老婆",
"老头",
"老太",
"翻腾",
"罗嗦",
"罐头",
"编辑",
"结实",
"红火",
"累赘",
"糨糊",
"糊涂",
"精神",
"粮食",
"簸箕",
"篱笆",
"算计",
"算盘",
"答应",
"笤帚",
"笑语",
"笑话",
"窟窿",
"窝囊",
"窗户",
"稳当",
"稀罕",
"称呼",
"秧歌",
"秀气",
"秀才",
"福气",
"祖宗",
"砚台",
"码头",
"石榴",
"石头",
"石匠",
"知识",
"眼睛",
"眯缝",
"眨巴",
"眉毛",
"相声",
"盘算",
"白净",
"痢疾",
"痛快",
"疟疾",
"疙瘩",
"疏忽",
"畜生",
"生意",
"甘蔗",
"琵琶",
"琢磨",
"琉璃",
"玻璃",
"玫瑰",
"玄乎",
"狐狸",
"状元",
"特务",
"牲口",
"牙碜",
"牌楼",
"爽快",
"爱人",
"热闹",
"烧饼",
"烟筒",
"烂糊",
"点心",
"炊帚",
"灯笼",
"火候",
"漂亮",
"滑溜",
"溜达",
"温和",
"清楚",
"消息",
"浪头",
"活泼",
"比方",
"正经",
"欺负",
"模糊",
"槟榔",
"棺材",
"棒槌",
"棉花",
"核桃",
"栅栏",
"柴火",
"架势",
"枕头",
"枇杷",
"机灵",
"本事",
"木头",
"木匠",
"朋友",
"月饼",
"月亮",
"暖和",
"明白",
"时候",
"新鲜",
"故事",
"收拾",
"收成",
"提防",
"挖苦",
"挑剔",
"指甲",
"指头",
"拾掇",
"拳头",
"拨弄",
"招牌",
"招呼",
"抬举",
"护士",
"折腾",
"扫帚",
"打量",
"打算",
"打点",
"打扮",
"打听",
"打发",
"扎实",
"扁担",
"戒指",
"懒得",
"意识",
"意思",
"情形",
"悟性",
"怪物",
"思量",
"怎么",
"念头",
"念叨",
"快活",
"忙活",
"志气",
"心思",
"得罪",
"张罗",
"弟兄",
"开通",
"应酬",
"庄稼",
"干事",
"帮手",
"帐篷",
"希罕",
"师父",
"师傅",
"巴结",
"巴掌",
"差事",
"工夫",
"岁数",
"屁股",
"尾巴",
"少爷",
"小气",
"小伙",
"将就",
"对头",
"对付",
"寡妇",
"家伙",
"客气",
"实在",
"官司",
"学问",
"学生",
"字号",
"嫁妆",
"媳妇",
"媒人",
"婆家",
"娘家",
"委屈",
"姑娘",
"姐夫",
"妯娌",
"妥当",
"妖精",
"奴才",
"女婿",
"头发",
"太阳",
"大爷",
"大方",
"大意",
"大夫",
"多少",
"多么",
"外甥",
"壮实",
"地道",
"地方",
"在乎",
"困难",
"嘴巴",
"嘱咐",
"嘟囔",
"嘀咕",
"喜欢",
"喇嘛",
"喇叭",
"商量",
"唾沫",
"哑巴",
"哈欠",
"哆嗦",
"咳嗽",
"和尚",
"告诉",
"告示",
"含糊",
"吓唬",
"后头",
"名字",
"名堂",
"合同",
"吆喝",
"叫唤",
"口袋",
"厚道",
"厉害",
"千斤",
"包袱",
"包涵",
"匀称",
"勤快",
"动静",
"动弹",
"功夫",
"力气",
"前头",
"刺猬",
"刺激",
"别扭",
"利落",
"利索",
"利害",
"分析",
"出息",
"凑合",
"凉快",
"冷战",
"冤枉",
"冒失",
"养活",
"关系",
"先生",
"兄弟",
"便宜",
"使唤",
"佩服",
"作坊",
"体面",
"位置",
"似的",
"伙计",
"休息",
"什么",
"人家",
"亲戚",
"亲家",
"交情",
"云彩",
"事情",
"买卖",
"主意",
"丫头",
"丧气",
"两口",
"东西",
"东家",
"世故",
"不由",
"不在",
"下水",
"下巴",
"上头",
"上司",
"丈夫",
"丈人",
"一辈",
"那个",
"菩萨",
"父亲",
"母亲",
"咕噜",
"邋遢",
"费用",
"冤家",
"甜头",
"介绍",
"荒唐",
"大人",
"泥鳅",
"幸福",
"熟悉",
"计划",
"扑腾",
"蜡烛",
"姥爷",
"照顾",
"喉咙",
"吉他",
"弄堂",
"蚂蚱",
"凤凰",
"拖沓",
"寒碜",
"糟蹋",
"倒腾",
"报复",
"逻辑",
"盘缠",
"喽啰",
"牢骚",
"咖喱",
"扫把",
"惦记",
} }
self.must_not_neural_tone_words = { self.must_not_neural_tone_words = {
"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎" "男子",
"女子",
"分子",
"原子",
"量子",
"莲子",
"石子",
"瓜子",
"电子",
"人人",
"虎虎",
} }
self.punc = ":,;。?!“”‘’':,;.?!" self.punc = ":,;。?!“”‘’':,;.?!"
@ -72,14 +463,15 @@ class ToneSandhi():
# word: "家里" # word: "家里"
# pos: "s" # pos: "s"
# finals: ['ia1', 'i3'] # finals: ['ia1', 'i3']
def _neural_sandhi(self, word: str, pos: str, def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals: List[str]) -> List[str]:
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
for j, item in enumerate(word): for j, item in enumerate(word):
if j - 1 >= 0 and item == word[j - 1] and pos[0] in { if (
"n", "v", "a" j - 1 >= 0
} and word not in self.must_not_neural_tone_words: and item == word[j - 1]
and pos[0] in {"n", "v", "a"}
and word not in self.must_not_neural_tone_words
):
finals[j] = finals[j][:-1] + "5" finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("") ge_idx = word.find("")
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
@ -89,9 +481,12 @@ class ToneSandhi():
# e.g. 走了, 看着, 去过 # e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
elif len(word) > 1 and word[-1] in "们子" and pos in { elif (
"r", "n" len(word) > 1
} and word not in self.must_not_neural_tone_words: and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里 # e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -100,21 +495,26 @@ class ToneSandhi():
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# 个做量词 # 个做量词
elif (ge_idx >= 1 and elif (
(word[ge_idx - 1].isnumeric() or ge_idx >= 1
word[ge_idx - 1] in "几有两半多各整每做是")) or word == '': and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5" finals[ge_idx] = finals[ge_idx][:-1] + "5"
else: else:
if word in self.must_neural_tone_words or word[ if (
-2:] in self.must_neural_tone_words: word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word) word_list = self._split_word(word)
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list): for i, word in enumerate(word_list):
# conventional neural in Chinese # conventional neural in Chinese
if word in self.must_neural_tone_words or word[ if (
-2:] in self.must_neural_tone_words: word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5" finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, []) finals = sum(finals_list, [])
return finals return finals
@ -126,15 +526,15 @@ class ToneSandhi():
else: else:
for i, char in enumerate(word): for i, char in enumerate(word):
# "不" before tone4 should be bu2, e.g. 不怕 # "不" before tone4 should be bu2, e.g. 不怕
if char == "" and i + 1 < len(word) and finals[i + if char == "" and i + 1 < len(word) and finals[i + 1][-1] == "4":
1][-1] == "4":
finals[i] = finals[i][:-1] + "2" finals[i] = finals[i][:-1] + "2"
return finals return finals
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零 # "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all( if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]): [item.isnumeric() for item in word if item != ""]
):
return finals return finals
# "一" between reduplication words shold be yi5, e.g. 看一看 # "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
@ -182,18 +582,19 @@ class ToneSandhi():
elif len(word_list[0]) == 1: elif len(word_list[0]) == 1:
finals[1] = finals[1][:-1] + "2" finals[1] = finals[1][:-1] + "2"
else: else:
finals_list = [ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
finals[:len(word_list[0])], finals[len(word_list[0]):]
]
if len(finals_list) == 2: if len(finals_list) == 2:
for i, sub in enumerate(finals_list): for i, sub in enumerate(finals_list):
# e.g. 所有/人 # e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2: if self._all_tone_three(sub) and len(sub) == 2:
finals_list[i][0] = finals_list[i][0][:-1] + "2" finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢 # e.g. 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \ elif (
finals_list[0][-1][-1] == "3": i == 1
and not self._all_tone_three(sub)
and finals_list[i][0][-1] == "3"
and finals_list[0][-1][-1] == "3"
):
finals_list[0][-1] = finals_list[0][-1][:-1] + "2" finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, []) finals = sum(finals_list, [])
# split idiom into two words who's length is 2 # split idiom into two words who's length is 2
@ -222,7 +623,7 @@ class ToneSandhi():
new_seg.append((word, pos)) new_seg.append((word, pos))
last_word = word[:] last_word = word[:]
if last_word == "": if last_word == "":
new_seg.append((last_word, 'd')) new_seg.append((last_word, "d"))
last_word = "" last_word = ""
return new_seg return new_seg
@ -236,12 +637,21 @@ class ToneSandhi():
new_seg = [] new_seg = []
# function 1 # function 1
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and i + 1 < len(seg) and seg[i - 1][ if (
0] == seg[i + 1][0] and seg[i - 1][1] == "v": i - 1 >= 0
and word == ""
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0] new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0]
else: else:
if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][ if (
0] == word and pos == "v": i - 2 >= 0
and seg[i - 1][0] == "一"
and seg[i - 2][0] == word
and pos == "v"
):
continue continue
else: else:
new_seg.append([word, pos]) new_seg.append([word, pos])
@ -257,22 +667,27 @@ class ToneSandhi():
# the first and the second words are all_tone_three # the first and the second words are all_tone_three
def _merge_continuous_three_tones( def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin( lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and self._all_tone_three( if (
sub_finals_list[i - 1]) and self._all_tone_three( i - 1 >= 0
sub_finals_list[i]) and not merge_last[i - 1]: and self._all_tone_three(sub_finals_list[i - 1])
and self._all_tone_three(sub_finals_list[i])
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len( if (
seg[i - 1][0]) + len(seg[i][0]) <= 3: not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
@ -287,21 +702,27 @@ class ToneSandhi():
# the last char of first word and the first char of second word is tone_three # the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2( def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin( lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \ if (
merge_last[i - 1]: i - 1 >= 0
and sub_finals_list[i - 1][-1][-1] == "3"
and sub_finals_list[i][0][-1] == "3"
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len( if (
seg[i - 1][0]) + len(seg[i][0]) <= 3: not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
@ -319,8 +740,7 @@ class ToneSandhi():
new_seg.append([word, pos]) new_seg.append([word, pos])
return new_seg return new_seg
def _merge_reduplication( def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if new_seg and word == new_seg[-1][0]: if new_seg and word == new_seg[-1][0]:
@ -329,8 +749,7 @@ class ToneSandhi():
new_seg.append([word, pos]) new_seg.append([word, pos])
return new_seg return new_seg
def pre_merge_for_modify( def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
seg = self._merge_bu(seg) seg = self._merge_bu(seg)
try: try:
seg = self._merge_yi(seg) seg = self._merge_yi(seg)
@ -349,8 +768,7 @@ class ToneSandhi():
seg = self._merge_er(seg) seg = self._merge_er(seg)
return seg return seg
def modified_tone(self, word: str, pos: str, def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals) finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals) finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals) finals = self._neural_sandhi(word, pos, finals)
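The ToneSandhi changes above appear to be pure reformatting. For orientation, here is a hypothetical standalone driver that mirrors how `_g2p` in the Chinese hunk earlier in this diff uses the class; the sample sentence is arbitrary.

```python
# Sketch mirroring the calls shown in the Chinese g2p hunk; standalone use like this is hypothetical.
import jieba.posseg as psg
from pypinyin import lazy_pinyin, Style
from text.tone_sandhi import ToneSandhi  # import path taken from the Chinese g2p hunk

sandhi = ToneSandhi()
seg = sandhi.pre_merge_for_modify(psg.lcut("我想一想"))  # merges 不 / 一 / reduplications first
for word, pos in seg:
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    print(word, pos, sandhi.modified_tone(word, pos, finals))  # 不/一/neutral/third-tone rules applied
```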

View File

@ -12,8 +12,9 @@ import numpy as np
from scipy.io.wavfile import read from scipy.io.wavfile import read
import torch import torch
import logging import logging
logging.getLogger('numba').setLevel(logging.ERROR)
logging.getLogger('matplotlib').setLevel(logging.ERROR) logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)
MATPLOTLIB_FLAG = False MATPLOTLIB_FLAG = False
@ -23,13 +24,17 @@ logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict['iteration'] iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict['learning_rate'] learning_rate = checkpoint_dict["learning_rate"]
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: if (
optimizer.load_state_dict(checkpoint_dict['optimizer']) optimizer is not None
saved_state_dict = checkpoint_dict['model'] and not skip_optimizer
if hasattr(model, 'module'): and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# assert "quantizer" not in k # assert "quantizer" not in k
# print("load", k) # print("load", k)
new_state_dict[k] = saved_state_dict[k] new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except: except:
traceback.print_exc() traceback.print_exc()
print("error, %s is not in the checkpoint" % k)#shape不对也会比如text_embedding当cleaner修改时 print(
"error, %s is not in the checkpoint" % k
) # shape不对也会比如text_embedding当cleaner修改时
new_state_dict[k] = v new_state_dict[k] = v
if hasattr(model, 'module'): if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict) model.module.load_state_dict(new_state_dict)
else: else:
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
print("load ") print("load ")
logger.info("Loaded checkpoint '{}' (iteration {})".format( logger.info(
checkpoint_path, iteration)) "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info("Saving model and optimizer state at iteration {} to {}".format( logger.info(
iteration, checkpoint_path)) "Saving model and optimizer state at iteration {} to {}".format(
if hasattr(model, 'module'): iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict() state_dict = model.module.state_dict()
else: else:
state_dict = model.state_dict() state_dict = model.state_dict()
torch.save({'model': state_dict, torch.save(
'iteration': iteration, {
'optimizer': optimizer.state_dict(), "model": state_dict,
'learning_rate': learning_rate}, checkpoint_path) "iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items(): for k, v in scalars.items():
writer.add_scalar(k, v, global_step) writer.add_scalar(k, v, global_step)
for k, v in histograms.items(): for k, v in histograms.items():
writer.add_histogram(k, v, global_step) writer.add_histogram(k, v, global_step)
for k, v in images.items(): for k, v in images.items():
writer.add_image(k, v, global_step, dataformats='HWC') writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items(): for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate) writer.add_audio(k, v, global_step, audio_sampling_rate)
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG: if not MATPLOTLIB_FLAG:
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib') mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt import matplotlib.pylab as plt
import numpy as np import numpy as np
fig, ax = plt.subplots(figsize=(10, 2)) fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
interpolation='none')
plt.colorbar(im, ax=ax) plt.colorbar(im, ax=ax)
plt.xlabel("Frames") plt.xlabel("Frames")
plt.ylabel("Channels") plt.ylabel("Channels")
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG: if not MATPLOTLIB_FLAG:
import matplotlib import matplotlib
matplotlib.use("Agg") matplotlib.use("Agg")
MATPLOTLIB_FLAG = True MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib') mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING) mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt import matplotlib.pylab as plt
import numpy as np import numpy as np
fig, ax = plt.subplots(figsize=(6, 4)) fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', im = ax.imshow(
interpolation='none') alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax) fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep' xlabel = "Decoder timestep"
if info is not None: if info is not None:
xlabel += '\n\n' + info xlabel += "\n\n" + info
plt.xlabel(xlabel) plt.xlabel(xlabel)
plt.ylabel('Encoder timestep') plt.ylabel("Encoder timestep")
plt.tight_layout() plt.tight_layout()
fig.canvas.draw() fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close() plt.close()
return data return data
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
def load_filepaths_and_text(filename, split="|"): def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f: with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f] filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text return filepaths_and_text
def get_hparams(init=True, stage=1): def get_hparams(init=True, stage=1):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration') parser.add_argument(
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir') "-c",
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step') "--config",
type=str,
default="./configs/s2.json",
help="JSON file for configuration",
)
parser.add_argument(
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
)
parser.add_argument(
"-rs",
"--resume_step",
type=int,
required=False,
default=None,
help="resume step",
)
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory') # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights') # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
@ -186,8 +230,7 @@ def get_hparams(init=True, stage=1):
return hparams return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts """Freeing up space by deleting saved ckpts
Arguments: Arguments:
@ -197,18 +240,28 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim
False -> lexicographically delete ckpts False -> lexicographically delete ckpts
""" """
import re import re
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1))) ckpts_files = [
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))) f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key sort_key = time_key if sort_by_time else name_key
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], x_sorted = lambda _x: sorted(
key=sort_key) [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
to_del = [os.path.join(path_to_models, fn) for fn in key=sort_key,
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] )
to_del = [
os.path.join(path_to_models, fn)
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)] del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del] rs = [del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir): def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json") config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f: with open(config_save_path, "r") as f:
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config) hparams = HParams(**config)
return hparams return hparams
def check_git_hash(model_dir): def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__)) source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")): if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir source_dir
)) )
)
return return
cur_hash = subprocess.getoutput("git rev-parse HEAD") cur_hash = subprocess.getoutput("git rev-parse HEAD")
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
if os.path.exists(path): if os.path.exists(path):
saved_hash = open(path).read() saved_hash = open(path).read()
if saved_hash != cur_hash: if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format( logger.warn(
saved_hash[:8], cur_hash[:8])) "git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else: else:
open(path, "w").write(cur_hash) open(path, "w").write(cur_hash)
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
return logger return logger
class HParams(): class HParams:
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
if type(v) == dict: if type(v) == dict:
@ -294,5 +353,10 @@ class HParams():
def __repr__(self): def __repr__(self):
return self.__dict__.__repr__() return self.__dict__.__repr__()
if __name__ == '__main__':
print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac')) if __name__ == "__main__":
print(
load_wav_to_torch(
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
)
)
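The utils hunks above also look behaviour-preserving; the checkpoint helpers keep their signatures. A minimal sketch of the save/load round-trip they implement, with a stand-in model and a hypothetical checkpoint path, assuming the module is importable as `utils` (as the training hunk above already does with `utils.save_checkpoint`):

```python
# Sketch only; the model, optimizer and path are placeholders, not values used by this repo.
import torch
import utils

net = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(net.parameters(), lr=1e-4)

utils.save_checkpoint(net, opt, 1e-4, 1, "G_1.pth")
net, opt, lr, it = utils.load_checkpoint("G_1.pth", net, opt)
print(lr, it)  # -> 0.0001 1
```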

View File

@ -71,6 +71,8 @@ conda install ffmpeg
```bash ```bash
sudo apt install ffmpeg sudo apt install ffmpeg
sudo apt install libsox-dev
conda install -c conda-forge 'ffmpeg<7'
``` ```
##### MacOS Users ##### MacOS Users

View File

@ -18,3 +18,4 @@ modelscope
sentencepiece sentencepiece
transformers transformers
chardet chardet
PyYAML

View File

@ -1,7 +1,7 @@
import os import os
import traceback,gradio as gr import traceback,gradio as gr
import logging import logging
from i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
i18n = I18nAuto() i18n = I18nAuto()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

985
webui.py
File diff suppressed because it is too large