more code refactor

Blaise 2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -16,7 +16,7 @@ __all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):

@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
    sort batches
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if (
            self.drop_last and len(self.dataset) % self.num_replicas != 0
        ):  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas)
                / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas
            )  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

@@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        buckets = []
        cur = []
        max_sec = bucket_width

@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [
                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
                for b in range(n_batch)
            ]
            shuffle(batches)
            indices = list(itertools.chain(*batches))

@@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
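
A minimal usage sketch of the refactored sampler (not part of this commit; the module paths and data paths below are assumptions, and torch.distributed must already be initialised, e.g. via torchrun, because the sampler reads rank and world size from it):

# usage sketch -- module paths and file paths are assumed, not taken from this diff
import torch.distributed as dist
from torch.utils.data import DataLoader

from AR.data.bucket_sampler import DistributedBucketSampler  # assumed module path
from AR.data.dataset import Text2SemanticDataset  # assumed module path

dist.init_process_group(backend="nccl")  # one process per GPU

dataset = Text2SemanticDataset(
    phoneme_path="dump/2-name2text.txt",  # placeholder paths
    semantic_path="dump/6-name2semantic.tsv",
)
# rank / world size are taken from torch.distributed inside the sampler;
# samples are bucketed by duration so each batch contains similar lengths
sampler = DistributedBucketSampler(dataset, batch_size=32)
loader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    collate_fn=dataset.collate,
)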

View File

@@ -6,14 +6,21 @@ from torch.utils.data import DataLoader

class Text2SemanticDataModule(LightningDataModule):
    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    def prepare_data(self):
        pass

@@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,

@@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
        #     pad_val=self.config['data']['pad_val'])

    def train_dataloader(self):
        batch_size = self.config["train"]["batch_size"]
        sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,

@@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16,
        )

    def val_dataloader(self):

@@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule):
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16,
        )

    # will this ever be used?

@@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
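
A minimal construction sketch for the data module (not from this commit; the import path and the config values are assumptions, using only the keys the module reads above):

from AR.data.data_module import Text2SemanticDataModule  # assumed module path

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},  # assumed values
    "train": {"batch_size": 8},
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="dump/6-name2semantic.tsv",  # placeholder paths
    train_phoneme_path="dump/2-name2text.txt",
)
dm.setup(stage=None)
loader = dm.train_dataloader()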

View File

@@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys

# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
from typing import List

import numpy as np
import pandas as pd
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence

# from config import exp_dir


def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
    seq = sequences[0]
    ndim = seq.ndim

@@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = (
            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
        )
        padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
        padded_sequences.append(padded_seq)
    batch = np.stack(padded_sequences)
    return batch


class Text2SemanticDataset(Dataset):
    """dataset class for text tokens to semantic model training."""

    def __init__(
        self,
        phoneme_path: str,
        semantic_path: str,
        max_sample: int = None,
        max_sec: int = 100,
        pad_val: int = 1024,
        # min value of phoneme/sec
        min_ps_ratio: int = 3,
        # max value of phoneme/sec
        max_ps_ratio: int = 25,
    ) -> None:
        super().__init__()

        self.semantic_data = pd.read_csv(
            semantic_path, delimiter="\t", encoding="utf-8"
        )
        # get dict
        self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
        self.path3 = "%s/3-bert" % (
            os.path.basename(phoneme_path)
        )  # "%s/3-bert"%exp_dir#bert_dir
        self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
        # pad for semantic tokens

@@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
        # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
        # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
        # self.hz=int(data[:-2])#
        self.hz = int(os.environ.get("hz", "25hz")[:-2])

        # max seconds of semantic token
        self.max_sec = max_sec

@@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")

    def init_batch(self):
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())

@@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
        for i in range(semantic_data_len):
            # iterate through in order first
            # get str
            item_name = self.semantic_data["item_name"][i]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]

@@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
                num_not_in += 1
                continue

            semantic_str = self.semantic_data["semantic_audio"][i]
            # get token list
            semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
            # (T); does it need to become (1, T)? -> no, because len() is needed
            # filter out samples that are too long
            if (
                len(semantic_ids) > self.max_sec * self.hz
            ):  # 1: infer the total duration from the token count and drop overlong clips (max duration, e.g. 60 s, set in the config); 40*25=1k
                num_deleted_bigger += 1
                continue
            # (T, ); this is not slow, so it can be done up front rather than per item in __getitem__
            phoneme = phoneme.split(" ")

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme)

@@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
                num_not_in += 1
                continue
            # if len(phoneme_ids) > 400:  # 2: use a fixed limit of semantic/2.5 instead
            if (
                len(phoneme_ids) > self.max_sec * self.hz / 2.5
            ):  # 2: use a fixed limit of semantic/2.5 instead
                num_deleted_ps += 1
                continue
            # if len(semantic_ids) > 1000:  # 3

@@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            if (
                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
            ):  # 4: phonemes per second, allowed range 3~25
                num_deleted_ps += 1
                # print(item_name)
                continue

@@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
            idx += 1
            self.item_names.append(item_name)

        min_num = 100  # with 20 it is simply not padded; with 30 it is padded but no ckpt is saved
        leng = len(self.semantic_phoneme)
        if leng < min_num:
            tmp1 = self.semantic_phoneme
            tmp2 = self.item_names
            self.semantic_phoneme = []
            self.item_names = []
            for _ in range(max(2, int(min_num / leng))):
                self.semantic_phoneme += tmp1
                self.item_names += tmp2

        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:

@@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
            )
        """
        there are 31 semantic datas not in phoneme datas
        deleted 34 audios who's duration are bigger than 54 seconds
        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        dataset.__len__(): 366463
        """
        # 345410 for LibriTTS
        print("dataset.__len__():", self.__len__())

@@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
        # semantic tokens target
        semantic_ids_len = len(semantic_ids)

        flag = 0
        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert) == True:
            bert_feature = torch.load(path_bert, map_location="cpu")
        else:
            flag = 1
        if flag == 1:
            # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
            bert_feature = None
        else:
            assert bert_feature.shape[-1] == len(phoneme_ids)
        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
            "phoneme_ids_len": phoneme_ids_len,
            "semantic_ids": semantic_ids,
            "semantic_ids_len": semantic_ids_len,
            "bert_feature": bert_feature,
        }

@@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
        semantic_ids_lens: List[int] = []
        # return

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))

@@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
        bert_padded.zero_()

        for idx, item in enumerate(examples):
            bert = item["bert_feature"]
            if bert != None:
                bert_padded[idx, :, : bert.shape[-1]] = bert

        return {
            # List[int]

@@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset):
        }


if __name__ == "__main__":
    root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
    dataset = Text2SemanticDataset(
        phoneme_path=root_dir + "phoneme_train.npy",
        semantic_path=root_dir + "semantic_train.tsv",
    )

    batch_size = 12
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
    )
    for i, batch in enumerate(dataloader):
        if i % 1000 == 0:
            print(i)
        # if i == 0:
        #     print('batch["ids"]:', batch["ids"])
        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],
        #           batch["phoneme_ids"].shape)
        #     print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
        #           batch["phoneme_ids_len"].shape)
        #     print('batch["semantic_ids"]:', batch["semantic_ids"],
        #           batch["semantic_ids"].shape)
        #     print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
        #           batch["semantic_ids_len"].shape)

View File

@@ -1,5 +1,6 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os, sys

now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict

@@ -12,29 +13,35 @@ from AR.modules.optim import ScaledAdam

class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
            print(
                self.load_state_dict(
                    torch.load(pretrained_s1, map_location="cpu")["weight"]
                )
            )
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()

@@ -47,63 +54,67 @@ class Text2SemanticLightningModule(LightningModule):
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        return

    # # get loss
    # loss, acc = self.model.forward(
    #     batch['phoneme_ids'], batch['phoneme_ids_len'],
    #     batch['semantic_ids'], batch['semantic_ids_len'],
    #     batch['bert_feature']
    # )
    #
    # self.log(
    #     "val_total_loss",
    #     loss,
    #     on_step=True,
    #     on_epoch=True,
    #     prog_bar=True,
    #     sync_dist=True)
    # self.log(
    #     f"val_top_{self.top_k}_acc",
    #     acc,
    #     on_step=True,
    #     on_epoch=True,
    #     prog_bar=True,
    #     sync_dist=True)
    #
    # # get infer output
    # semantic_len = batch['semantic_ids'].size(1)
    # prompt_len = min(int(semantic_len * 0.5), 150)
    # prompt = batch['semantic_ids'][:, :prompt_len]
    # pred_semantic = self.model.infer(batch['phoneme_ids'],
    #                                  batch['phoneme_ids_len'], prompt,
    #                                  batch['bert_feature']
    #                                  )
    # save_name = f'semantic_toks_{batch_idx}.pt'
    # save_path = os.path.join(self.eval_dir, save_name)
    # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
        )
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,

@@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule):
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config["optimizer"]["lr_init"],
                    peak_lr=self.config["optimizer"]["lr"],
                    end_lr=self.config["optimizer"]["lr_end"],
                    warmup_steps=self.config["optimizer"]["warmup_steps"],
                    total_steps=self.config["optimizer"]["decay_steps"],
                )
            },
        }

View File

@@ -3,7 +3,12 @@ import torch
from tqdm import tqdm

from AR.models.utils import make_pad_mask
from AR.models.utils import (
    topk_sampling,
    sample,
    logits_to_probs,
    multinomial_sample_one_no_sync,
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm

@@ -22,35 +27,39 @@ default_config = {
    "p_dropout": 0.0,
    "vocab_size": 1024 + 1,
    "phoneme_vocab_size": 512,
    "EOS": 1024,
}


class Text2SemanticDecoder(nn.Module):
    def __init__(self, config, norm_first=False, top_k=3):
        super(Text2SemanticDecoder, self).__init__()
        self.model_dim = config["model"]["hidden_dim"]
        self.embedding_dim = config["model"]["embedding_dim"]
        self.num_head = config["model"]["head"]
        self.num_layers = config["model"]["n_layer"]
        self.norm_first = norm_first
        self.vocab_size = config["model"]["vocab_size"]
        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
        self.p_dropout = config["model"]["dropout"]
        self.EOS = config["model"]["EOS"]
        self.norm_first = norm_first
        assert self.EOS == self.vocab_size - 1
        # should be same as num of kmeans bin
        # assert self.EOS == 1024
        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_text_embedding = TokenEmbedding(
            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
        )
        self.ar_text_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True
        )
        self.ar_audio_embedding = TokenEmbedding(
            self.embedding_dim, self.vocab_size, self.p_dropout
        )
        self.ar_audio_position = SinePositionalEmbedding(
            self.embedding_dim, dropout=0.1, scale=False, alpha=True
        )

        self.h = TransformerEncoder(
            TransformerEncoderLayer(

@@ -59,28 +68,30 @@ class Text2SemanticDecoder(nn.Module):
                dim_feedforward=self.model_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=self.num_layers,
            norm=LayerNorm(self.model_dim) if norm_first else None,
        )

        self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
        self.loss_fct = nn.CrossEntropyLoss(reduction="sum")

        self.ar_accuracy_metric = MulticlassAccuracy(
            self.vocab_size,
            top_k=top_k,
            average="micro",
            multidim_average="global",
            ignore_index=self.EOS,
        )

    def forward(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        """
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask(x_lens)

@@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module):
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))

@@ -122,26 +138,28 @@ class Text2SemanticDecoder(nn.Module):
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        # loss
        # from feiteng: the longer the duration per step, the larger the gradient update should be, hence sum
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc

    # need to check how this differs from forward, and what to feed as prompts when there is no semantic input
    def infer(
        self,
        x,
        x_lens,
        prompts,
        bert_feature,
        top_k: int = -100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
    ):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)

        # AR Decoder

@@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module):
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            y_attn_mask = F.pad(
                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                y.device
            )

            xy_dec, _ = self.h(
                (xy_pos, None),
                mask=xy_attn_mask,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1.0, temperature=temperature
            )

            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
                stop = True

            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                stop = True
            if stop:
                if prompts.shape[1] == y.shape[1]:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                    print("bad zero prediction")
                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                break
            # the semantic_ids generated this step and the previous y form the new y

@@ -198,23 +218,24 @@ class Text2SemanticDecoder(nn.Module):
        return y

    def pad_y_eos(self, y, y_mask_int, eos_id):
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # shift by one position
        return targets[:, :-1], targets[:, 1:]

    def infer_panel(
        self,
        x,  ##### all text tokens
        x_lens,
        prompts,  #### reference-audio tokens
        bert_feature,
        top_k: int = -100,
        early_stop_num: int = -1,
        temperature: float = 1.0,
    ):
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)

        # AR Decoder

@@ -224,75 +245,81 @@ class Text2SemanticDecoder(nn.Module):
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
        stop = False
        # print(1111111,self.num_layers)
        cache = {
            "all_stage": self.num_layers,
            "k": [None] * self.num_layers,  ### hand-written from the config
            "v": [None] * self.num_layers,
            # "xy_pos": None,  ## the y_pos positional encoding differs every step, so it cannot be cached and xy_pos has to be rebuilt each time; mostly a matter of how it is written, and the cost is negligible, so leave it
            "y_emb": None,  ## only the newest samples need an embedding, then concatenate with the history
            # "logits": None,  ### the original already computes only the tail and concatenates, nothing to do
            # "xy_dec": None,  ### not needed, only the last position is used for the logits
            "first_infer": 1,
            "stage": 0,
        }
        for idx in tqdm(range(1500)):
            if cache["first_infer"] == 1:
                y_emb = self.ar_audio_embedding(y)
            else:
                y_emb = torch.cat(
                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
                )
            cache["y_emb"] = y_emb
            y_pos = self.ar_audio_position(y_emb)
            # x and the gradually growing y are fed to the model together
            if cache["first_infer"] == 1:
                xy_pos = torch.concat([x, y_pos], dim=1)
            else:
                xy_pos = y_pos[:, -1:]
            y_len = y_pos.shape[1]
            ### the following three are not cached
            if cache["first_infer"] == 1:
                x_attn_mask_pad = F.pad(
                    x_attn_mask,
                    (0, y_len),  ### the all-zero xx block extends to all-zero xx plus all-one xy, shape (x, x+y)
                    value=True,
                )
                y_attn_mask = F.pad(  ### the upper-right ones of yy extend left with zeros over xy, shape (y, x+y)
                    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
                    (x_len, 0),
                    value=False,
                )
                xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
                    y.device
                )
            else:
                ### rightmost column (wrong)
                # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
                # xy_attn_mask[:,-1]=False
                ### bottom row (correct)
                xy_attn_mask = torch.zeros(
                    (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
                )
            # pdb.set_trace()
            ### the cache is the key part here
            # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
            xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
            logits = self.ar_predict_layer(
                xy_dec[:, -1]
            )  ## no change needed: with the cache there is only one frame anyway, same as taking the last frame
            # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
            samples = sample(
                logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
            )[0].unsqueeze(0)
            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                print("use early stop num:", early_stop_num)
                stop = True

            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
                stop = True
            if stop:
                if prompts.shape[1] == y.shape[1]:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                    print("bad zero prediction")
                print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
                break
            # the semantic_ids generated this step and the previous y form the new y
            # print(samples.shape)  # [1,1]; the first 1 is bs
            y = torch.concat([y, samples], dim=1)
            cache["first_infer"] = 0
        return y, idx
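
A self-contained sketch of the two mask shapes used by the cached decoding loop above (not part of this commit; the sizes are arbitrary). On the first step the full (x+y, x+y) mask is built; on later steps only the newest token queries the cached keys/values, so a single all-False row suffices:

import torch
import torch.nn.functional as F

x_len, y_len = 4, 3
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

# first step: text attends only to text, audio attends causally to text + audio
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
first_step_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
print(first_step_mask.shape)  # torch.Size([7, 7])

# later steps: only the newest token is fed in, nothing is masked out
incremental_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
print(incremental_mask.shape)  # torch.Size([1, 7])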

View File

@@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()

@@ -9,7 +10,7 @@ def sequence_mask(length, max_length=None):
    return x.unsqueeze(0) < length.unsqueeze(1)


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
        lengths:

@@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)

@@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits,
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

@@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits,
        # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits

@@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
from typing import Optional, Tuple


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization

@@ -115,7 +114,7 @@ def logits_to_probs(
    top_p: Optional[int] = None,
    repetition_penalty: float = 1.0,
):
    previous_tokens = previous_tokens.squeeze()
    # print(logits.shape,previous_tokens.shape)
    # pdb.set_trace()
    if previous_tokens is not None and repetition_penalty != 1.0:

@@ -159,4 +158,3 @@ def sample(
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
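
A small self-contained check of the top-k/top-p filtering helper above (not part of this commit; the import path follows the AR.models.utils imports seen elsewhere in this diff):

import torch
import torch.nn.functional as F

from AR.models.utils import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
print(probs)  # everything outside the two largest logits now has probability 0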

View File

@@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

F.multi_head_attention_forward = multi_head_attention_forward_patched


class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information

@@ -76,66 +78,71 @@ class MultiheadAttention(Module):
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:

@@ -143,7 +150,8 @@ class MultiheadAttention(Module):
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)

@@ -156,7 +164,8 @@ class MultiheadAttention(Module):
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            if self.bias_k is not None:
                xavier_normal_(self.bias_k)

@@ -190,14 +199,15 @@ class MultiheadAttention(Module):
        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:

@@ -251,23 +261,26 @@ class MultiheadAttention(Module):
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(
                key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )
        why_not_fast_path = ""
        if not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif (
            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
        ):
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:

@@ -288,29 +301,41 @@ class MultiheadAttention(Module):
            why_not_fast_path = "attn_mask was not None"
        elif query.is_nested and key_padding_mask is not None:
            why_not_fast_path = (
                "key_padding_mask is not supported with NestedTensor input"
            )
        elif self.num_heads % 2 == 1:
            why_not_fast_path = "num_heads is odd"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                [
                    (x is None or x.is_cuda or "cpu" in str(x.device))
                    for x in tensor_args
                ]
            ):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(
                [x is not None and x.requires_grad for x in tensor_args]
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                return torch._native_multi_head_attention(
                    query,

@@ -322,17 +347,21 @@ class MultiheadAttention(Module):
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    key_padding_mask if key_padding_mask is not None else attn_mask,
                    need_weights,
                    average_attn_weights,
                    1
                    if key_padding_mask is not None
                    else 0
                    if attn_mask is not None
                    else None,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, ( assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. " "MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}") + f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched: if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property # make sure that the transpose op does not affect the "is" property
@ -343,9 +372,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)] query, key = [x.transpose(1, 0) for x in (query, key)]
value = key value = key
else: else:
query, key, value = [ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
x.transpose(1, 0) for x in (query, key, value)
]
if not self._qkv_same_embed_dim: if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward( attn_output, attn_output_weights = F.multi_head_attention_forward(
@ -370,7 +397,9 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight, q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache ) average_attn_weights=average_attn_weights,
cache=cache,
)
else: else:
attn_output, attn_output_weights = F.multi_head_attention_forward( attn_output, attn_output_weights = F.multi_head_attention_forward(
query, query,
@ -390,7 +419,9 @@ class MultiheadAttention(Module):
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=need_weights, need_weights=need_weights,
attn_mask=attn_mask, attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache ) average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched: if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights return attn_output.transpose(1, 0), attn_output_weights
else: else:
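A minimal usage sketch of the cache-aware forward added above (shapes, layer count and constructor defaults are assumptions; the cache keys mirror the ones consumed by the patched multi_head_attention_forward later in this diff):
import torch
mha = MultiheadAttention(512, 16, batch_first=True)   # default linear1_cls/linear2_cls assumed
cache = {
    "all_stage": 1,            # number of attention layers sharing this dict
    "stage": 0,                # which layer runs next
    "first_infer": 1,          # 1 on the first decoding step, 0 afterwards
    "k": [None], "v": [None],  # per-layer key/value buffers
}
x = torch.randn(1, 10, 512)                                       # (batch, time, embed_dim)
# module left in training mode, so the non-fused path (which honours `cache`) is taken
out, _ = mha(x, x, x, need_weights=False, cache=cache)            # fills cache["k"][0] / cache["v"][0]
cache["first_infer"] = 0
step = torch.randn(1, 1, 512)                                     # one new decoding step
out, _ = mha(step, step, step, need_weights=False, cache=cache)   # appends to the cached k/v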

View File

@ -7,10 +7,11 @@ from torch import nn
class TokenEmbedding(nn.Module): class TokenEmbedding(nn.Module):
def __init__( def __init__(
self, self,
embedding_dim: int, embedding_dim: int,
vocab_size: int, vocab_size: int,
dropout: float=0.0, ): dropout: float = 0.0,
):
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module):
return self.word_embeddings.weight return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor: def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1] return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = self.word_embeddings(x) x = self.word_embeddings(x)
@ -34,11 +35,12 @@ class TokenEmbedding(nn.Module):
class SinePositionalEmbedding(nn.Module): class SinePositionalEmbedding(nn.Module):
def __init__( def __init__(
self, self,
embedding_dim: int, embedding_dim: int,
dropout: float=0.0, dropout: float = 0.0,
scale: bool=False, scale: bool = False,
alpha: bool=False, ): alpha: bool = False,
):
super().__init__() super().__init__()
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim) pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse: if self.reverse:
position = torch.arange( position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else: else:
position = torch.arange( position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
-(math.log(10000.0) / self.embedding_dim)) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term) pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x) self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)] output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output) return self.dropout(output)

View File

@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements a warmup learning-rate schedule for multiple optimizers: the LR ramps from 'init_lr' to 'peak_lr' over 'warmup_steps', then decays with a cosine schedule towards 'end_lr'. Implements a warmup learning-rate schedule for multiple optimizers: the LR ramps from 'init_lr' to 'peak_lr' over 'warmup_steps', then decays with a cosine schedule towards 'end_lr'.
""" """
def __init__(self, def __init__(
optimizer, self,
init_lr, optimizer,
peak_lr, init_lr,
end_lr, peak_lr,
warmup_steps=10000, end_lr,
total_steps=400000, warmup_steps=10000,
current_step=0): total_steps=400000,
current_step=0,
):
self.init_lr = init_lr self.init_lr = init_lr
self.peak_lr = peak_lr self.peak_lr = peak_lr
self.end_lr = end_lr self.end_lr = end_lr
@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr] self._last_lr = [self.lr]
def set_lr(self, lr): def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups] self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups: for g in self.optimizer.param_groups:
# g['lr'] = lr # g['lr'] = lr
g['lr'] = self.end_lr###lock the lr (use linear) g["lr"] = self.end_lr ### lock the lr (use linear)
def step(self): def step(self):
if self._current_step < self.warmup_steps: if self._current_step < self.warmup_steps:
@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
else: else:
decay_ratio = (self._current_step - self.warmup_steps) / ( decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps) self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0: if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError( raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings." "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr=lr=self.end_lr=0.002###lock to linear###it won't behave, just lock it! self.lr = lr = self.end_lr = 0.002 ### lock to linear ### it won't behave, just lock it!
self.set_lr(lr) self.set_lr(lr)
self.lr = lr self.lr = lr
self._current_step += 1 self._current_step += 1
return self.lr return self.lr
if __name__ == "__main__":
if __name__ == '__main__':
m = nn.Linear(10, 10) m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4) opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule( s = WarmupCosineLRSchedule(
opt, opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
1e-6, )
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
lrs = [] lrs = []
for i in range(25000): for i in range(25000):
s.step() s.step()
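Note that the `self.lr = lr = self.end_lr = 0.002` line above overrides the cosine branch, so after warmup the LR is locked at 0.002. For reference, a standalone sketch of what the unlocked schedule computes (linear warmup assumed, per the docstring; not part of this commit):
import math
def warmup_cosine_lr(step, init_lr, peak_lr, end_lr, warmup_steps, total_steps):
    # linear warmup from init_lr to peak_lr, then cosine decay towards end_lr
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    decay_ratio = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return end_lr + coeff * (peak_lr - end_lr)
# warmup_cosine_lr(0,     1e-6, 2e-4, 1e-6, 2000, 20000) -> 1e-6  (start)
# warmup_cosine_lr(2000,  1e-6, 2e-4, 1e-6, 2000, 20000) -> 2e-4  (peak)
# warmup_cosine_lr(20000, 1e-6, 2e-4, 1e-6, 2000, 20000) -> 1e-6  (end)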

View File

@ -1,9 +1,16 @@
from torch.nn.functional import * from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
# import torch # import torch
# Tensor = torch.Tensor # Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union # from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched( def multi_head_attention_forward_patched(
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,
@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None, static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None, static_v: Optional[Tensor] = None,
average_attn_weights: bool = True, average_attn_weights: bool = True,
is_causal: bool = False,cache=None is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
""" """
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops): if has_torch_function(tens_ops):
return handle_torch_function( return handle_torch_function(
multi_head_attention_forward, multi_head_attention_forward,
@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight, v_proj_weight=v_proj_weight,
static_k=static_k, static_k=static_k,
static_v=static_v, static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache average_attn_weights=average_attn_weights,
cache=cache,
) )
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the # is batched, run the computation and before returning squeeze the
@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask", mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask), other_type=_none_or_dtype(attn_mask),
other_name="attn_mask", other_name="attn_mask",
target_type=query.dtype target_type=query.dtype,
) )
if is_causal and attn_mask is None: if is_causal and attn_mask is None:
@ -184,59 +205,84 @@ def multi_head_attention_forward_patched(
check_other=False, check_other=False,
) )
if key_padding_mask is not None: if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it. # We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no # Turn off use of is_causal hint, as the merged mask is no
# longer causal. # longer causal.
is_causal = False is_causal = False
assert embed_dim == embed_dim_to_check, \ assert (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor): if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing # embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc') head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else: else:
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight: if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used # allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \ assert (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else: else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
# #
# compute in-projection # compute in-projection
# #
if not use_separate_proj_weight: if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else: else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" assert (
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" q_proj_weight is not None
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" ), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None: if in_proj_bias is None:
b_q = b_k = b_v = None b_q = b_k = b_v = None
else: else:
b_q, b_k, b_v = in_proj_bias.chunk(3) b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) q, k, v = _in_projection(
if(cache!=None): query,
if(cache["first_infer"]==1): key,
cache["k"][cache["stage"]]=k value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape) # print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v cache["v"][cache["stage"]] = v
else:###each of the 12 layers keeps its own cache_kv else: ### each of the 12 layers keeps its own cache_kv
# print(1,cache["k"].shape) # print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##the time dim was originally 1, but the proj may have transposed it, so time is now dim 0 cache["k"][cache["stage"]] = torch.cat(
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0) [cache["k"][cache["stage"]], k], 0
) ##the time dim was originally 1, but the proj may have transposed it, so time is now dim 0
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape) # print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0] src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]] k = cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]] v = cache["v"][cache["stage"]]
# if attn_mask is not None: # if attn_mask is not None:
# attn_mask=attn_mask[-1:,] # attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask) # print(attn_mask.shape,attn_mask)
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# print(2333,cache) # print(2333,cache)
# prep attention mask # prep attention mask
@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2: if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len) correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size: if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0) attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3: elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len) correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size: if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else: else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second) # add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None: if bias_k is not None and bias_v is not None:
@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \ assert (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim, \ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k k = static_k
if static_v is None: if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else: else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed # TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \ assert (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim, \ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v v = static_v
# add zero attention along batch dimension (now first) # add zero attention along batch dimension (now first)
if add_zero_attn: if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim) zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) k = torch.cat(
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None: if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1)) attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None: if key_padding_mask is not None:
@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(
# merge key padding and attention masks # merge key padding and attention masks
if key_padding_mask is not None: if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \ assert key_padding_mask.shape == (
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" bsz,
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ src_len,
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None: if attn_mask is None:
attn_mask = key_padding_mask attn_mask = key_padding_mask
else: else:
@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape B, Nt, E = q.shape
q_scaled = q / math.sqrt(E) q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None: if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else: else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1) attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v) attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim) k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) attn_output = scaled_dot_product_attention(
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
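The cache branch above is a per-layer KV cache for incremental decoding (time is dim 0 after the in-projection). A stripped-down sketch of the same update rule as a hypothetical standalone helper:
import torch
def update_kv_cache(cache, k, v):
    # append this layer's new keys/values and advance the stage pointer
    i = cache["stage"]
    if cache["first_infer"] == 1:      # first step: initialise the buffers
        cache["k"][i] = k
        cache["v"][i] = v
    else:                              # later steps: concatenate along the time dim (dim 0)
        cache["k"][i] = torch.cat([cache["k"][i], k], 0)
        cache["v"][i] = torch.cat([cache["v"][i], v], 0)
    cache["stage"] = (i + 1) % cache["all_stage"]
    return cache["k"][i], cache["v"][i]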

View File

@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving. # floors), should be expectation-preserving.
floor = -0.043637 floor = -0.043637
ceil = 1.2 ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor) d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
) + torch.rand_like(deriv) deriv
)
if __name__ == "__main__": if __name__ == "__main__":
# for self-testing only. # for self-testing only.
assert d_scaled.min() >= 0.0 assert d_scaled.min() >= 0.0
@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor: def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors (d,) = ctx.saved_tensors
# the same constants as used in forward pass. # the same constants as used in forward pass.
floor = -0.043637 floor = -0.043637
ceil = 1.2 ceil = 1.2
@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module):
class ActivationBalancerFunction(torch.autograd.Function): class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
x: Tensor, x: Tensor,
scale_factor: Tensor, scale_factor: Tensor,
sign_factor: Optional[Tensor], sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor: channel_dim: int,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
ctx.channel_dim = channel_dim ctx.channel_dim = channel_dim
@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1) scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, ) return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor( def _compute_scale_factor(
x: Tensor, x: Tensor,
channel_dim: int, channel_dim: int,
min_abs: float, min_abs: float,
max_abs: float, max_abs: float,
gain_factor: float, gain_factor: float,
max_factor: float, ) -> Tensor: max_factor: float,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim] sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -145,23 +153,25 @@ def _compute_scale_factor(
else: else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean < min_abs. # x_abs_mean < min_abs.
below_threshold = ( below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( min=0, max=max_factor
min=0, max=max_factor) )
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor) min=0, max=max_factor
)
return below_threshold - above_threshold return below_threshold - above_threshold
def _compute_sign_factor( def _compute_sign_factor(
x: Tensor, x: Tensor,
channel_dim: int, channel_dim: int,
min_positive: float, min_positive: float,
max_positive: float, max_positive: float,
gain_factor: float, gain_factor: float,
max_factor: float, ) -> Tensor: max_factor: float,
) -> Tensor:
if channel_dim < 0: if channel_dim < 0:
channel_dim += x.ndim channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim] sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -171,18 +181,18 @@ def _compute_sign_factor(
else: else:
# 0 if proportion_positive >= min_positive, else can be # 0 if proportion_positive >= min_positive, else can be
# as large as max_factor. # as large as max_factor.
factor1 = ((min_positive - proportion_positive) * factor1 = (
(gain_factor / min_positive)).clamp_( (min_positive - proportion_positive) * (gain_factor / min_positive)
min=0, max=max_factor) ).clamp_(min=0, max=max_factor)
if max_positive == 1.0: if max_positive == 1.0:
factor2 = 0.0 factor2 = 0.0
else: else:
# 0 if self.proportion_positive <= max_positive, else can be # 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor. # as large as -max_factor.
factor2 = ((proportion_positive - max_positive) * factor2 = (
(gain_factor / (1.0 - max_positive))).clamp_( (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
min=0, max=max_factor) ).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2 sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1: # require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float) assert not isinstance(sign_factor, float)
@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module):
""" """
def __init__( def __init__(
self, self,
num_channels: int, num_channels: int,
channel_dim: int, channel_dim: int,
min_positive: float=0.05, min_positive: float = 0.05,
max_positive: float=0.95, max_positive: float = 0.95,
max_factor: float=0.04, max_factor: float = 0.04,
sign_gain_factor: float=0.01, sign_gain_factor: float = 0.01,
scale_gain_factor: float=0.02, scale_gain_factor: float = 0.02,
min_abs: float=0.2, min_abs: float = 0.2,
max_abs: float=100.0, max_abs: float = 100.0,
min_prob: float=0.1, ): min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__() super(ActivationBalancer, self).__init__()
self.num_channels = num_channels self.num_channels = num_channels
self.channel_dim = channel_dim self.channel_dim = channel_dim
@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
torch.jit.is_tracing()):
return _no_op(x) return _no_op(x)
count = self.cpu_count count = self.cpu_count
@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits # the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default) # a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0))) prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob: if random.random() < prob:
sign_gain_factor = 0.5 sign_gain_factor = 0.5
@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive, self.min_positive,
self.max_positive, self.max_positive,
gain_factor=self.sign_gain_factor / prob, gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, ) max_factor=self.max_factor,
)
else: else:
sign_factor = None sign_factor = None
@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs, min_abs=self.min_abs,
max_abs=self.max_abs, max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob, gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, ) max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply( return ActivationBalancerFunction.apply(
x, x,
scale_factor, scale_factor,
sign_factor, sign_factor,
self.channel_dim, ) self.channel_dim,
)
else: else:
return _no_op(x) return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, def BalancedDoubleSwish(
min_prob=0.25) -> nn.Sequential: d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
""" """
ActivationBalancer -> DoubleSwish ActivationBalancer -> DoubleSwish
""" """
balancer = ActivationBalancer( balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob) d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential( return nn.Sequential(
balancer, balancer,
DoubleSwish(), ) DoubleSwish(),
)
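A small usage sketch for BalancedDoubleSwish (dimensions are assumed); per ActivationBalancerFunction above, the balancer leaves the forward value unchanged and only nudges gradients in backward:
import torch
import torch.nn as nn
ffn = nn.Sequential(
    nn.Linear(512, 2048),
    BalancedDoubleSwish(2048),   # ActivationBalancer over the last dim, then DoubleSwish
    nn.Linear(2048, 512),
)
y = ffn(torch.randn(4, 100, 512))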

View File

@ -26,26 +26,28 @@ class LayerNorm(nn.Module):
elementwise_affine: bool elementwise_affine: bool
def __init__( def __init__(
self, self,
normalized_shape: _shape_t, normalized_shape: _shape_t,
eps: float=1e-5, eps: float = 1e-5,
elementwise_affine: bool=True, elementwise_affine: bool = True,
device=None, device=None,
dtype=None, ) -> None: dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__() super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral): if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment # mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, ) # type: ignore[assignment] normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple( self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
normalized_shape) # type: ignore[arg-type]
self.eps = eps self.eps = eps
self.elementwise_affine = elementwise_affine self.elementwise_affine = elementwise_affine
if self.elementwise_affine: if self.elementwise_affine:
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)) torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter( self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)) torch.empty(self.normalized_shape, **factory_kwargs)
)
else: else:
self.register_parameter("weight", None) self.register_parameter("weight", None)
self.register_parameter("bias", None) self.register_parameter("bias", None)
@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
nn.init.ones_(self.weight) nn.init.ones_(self.weight)
nn.init.zeros_(self.bias) nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any=None) -> Tensor: def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple): if isinstance(input, tuple):
input, embedding = input input, embedding = input
return (F.layer_norm( return (
input, F.layer_norm(
self.normalized_shape, input,
self.weight, self.normalized_shape,
self.bias, self.weight,
self.eps, ), embedding, ) self.bias,
self.eps,
),
embedding,
)
assert embedding is None assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight, return F.layer_norm(
self.bias, self.eps) input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return ( return (
"{normalized_shape}, eps={eps}, " "{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)) "elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module): class IdentityNorm(nn.Module):
def __init__( def __init__(
self, self,
d_model: int, d_model: int,
eps: float=1e-5, eps: float = 1e-5,
device=None, device=None,
dtype=None, ) -> None: dtype=None,
) -> None:
super(IdentityNorm, self).__init__() super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any=None) -> Tensor: def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple): if isinstance(input, tuple):
return input return input
@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module):
self.norm = norm self.norm = norm
def forward( def forward(
self, self,
src: Tensor, src: Tensor,
mask: Optional[Tensor]=None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor]=None, src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool=False,cache=None ) -> Tensor: return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
Args: Args:
@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod( output = mod(
output, output,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache) src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0]) layer_states.append(output[0])
if self.norm is not None: if self.norm is not None:
@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
output = src output = src
for mod in self.layers: for mod in self.layers:
output = mod(output, output = mod(
src_mask=mask, output,
src_key_padding_mask=src_key_padding_mask, cache=cache) src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None: if self.norm is not None:
output = self.norm(output) output = self.norm(output)
@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"] __constants__ = ["batch_first", "norm_first"]
def __init__( def __init__(
self, self,
d_model: int, d_model: int,
nhead: int, nhead: int,
dim_feedforward: int=2048, dim_feedforward: int = 2048,
dropout: float=0.1, dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool=False, batch_first: bool = False,
norm_first: bool=False, norm_first: bool = False,
device=None, device=None,
dtype=None, dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear, linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear, linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear, linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear, linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module=LayerNorm, layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float=1e-5, layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False, ) -> None: adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__() super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead) # print(233333333333,d_model,nhead)
# import os # import os
# os._exit(2333333) # os._exit(2333333)
self.self_attn = MultiheadAttention( self.self_attn = MultiheadAttention(
d_model,#512 16 d_model, # 512 16
nhead, nhead,
dropout=dropout, dropout=dropout,
batch_first=batch_first, batch_first=batch_first,
linear1_cls=linear1_self_attention_cls, linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls, linear2_cls=linear2_self_attention_cls,
**factory_kwargs, ) **factory_kwargs,
)
# Implementation of Feedforward model # Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, self.linear1 = linear1_feedforward_cls(
**factory_kwargs) d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, self.linear2 = linear2_feedforward_cls(
**factory_kwargs) dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout) self.dropout1 = nn.Dropout(dropout)
@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm: if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm( norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
d_model, eps=layer_norm_eps, **factory_kwargs)
else: else:
norm2 = layer_norm_cls( norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm: if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1) self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module):
self.activation = F.relu self.activation = F.relu
def forward( def forward(
self, self,
src: Tensor, src: Tensor,
src_mask: Optional[Tensor]=None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor: src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer. r"""Pass the input through the encoder layer.
Args: Args:
@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None: if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype _skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point( if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask): src_key_padding_mask
):
raise AssertionError( raise AssertionError(
"only bool and floating types of key_padding_mask are supported" "only bool and floating types of key_padding_mask are supported"
) )
@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block( x = x + self._sa_block(
self.norm1(x, stage_embedding), self.norm1(x, stage_embedding),
src_mask, src_mask,
src_key_padding_mask,cache=cache ) src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding)) x = x + self._ff_block(self.norm2(x, stage_embedding))
else: else:
x = self.norm1( x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache), x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding, ) stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding) x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple: if is_src_tuple:
@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module):
# self-attention block # self-attention block
def _sa_block( def _sa_block(
self, self,
x: Tensor, x: Tensor,
attn_mask: Optional[Tensor], attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor: key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask) # print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os # import os
# os._exit(23333) # os._exit(23333)
x = self.self_attn( x = self.self_attn(
@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x, x,
attn_mask=attn_mask, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0] need_weights=False,
cache=cache,
)[0]
return self.dropout1(x) return self.dropout1(x)
# feed forward block # feed forward block
@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
self.d_model = d_model self.d_model = d_model
self.eps = self.norm.eps self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor: def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple): if isinstance(input, tuple):
input, embedding = input input, embedding = input
weight, bias = torch.split( weight, bias = torch.split(
self.project_layer(embedding), self.project_layer(embedding),
split_size_or_sections=self.d_model, split_size_or_sections=self.d_model,
dim=-1, ) dim=-1,
)
return (weight * self.norm(input) + bias, embedding) return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split( weight, bias = torch.split(
self.project_layer(embedding), self.project_layer(embedding),
split_size_or_sections=self.d_model, split_size_or_sections=self.d_model,
dim=-1, ) dim=-1,
)
return weight * self.norm(input) + bias return weight * self.norm(input) + bias
def _get_clones(module, N): def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
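AdaptiveLayerNorm predicts a per-position scale and shift from the conditioning embedding. A minimal wiring sketch; project_layer is not shown in this hunk, so it is assumed to be nn.Linear(d_model, 2 * d_model):
import torch
d_model = 512
adanorm = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
x = torch.randn(10, 4, d_model)      # (time, batch, d_model)
emb = torch.randn(10, 4, d_model)    # stage / conditioning embedding
y = adanorm(x, emb)                  # weight * LayerNorm(x) + bias, with weight/bias split from project_layer(emb)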

View File

@ -27,46 +27,44 @@ class GruutPhonemizer:
"": "", "": "",
"": "", "": "",
"«": "«", "«": "«",
"»": "»" "»": "»",
} }
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])" self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str: def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text) text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text) text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text) text = regex.sub(r"\pZ+", r" ", text)
return text.strip() return text.strip()
def _convert_punctuation(self, word: Word) -> str: def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes: if not word.phonemes:
return '' return ""
if word.phonemes[0] in ['', '|']: if word.phonemes[0] in ["", "|"]:
return word.text.strip() return word.text.strip()
phonemes = ''.join(word.phonemes) phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex # remove modifier characters ˈˌː with regex
phonemes = re.sub(r'[ˈˌː͡]', '', phonemes) phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip() return phonemes.strip()
def phonemize(self, text: str, espeak: bool=False) -> str: def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text) text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [ sents: List[Sentence] = [
sent sent
for sent in self._phonemizer( for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
text_to_phonemize, lang="en-us", espeak=espeak)
] ]
words: List[str] = [ words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents) self._convert_punctuation(word) for word in itertools.chain(*sents)
] ]
return ' '.join(words) return " ".join(words)
def transform(self, phonemes): def transform(self, phonemes):
# convert phonemes to ids # convert phonemes to ids
# dictionary is in symbols.py # dictionary is in symbols.py
return [ return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
self.symbol_to_id[p] for p in phonemes
if p in self.symbol_to_id.keys()
]
if __name__ == "__main__": if __name__ == "__main__":
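A hedged usage sketch (the constructor arguments are not shown in this hunk and are assumed):
phonemizer = GruutPhonemizer(language="en-us")   # assumed signature
ipa = phonemizer.phonemize("Hello, world!")      # space-separated IPA, punctuation preserved
ids = phonemizer.transform(ipa)                  # per-symbol lookup into symbol_to_id; unknown symbols are skipped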

View File

@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
PAD = '_' PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” ' PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'" IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ") SPACE_ID = SYMBOLS.index(" ")
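The symbol_to_id table used by the phonemizer above is presumably built from this list; an assumed one-liner (not part of this diff):
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}   # e.g. SYMBOL_TO_ID[PAD] == 0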

View File

@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path): def save_config_to_yaml(config, path):
assert path.endswith('.yaml') assert path.endswith(".yaml")
with open(path, 'w') as f: with open(path, "w") as f:
f.write(yaml.dump(config)) f.write(yaml.dump(config))
f.close() f.close()
def write_args(args, path): def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args) args_dict = dict(
if not name.startswith('_')) (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
with open(path, 'a') as args_file: )
args_file.write('==> torch version: {}\n'.format(torch.__version__)) with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write( args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
args_file.write('==> Cmd:\n') )
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv)) args_file.write(str(sys.argv))
args_file.write('\n==> args:\n') args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()): for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v))) args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close() args_file.close()
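A short round-trip sketch for these helpers (paths are hypothetical; load_yaml_config is assumed to return a plain dict):
cfg = load_yaml_config("configs/s1.yaml")          # hypothetical path
cfg["train"]["batch_size"] = 4
save_config_to_yaml(cfg, "configs/s1_local.yaml")  # must end in .yaml (asserted above)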

View File

@ -1,31 +1,31 @@
train: train:
seed: 1234 seed: 1234
epochs: 300 epochs: 300
batch_size: 8 batch_size: 8
gradient_accumulation: 4 gradient_accumulation: 4
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 16 precision: 16
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 54 max_sec: 54
num_workers: 1 num_workers: 1
pad_val: 1024 # same as EOS in model pad_val: 1024 # same as EOS in model
model: model:
vocab_size: 1025 vocab_size: 1025
phoneme_vocab_size: 512 phoneme_vocab_size: 512
embedding_dim: 512 embedding_dim: 512
hidden_dim: 512 hidden_dim: 512
head: 16 head: 16
linear_units: 2048 linear_units: 2048
n_layer: 12 n_layer: 12
dropout: 0 dropout: 0
EOS: 1024 EOS: 1024
inference: inference:
top_k: 5 top_k: 5
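The optimizer block presumably feeds the WarmupCosineLRSchedule shown earlier in this commit; an assumed mapping, not verified against the training code in this diff:
# scheduler = WarmupCosineLRSchedule(
#     optimizer,
#     init_lr=cfg["optimizer"]["lr_init"],
#     peak_lr=cfg["optimizer"]["lr"],
#     end_lr=cfg["optimizer"]["lr_end"],
#     warmup_steps=cfg["optimizer"]["warmup_steps"],
#     total_steps=cfg["optimizer"]["decay_steps"],
# )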

View File

@ -1,31 +1,31 @@
train: train:
seed: 1234 seed: 1234
epochs: 300 epochs: 300
batch_size: 8 batch_size: 8
gradient_accumulation: 4 gradient_accumulation: 4
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 16-mixed precision: 16-mixed
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 54 max_sec: 54
num_workers: 1 num_workers: 1
pad_val: 1024 # same as EOS in model pad_val: 1024 # same as EOS in model
model: model:
vocab_size: 1025 vocab_size: 1025
phoneme_vocab_size: 512 phoneme_vocab_size: 512
embedding_dim: 1024 embedding_dim: 1024
hidden_dim: 1024 hidden_dim: 1024
head: 16 head: 16
linear_units: 2048 linear_units: 2048
n_layer: 16 n_layer: 16
dropout: 0 dropout: 0
EOS: 1024 EOS: 1024
inference: inference:
top_k: 5 top_k: 5

View File

@ -1,31 +1,31 @@
train: train:
seed: 1234 seed: 1234
epochs: 300 epochs: 300
batch_size: 12 batch_size: 12
gradient_accumulation: 4 gradient_accumulation: 4
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 16-mixed precision: 16-mixed
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 54 max_sec: 54
num_workers: 1 num_workers: 1
pad_val: 1024 # same as EOS in model pad_val: 1024 # same as EOS in model
model: model:
vocab_size: 1025 vocab_size: 1025
phoneme_vocab_size: 512 phoneme_vocab_size: 512
embedding_dim: 1024 embedding_dim: 1024
hidden_dim: 1024 hidden_dim: 1024
head: 16 head: 16
linear_units: 2048 linear_units: 2048
n_layer: 6 n_layer: 6
dropout: 0 dropout: 0
EOS: 1024 EOS: 1024
inference: inference:
top_k: 5 top_k: 5

View File

@ -1,31 +1,31 @@
train: train:
seed: 1234 seed: 1234
epochs: 20 epochs: 20
batch_size: 8 batch_size: 8
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 16-mixed precision: 16-mixed
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 54 max_sec: 54
num_workers: 4 num_workers: 4
pad_val: 1024 # same as EOS in model pad_val: 1024 # same as EOS in model
model: model:
vocab_size: 1025 vocab_size: 1025
phoneme_vocab_size: 512 phoneme_vocab_size: 512
embedding_dim: 512 embedding_dim: 512
hidden_dim: 512 hidden_dim: 512
head: 16 head: 16
linear_units: 2048 linear_units: 2048
n_layer: 24 n_layer: 24
dropout: 0 dropout: 0
EOS: 1024 EOS: 1024
random_bert: 0 random_bert: 0
inference: inference:
top_k: 5 top_k: 5

View File

@ -1,77 +1,77 @@
train: train:
seed: 1234 seed: 1234
epochs: 100 epochs: 100
batch_size: 6 batch_size: 6
gradient_accumulation: 4 gradient_accumulation: 4
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 32 precision: 32
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 40 max_sec: 40
num_workers: 1 num_workers: 1
pad_val: 1024 # same as EOS in model pad_val: 1024 # same as EOS in model
model: model:
saving_path: "ckpt/" saving_path: "ckpt/"
resume_checkpoint: null resume_checkpoint: null
vocoder_config_path: "quantizer/new_ckpt/config.json" vocoder_config_path: "quantizer/new_ckpt/config.json"
vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000" vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
datadir: "/home/liweiche/GigaSpeech/wavs" datadir: "/home/liweiche/GigaSpeech/wavs"
metapath: "/home/liweiche/GigaSpeech/train2.json" metapath: "/home/liweiche/GigaSpeech/train2.json"
val_metapath: "/home/liweiche/GigaSpeech/dev2.json" val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
sampledir: "logs/" sampledir: "logs/"
pretrained_path: null pretrained_path: null
lr: 0.0001 lr: 0.0001
batch_size: 200.0 batch_size: 200.0
train_bucket_size: 8192 train_bucket_size: 8192
training_step: 800000 training_step: 800000
optim_flat_percent: 0.0 optim_flat_percent: 0.0
warmup_step: 50 warmup_step: 50
adam_beta1: 0.9 adam_beta1: 0.9
adam_beta2: 0.98 adam_beta2: 0.98
ffd_size: 3072 ffd_size: 3072
hidden_size: 768 hidden_size: 768
enc_nlayers: 6 enc_nlayers: 6
dec_nlayers: 6 dec_nlayers: 6
nheads: 12 nheads: 12
ar_layer: 4 ar_layer: 4
ar_ffd_size: 1024 ar_ffd_size: 1024
ar_hidden_size: 256 ar_hidden_size: 256
ar_nheads: 4 ar_nheads: 4
aligner_softmax_temp: 1.0 aligner_softmax_temp: 1.0
layer_norm_eps: 0.00001 layer_norm_eps: 0.00001
speaker_embed_dropout: 0.05 speaker_embed_dropout: 0.05
label_smoothing: 0.0 label_smoothing: 0.0
val_check_interval: 5000 val_check_interval: 5000
check_val_every_n_epoch: 1 check_val_every_n_epoch: 1
precision: "fp16" precision: "fp16"
nworkers: 16 nworkers: 16
distributed: true distributed: true
accelerator: "ddp" accelerator: "ddp"
version: null version: null
accumulate_grad_batches: 1 accumulate_grad_batches: 1
use_repetition_token: true use_repetition_token: true
use_repetition_gating: false use_repetition_gating: false
repetition_penalty: 1.0 repetition_penalty: 1.0
sampling_temperature: 1.0 sampling_temperature: 1.0
top_k: -1 top_k: -1
min_top_k: 3 min_top_k: 3
top_p: 0.8 top_p: 0.8
sample_num: 4 sample_num: 4
length_penalty_max_length: 15000 length_penalty_max_length: 15000
length_penalty_max_prob: 0.95 length_penalty_max_prob: 0.95
max_input_length: 2048 max_input_length: 2048
max_output_length: 2000 max_output_length: 2000
sample_rate: 16000 sample_rate: 16000
n_codes: 1024 n_codes: 1024
n_cluster_groups: 1 n_cluster_groups: 1
phone_context_window: 4 phone_context_window: 4
phoneset_size: 1000 phoneset_size: 1000
inference: inference:
top_k: 5 top_k: 5
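
Besides training settings, this config carries decoding knobs (sampling_temperature, top_k, top_p, repetition_penalty, min_top_k). The decode loop itself is not part of this hunk; as a point of reference, temperature plus top-p (nucleus) filtering is usually implemented along these lines (a sketch, not the repo's sampler):

import torch

def sample_top_p(logits, top_p=0.8, temperature=1.0):
    # logits: [batch, vocab]; keep the smallest prefix of tokens whose mass reaches top_p
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    keep = (sorted_probs.cumsum(dim=-1) - sorted_probs) <= top_p
    keep[..., 0] = True                          # never drop the most likely token
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, 1)
    return sorted_idx.gather(-1, choice)         # [batch, 1] sampled token ids

next_token = sample_top_p(torch.randn(2, 1024), top_p=0.8, temperature=1.0)
print(next_token.shape)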


@@ -1,32 +1,32 @@
gpu: gpu:
n_card: 1 n_card: 1
n_process_per_card: 2 n_process_per_card: 2
io: io:
text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
save_every_n_epoch: 1 save_every_n_epoch: 1
precision: 16-mixed precision: 16-mixed
gradient_clip: 1.0 gradient_clip: 1.0
optimizer: optimizer:
lr: 0.01 lr: 0.01
lr_init: 0.00001 lr_init: 0.00001
lr_end: 0.0001 lr_end: 0.0001
warmup_steps: 2000 warmup_steps: 2000
decay_steps: 40000 decay_steps: 40000
data: data:
max_eval_sample: 8 max_eval_sample: 8
max_sec: 54 max_sec: 54
num_workers: 1 num_workers: 1
pad_val: 1024 # same with EOS in model pad_val: 1024 # same with EOS in model
model: model:
vocab_size: 1025 vocab_size: 1025
phoneme_vocab_size: 512 phoneme_vocab_size: 512
embedding_dim: 512 embedding_dim: 512
hidden_dim: 512 hidden_dim: 512
head: 16 head: 16
linear_units: 2048 linear_units: 2048
n_layer: 24 n_layer: 24
dropout: 0 dropout: 0
EOS: 1024 EOS: 1024
random_bert: 0 random_bert: 0
inference: inference:
top_k: 5 top_k: 5


@@ -11,23 +11,30 @@ logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import ( from transformers import (
Wav2Vec2FeatureExtractor, Wav2Vec2FeatureExtractor,
HubertModel, HubertModel,
Wav2Vec2Model,
) )
import utils import utils
import torch.nn as nn import torch.nn as nn
cnhubert_base_path=None cnhubert_base_path = None
class CNHubert(nn.Module): class CNHubert(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path) self.model = HubertModel.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path) self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
)
def forward(self, x): def forward(self, x):
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"] feats = self.model(input_values)["last_hidden_state"]
return feats return feats
# class CNHubertLarge(nn.Module): # class CNHubertLarge(nn.Module):
# def __init__(self): # def __init__(self):
# super().__init__() # super().__init__()
@@ -59,12 +66,12 @@ class CNHubert(nn.Module):
# return feats # return feats
def get_model(): def get_model():
model = CNHubert() model = CNHubert()
model.eval() model.eval()
return model return model
# def get_large_model(): # def get_large_model():
# model = CNHubertLarge() # model = CNHubertLarge()
# model.eval() # model.eval()
@@ -80,18 +87,18 @@ def get_model():
# model.eval() # model.eval()
# return model # return model
def get_content(hmodel, wav_16k_tensor): def get_content(hmodel, wav_16k_tensor):
with torch.no_grad(): with torch.no_grad():
feats = hmodel(wav_16k_tensor) feats = hmodel(wav_16k_tensor)
return feats.transpose(1,2) return feats.transpose(1, 2)
if __name__ == '__main__': if __name__ == "__main__":
model = get_model() model = get_model()
src_path = "/Users/Shared/原音频2.wav" src_path = "/Users/Shared/原音频2.wav"
wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
model = model model = model
wav_16k_tensor = wav_16k_tensor wav_16k_tensor = wav_16k_tensor
feats = get_content(model,wav_16k_tensor) feats = get_content(model, wav_16k_tensor)
print(feats.shape) print(feats.shape)
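
A hedged usage sketch for the CNHubert wrapper above. The module name (cnhubert) and the checkpoint directory are assumptions, since this hunk shows neither the file path nor where cnhubert_base_path gets set; torchaudio is used only to produce 16 kHz mono input.

import torch
import torchaudio

import cnhubert  # assumed module name for the file shown above

cnhubert.cnhubert_base_path = "pretrained_models/chinese-hubert-base"  # assumed checkpoint dir
model = cnhubert.get_model()

wav, sr = torchaudio.load("sample.wav")                               # placeholder input file
wav16k = torchaudio.functional.resample(wav, sr, 16000).mean(dim=0)   # mono, 16 kHz
with torch.no_grad():
    feats = model(wav16k)   # last_hidden_state, [1, n_frames, 768] for a HuBERT-base checkpoint
print(feats.shape)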


@@ -3,20 +3,23 @@ import torch
def get_model(): def get_model():
import whisper import whisper
model = whisper.load_model("small", device='cpu')
model = whisper.load_model("small", device="cpu")
return model.encoder return model.encoder
def get_content(model=None, wav_16k_tensor=None): def get_content(model=None, wav_16k_tensor=None):
from whisper import log_mel_spectrogram, pad_or_trim from whisper import log_mel_spectrogram, pad_or_trim
dev = next(model.parameters()).device dev = next(model.parameters()).device
mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
# if torch.cuda.is_available(): # if torch.cuda.is_available():
# mel = mel.to(torch.float16) # mel = mel.to(torch.float16)
feature_len = mel.shape[-1] // 2 feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频" assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad(): with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2) feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
return feature return feature
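
The same "content feature" idea, but calling the openai-whisper API directly instead of the small wrapper above; the wav path is a placeholder, and the 30-second limit mirrors the assert in the wrapper.

import torch
import whisper

encoder = whisper.load_model("small", device="cpu").encoder
audio = whisper.load_audio("sample.wav")                      # 16 kHz mono float32
mel = whisper.log_mel_spectrogram(torch.from_numpy(audio))[:, :3000]
assert mel.shape[-1] < 3000, "clip inputs to under 30 s, as the wrapper above does"
with torch.no_grad():
    feats = encoder(whisper.pad_or_trim(mel, 3000).unsqueeze(0))[:, : mel.shape[-1] // 2]
print(feats.shape)        # [1, n_mel_frames // 2, 768] for the small model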

File diff suppressed because it is too large


@@ -1,189 +1,189 @@
import math import math
import numpy as np
import torch import torch
from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01): def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__ classname = m.__class__.__name__
if classname.find("Conv") != -1: if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std) m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1): def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2) return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape): def convert_pad_shape(pad_shape):
l = pad_shape[::-1] l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist] pad_shape = [item for sublist in l for item in sublist]
return pad_shape return pad_shape
def intersperse(lst, item): def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1) result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst result[1::2] = lst
return result return result
def kl_divergence(m_p, logs_p, m_q, logs_q): def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)""" """KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5 kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) kl += (
return kl 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape): def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows.""" """Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples)) return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x): def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g return g
def slice_segments(x, ids_str, segment_size=4): def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size]) ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)): for i in range(x.size(0)):
idx_str = ids_str[i] idx_str = ids_str[i]
idx_end = idx_str + segment_size idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end] ret[i] = x[i, :, idx_str:idx_end]
return ret return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4): def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size() b, d, t = x.size()
if x_lengths is None: if x_lengths is None:
x_lengths = t x_lengths = t
ids_str_max = x_lengths - segment_size + 1 ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size) ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str return ret, ids_str
def get_timing_signal_1d( def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
length, channels, min_timescale=1.0, max_timescale=1.0e4): position = torch.arange(length, dtype=torch.float)
position = torch.arange(length, dtype=torch.float) num_timescales = channels // 2
num_timescales = channels // 2 log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
log_timescale_increment = ( num_timescales - 1
math.log(float(max_timescale) / float(min_timescale)) / )
(num_timescales - 1)) inv_timescales = min_timescale * torch.exp(
inv_timescales = min_timescale * torch.exp( torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) )
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2]) signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length) signal = signal.view(1, channels, length)
return signal return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size() b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device) return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size() b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length): def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask return mask
@torch.jit.script @torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0] n_channels_int = n_channels[0]
in_act = input_a + input_b in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :]) t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act acts = t_act * s_act
return acts return acts
def convert_pad_shape(pad_shape): def convert_pad_shape(pad_shape):
l = pad_shape[::-1] l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist] pad_shape = [item for sublist in l for item in sublist]
return pad_shape return pad_shape
def shift_1d(x): def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x return x
def sequence_mask(length, max_length=None): def sequence_mask(length, max_length=None):
if max_length is None: if max_length is None:
max_length = length.max() max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device) x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1) return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask): def generate_path(duration, mask):
""" """
duration: [b, 1, t_x] duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x] mask: [b, 1, t_y, t_x]
""" """
device = duration.device device = duration.device
b, _, t_y, t_x = mask.shape b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1) cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x) cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y) path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask path = path.unsqueeze(1).transpose(2, 3) * mask
return path return path
def clip_grad_value_(parameters, clip_value, norm_type=2): def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters)) parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type) norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None: if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value) clip_value = float(clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def squeeze(x, x_mask=None, n_sqz=2): def squeeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size() b, c, t = x.size()
t = (t // n_sqz) * n_sqz t = (t // n_sqz) * n_sqz
x = x[:, :, :t] x = x[:, :, :t]
x_sqz = x.view(b, c, t // n_sqz, n_sqz) x_sqz = x.view(b, c, t // n_sqz, n_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
if x_mask is not None: if x_mask is not None:
x_mask = x_mask[:, :, n_sqz - 1::n_sqz] x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
else: else:
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, n_sqz=2): def unsqueeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size() b, c, t = x.size()
x_unsqz = x.view(b, n_sqz, c // n_sqz, t) x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
if x_mask is not None: if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
else: else:
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask return x_unsqz * x_mask, x_mask
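
Two of the helpers above in action: sequence_mask builds a boolean length mask and rand_slice_segments cuts random fixed-size windows. The module.commons import path follows the "from module.commons import ..." line in the models hunk further down; treat it as an assumption about where this file lives.

import torch

from module.commons import sequence_mask, rand_slice_segments  # path assumed from the models hunk below

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths, max_length=6))
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])

x = torch.randn(2, 4, 10)                       # [batch, channels, time]
segments, ids_str = rand_slice_segments(x, torch.tensor([10, 10]), segment_size=4)
print(segments.shape, ids_str)                  # torch.Size([2, 4, 4]) plus the two start indices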


@@ -76,10 +76,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
print("kmeans start ... ") print("kmeans start ... ")
for _ in tqdm(range(num_iters)): for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange( diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
means, "c d -> () c d" dists = -(diffs**2).sum(dim=-1)
)
dists = -(diffs ** 2).sum(dim=-1)
buckets = dists.max(dim=-1).indices buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters) bins = torch.bincount(buckets, minlength=num_clusters)
@@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch. randomly selected vector from the current batch.
""" """
def __init__( def __init__(
self, self,
dim: int, dim: int,
@@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module):
): ):
super().__init__() super().__init__()
self.decay = decay self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim) embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size self.codebook_size = codebook_size
@@ -147,7 +148,7 @@ class EuclideanCodebook(nn.Module):
self.cluster_size.data.copy_(cluster_size) self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True])) self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization # Make sure all buffers across workers are in sync after initialization
#broadcast_tensors(self.buffers()) # broadcast_tensors(self.buffers())
def replace_(self, samples, mask): def replace_(self, samples, mask):
modified_codebook = torch.where( modified_codebook = torch.where(
@@ -165,7 +166,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "... d -> (...) d") batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes) self.replace_(batch_samples, mask=expired_codes)
#broadcast_tensors(self.buffers()) # broadcast_tensors(self.buffers())
def preprocess(self, x): def preprocess(self, x):
x = rearrange(x, "... d -> (...) d") x = rearrange(x, "... d -> (...) d")
@@ -246,6 +247,7 @@ class VectorQuantization(nn.Module):
randomly selected vector from the current batch. randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss. commitment_weight (float): Weight for commitment loss.
""" """
def __init__( def __init__(
self, self,
dim: int, dim: int,
@@ -256,22 +258,31 @@ class VectorQuantization(nn.Module):
kmeans_init: bool = True, kmeans_init: bool = True,
kmeans_iters: int = 50, kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2, threshold_ema_dead_code: int = 2,
commitment_weight: float = 1., commitment_weight: float = 1.0,
): ):
super().__init__() super().__init__()
_codebook_dim: int = default(codebook_dim, dim) _codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) self.project_in = (
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon self.epsilon = epsilon
self.commitment_weight = commitment_weight self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, self._codebook = EuclideanCodebook(
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, dim=_codebook_dim,
decay=decay, epsilon=epsilon, codebook_size=codebook_size,
threshold_ema_dead_code=threshold_ema_dead_code) kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size self.codebook_size = codebook_size
@property @property
@@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation. """Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
""" """
def __init__(self, *, num_quantizers, **kwargs): def __init__(self, *, num_quantizers, **kwargs):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)] [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
) )
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0 quantized_out = 0.0
residual = x residual = x
@@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized return quantized_out, out_indices, out_losses, out_quantized
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor: def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x residual = x
all_indices = [] all_indices = []
n_q = n_q or len(self.layers) n_q = n_q or len(self.layers)
@@ -358,10 +374,10 @@ class ResidualVectorQuantization(nn.Module):
out_indices = torch.stack(all_indices) out_indices = torch.stack(all_indices)
return out_indices return out_indices
def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor: def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device) quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices): for i, indices in enumerate(q_indices):
layer = self.layers[st + i] layer = self.layers[st + i]
quantized = layer.decode(indices) quantized = layer.decode(indices)
quantized_out = quantized_out + quantized quantized_out = quantized_out + quantized
return quantized_out return quantized_out
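
A minimal forward pass through the residual vector quantizer defined above, mainly to pin down the tensor layouts. The import path is a guess (this hunk does not name its file), the dimensions are arbitrary, and kmeans_init is disabled so the demo needs no warm-up batch.

import torch

from module.core_vq import ResidualVectorQuantization  # import path is an assumption

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=256, codebook_size=1024,
    kmeans_init=False,              # uniform codebook init, so no k-means pass is required
)
x = torch.randn(8, 256, 100)        # [batch, dim, time], the layout VectorQuantization expects
quantized, indices, losses, per_layer = rvq(x)
print(quantized.shape, indices.shape, losses.shape)   # x-shaped output, [n_q, batch, time], [n_q]

codes = rvq.encode(x, n_q=2)        # stop after the first two residual stages
recon = rvq.decode(codes)           # sum of the two decoded stages, back to [batch, dim, time]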


@@ -1,6 +1,6 @@
import time,logging import time, logging
import os import os
import random,traceback import random, traceback
import numpy as np import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
@@ -16,41 +16,44 @@ import torch
import requests import requests
from scipy.io import wavfile from scipy.io import wavfile
from io import BytesIO from io import BytesIO
# from config import exp_dir # from config import exp_dir
from my_utils import load_audio from my_utils import load_audio
class TextAudioSpeakerLoader(torch.utils.data.Dataset): class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
1) loads audio, speaker_id, text pairs 1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers 2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files. 3) computes spectrograms from audio files.
""" """
def __init__(self, hparams, val=False): def __init__(self, hparams, val=False):
exp_dir=hparams.exp_dir exp_dir = hparams.exp_dir
self.path2="%s/2-name2text.txt"%exp_dir self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4="%s/4-cnhubert"%exp_dir self.path4 = "%s/4-cnhubert" % exp_dir
self.path5="%s/5-wav32k"%exp_dir self.path5 = "%s/5-wav32k" % exp_dir
assert os.path.exists(self.path2) assert os.path.exists(self.path2)
assert os.path.exists(self.path4) assert os.path.exists(self.path4)
assert os.path.exists(self.path5) assert os.path.exists(self.path5)
names4=set([name[:-3]for name in list(os.listdir(self.path4))])#去除.pt后缀 names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5=set(os.listdir(self.path5)) names5 = set(os.listdir(self.path5))
self.phoneme_data={} self.phoneme_data = {}
with open(self.path2,"r",encoding="utf8")as f: with open(self.path2, "r", encoding="utf8") as f:
lines=f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
for line in lines: for line in lines:
tmp=line.split("\t") tmp = line.split("\t")
if(len(tmp)!=4):continue if len(tmp) != 4:
self.phoneme_data[tmp[0]]=[tmp[1]] continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text=list(set(self.phoneme_data)&names4&names5) self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp=self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng=len(tmp) leng = len(tmp)
min_num=100 min_num = 100
if(leng<min_num): if leng < min_num:
self.audiopaths_sid_text=[] self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))): for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp self.audiopaths_sid_text += tmp
self.max_wav_value = hparams.max_wav_value self.max_wav_value = hparams.max_wav_value
@@ -69,20 +72,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audiopaths_sid_text_new = [] audiopaths_sid_text_new = []
lengths = [] lengths = []
skipped_phone = 0 skipped_phone = 0
skipped_dur = 0 skipped_dur = 0
for audiopath in tqdm(self.audiopaths_sid_text): for audiopath in tqdm(self.audiopaths_sid_text):
try: try:
phoneme = self.phoneme_data[audiopath][0] phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ') phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme) phoneme_ids = cleaned_text_to_sequence(phoneme)
except Exception: except Exception:
print(f"{audiopath} not in self.phoneme_data !") print(f"{audiopath} not in self.phoneme_data !")
skipped_phone += 1 skipped_phone += 1
continue continue
size=os.path.getsize("%s/%s"%(self.path5,audiopath)) size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2 duration = size / self.sampling_rate / 2
if (54 > duration > 0.6 or self.val): if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids]) audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length)) lengths.append(size // (2 * self.hop_length))
else: else:
@@ -90,7 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
continue continue
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
print("total left: ", len(audiopaths_sid_text_new)) print("total left: ", len(audiopaths_sid_text_new))
assert len(audiopaths_sid_text_new)>1#至少能凑够batch size这里todo assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths self.lengths = lengths
@@ -98,30 +101,41 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids) text = torch.FloatTensor(phoneme_ids)
try: try:
spec, wav = self.get_audio("%s/%s"%(self.path5,audiopath)) spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad(): with torch.no_grad():
ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu") ssl = torch.load(
if(ssl.shape[-1]!=spec.shape[-1]): "%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
typee=ssl.dtype )
ssl=F.pad(ssl.float(),(0,1),mode="replicate").to(typee) if ssl.shape[-1] != spec.shape[-1]:
ssl.requires_grad=False typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
except: except:
traceback.print_exc() traceback.print_exc()
spec = torch.zeros(1025, 100) spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100*self.hop_length) wav = torch.zeros(1, 100 * self.hop_length)
ssl=torch.zeros(1,768,100) ssl = torch.zeros(1, 768, 100)
text=text[-1:] text = text[-1:]
print("load audio or ssl error!!!!!!", audiopath) print("load audio or ssl error!!!!!!", audiopath)
# print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad) # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
return (ssl, spec, wav, text) return (ssl, spec, wav, text)
def get_audio(self, filename): def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768 audio_array = load_audio(
filename, self.sampling_rate
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
# print(filename,audio_array.max(),audio_array.min(),audio_array.mean()) # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
audio=torch.FloatTensor(audio_array)#/32768 audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,self.sampling_rate, self.hop_length, self.win_length,center=False) spec = spectrogram_torch(
audio_norm,
self.filter_length,
self.sampling_rate,
self.hop_length,
self.win_length,
center=False,
)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio_norm return spec, audio_norm
@@ -131,39 +145,51 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
# with torch.no_grad(): # with torch.no_grad():
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
def __len__(self): def __len__(self):
return len(self.audiopaths_sid_text) return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel): def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1]- wav.shape[-1]//self.hop_length) < 3, ("first", ssl.shape, wav.shape) assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first",
ssl.shape,
wav.shape,
)
len_mel = mel.shape[1] len_mel = mel.shape[1]
if self.val: if self.val:
reference_mel = mel[:, :len_mel//3] reference_mel = mel[:, : len_mel // 3]
return reference_mel, ssl, wav, mel return reference_mel, ssl, wav, mel
dir = random.randint(0, 1) dir = random.randint(0, 1)
sep_point = random.randint(int(len_mel//3), int(len_mel//3*2)) sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
if dir == 0: if dir == 0:
reference_mel = mel[:, :sep_point] reference_mel = mel[:, :sep_point]
ssl = ssl[:, :, sep_point:] ssl = ssl[:, :, sep_point:]
wav2 = wav[:, sep_point*self.hop_length:] wav2 = wav[:, sep_point * self.hop_length :]
mel = mel[:, sep_point:] mel = mel[:, sep_point:]
else: else:
reference_mel = mel[:, sep_point:] reference_mel = mel[:, sep_point:]
ssl = ssl[:, :, :sep_point] ssl = ssl[:, :, :sep_point]
wav2 = wav[:, :sep_point*self.hop_length] wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point] mel = mel[:, :sep_point]
assert abs(ssl.shape[-1]- wav2.shape[-1]//self.hop_length) < 3, (ssl.shape, wav.shape,wav2.shape, mel.shape, sep_point,self.hop_length, sep_point*self.hop_length, dir) assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate(): class TextAudioSpeakerCollate:
""" Zero-pads model inputs and targets """Zero-pads model inputs and targets"""
"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False):
self.return_ids = return_ids self.return_ids = return_ids
@@ -176,8 +202,8 @@ class TextAudioSpeakerCollate():
""" """
# Right zero-pad all one-hot text sequences to max input length # Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort( _, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]), torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
dim=0, descending=True) )
max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -194,7 +220,7 @@ class TextAudioSpeakerCollate():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len) text_padded = torch.LongTensor(len(batch), max_text_len)
spec_padded.zero_() spec_padded.zero_()
wav_padded.zero_() wav_padded.zero_()
@@ -205,23 +231,31 @@ class TextAudioSpeakerCollate():
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
ssl = row[0] ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2) ssl_lengths[i] = ssl.size(2)
spec = row[1] spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1) spec_lengths[i] = spec.size(1)
wav = row[2] wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1) wav_lengths[i] = wav.size(1)
text = row[3] text = row[3]
text_padded[i, :text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
return (
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
""" """
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths self.lengths = dataset.lengths
# print(233333333333333,self.lengths,dir(dataset)) # print(233333333333333,self.lengths,dir(dataset))
@@ -254,7 +296,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
buckets[idx_bucket].append(i) buckets[idx_bucket].append(i)
for i in range(len(buckets) - 1, 0, -1): for i in range(len(buckets) - 1, 0, -1):
# for i in range(len(buckets) - 1, -1, -1): # for i in range(len(buckets) - 1, -1, -1):
if len(buckets[i]) == 0: if len(buckets[i]) == 0:
buckets.pop(i) buckets.pop(i)
self.boundaries.pop(i + 1) self.boundaries.pop(i + 1)
@@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
for i in range(len(buckets)): for i in range(len(buckets)):
len_bucket = len(buckets[i]) len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size total_batch_size = self.num_replicas * self.batch_size
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem) num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket return buckets, num_samples_per_bucket
@@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample # subsample
ids_bucket = ids_bucket[self.rank::self.num_replicas] ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching # batching
for j in range(len(ids_bucket) // self.batch_size): for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch) batches.append(batch)
if self.shuffle: if self.shuffle:
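
To see what the bucket sampler above does in isolation, here is a self-contained toy run; in the real pipeline it wraps TextAudioSpeakerLoader and feeds TextAudioSpeakerCollate. The import path is an assumption, and the stand-in dataset only provides the .lengths attribute the sampler reads.

from torch.utils.data import DataLoader, Dataset

from module.data_utils import DistributedBucketSampler  # import path is an assumption

class ToyDataset(Dataset):
    """Stand-in exposing .lengths (frame counts) like TextAudioSpeakerLoader."""
    def __init__(self):
        self.lengths = [40, 120, 260, 310, 55, 480, 220, 90]
    def __len__(self):
        return len(self.lengths)
    def __getitem__(self, i):
        return i

dataset = ToyDataset()
sampler = DistributedBucketSampler(
    dataset, batch_size=2,
    boundaries=[32, 300, 500],      # bucket edges in frames; items outside the range are dropped
    num_replicas=1, rank=0, shuffle=True,
)
loader = DataLoader(dataset, batch_sampler=sampler)
print(list(loader))                 # batches of indices, grouped by similar length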


@@ -5,64 +5,69 @@ from torch.nn import functional as F
def feature_loss(fmap_r, fmap_g): def feature_loss(fmap_r, fmap_g):
loss = 0 loss = 0
for dr, dg in zip(fmap_r, fmap_g): for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg): for rl, gl in zip(dr, dg):
rl = rl.float().detach() rl = rl.float().detach()
gl = gl.float() gl = gl.float()
loss += torch.mean(torch.abs(rl - gl)) loss += torch.mean(torch.abs(rl - gl))
return loss * 2 return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs): def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0 loss = 0
r_losses = [] r_losses = []
g_losses = [] g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs): for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float() dr = dr.float()
dg = dg.float() dg = dg.float()
r_loss = torch.mean((1-dr)**2) r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2) g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss) loss += r_loss + g_loss
r_losses.append(r_loss.item()) r_losses.append(r_loss.item())
g_losses.append(g_loss.item()) g_losses.append(g_loss.item())
return loss, r_losses, g_losses return loss, r_losses, g_losses
def generator_loss(disc_outputs): def generator_loss(disc_outputs):
loss = 0 loss = 0
gen_losses = [] gen_losses = []
for dg in disc_outputs: for dg in disc_outputs:
dg = dg.float() dg = dg.float()
l = torch.mean((1-dg)**2) l = torch.mean((1 - dg) ** 2)
gen_losses.append(l) gen_losses.append(l)
loss += l loss += l
return loss, gen_losses return loss, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
""" """
z_p, logs_q: [b, h, t_t] z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t] m_p, logs_p: [b, h, t_t]
""" """
z_p = z_p.float() z_p = z_p.float()
logs_q = logs_q.float() logs_q = logs_q.float()
m_p = m_p.float() m_p = m_p.float()
logs_p = logs_p.float() logs_p = logs_p.float()
z_mask = z_mask.float() z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def mle_loss(z, m, logs, logdet, mask): def mle_loss(z, m, logs, logdet, mask):
l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term l = torch.sum(logs) + 0.5 * torch.sum(
l = l - torch.sum(logdet) # log jacobian determinant torch.exp(-2 * logs) * ((z - m) ** 2)
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes ) # neg normal likelihood w/o the constant term
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term l = l - torch.sum(logdet) # log jacobian determinant
return l l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l
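
A quick numeric check of the least-squares GAN losses above: a discriminator that outputs 1 for real and 0 for generated clips gives zero discriminator loss, and the generator loss for those same outputs is exactly 1. The module.losses import path is an assumption.

import torch

from module.losses import discriminator_loss, generator_loss  # import path is an assumption

real_outputs = [torch.ones(2, 5)]     # one sub-discriminator, perfectly confident on real
fake_outputs = [torch.zeros(2, 5)]    # and perfectly confident on generated

d_loss, r_losses, g_losses = discriminator_loss(real_outputs, fake_outputs)
print(d_loss.item(), r_losses, g_losses)   # 0.0 [0.0] [0.0]

g_loss, per_disc = generator_loss(fake_outputs)
print(g_loss.item())                       # 1.0, i.e. mean((1 - 0) ** 2)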


@@ -49,21 +49,37 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.: if torch.min(y) < -1.0:
print('min value is ', torch.min(y)) print("min value is ", torch.min(y))
if torch.max(y) > 1.: if torch.max(y) > 1.0:
print('max value is ', torch.max(y)) print("max value is ", torch.max(y))
global hann_window global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], spec = torch.stft(
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec return spec
@@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis global mel_basis
dtype_device = str(spec.dtype) + '_' + str(spec.device) dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec) spec = spectral_normalize_torch(spec)
return spec return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): def mel_spectrogram_torch(
if torch.min(y) < -1.: y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
print('min value is ', torch.min(y)) ):
if torch.max(y) > 1.: if torch.min(y) < -1.0:
print('max value is ', torch.max(y)) print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window global mel_basis, hann_window
dtype_device = str(y.dtype) + '_' + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + '_' + dtype_device wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis: if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = librosa_mel_fn(
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window: if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], spec = torch.stft(
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
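
For orientation, this is how the STFT helpers above are typically invoked. The numbers (2048-point FFT, hop 640, 32 kHz, 128 mels) are assumptions chosen to be consistent with the 1025-bin spectrogram seen in the data-loading hunk earlier, not values read from this file, and the import path is likewise assumed.

import torch

from module.mel_processing import spectrogram_torch, spec_to_mel_torch  # import path is an assumption

y = torch.rand(1, 32000) * 2 - 1               # one second of placeholder audio in [-1, 1) at 32 kHz
spec = spectrogram_torch(y, n_fft=2048, sampling_rate=32000,
                         hop_size=640, win_size=2048, center=False)
mel = spec_to_mel_torch(spec, n_fft=2048, num_mels=128,
                        sampling_rate=32000, fmin=0, fmax=None)
print(spec.shape, mel.shape)                   # [1, 1025, T] and [1, 128, T]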


@@ -12,12 +12,21 @@ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.mrte_model import MRTE from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
from text import symbols from text import symbols
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module): class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): def __init__(
self,
in_channels,
filter_channels,
kernel_size,
p_dropout,
n_flows=4,
gin_channels=0,
):
super().__init__() super().__init__()
filter_channels = in_channels # it needs to be removed from future version. filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels self.in_channels = in_channels
@@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList() self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2)) self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows): for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip()) self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1) self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList() self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2)) self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4): for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip()) self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1) self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1) self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0: if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1) self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w) h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask) h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q z_q = e_q
for flow in self.post_flows: for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1) z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) logdet_tot_q += torch.sum(
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0 logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask) z0, logdet = self.log_flow(z0, x_mask)
@@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows: for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse) z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b] return nll + logq # [b]
else: else:
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows: for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse) z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1) z0, z1 = torch.split(z, [1, 1], 1)
@@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@@ -108,9 +141,13 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels) self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels) self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1) self.proj = nn.Conv1d(filter_channels, 1, 1)
@@ -135,15 +172,17 @@ class DurationPredictor(nn.Module):
class TextEncoder(nn.Module): class TextEncoder(nn.Module):
def __init__(self, def __init__(
out_channels, self,
hidden_channels, out_channels,
filter_channels, hidden_channels,
n_heads, filter_channels,
n_layers, n_heads,
kernel_size, n_layers,
p_dropout, kernel_size,
latent_channels=192): p_dropout,
latent_channels=192,
):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@@ -160,17 +199,14 @@ class TextEncoder(nn.Module):
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers//2, n_layers // 2,
kernel_size, kernel_size,
p_dropout) p_dropout,
)
self.encoder_text = attentions.Encoder( self.encoder_text = attentions.Encoder(
hidden_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
filter_channels, )
n_heads,
n_layers,
kernel_size,
p_dropout)
self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE() self.mrte = MRTE()
@@ -179,21 +215,25 @@ class TextEncoder(nn.Module):
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers//2, n_layers // 2,
kernel_size, kernel_size,
p_dropout) p_dropout,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, test=None): def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
y = self.ssl_proj(y * y_mask) * y_mask y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask) y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype) text_mask = torch.unsqueeze(
if test == 1 : commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
if test == 1:
text[:, :] = 0 text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2) text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask) text = self.encoder_text(text * text_mask, text_mask)
@@ -208,9 +248,9 @@ class TextEncoder(nn.Module):
def extract_latent(self, x): def extract_latent(self, x):
x = self.ssl_proj(x) x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x) quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0,1) return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer,refer_mask, ge):
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask y = self.vq_proj(quantized) * y_mask
@@ -224,15 +264,18 @@ class TextEncoder(nn.Module):
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized
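The projection above packs mean and log-scale into one tensor that is split with torch.split; a minimal self-contained sketch of that split and the reparameterised sample used later at inference (the Conv1d here is a stand-in for illustration, not the module from this file):

import torch
import torch.nn as nn

hidden_channels, out_channels, T = 192, 192, 50
proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)  # stand-in for self.proj

y = torch.randn(1, hidden_channels, T)
stats = proj(y)                                     # [1, 2 * out_channels, T]
m, logs = torch.split(stats, out_channels, dim=1)   # mean / log-scale halves
z = m + torch.randn_like(m) * torch.exp(logs)       # sample, as in infer()
print(z.shape)                                      # torch.Size([1, 192, 50])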
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module):
        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False): def forward(self, x, x_mask, g=None, reverse=False):
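The body of this forward is elided by the hunk above; presumably it runs the flows in order and, for reverse=True, walks them back-to-front, as in the usual VITS-style coupling block. A hedged sketch of that control flow (an assumption based on that convention, not code from this diff):

def _run_flows(flows, x, x_mask, g=None, reverse=False):
    # forward: apply transforms in order; reverse: undo them back-to-front
    if not reverse:
        for flow in flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
    else:
        for flow in reversed(flows):
            x = flow(x, x_mask, g=g, reverse=reverse)
    return x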
@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module):
class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module):
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g != None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module):
class WNEncoder(nn.Module): class WNEncoder(nn.Module):
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels,
hidden_channels, out_channels,
kernel_size, hidden_channels,
dilation_rate, kernel_size,
n_layers, dilation_rate,
gin_channels=0): n_layers,
gin_channels=0,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -312,11 +375,20 @@ class WNEncoder(nn.Module):
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.norm = modules.LayerNorm(out_channels) self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g) x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask out = self.proj(x) * x_mask
@ -325,24 +397,45 @@ class WNEncoder(nn.Module):
class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
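Each ConvTranspose1d above upsamples by its rate, so the generated waveform has the product of upsample_rates samples per input latent frame; a quick check with example rates (the rates below are assumed example values for illustration, not taken from this diff):

import math

upsample_rates = [10, 8, 2, 2, 2]              # assumed example values
samples_per_frame = math.prod(upsample_rates)  # 640
frames = 100
print(frames * samples_per_frame)              # 64000 output samples for 100 frames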
@ -373,7 +466,7 @@ class Generator(torch.nn.Module):
return x return x
def remove_weight_norm(self): def remove_weight_norm(self):
print('Removing weight norm...') print("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_weight_norm(l)
for l in self.resblocks: for l in self.resblocks:
@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module):
self.period = period self.period = period
self.use_spectral_norm = use_spectral_norm self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
norm_f(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=(get_padding(kernel_size, 1), 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x): def forward(self, x):
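The elided body presumably pads the waveform up to a multiple of the period and folds it into a 2-D map so the (kernel_size, 1) convolutions above scan each period column; a hedged sketch of that reshape (the helper name and shapes are illustrative, not from this diff):

import torch
import torch.nn.functional as F

def fold_by_period(x, period):
    # x: [B, 1, T] -> [B, 1, T // period, period], reflect-padding T up to a multiple of period
    b, c, t = x.shape
    if t % period != 0:
        pad = period - (t % period)
        x = F.pad(x, (0, pad), "reflect")
        t = t + pad
    return x.view(b, c, t // period, period)

print(fold_by_period(torch.randn(2, 1, 1000), 3).shape)  # torch.Size([2, 1, 334, 3])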
@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__() super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x): def forward(self, x):
@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11] periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
self.discriminators = nn.ModuleList(discs) self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat): def forward(self, y, y_hat):
@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class ReferenceEncoder(nn.Module): class ReferenceEncoder(nn.Module):
''' """
inputs --- [N, Ty/r, n_mels*r] mels inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size] outputs --- [N, ref_enc_gru_size]
''' """
def __init__(self, spec_channels, gin_channels=0): def __init__(self, spec_channels, gin_channels=0):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
def forward(self, inputs): def forward(self, inputs):
@ -527,23 +675,31 @@ class Quantizer_module(torch.nn.Module):
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
def forward(self, x): def forward(self, x):
        d = (
            torch.sum(x**2, 1, keepdim=True)
            + torch.sum(self.embedding.weight**2, 1)
            - 2 * torch.matmul(x, self.embedding.weight.T)
        )
min_indicies = torch.argmin(d, 1) min_indicies = torch.argmin(d, 1)
z_q = self.embedding(min_indicies) z_q = self.embedding(min_indicies)
return z_q, min_indicies return z_q, min_indicies
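The distance term above is the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e evaluated against every codebook entry at once; a small self-contained check against torch.cdist:

import torch

x = torch.randn(7, 16)   # 7 vectors of dim 16
e = torch.randn(40, 16)  # 40 codebook entries

d = (
    torch.sum(x**2, 1, keepdim=True)
    + torch.sum(e**2, 1)
    - 2 * torch.matmul(x, e.T)
)
print(torch.allclose(d, torch.cdist(x, e).pow(2), atol=1e-4))  # True, up to rounding
print(torch.argmin(d, 1).shape)                                # torch.Size([7])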
class Quantizer(torch.nn.Module): class Quantizer(torch.nn.Module):
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160): def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
super(Quantizer, self).__init__() super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0 assert embed_dim % n_code_groups == 0
        self.quantizer_modules = nn.ModuleList(
            [
                Quantizer_module(n_codes, embed_dim // n_code_groups)
                for _ in range(n_code_groups)
            ]
        )
self.n_code_groups = n_code_groups self.n_code_groups = n_code_groups
self.embed_dim = embed_dim self.embed_dim = embed_dim
def forward(self, xin): def forward(self, xin):
#B, C, T # B, C, T
B, C, T = xin.shape B, C, T = xin.shape
xin = xin.transpose(1, 2) xin = xin.transpose(1, 2)
x = xin.reshape(-1, self.embed_dim) x = xin.reshape(-1, self.embed_dim)
@ -553,38 +709,41 @@ class Quantizer(torch.nn.Module):
for _x, m in zip(x, self.quantizer_modules): for _x, m in zip(x, self.quantizer_modules):
_z_q, _min_indicies = m(_x) _z_q, _min_indicies = m(_x)
z_q.append(_z_q) z_q.append(_z_q)
            min_indicies.append(_min_indicies)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
            (z_q - xin.detach()) ** 2
        )
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
        return z_q, loss, codes.transpose(1, 2)
def embed(self, x): def embed(self, x):
#idx: N, 4, T # idx: N, 4, T
x=x.transpose(1, 2) x = x.transpose(1, 2)
x = torch.split(x, 1, 2) x = torch.split(x, 1, 2)
ret = [] ret = []
for q, embed in zip(x, self.quantizer_modules): for q, embed in zip(x, self.quantizer_modules):
q = embed.embedding(q.squeeze(-1)) q = embed.embedding(q.squeeze(-1))
ret.append(q) ret.append(q)
ret = torch.cat(ret, -1) ret = torch.cat(ret, -1)
return ret.transpose(1, 2) #N, C, T return ret.transpose(1, 2) # N, C, T
class CodePredictor(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_q=8,
        dims=1024,
        ssl_dim=768,
    ):
        super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
@ -594,19 +753,18 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
        self.ref_enc = modules.MelStyleEncoder(
            ssl_dim, style_vector_dim=hidden_channels
        )
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q self.n_q = n_q
self.dims = dims self.dims = dims
def forward(self, x, x_mask, refer, codes, infer=False): def forward(self, x, x_mask, refer, codes, infer=False):
x = x.detach() x = x.detach()
x = self.vq_proj(x * x_mask) * x_mask x = self.vq_proj(x * x_mask) * x_mask
@ -614,7 +772,9 @@ class CodePredictor(nn.Module):
x = x + g x = x + g
x = self.encoder(x * x_mask, x_mask) x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3) logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
target = codes[1:].transpose(0, 1) target = codes[1:].transpose(0, 1)
if not infer: if not infer:
logits = logits.reshape(-1, self.dims) logits = logits.reshape(-1, self.dims)
@ -626,44 +786,44 @@ class CodePredictor(nn.Module):
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1) correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item() top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
print('Top-10 Accuracy:', top3_acc, "%") print("Top-10 Accuracy:", top3_acc, "%")
pred_codes = torch.argmax(logits, dim=-1) pred_codes = torch.argmax(logits, dim=-1)
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item() acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
print('Top-1 Accuracy:', acc, "%") print("Top-1 Accuracy:", acc, "%")
return pred_codes.transpose(0, 1) return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        **kwargs
    ):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@ -685,34 +845,50 @@ class SynthesizerTrn(nn.Module):
        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )

        ssl_dim = 768
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        if freeze_quantizer:
            self.ssl_proj.requires_grad_(False)
            self.quantizer.requires_grad_(False)
@ -721,56 +897,85 @@ class SynthesizerTrn(nn.Module):
        # self.enc_p.mrte.requires_grad_(False)

    def forward(self, ssl, y, y_lengths, text, text_lengths):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)

        with autocast(enabled=False):
            ssl = self.ssl_proj(ssl)
            quantized, codes, commit_loss, quantized_list = self.quantizer(
                ssl, layers=[0]
            )

        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)
        return (
            o,
            commit_loss,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )

    def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)

        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge, test=test
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)

    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5):
        refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
        refer_mask = torch.unsqueeze(
            commons.sequence_mask(refer_lengths, refer.size(2)), 1
        ).to(refer.dtype)
        ge = self.ref_enc(refer * refer_mask, refer_mask)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -779,6 +984,6 @@ class SynthesizerTrn(nn.Module):
        return o

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)
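Taken together, decode() above maps semantic codes plus a reference spectrogram to a waveform; a heavily hedged usage sketch (all shapes and sizes below are made-up illustrations, not values from this diff, and `model` is assumed to be an already constructed SynthesizerTrn):

import torch

codes = torch.randint(0, 1024, (1, 1, 200))  # assumed [B, n_q, T_code] code layout
text = torch.randint(1, 100, (1, 50))        # phoneme ids
refer = torch.randn(1, 1025, 300)            # reference spectrogram [B, spec_channels, T_ref]
# audio = model.decode(codes, text, refer, noise_scale=0.5)  # expected shape [1, 1, T_wav]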



@ -5,46 +5,74 @@ from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention from module.attentions import MultiHeadAttention
class MRTE(nn.Module): class MRTE(nn.Module):
def __init__(self, def __init__(
content_enc_channels=192, self,
hidden_size=512, content_enc_channels=192,
out_channels=192, hidden_size=512,
kernel_size=5, out_channels=192,
n_heads=4, kernel_size=5,
ge_layer = 2 n_heads=4,
): ge_layer=2,
):
super(MRTE, self).__init__() super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads) self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size,out_channels, 1) self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
        if ge == None:
            ge = 0
        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)
        if test != None:
            if test == 0:
                x = (
                    self.cross_attention(
                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ssl_enc
                    + ge
                )
            elif test == 1:
                x = ssl_enc + ge
            elif test == 2:
                x = (
                    self.cross_attention(
                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ge
                )
            else:
                raise ValueError("test should be 0,1,2")
        else:
            x = (
                self.cross_attention(
                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                )
                + ssl_enc
                + ge
            )
        x = self.c_post(x * ssl_mask)
        return x
class SpeakerEncoder(torch.nn.Module): class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256): def __init__(
self,
mel_n_channels=80,
model_num_layers=2,
model_hidden_size=256,
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__() super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU() self.relu = nn.ReLU()
@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module):
class MELEncoder(nn.Module): class MELEncoder(nn.Module):
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels,
hidden_channels, out_channels,
kernel_size, hidden_channels,
dilation_rate, kernel_size,
n_layers): dilation_rate,
n_layers,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -81,80 +111,82 @@ class MELEncoder(nn.Module):
x = self.enc(x) x = self.enc(x)
x = self.proj(x) x = self.proj(x)
return x return x
class WN(torch.nn.Module): class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__() super(WN, self).__init__()
assert(kernel_size % 2 == 1) assert kernel_size % 2 == 1
self.hidden_channels =hidden_channels self.hidden_channels = hidden_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilation_rate = dilation_rate self.dilation_rate = dilation_rate
self.n_layers = n_layers self.n_layers = n_layers
self.in_layers = torch.nn.ModuleList() self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList() self.res_skip_layers = torch.nn.ModuleList()
for i in range(n_layers): for i in range(n_layers):
dilation = dilation_rate ** i dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2) padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer)
            self.in_layers.append(in_layer)
# last one is not necessary # last one is not necessary
if i < n_layers - 1: if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels res_skip_channels = 2 * hidden_channels
else: else:
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name='weight') res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def forward(self, x): def forward(self, x):
output = torch.zeros_like(x) output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels]) n_channels_tensor = torch.IntTensor([self.hidden_channels])
for i in range(self.n_layers): for i in range(self.n_layers):
x_in = self.in_layers[i](x) x_in = self.in_layers[i](x)
            acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts) res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1: if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:] res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = (x + res_acts) x = x + res_acts
output = output + res_skip_acts[:,self.hidden_channels:,:] output = output + res_skip_acts[:, self.hidden_channels :, :]
else: else:
output = output + res_skip_acts output = output + res_skip_acts
return output return output
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.in_layers: for l in self.in_layers:
remove_weight_norm(l) remove_weight_norm(l)
for l in self.res_skip_layers: for l in self.res_skip_layers:
remove_weight_norm(l) remove_weight_norm(l)
@torch.jit.script @torch.jit.script
def fused_add_tanh_sigmoid_multiply(input, n_channels): def fused_add_tanh_sigmoid_multiply(input, n_channels):
n_channels_int = n_channels[0] n_channels_int = n_channels[0]
t_act = torch.tanh(input[:, :n_channels_int, :]) t_act = torch.tanh(input[:, :n_channels_int, :])
s_act = torch.sigmoid(input[:, n_channels_int:, :]) s_act = torch.sigmoid(input[:, n_channels_int:, :])
acts = t_act * s_act acts = t_act * s_act
return acts return acts
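This fused helper is the WaveNet-style gated activation tanh(a) * sigmoid(b) applied to the two halves of the channel dimension; a tiny equivalence check, assuming the function above is in scope:

import torch

hidden = 4
x = torch.randn(2, 2 * hidden, 10)   # [B, 2 * hidden_channels, T]
n_channels = torch.IntTensor([hidden])

fused = fused_add_tanh_sigmoid_multiply(x, n_channels)
manual = torch.tanh(x[:, :hidden]) * torch.sigmoid(x[:, hidden:])
print(torch.allclose(fused, manual))  # True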
if __name__ == "__main__":
if __name__ == '__main__': content_enc = torch.randn(3, 192, 100)
content_enc = torch.randn(3,192,100) content_mask = torch.ones(3, 1, 100)
content_mask = torch.ones(3,1,100) ref_mel = torch.randn(3, 128, 30)
ref_mel = torch.randn(3,128,30) ref_mask = torch.ones(3, 1, 30)
ref_mask = torch.ones(3,1,30)
model = MRTE() model = MRTE()
out = model(content_enc,content_mask,ref_mel,ref_mask) out = model(content_enc, content_mask, ref_mel, ref_mask)
print(out.shape) print(out.shape)


@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch. randomly selected vector from the current batch.
""" """
def __init__( def __init__(
self, self,
dimension: int = 256, dimension: int = 256,
@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
threshold_ema_dead_code=self.threshold_ema_dead_code, threshold_ema_dead_code=self.threshold_ema_dead_code,
) )
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult: def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor. """Residual vector quantization on the given input tensor.
Args: Args:
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
""" """
n_q = n_q if n_q else self.n_q n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q: if layers and max(layers) >= n_q:
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.') raise ValueError(
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth. """Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer. and returns indices for each quantizer.
@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module):
st (int): Start to decode input codes from which layers. Default: 0. st (int): Start to decode input codes from which layers. Default: 0.
""" """
quantized = self.vq.decode(codes, st=st) quantized = self.vq.decode(codes, st=st)
return quantized return quantized
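Residual VQ quantizes the input with the first codebook and then quantizes what is left over with the next one, summing the stages at decode time; a self-contained toy illustration of that idea (random codebooks and shapes, not the module above):

import torch

def nearest(x, codebook):
    # x: [N, D], codebook: [K, D] -> indices [N], quantized vectors [N, D]
    idx = torch.cdist(x, codebook).argmin(dim=1)
    return idx, codebook[idx]

torch.manual_seed(0)
x = torch.randn(5, 8)
cb1, cb2 = torch.randn(32, 8), torch.randn(32, 8)

i1, q1 = nearest(x, cb1)        # stage 1 quantizes x
i2, q2 = nearest(x - q1, cb2)   # stage 2 quantizes the residual
recon = q1 + q2                 # "decode" sums the stage outputs
print((x - q1).norm().item(), (x - recon).norm().item())  # one-stage vs two-stage error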

View File

@ -9,66 +9,63 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3 DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs, def piecewise_rational_quadratic_transform(
unnormalized_widths, inputs,
unnormalized_heights, unnormalized_widths,
unnormalized_derivatives, unnormalized_heights,
inverse=False, unnormalized_derivatives,
tails=None, inverse=False,
tail_bound=1., tails=None,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, tail_bound=1.0,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_derivative=DEFAULT_MIN_DERIVATIVE): min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None: if tails is None:
spline_fn = rational_quadratic_spline spline_fn = rational_quadratic_spline
spline_kwargs = {} spline_kwargs = {}
else: else:
spline_fn = unconstrained_rational_quadratic_spline spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = { spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
'tails': tails,
'tail_bound': tail_bound
}
outputs, logabsdet = spline_fn( outputs, logabsdet = spline_fn(
inputs=inputs, inputs=inputs,
unnormalized_widths=unnormalized_widths, unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights, unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives, unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse, inverse=inverse,
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative, min_derivative=min_derivative,
**spline_kwargs **spline_kwargs
) )
return outputs, logabsdet return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
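searchsorted here returns, for each input, the index of the bin whose left edge it has passed; with the eps nudge on the last edge it agrees with torch.searchsorted(..., right=True) - 1. A small check, assuming the helper above is in scope (note it modifies its first argument in place, hence the clone):

import torch

edges = torch.tensor([0.0, 0.2, 0.5, 0.9, 1.0])     # cumulative bin edges
x = torch.tensor([0.1, 0.5, 0.85, 0.95])

idx = searchsorted(edges.clone(), x)
ref = torch.searchsorted(edges, x, right=True) - 1
print(idx, ref)  # tensor([0, 2, 2, 3]) both times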
def unconstrained_rational_quadratic_spline(inputs, def unconstrained_rational_quadratic_spline(
unnormalized_widths, inputs,
unnormalized_heights, unnormalized_widths,
unnormalized_derivatives, unnormalized_heights,
inverse=False, unnormalized_derivatives,
tails='linear', inverse=False,
tail_bound=1., tails="linear",
min_bin_width=DEFAULT_MIN_BIN_WIDTH, tail_bound=1.0,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_derivative=DEFAULT_MIN_DERIVATIVE): min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs) outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs) logabsdet = torch.zeros_like(inputs)
if tails == 'linear': if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1) constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., 0] = constant
@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs,
outputs[outside_interval_mask] = inputs[outside_interval_mask] outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0 logabsdet[outside_interval_mask] = 0
else: else:
raise RuntimeError('{} tails are not implemented.'.format(tails)) raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( (
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask], inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse, inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative min_derivative=min_derivative,
) )
return outputs, logabsdet return outputs, logabsdet
def rational_quadratic_spline(inputs,
unnormalized_widths, def rational_quadratic_spline(
unnormalized_heights, inputs,
unnormalized_derivatives, unnormalized_widths,
inverse=False, unnormalized_heights,
left=0., right=1., bottom=0., top=1., unnormalized_derivatives,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, inverse=False,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, left=0.0,
min_derivative=DEFAULT_MIN_DERIVATIVE): right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right: if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain') raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1] num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0: if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins') raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0: if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins') raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1) widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1) cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left cumwidths[..., 0] = left
cumwidths[..., -1] = right cumwidths[..., -1] = right
@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs,
heights = F.softmax(unnormalized_heights, dim=-1) heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1) cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom cumheights[..., 0] = bottom
cumheights[..., -1] = top cumheights[..., -1] = top
@ -150,15 +159,13 @@ def rational_quadratic_spline(inputs,
input_heights = heights.gather(-1, bin_idx)[..., 0] input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse: if inverse:
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all() assert (discriminant >= 0).all()
@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs,
outputs = root * input_bin_widths + input_cumwidths outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root) theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet return outputs, -logabsdet
@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs,
theta = (inputs - input_cumwidths) / input_bin_widths theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta) theta_one_minus_theta = theta * (1 - theta)
        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet return outputs, logabsdet
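A useful sanity check on this spline is that the inverse branch undoes the forward branch and the returned log|det| flips sign; a hedged self-contained check with random parameters and inputs kept inside the default [0, 1] domain (assuming rational_quadratic_spline above is in scope; it should print True True, up to float error):

import torch

torch.manual_seed(0)
x = torch.rand(10) * 0.8 + 0.1   # stay inside [0, 1]
w = torch.randn(10, 8)           # unnormalized widths (num_bins = 8)
h = torch.randn(10, 8)           # unnormalized heights
d = torch.randn(10, 9)           # unnormalized derivatives (num_bins + 1)

y, logdet = rational_quadratic_spline(x, w, h, d)
x_back, inv_logdet = rational_quadratic_spline(y, w, h, d, inverse=True)
print(torch.allclose(x, x_back, atol=1e-4), torch.allclose(logdet, -inv_logdet, atol=1e-4))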


@ -1,50 +1,81 @@
import os, torch, sys
from subprocess import Popen

now_dir = os.getcwd()
sys.path.append(now_dir)
from config import (
    text_path,
    wav_dir,
    n_card,
    exp_name,
    n_parts,
    exp_dir,
)

os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True)
os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True)
##############step1 ##############step1
ps=[] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/1-get-text.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/1-get-text.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
for p in ps: for p in ps:
p.wait() p.wait()
opt=[] opt = []
for i_part in range(n_parts): for i_part in range(n_parts):
txt_path = "%s/2-name2text-%s.txt" % (exp_dir, i_part) txt_path = "%s/2-name2text-%s.txt" % (exp_dir, i_part)
with open(txt_path,"r")as f: with open(txt_path, "r") as f:
opt+=f.read().strip("\n").split("\n") opt += f.read().strip("\n").split("\n")
os.remove(txt_path) os.remove(txt_path)
with open("%s/2-name2text.txt"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") with open("%s/2-name2text.txt" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")
############step2 ############step2
ps=[] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
for p in ps: for p in ps:
p.wait() p.wait()
#############step3 #############step3
ps=[] ps = []
for i_part in range(n_parts): for i_part in range(n_parts):
cmd="python prepare/3-get-semantic.py %s %s %s %s %s"%(text_path,exp_name,i_part,n_parts,i_part%n_card) cmd = "python prepare/3-get-semantic.py %s %s %s %s %s" % (
text_path,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
ps.append(p) ps.append(p)
for p in ps: for p in ps:
p.wait() p.wait()
opt=["item_name semantic_audio"] opt = ["item_name semantic_audio"]
for i_part in range(n_parts): for i_part in range(n_parts):
semantic_path = "%s/6-name2semantic-%s.tsv" % (exp_dir, i_part) semantic_path = "%s/6-name2semantic-%s.tsv" % (exp_dir, i_part)
with open(semantic_path,"r")as f: with open(semantic_path, "r") as f:
opt+=f.read().strip("\n").split("\n") opt += f.read().strip("\n").split("\n")
os.remove(semantic_path) os.remove(semantic_path)
with open("%s/6-name2semantic.tsv"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") with open("%s/6-name2semantic.tsv" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")


@ -2,16 +2,16 @@
import os import os
inp_text= os.environ.get("inp_text") inp_text = os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir") inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name") exp_name = os.environ.get("exp_name")
i_part= os.environ.get("i_part") i_part = os.environ.get("i_part")
all_parts= os.environ.get("all_parts") all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir= os.environ.get("opt_dir") opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir= os.environ.get("bert_pretrained_dir") bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
is_half=eval(os.environ.get("is_half","True")) is_half = eval(os.environ.get("is_half", "True"))
import sys,numpy as np,traceback,pdb import sys, numpy as np, traceback, pdb
import os.path import os.path
from glob import glob from glob import glob
from tqdm import tqdm from tqdm import tqdm
@ -31,25 +31,29 @@ import numpy as np
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
    dir = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
    torch.save(fea, tmp_path)
    shutil.move(tmp_path, "%s/%s" % (dir, name))


txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if os.path.exists(txt_path) == False:
    bert_dir = "%s/3-bert" % (opt_dir)
    os.makedirs(opt_dir, exist_ok=True)
    os.makedirs(bert_dir, exist_ok=True)
    device = "cuda:0"
    tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
    if is_half == True:
        bert_model = bert_model.half().to(device)
    else:
        bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt") inputs = tokenizer(text, return_tensors="pt")
@ -67,51 +71,55 @@ if(os.path.exists(txt_path)==False):
phone_level_feature = torch.cat(phone_level_feature, dim=0) phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T return phone_level_feature.T
def process(data,res):
for name,text,lan in data: def process(data, res):
for name, text, lan in data:
try: try:
name=os.path.basename(name) name = os.path.basename(name)
phones, word2ph, norm_text=clean_text(text.replace("%", '-').replace('', ','),lan) phones, word2ph, norm_text = clean_text(
path_bert="%s/%s.pt"%(bert_dir,name) text.replace("%", "-").replace("", ","), lan
if (os.path.exists(path_bert) == False and lan == "zh"): )
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph) bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones) assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert) # torch.save(bert_feature, path_bert)
my_save(bert_feature, path_bert) my_save(bert_feature, path_bert)
phones = " ".join(phones) phones = " ".join(phones)
# res.append([name,phones]) # res.append([name,phones])
res.append([name,phones, word2ph, norm_text]) res.append([name, phones, word2ph, norm_text])
except: except:
print(name, text, traceback.format_exc()) print(name, text, traceback.format_exc())
todo=[] todo = []
res=[] res = []
with open(inp_text,"r",encoding="utf8")as f: with open(inp_text, "r", encoding="utf8") as f:
lines=f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
language_v1_to_language_v2={ language_v1_to_language_v2 = {
"ZH":"zh", "ZH": "zh",
"zh":"zh", "zh": "zh",
"JP":"ja", "JP": "ja",
"jp":"ja", "jp": "ja",
"JA":"ja", "JA": "ja",
"ja":"ja", "ja": "ja",
"EN":"en", "EN": "en",
"en":"en", "en": "en",
"En":"en", "En": "en",
} }
for line in lines[int(i_part)::int(all_parts)]: for line in lines[int(i_part) :: int(all_parts)]:
try: try:
wav_name,spk_name,language,text=line.split("|") wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"]) # todo.append([name,text,"zh"])
todo.append([wav_name,text,language_v1_to_language_v2.get(language,language)]) todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
except: except:
print(line,traceback.format_exc()) print(line, traceback.format_exc())
process(todo,res)
opt=[]
for name,phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s"%(name,phones, word2ph, norm_text))
with open(txt_path,"w",encoding="utf8")as f:
f.write("\n".join(opt)+"\n")
process(todo, res)
opt = []
for name, phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
with open(txt_path, "w", encoding="utf8") as f:
f.write("\n".join(opt) + "\n")
View File
@ -1,20 +1,23 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys,os import sys, os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
is_half=eval(os.environ.get("is_half","True"))
import pdb,traceback,numpy as np,logging inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
is_half = eval(os.environ.get("is_half", "True"))
import pdb, traceback, numpy as np, logging
from scipy.io import wavfile from scipy.io import wavfile
import librosa,torch import librosa, torch
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from my_utils import load_audio from my_utils import load_audio
@ -32,63 +35,75 @@ from my_utils import load_audio
from time import time as ttime from time import time as ttime
import shutil import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95 def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
alpha=0.5 dir = os.path.dirname(path)
device="cuda:0" name = os.path.basename(path)
model=cnhubert.get_model() tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
if(is_half==True): torch.save(fea, tmp_path)
model=model.half().to(device) shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
device = "cuda:0"
model = cnhubert.get_model()
if is_half == True:
model = model.half().to(device)
else: else:
model = model.to(device) model = model.to(device)
def name2go(wav_name): def name2go(wav_name):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name) hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if(os.path.exists(hubert_path)):return if os.path.exists(hubert_path):
wav_path="%s/%s"%(inp_wav_dir,wav_name) return
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
tmp_audio = load_audio(wav_path, 32000) tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max() tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2: if tmp_max > 2.2:
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max)) print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
return return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + (
tmp_audio = librosa.resample( (1 - alpha) * 32768
tmp_audio32, orig_sr=32000, target_sr=16000 ) * tmp_audio
) tmp_audio = librosa.resample(tmp_audio32, orig_sr=32000, target_sr=16000)
tensor_wav16 = torch.from_numpy(tmp_audio) tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True): if is_half == True:
tensor_wav16=tensor_wav16.half().to(device) tensor_wav16 = tensor_wav16.half().to(device)
else: else:
tensor_wav16 = tensor_wav16.to(device) tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) ssl = (
if np.isnan(ssl.detach().numpy()).sum()!= 0:return model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"]
.transpose(1, 2)
.cpu()
) # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
return
wavfile.write( wavfile.write(
"%s/%s"%(wav32dir,wav_name), "%s/%s" % (wav32dir, wav_name),
32000, 32000,
tmp_audio32.astype("int16"), tmp_audio32.astype("int16"),
) )
# torch.save(ssl,hubert_path ) # torch.save(ssl,hubert_path )
my_save(ssl,hubert_path ) my_save(ssl, hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]: with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try: try:
# wav_name,text=line.split("\t") # wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
wav_name=os.path.basename(wav_name) wav_name = os.path.basename(wav_name)
name2go(wav_name) name2go(wav_name)
except: except:
print(line,traceback.format_exc()) print(line, traceback.format_exc())
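The tmp_audio32 expression above blends two copies of the clip before the 32 kHz wav is written: an alpha-weighted copy peak-normalized to maxx of int16 full scale, and a (1 - alpha)-weighted copy left at its original level, both in 32768 units. A small numpy sketch of that blend, assuming load_audio returns float samples roughly in [-1, 1], which is how the surrounding code treats it:

import numpy as np

def blend_normalize(audio: np.ndarray, maxx: float = 0.95, alpha: float = 0.5) -> np.ndarray:
    # mix a peak-normalized copy with an original-level copy, in int16 units
    peak = np.abs(audio).max()
    normalized = audio / peak * (maxx * alpha * 32768)   # peak-normalized part
    passthrough = (1 - alpha) * 32768 * audio            # original-level part
    return normalized + passthrough

audio = np.array([0.1, -0.4, 0.2], dtype=np.float32)
print(blend_normalize(audio).astype("int16"))  # e.g. [  5529 -22118  11059]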
View File
@ -1,24 +1,27 @@
import os import os
inp_text= os.environ.get("inp_text")
exp_name= os.environ.get("exp_name") inp_text = os.environ.get("inp_text")
i_part= os.environ.get("i_part") exp_name = os.environ.get("exp_name")
all_parts= os.environ.get("all_parts") i_part = os.environ.get("i_part")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES") all_parts = os.environ.get("all_parts")
opt_dir= os.environ.get("opt_dir") os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
pretrained_s2G= os.environ.get("pretrained_s2G") opt_dir = os.environ.get("opt_dir")
s2config_path= os.environ.get("s2config_path") pretrained_s2G = os.environ.get("pretrained_s2G")
is_half=eval(os.environ.get("is_half","True")) s2config_path = os.environ.get("s2config_path")
import math,traceback is_half = eval(os.environ.get("is_half", "True"))
import math, traceback
import multiprocessing import multiprocessing
import sys,pdb import sys, pdb
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
from random import shuffle from random import shuffle
import torch.multiprocessing as mp import torch.multiprocessing as mp
from glob import glob from glob import glob
from tqdm import tqdm from tqdm import tqdm
import logging,librosa,utils,torch import logging, librosa, utils, torch
from module.models import SynthesizerTrn from module.models import SynthesizerTrn
logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G # from config import pretrained_s2G
@ -30,52 +33,58 @@ logging.getLogger("numba").setLevel(logging.WARNING)
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
hubert_dir="%s/4-cnhubert"%(opt_dir) hubert_dir = "%s/4-cnhubert" % (opt_dir)
semantic_path="%s/6-name2semantic-%s.tsv"%(opt_dir,i_part) semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if(os.path.exists(semantic_path)==False): if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir,exist_ok=True) os.makedirs(opt_dir, exist_ok=True)
device="cuda:0" device = "cuda:0"
hps = utils.get_hparams_from_file(s2config_path) hps = utils.get_hparams_from_file(s2config_path)
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model) **hps.model
if(is_half==True): )
vq_model=vq_model.half().to(device) if is_half == True:
vq_model = vq_model.half().to(device)
else: else:
vq_model = vq_model.to(device) vq_model = vq_model.to(device)
vq_model.eval() vq_model.eval()
# utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True) # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True) # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
print(vq_model.load_state_dict(torch.load(pretrained_s2G,map_location="cpu")["weight"], strict=False)) print(
vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
)
)
def name2go(wav_name,lines): def name2go(wav_name, lines):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if(os.path.exists(hubert_path)==False):return if os.path.exists(hubert_path) == False:
return
ssl_content = torch.load(hubert_path, map_location="cpu") ssl_content = torch.load(hubert_path, map_location="cpu")
if(is_half==True): if is_half == True:
ssl_content=ssl_content.half().to(device) ssl_content = ssl_content.half().to(device)
else: else:
ssl_content = ssl_content.to(device) ssl_content = ssl_content.to(device)
codes = vq_model.extract_latent(ssl_content) codes = vq_model.extract_latent(ssl_content)
semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()]) semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
lines.append("%s\t%s"%(wav_name,semantic)) lines.append("%s\t%s" % (wav_name, semantic))
with open(inp_text,"r",encoding="utf8")as f: with open(inp_text, "r", encoding="utf8") as f:
lines=f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
lines1=[] lines1 = []
for line in lines[int(i_part)::int(all_parts)]: for line in lines[int(i_part) :: int(all_parts)]:
# print(line) # print(line)
try: try:
# wav_name,text=line.split("\t") # wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|") wav_name, spk_name, language, text = line.split("|")
wav_name=os.path.basename(wav_name) wav_name = os.path.basename(wav_name)
# name2go(name,lines1) # name2go(name,lines1)
name2go(wav_name,lines1) name2go(wav_name, lines1)
except: except:
print(line,traceback.format_exc()) print(line, traceback.format_exc())
with open(semantic_path,"w",encoding="utf8")as f:f.write("\n".join(lines1)) with open(semantic_path, "w", encoding="utf8") as f:
f.write("\n".join(lines1))
View File
@ -6,49 +6,56 @@ import cn2an
from pypinyin import lazy_pinyin, Style from pypinyin import lazy_pinyin, Style
import sys import sys
sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master") sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master")
from text.symbols import punctuation from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi from text.tone_sandhi import ToneSandhi
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in pinyin_to_symbol_map = {
open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()} line.split("\t")[0]: line.strip().split("\t")[1]
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba.posseg as psg import jieba.posseg as psg
rep_map = { rep_map = {
'': ',', "": ",",
'': ',', "": ",",
'': ',', "": ",",
'': '.', "": ".",
'': '!', "": "!",
'': '?', "": "?",
'\n': '.', "\n": ".",
"·": ",", "·": ",",
'': ",", "": ",",
'...': '', "...": "",
'$': '.', "$": ".",
'/': ',', "/": ",",
'': "-" "": "-",
} }
tone_modifier = ToneSandhi() tone_modifier = ToneSandhi()
def replace_punctuation(text): def replace_punctuation(text):
text = text.replace("", "").replace("","") text = text.replace("", "").replace("", "")
pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys())) pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text) replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text return replaced_text
def g2p(text): def g2p(text):
pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation)) pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip()!=''] sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
phones, word2ph = _g2p(sentences) phones, word2ph = _g2p(sentences)
return phones, word2ph return phones, word2ph
@ -56,10 +63,10 @@ def g2p(text):
def _get_initials_finals(word): def _get_initials_finals(word):
initials = [] initials = []
finals = [] finals = []
orig_initials = lazy_pinyin( orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin( orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals): for c, v in zip(orig_initials, orig_finals):
initials.append(c) initials.append(c)
finals.append(v) finals.append(v)
@ -72,17 +79,16 @@ def _g2p(segments):
for seg in segments: for seg in segments:
pinyins = [] pinyins = []
# Replace all English words in the sentence # Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg) seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg) seg_cut = psg.lcut(seg)
initials = [] initials = []
finals = [] finals = []
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut: for word, pos in seg_cut:
if pos == 'eng': if pos == "eng":
continue continue
sub_initials, sub_finals = _get_initials_finals(word) sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
sub_finals)
initials.append(sub_initials) initials.append(sub_initials)
finals.append(sub_finals) finals.append(sub_finals)
@ -91,7 +97,7 @@ def _g2p(segments):
finals = sum(finals, []) finals = sum(finals, [])
# #
for c, v in zip(initials, finals): for c, v in zip(initials, finals):
raw_pinyin = c+v raw_pinyin = c + v
# NOTE: post process for pypinyin outputs # NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii # we discriminate i, ii and iii
if c == v: if c == v:
@ -102,40 +108,40 @@ def _g2p(segments):
v_without_tone = v[:-1] v_without_tone = v[:-1]
tone = v[-1] tone = v[-1]
pinyin = c+v_without_tone pinyin = c + v_without_tone
assert tone in '12345' assert tone in "12345"
if c: if c:
# 多音节 # 多音节
v_rep_map = { v_rep_map = {
"uei": 'ui', "uei": "ui",
'iou': 'iu', "iou": "iu",
'uen': 'un', "uen": "un",
} }
if v_without_tone in v_rep_map.keys(): if v_without_tone in v_rep_map.keys():
pinyin = c+v_rep_map[v_without_tone] pinyin = c + v_rep_map[v_without_tone]
else: else:
# 单音节 # 单音节
pinyin_rep_map = { pinyin_rep_map = {
'ing': 'ying', "ing": "ying",
'i': 'yi', "i": "yi",
'in': 'yin', "in": "yin",
'u': 'wu', "u": "wu",
} }
if pinyin in pinyin_rep_map.keys(): if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin] pinyin = pinyin_rep_map[pinyin]
else: else:
single_rep_map = { single_rep_map = {
'v': 'yu', "v": "yu",
'e': 'e', "e": "e",
'i': 'y', "i": "y",
'u': 'w', "u": "w",
} }
if pinyin[0] in single_rep_map.keys(): if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]]+pinyin[1:] pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ') new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone new_v = new_v + tone
phone = [new_c, new_v] phone = [new_c, new_v]
word2ph.append(len(phone)) word2ph.append(len(phone))
@ -144,9 +150,8 @@ def _g2p(segments):
return phones_list, word2ph return phones_list, word2ph
def text_normalize(text): def text_normalize(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text) numbers = re.findall(r"\d+(?:\.?\d+)?", text)
for number in numbers: for number in numbers:
text = text.replace(number, cn2an.an2cn(number), 1) text = text.replace(number, cn2an.an2cn(number), 1)
text = replace_punctuation(text) text = replace_punctuation(text)
@ -154,7 +159,7 @@ def text_normalize(text):
return text return text
if __name__ == '__main__': if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"
text = "你好" text = "你好"
View File
@ -1,29 +1,27 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
language_module_map = { language_module_map = {"zh": chinese, "ja": japanese, "en": english}
'zh': chinese,
"ja": japanese,
'en': english
}
special = [ special = [
('%', 'zh', "SP"), ("%", "zh", "SP"),
('', 'zh', "SP2"), ("", "zh", "SP2"),
('^', 'zh', "SP3"), ("^", "zh", "SP3"),
# ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧 # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧
] ]
def clean_text(text, language): def clean_text(text, language):
for special_s, special_l, target_symbol in special: for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l: if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol) return clean_special(text, language, special_s, target_symbol)
language_module = language_module_map[language] language_module = language_module_map[language]
norm_text = language_module.text_normalize(text) norm_text = language_module.text_normalize(text)
if(language=="zh"): if language == "zh":
phones, word2ph = language_module.g2p(norm_text) phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph) assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph) assert len(norm_text) == len(word2ph)
else: else:
phones = language_module.g2p(norm_text) phones = language_module.g2p(norm_text)
word2ph=None word2ph = None
for ph in phones: for ph in phones:
assert ph in symbols assert ph in symbols
@ -41,17 +39,17 @@ def clean_special(text, language, special_s, target_symbol):
new_ph = [] new_ph = []
for ph in phones: for ph in phones:
assert ph in symbols assert ph in symbols
if ph == ',': if ph == ",":
new_ph.append(target_symbol) new_ph.append(target_symbol)
else: else:
new_ph.append(ph) new_ph.append(ph)
return new_ph return new_ph
def text_to_sequence(text, language): def text_to_sequence(text, language):
phones = clean_text(text) phones = clean_text(text)
return cleaned_text_to_sequence(phones) return cleaned_text_to_sequence(phones)
if __name__ == '__main__':
print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh'))
if __name__ == "__main__":
print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))
View File
@ -8,20 +8,87 @@ from string import punctuation
from text import symbols from text import symbols
current_file_path = os.path.dirname(__file__) current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep') CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle') CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
_g2p = G2p() _g2p = G2p()
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
def replace_phs(phs): def replace_phs(phs):
rep_map = { rep_map = {";": ",", ":": ",", "'": "-", '"': "-"}
';': ',',
':': ',',
'\'': '-',
'"': '-'
}
phs_new = [] phs_new = []
for ph in phs: for ph in phs:
if ph in symbols: if ph in symbols:
@ -29,9 +96,10 @@ def replace_phs(phs):
elif ph in rep_map.keys(): elif ph in rep_map.keys():
phs_new.append(rep_map[ph]) phs_new.append(rep_map[ph])
else: else:
print('ph not in symbols: ', ph) print("ph not in symbols: ", ph)
return phs_new return phs_new
def read_dict(): def read_dict():
g2p_dict = {} g2p_dict = {}
start_line = 49 start_line = 49
@ -41,13 +109,13 @@ def read_dict():
while line: while line:
if line_index >= start_line: if line_index >= start_line:
line = line.strip() line = line.strip()
word_split = line.split(' ') word_split = line.split(" ")
word = word_split[0] word = word_split[0]
syllable_split = word_split[1].split(' - ') syllable_split = word_split[1].split(" - ")
g2p_dict[word] = [] g2p_dict[word] = []
for syllable in syllable_split: for syllable in syllable_split:
phone_split = syllable.split(' ') phone_split = syllable.split(" ")
g2p_dict[word].append(phone_split) g2p_dict[word].append(phone_split)
line_index = line_index + 1 line_index = line_index + 1
@ -57,13 +125,13 @@ def read_dict():
def cache_dict(g2p_dict, file_path): def cache_dict(g2p_dict, file_path):
with open(file_path, 'wb') as pickle_file: with open(file_path, "wb") as pickle_file:
pickle.dump(g2p_dict, pickle_file) pickle.dump(g2p_dict, pickle_file)
def get_dict(): def get_dict():
if os.path.exists(CACHE_PATH): if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, 'rb') as pickle_file: with open(CACHE_PATH, "rb") as pickle_file:
g2p_dict = pickle.load(pickle_file) g2p_dict = pickle.load(pickle_file)
else: else:
g2p_dict = read_dict() g2p_dict = read_dict()
@ -71,6 +139,7 @@ def get_dict():
return g2p_dict return g2p_dict
eng_dict = get_dict() eng_dict = get_dict()
@ -78,8 +147,8 @@ def text_normalize(text):
# todo: eng text normalize # todo: eng text normalize
return text.replace(";", ",") return text.replace(";", ",")
def g2p(text):
def g2p(text):
phones = [] phones = []
words = re.split(r"([,;.\-\?\!\s+])", text) words = re.split(r"([,;.\-\?\!\s+])", text)
for w in words: for w in words:
@ -97,6 +166,7 @@ def g2p(text):
return replace_phs(phones) return replace_phs(phones)
if __name__ == "__main__": if __name__ == "__main__":
# print(get_dict()) # print(get_dict())
print(g2p("hello")) print(g2p("hello"))
@ -106,4 +176,4 @@ if __name__ == "__main__":
# for group in syllables: # for group in syllables:
# for ph in group: # for ph in group:
# all_phones.add(ph) # all_phones.add(ph)
# print(all_phones) # print(all_phones)
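get_dict above is a straightforward build-or-load cache: the parsed CMU dictionary is unpickled from cmudict_cache.pickle when the cache file exists, and otherwise rebuilt with read_dict and, presumably, written back via cache_dict. A generic sketch of the same pattern (the load_or_build name and the build callable are hypothetical, not part of this diff):

import os
import pickle

def load_or_build(cache_path, build):
    # return the cached object if present, otherwise build it and cache it
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    obj = build()
    with open(cache_path, "wb") as f:
        pickle.dump(obj, f)
    return obj

# usage sketch: eng_dict = load_or_build(CACHE_PATH, read_dict)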
View File
@ -8,57 +8,63 @@ from text import symbols
# Regular expression matching Japanese without punctuation marks: # Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile( _japanese_characters = re.compile(
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# Regular expression matching non-Japanese characters or punctuation marks: # Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile( _japanese_marks = re.compile(
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# List of (symbol, Japanese) pairs for marks: # List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ _symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("", "パーセント")]]
('', 'パーセント')
]]
# List of (consonant, sokuon) pairs: # List of (consonant, sokuon) pairs:
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ _real_sokuon = [
(r'Q([↑↓]*[kg])', r'k#\1'), (re.compile("%s" % x[0]), x[1])
(r'Q([↑↓]*[tdjʧ])', r't#\1'), for x in [
(r'Q([↑↓]*[sʃ])', r's\1'), (r"Q([↑↓]*[kg])", r"k#\1"),
(r'Q([↑↓]*[pb])', r'p#\1') (r"Q([↑↓]*[tdjʧ])", r"t#\1"),
]] (r"Q([↑↓]*[sʃ])", r"s\1"),
(r"Q([↑↓]*[pb])", r"p#\1"),
]
]
# List of (consonant, hatsuon) pairs: # List of (consonant, hatsuon) pairs:
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ _real_hatsuon = [
(r'N([↑↓]*[pbm])', r'm\1'), (re.compile("%s" % x[0]), x[1])
(r'N([↑↓]*[ʧʥj])', r'n^\1'), for x in [
(r'N([↑↓]*[tdn])', r'n\1'), (r"N([↑↓]*[pbm])", r"m\1"),
(r'N([↑↓]*[kg])', r'ŋ\1') (r"N([↑↓]*[ʧʥj])", r"n^\1"),
]] (r"N([↑↓]*[tdn])", r"n\1"),
(r"N([↑↓]*[kg])", r"ŋ\1"),
]
]
def post_replace_ph(ph): def post_replace_ph(ph):
rep_map = { rep_map = {
'': ',', "": ",",
'': ',', "": ",",
'': ',', "": ",",
'': '.', "": ".",
'': '!', "": "!",
'': '?', "": "?",
'\n': '.', "\n": ".",
"·": ",", "·": ",",
'': ",", "": ",",
'...': '' "...": "",
} }
if ph in rep_map.keys(): if ph in rep_map.keys():
ph = rep_map[ph] ph = rep_map[ph]
if ph in symbols: if ph in symbols:
return ph return ph
if ph not in symbols: if ph not in symbols:
ph = 'UNK' ph = "UNK"
return ph return ph
def symbols_to_japanese(text): def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese: for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
@ -66,7 +72,7 @@ def symbols_to_japanese(text):
def preprocess_jap(text): def preprocess_jap(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
text = symbols_to_japanese(text) text = symbols_to_japanese(text)
sentences = re.split(_japanese_marks, text) sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text) marks = re.findall(_japanese_marks, text)
@ -77,13 +83,15 @@ def preprocess_jap(text):
text += p.split(" ") text += p.split(" ")
if i < len(marks): if i < len(marks):
text += [marks[i].replace(' ', '')] text += [marks[i].replace(" ", "")]
return text return text
def text_normalize(text): def text_normalize(text):
# todo: jap text normalize # todo: jap text normalize
return text return text
def g2p(norm_text): def g2p(norm_text):
phones = preprocess_jap(norm_text) phones = preprocess_jap(norm_text)
phones = [post_replace_ph(i) for i in phones] phones = [post_replace_ph(i) for i in phones]
@ -91,7 +99,7 @@ def g2p(norm_text):
return phones return phones
if __name__ == '__main__': if __name__ == "__main__":
for line in open("../../../Downloads/transcript_utf8.txt").readlines(): for line in open("../../../Downloads/transcript_utf8.txt").readlines():
text = line.split(":")[1] text = line.split(":")[1]
phones = g2p(text) phones = g2p(text)
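The _real_sokuon and _real_hatsuon tables above are lists of (compiled pattern, replacement) pairs, presumably applied in order to the romanized string much like symbols_to_japanese applies its own table, turning the placeholder consonants Q (sokuon) and N (hatsuon) into context-dependent phones. A standalone sketch of applying such a table, reusing the sokuon patterns shown above:

import re

# (pattern, replacement) pairs, applied in order; e.g. Q before k/g becomes "k#"
sokuon_rules = [
    (re.compile(r"Q([↑↓]*[kg])"), r"k#\1"),
    (re.compile(r"Q([↑↓]*[tdjʧ])"), r"t#\1"),
    (re.compile(r"Q([↑↓]*[sʃ])"), r"s\1"),
    (re.compile(r"Q([↑↓]*[pb])"), r"p#\1"),
]

def apply_rules(text, rules):
    for pattern, replacement in rules:
        text = pattern.sub(replacement, text)
    return text

print(apply_rules("gaQkou", sokuon_rules))  # gak#kou  ("Qk" -> "k#k")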
View File
@ -1,24 +1,397 @@
import os import os
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ['!', '?', '', ",", "."]#@是SP停顿 punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-") punctuation.append("-")
pu_symbols = punctuation + ["SP", 'SP2', 'SP3', "UNK"] pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] # pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
pad = '_' pad = "_"
c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh'] c = [
v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5'] "AA",
"EE",
"OO",
"b",
"c",
"ch",
"d",
"f",
"g",
"h",
"j",
"k",
"l",
"m",
"n",
"p",
"q",
"r",
"s",
"sh",
"t",
"w",
"x",
"y",
"z",
"zh",
]
v = [
"E1",
"En1",
"a1",
"ai1",
"an1",
"ang1",
"ao1",
"e1",
"ei1",
"en1",
"eng1",
"er1",
"i1",
"i01",
"ia1",
"ian1",
"iang1",
"iao1",
"ie1",
"in1",
"ing1",
"iong1",
"ir1",
"iu1",
"o1",
"ong1",
"ou1",
"u1",
"ua1",
"uai1",
"uan1",
"uang1",
"ui1",
"un1",
"uo1",
"v1",
"van1",
"ve1",
"vn1",
"E2",
"En2",
"a2",
"ai2",
"an2",
"ang2",
"ao2",
"e2",
"ei2",
"en2",
"eng2",
"er2",
"i2",
"i02",
"ia2",
"ian2",
"iang2",
"iao2",
"ie2",
"in2",
"ing2",
"iong2",
"ir2",
"iu2",
"o2",
"ong2",
"ou2",
"u2",
"ua2",
"uai2",
"uan2",
"uang2",
"ui2",
"un2",
"uo2",
"v2",
"van2",
"ve2",
"vn2",
"E3",
"En3",
"a3",
"ai3",
"an3",
"ang3",
"ao3",
"e3",
"ei3",
"en3",
"eng3",
"er3",
"i3",
"i03",
"ia3",
"ian3",
"iang3",
"iao3",
"ie3",
"in3",
"ing3",
"iong3",
"ir3",
"iu3",
"o3",
"ong3",
"ou3",
"u3",
"ua3",
"uai3",
"uan3",
"uang3",
"ui3",
"un3",
"uo3",
"v3",
"van3",
"ve3",
"vn3",
"E4",
"En4",
"a4",
"ai4",
"an4",
"ang4",
"ao4",
"e4",
"ei4",
"en4",
"eng4",
"er4",
"i4",
"i04",
"ia4",
"ian4",
"iang4",
"iao4",
"ie4",
"in4",
"ing4",
"iong4",
"ir4",
"iu4",
"o4",
"ong4",
"ou4",
"u4",
"ua4",
"uai4",
"uan4",
"uang4",
"ui4",
"un4",
"uo4",
"v4",
"van4",
"ve4",
"vn4",
"E5",
"En5",
"a5",
"ai5",
"an5",
"ang5",
"ao5",
"e5",
"ei5",
"en5",
"eng5",
"er5",
"i5",
"i05",
"ia5",
"ian5",
"iang5",
"iao5",
"ie5",
"in5",
"ing5",
"iong5",
"ir5",
"iu5",
"o5",
"ong5",
"ou5",
"u5",
"ua5",
"uai5",
"uan5",
"uang5",
"ui5",
"un5",
"uo5",
"v5",
"van5",
"ve5",
"vn5",
]
v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] v_without_tone = [
"E",
"En",
"a",
"ai",
"an",
"ang",
"ao",
"e",
"ei",
"en",
"eng",
"er",
"i",
"i0",
"ia",
"ian",
"iang",
"iao",
"ie",
"in",
"ing",
"iong",
"ir",
"iu",
"o",
"ong",
"ou",
"u",
"ua",
"uai",
"uan",
"uang",
"ui",
"un",
"uo",
"v",
"van",
"ve",
"vn",
]
# japanese # japanese
ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky', ja_symbols = [
'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'v', 'w', 'y', 'z'] "I",
"N",
"U",
"a",
"b",
"by",
"ch",
"cl",
"d",
"dy",
"e",
"f",
"g",
"gy",
"h",
"hy",
"i",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"p",
"py",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"u",
"v",
"w",
"y",
"z",
]
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols)) symbols = sorted(set(symbols))
if __name__ == '__main__': if __name__ == "__main__":
print(len(symbols)) print(len(symbols))
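The symbols list above is deduplicated and sorted so that each entry gets a stable index; downstream code such as cleaned_text_to_sequence (imported in the cleaner module earlier in this diff) would typically turn phones into those indices through a lookup table. A small hypothetical sketch of that mapping (the _symbol_to_id name and the miniature table are illustrative, not taken from this commit):

symbols = ["_", ",", ".", "AA", "a1", "b"]  # tiny stand-in for the real table
_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def phones_to_ids(phones):
    return [_symbol_to_id[ph] for ph in phones]

print(phones_to_ids(["b", "a1", "."]))  # [5, 4, 2]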
View File
@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin
from pypinyin import Style from pypinyin import Style
class ToneSandhi(): class ToneSandhi:
def __init__(self): def __init__(self):
self.must_neural_tone_words = { self.must_neural_tone_words = {
'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', "麻烦",
'难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', "麻利",
'里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', "鸳鸯",
'软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号', "高粱",
'认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当', "骨头",
'蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻', "骆驼",
'舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', "马虎",
'胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', "首饰",
'老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', "馒头",
'精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', "馄饨",
'窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台', "风筝",
'码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算', "难为",
'白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨', "队伍",
'琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', "阔气",
'爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', "闺女",
'溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', "门道",
'棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事', "锄头",
'木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾', "铺盖",
'收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼', "铃铛",
'抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', "铁匠",
'扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', "钥匙",
'念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', "里脊",
'干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', "里头",
'屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', "部分",
'实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈', "那么",
'姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方', "道士",
'大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', "造化",
'嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', "迷糊",
'咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', "连累",
'叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹', "这么",
'功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息', "这个",
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤', "运气",
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', "过去",
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', "软和",
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨', "转悠",
'父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', "踏实",
'幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', "跳蚤",
'凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', "跟头",
'扫把', '惦记' "趔趄",
"财主",
"豆腐",
"讲究",
"记性",
"记号",
"认识",
"规矩",
"见识",
"裁缝",
"补丁",
"衣裳",
"衣服",
"衙门",
"街坊",
"行李",
"行当",
"蛤蟆",
"蘑菇",
"薄荷",
"葫芦",
"葡萄",
"萝卜",
"荸荠",
"苗条",
"苗头",
"苍蝇",
"芝麻",
"舒服",
"舒坦",
"舌头",
"自在",
"膏药",
"脾气",
"脑袋",
"脊梁",
"能耐",
"胳膊",
"胭脂",
"胡萝",
"胡琴",
"胡同",
"聪明",
"耽误",
"耽搁",
"耷拉",
"耳朵",
"老爷",
"老实",
"老婆",
"老头",
"老太",
"翻腾",
"罗嗦",
"罐头",
"编辑",
"结实",
"红火",
"累赘",
"糨糊",
"糊涂",
"精神",
"粮食",
"簸箕",
"篱笆",
"算计",
"算盘",
"答应",
"笤帚",
"笑语",
"笑话",
"窟窿",
"窝囊",
"窗户",
"稳当",
"稀罕",
"称呼",
"秧歌",
"秀气",
"秀才",
"福气",
"祖宗",
"砚台",
"码头",
"石榴",
"石头",
"石匠",
"知识",
"眼睛",
"眯缝",
"眨巴",
"眉毛",
"相声",
"盘算",
"白净",
"痢疾",
"痛快",
"疟疾",
"疙瘩",
"疏忽",
"畜生",
"生意",
"甘蔗",
"琵琶",
"琢磨",
"琉璃",
"玻璃",
"玫瑰",
"玄乎",
"狐狸",
"状元",
"特务",
"牲口",
"牙碜",
"牌楼",
"爽快",
"爱人",
"热闹",
"烧饼",
"烟筒",
"烂糊",
"点心",
"炊帚",
"灯笼",
"火候",
"漂亮",
"滑溜",
"溜达",
"温和",
"清楚",
"消息",
"浪头",
"活泼",
"比方",
"正经",
"欺负",
"模糊",
"槟榔",
"棺材",
"棒槌",
"棉花",
"核桃",
"栅栏",
"柴火",
"架势",
"枕头",
"枇杷",
"机灵",
"本事",
"木头",
"木匠",
"朋友",
"月饼",
"月亮",
"暖和",
"明白",
"时候",
"新鲜",
"故事",
"收拾",
"收成",
"提防",
"挖苦",
"挑剔",
"指甲",
"指头",
"拾掇",
"拳头",
"拨弄",
"招牌",
"招呼",
"抬举",
"护士",
"折腾",
"扫帚",
"打量",
"打算",
"打点",
"打扮",
"打听",
"打发",
"扎实",
"扁担",
"戒指",
"懒得",
"意识",
"意思",
"情形",
"悟性",
"怪物",
"思量",
"怎么",
"念头",
"念叨",
"快活",
"忙活",
"志气",
"心思",
"得罪",
"张罗",
"弟兄",
"开通",
"应酬",
"庄稼",
"干事",
"帮手",
"帐篷",
"希罕",
"师父",
"师傅",
"巴结",
"巴掌",
"差事",
"工夫",
"岁数",
"屁股",
"尾巴",
"少爷",
"小气",
"小伙",
"将就",
"对头",
"对付",
"寡妇",
"家伙",
"客气",
"实在",
"官司",
"学问",
"学生",
"字号",
"嫁妆",
"媳妇",
"媒人",
"婆家",
"娘家",
"委屈",
"姑娘",
"姐夫",
"妯娌",
"妥当",
"妖精",
"奴才",
"女婿",
"头发",
"太阳",
"大爷",
"大方",
"大意",
"大夫",
"多少",
"多么",
"外甥",
"壮实",
"地道",
"地方",
"在乎",
"困难",
"嘴巴",
"嘱咐",
"嘟囔",
"嘀咕",
"喜欢",
"喇嘛",
"喇叭",
"商量",
"唾沫",
"哑巴",
"哈欠",
"哆嗦",
"咳嗽",
"和尚",
"告诉",
"告示",
"含糊",
"吓唬",
"后头",
"名字",
"名堂",
"合同",
"吆喝",
"叫唤",
"口袋",
"厚道",
"厉害",
"千斤",
"包袱",
"包涵",
"匀称",
"勤快",
"动静",
"动弹",
"功夫",
"力气",
"前头",
"刺猬",
"刺激",
"别扭",
"利落",
"利索",
"利害",
"分析",
"出息",
"凑合",
"凉快",
"冷战",
"冤枉",
"冒失",
"养活",
"关系",
"先生",
"兄弟",
"便宜",
"使唤",
"佩服",
"作坊",
"体面",
"位置",
"似的",
"伙计",
"休息",
"什么",
"人家",
"亲戚",
"亲家",
"交情",
"云彩",
"事情",
"买卖",
"主意",
"丫头",
"丧气",
"两口",
"东西",
"东家",
"世故",
"不由",
"不在",
"下水",
"下巴",
"上头",
"上司",
"丈夫",
"丈人",
"一辈",
"那个",
"菩萨",
"父亲",
"母亲",
"咕噜",
"邋遢",
"费用",
"冤家",
"甜头",
"介绍",
"荒唐",
"大人",
"泥鳅",
"幸福",
"熟悉",
"计划",
"扑腾",
"蜡烛",
"姥爷",
"照顾",
"喉咙",
"吉他",
"弄堂",
"蚂蚱",
"凤凰",
"拖沓",
"寒碜",
"糟蹋",
"倒腾",
"报复",
"逻辑",
"盘缠",
"喽啰",
"牢骚",
"咖喱",
"扫把",
"惦记",
} }
self.must_not_neural_tone_words = { self.must_not_neural_tone_words = {
"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎" "男子",
"女子",
"分子",
"原子",
"量子",
"莲子",
"石子",
"瓜子",
"电子",
"人人",
"虎虎",
} }
self.punc = ":,;。?!“”‘’':,;.?!" self.punc = ":,;。?!“”‘’':,;.?!"
@ -72,14 +463,15 @@ class ToneSandhi():
# word: "家里" # word: "家里"
# pos: "s" # pos: "s"
# finals: ['ia1', 'i3'] # finals: ['ia1', 'i3']
def _neural_sandhi(self, word: str, pos: str, def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals: List[str]) -> List[str]:
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
for j, item in enumerate(word): for j, item in enumerate(word):
if j - 1 >= 0 and item == word[j - 1] and pos[0] in { if (
"n", "v", "a" j - 1 >= 0
} and word not in self.must_not_neural_tone_words: and item == word[j - 1]
and pos[0] in {"n", "v", "a"}
and word not in self.must_not_neural_tone_words
):
finals[j] = finals[j][:-1] + "5" finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("") ge_idx = word.find("")
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
@ -89,9 +481,12 @@ class ToneSandhi():
# e.g. 走了, 看着, 去过 # e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
elif len(word) > 1 and word[-1] in "们子" and pos in { elif (
"r", "n" len(word) > 1
} and word not in self.must_not_neural_tone_words: and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里 # e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -100,21 +495,26 @@ class ToneSandhi():
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
# 个做量词 # 个做量词
elif (ge_idx >= 1 and elif (
(word[ge_idx - 1].isnumeric() or ge_idx >= 1
word[ge_idx - 1] in "几有两半多各整每做是")) or word == '': and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5" finals[ge_idx] = finals[ge_idx][:-1] + "5"
else: else:
if word in self.must_neural_tone_words or word[ if (
-2:] in self.must_neural_tone_words: word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word) word_list = self._split_word(word)
finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]] finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list): for i, word in enumerate(word_list):
# conventional neural in Chinese # conventional neural in Chinese
if word in self.must_neural_tone_words or word[ if (
-2:] in self.must_neural_tone_words: word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5" finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, []) finals = sum(finals_list, [])
return finals return finals
@ -126,15 +526,15 @@ class ToneSandhi():
else: else:
for i, char in enumerate(word): for i, char in enumerate(word):
# "不" before tone4 should be bu2, e.g. 不怕 # "不" before tone4 should be bu2, e.g. 不怕
if char == "" and i + 1 < len(word) and finals[i + if char == "" and i + 1 < len(word) and finals[i + 1][-1] == "4":
1][-1] == "4":
finals[i] = finals[i][:-1] + "2" finals[i] = finals[i][:-1] + "2"
return finals return finals
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零 # "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all( if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]): [item.isnumeric() for item in word if item != ""]
):
return finals return finals
# "一" between reduplication words shold be yi5, e.g. 看一看 # "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]: elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
@ -161,10 +561,10 @@ class ToneSandhi():
first_subword = word_list[0] first_subword = word_list[0]
first_begin_idx = word.find(first_subword) first_begin_idx = word.find(first_subword)
if first_begin_idx == 0: if first_begin_idx == 0:
second_subword = word[len(first_subword):] second_subword = word[len(first_subword) :]
new_word_list = [first_subword, second_subword] new_word_list = [first_subword, second_subword]
else: else:
second_subword = word[:-len(first_subword)] second_subword = word[: -len(first_subword)]
new_word_list = [second_subword, first_subword] new_word_list = [second_subword, first_subword]
return new_word_list return new_word_list
@ -182,18 +582,19 @@ class ToneSandhi():
elif len(word_list[0]) == 1: elif len(word_list[0]) == 1:
finals[1] = finals[1][:-1] + "2" finals[1] = finals[1][:-1] + "2"
else: else:
finals_list = [ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
finals[:len(word_list[0])], finals[len(word_list[0]):]
]
if len(finals_list) == 2: if len(finals_list) == 2:
for i, sub in enumerate(finals_list): for i, sub in enumerate(finals_list):
# e.g. 所有/人 # e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2: if self._all_tone_three(sub) and len(sub) == 2:
finals_list[i][0] = finals_list[i][0][:-1] + "2" finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢 # e.g. 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \ elif (
finals_list[0][-1][-1] == "3": i == 1
and not self._all_tone_three(sub)
and finals_list[i][0][-1] == "3"
and finals_list[0][-1][-1] == "3"
):
finals_list[0][-1] = finals_list[0][-1][:-1] + "2" finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, []) finals = sum(finals_list, [])
# split idiom into two words who's length is 2 # split idiom into two words who's length is 2
@ -222,7 +623,7 @@ class ToneSandhi():
new_seg.append((word, pos)) new_seg.append((word, pos))
last_word = word[:] last_word = word[:]
if last_word == "": if last_word == "":
new_seg.append((last_word, 'd')) new_seg.append((last_word, "d"))
last_word = "" last_word = ""
return new_seg return new_seg
@ -236,12 +637,21 @@ class ToneSandhi():
new_seg = [] new_seg = []
# function 1 # function 1
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and i + 1 < len(seg) and seg[i - 1][ if (
0] == seg[i + 1][0] and seg[i - 1][1] == "v": i - 1 >= 0
and word == ""
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0] new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0]
else: else:
if i - 2 >= 0 and seg[i - 1][0] == "" and seg[i - 2][ if (
0] == word and pos == "v": i - 2 >= 0
and seg[i - 1][0] == ""
and seg[i - 2][0] == word
and pos == "v"
):
continue continue
else: else:
new_seg.append([word, pos]) new_seg.append([word, pos])
@ -257,22 +667,27 @@ class ToneSandhi():
# the first and the second words are all_tone_three # the first and the second words are all_tone_three
def _merge_continuous_three_tones( def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin( lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and self._all_tone_three( if (
sub_finals_list[i - 1]) and self._all_tone_three( i - 1 >= 0
sub_finals_list[i]) and not merge_last[i - 1]: and self._all_tone_three(sub_finals_list[i - 1])
and self._all_tone_three(sub_finals_list[i])
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len( if (
seg[i - 1][0]) + len(seg[i][0]) <= 3: not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
@ -287,21 +702,27 @@ class ToneSandhi():
# the last char of first word and the first char of second word is tone_three # the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2( def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
sub_finals_list = [ sub_finals_list = [
lazy_pinyin( lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg for (word, pos) in seg
] ]
assert len(sub_finals_list) == len(seg) assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg) merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \ if (
merge_last[i - 1]: i - 1 >= 0
and sub_finals_list[i - 1][-1][-1] == "3"
and sub_finals_list[i][0][-1] == "3"
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len( if (
seg[i - 1][0]) + len(seg[i][0]) <= 3: not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True merge_last[i] = True
else: else:
@ -313,14 +734,13 @@ class ToneSandhi():
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and seg[i-1][0] != "#": if i - 1 >= 0 and word == "" and seg[i - 1][0] != "#":
new_seg[-1][0] = new_seg[-1][0] + seg[i][0] new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
else: else:
new_seg.append([word, pos]) new_seg.append([word, pos])
return new_seg return new_seg
def _merge_reduplication( def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = [] new_seg = []
for i, (word, pos) in enumerate(seg): for i, (word, pos) in enumerate(seg):
if new_seg and word == new_seg[-1][0]: if new_seg and word == new_seg[-1][0]:
@ -329,8 +749,7 @@ class ToneSandhi():
new_seg.append([word, pos]) new_seg.append([word, pos])
return new_seg return new_seg
def pre_merge_for_modify( def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
seg = self._merge_bu(seg) seg = self._merge_bu(seg)
try: try:
seg = self._merge_yi(seg) seg = self._merge_yi(seg)
@ -349,8 +768,7 @@ class ToneSandhi():
seg = self._merge_er(seg) seg = self._merge_er(seg)
return seg return seg
def modified_tone(self, word: str, pos: str, def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals) finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals) finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals) finals = self._neural_sandhi(word, pos, finals)
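modified_tone above applies the sandhi passes in a fixed order: the 不 pass, the 一 pass, then the neutral-tone pass (the hunk ends before any remaining calls, such as the third-tone handling that the _all_tone_three helpers above exist for). As a small standalone illustration of the first pass, based on the loop shown earlier in _bu_sandhi, 不 before a fourth-tone syllable is re-marked as second tone; the full method also covers the neutral-tone case for 不 in the middle of a word, which this hunk does not show:

def bu_sandhi(word, finals):
    # finals carry the tone as their last character, as in the class above;
    # 不 before a fourth-tone syllable is read with tone 2, e.g. 不怕 bu2 pa4
    out = list(finals)
    for i, char in enumerate(word):
        if char == "不" and i + 1 < len(word) and out[i + 1][-1] == "4":
            out[i] = out[i][:-1] + "2"
    return out

print(bu_sandhi("不怕", ["u4", "a4"]))  # ['u2', 'a4']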