diff --git a/GPT_SoVITS/AR/data/bucket_sampler.py b/GPT_SoVITS/AR/data/bucket_sampler.py index ee59479..7d752db 100644 --- a/GPT_SoVITS/AR/data/bucket_sampler.py +++ b/GPT_SoVITS/AR/data/bucket_sampler.py @@ -16,7 +16,7 @@ __all__ = [ "DistributedBucketSampler", ] -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) class DistributedBucketSampler(Sampler[T_co]): @@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]): sort batches """ - def __init__(self, - dataset: Dataset, - num_replicas: Optional[int]=None, - rank: Optional[int]=None, - shuffle: bool=True, - seed: int=0, - drop_last: bool=False, - batch_size: int=32) -> None: + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 32, + ) -> None: if num_replicas is None: if not dist.is_available(): - raise RuntimeError( - "Requires distributed package to be available") + raise RuntimeError("Requires distributed package to be available") num_replicas = dist.get_world_size() if rank is None: if not dist.is_available(): - raise RuntimeError( - "Requires distributed package to be available") + raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() torch.cuda.set_device(rank) if rank >= num_replicas or rank < 0: - raise ValueError("Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, num_replicas - 1)) + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1) + ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]): self.drop_last = drop_last # If the dataset length is evenly divisible by # of replicas, then there # is no need to drop any data, since the dataset will be split equally. - if self.drop_last and len( - self. - dataset) % self.num_replicas != 0: # type: ignore[arg-type] + if ( + self.drop_last and len(self.dataset) % self.num_replicas != 0 + ): # type: ignore[arg-type] # Split to nearest available length that is evenly divisible. # This is to ensure each rank receives the same amount of data when # using this Sampler. 
self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / - self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil( - len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + len(self.dataset) / self.num_replicas + ) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]): id_with_lengths.sort(key=lambda x: x[1]) return id_with_lengths - def make_buckets(self, bucket_width: float=2.0): + def make_buckets(self, bucket_width: float = 2.0): buckets = [] cur = [] max_sec = bucket_width @@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]): shuffled_bucket = list(itertools.chain(*shuffled_bucket)) n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size)) batches = [ - shuffled_bucket[b * grouped_batch_size:(b + 1) * - grouped_batch_size] for b in range(n_batch) + shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] + for b in range(n_batch) ] shuffle(batches) indices = list(itertools.chain(*batches)) @@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]): if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / - len(indices)))[:padding_size] + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] else: # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py index 4c300f1..f3d895a 100644 --- a/GPT_SoVITS/AR/data/data_module.py +++ b/GPT_SoVITS/AR/data/data_module.py @@ -6,14 +6,21 @@ from torch.utils.data import DataLoader class Text2SemanticDataModule(LightningDataModule): - def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None): + def __init__( + self, + config, + train_semantic_path, + train_phoneme_path, + dev_semantic_path=None, + dev_phoneme_path=None, + ): super().__init__() self.config = config self.train_semantic_path = train_semantic_path self.train_phoneme_path = train_phoneme_path self.dev_semantic_path = dev_semantic_path self.dev_phoneme_path = dev_phoneme_path - self.num_workers = self.config['data']['num_workers'] + self.num_workers = self.config["data"]["num_workers"] def prepare_data(self): pass @@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule): self._train_dataset = Text2SemanticDataset( phoneme_path=self.train_phoneme_path, semantic_path=self.train_semantic_path, - max_sec=self.config['data']['max_sec'], - pad_val=self.config['data']['pad_val']) + max_sec=self.config["data"]["max_sec"], + pad_val=self.config["data"]["pad_val"], + ) self._dev_dataset = self._train_dataset # self._dev_dataset = Text2SemanticDataset( # phoneme_path=self.dev_phoneme_path, @@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule): # pad_val=self.config['data']['pad_val']) def train_dataloader(self): - batch_size = self.config['train']['batch_size'] - sampler = DistributedBucketSampler( - self._train_dataset, 
batch_size=batch_size) + batch_size = self.config["train"]["batch_size"] + sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) return DataLoader( self._train_dataset, batch_size=batch_size, @@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule): collate_fn=self._train_dataset.collate, num_workers=self.num_workers, persistent_workers=True, - prefetch_factor=16 + prefetch_factor=16, ) def val_dataloader(self): @@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule): batch_size=1, shuffle=False, collate_fn=self._train_dataset.collate, - num_workers=max(self.num_workers,12), + num_workers=max(self.num_workers, 12), persistent_workers=True, - prefetch_factor=16 + prefetch_factor=16, ) # 这个会使用到嘛? @@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule): self._dev_dataset, batch_size=1, shuffle=False, - collate_fn=self._train_dataset.collate) + collate_fn=self._train_dataset.collate, + ) diff --git a/GPT_SoVITS/AR/data/dataset.py b/GPT_SoVITS/AR/data/dataset.py index 72c9e2e..47adacc 100644 --- a/GPT_SoVITS/AR/data/dataset.py +++ b/GPT_SoVITS/AR/data/dataset.py @@ -1,21 +1,24 @@ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py import pdb import sys + # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert") -import traceback,os +import traceback, os from typing import Dict from typing import List import numpy as np import pandas as pd -import torch,json +import torch, json from torch.utils.data import DataLoader from torch.utils.data import Dataset from transformers import AutoTokenizer from text import cleaned_text_to_sequence + # from config import exp_dir + def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0): seq = sequences[0] ndim = seq.ndim @@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0 padded_sequences = [] for seq, length in zip(sequences, seq_lengths): - padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * ( - ndim - axis - 1) - padded_seq = np.pad( - seq, padding, mode='constant', constant_values=pad_value) + padding = ( + [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1) + ) + padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value) padded_sequences.append(padded_seq) batch = np.stack(padded_sequences) return batch + class Text2SemanticDataset(Dataset): """dataset class for text tokens to semantic model training.""" - def __init__(self, - phoneme_path: str, - semantic_path: str, - max_sample: int = None, - max_sec: int = 100, - pad_val: int = 1024, - # min value of phoneme/sec - min_ps_ratio: int = 3, - # max value of phoneme/sec - max_ps_ratio: int = 25) -> None: + def __init__( + self, + phoneme_path: str, + semantic_path: str, + max_sample: int = None, + max_sec: int = 100, + pad_val: int = 1024, + # min value of phoneme/sec + min_ps_ratio: int = 3, + # max value of phoneme/sec + max_ps_ratio: int = 25, + ) -> None: super().__init__() - self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8") + self.semantic_data = pd.read_csv( + semantic_path, delimiter="\t", encoding="utf-8" + ) # get dict - self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path - self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir - self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path + self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path + 
self.path3 = "%s/3-bert" % ( + os.path.basename(phoneme_path) + ) # "%s/3-bert"%exp_dir#bert_dir + self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path assert os.path.exists(self.path2) assert os.path.exists(self.path6) - self.phoneme_data={} - with open(self.path2,"r",encoding="utf8")as f: - lines=f.read().strip("\n").split("\n") + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") for line in lines: - tmp=line.split("\t") - if(len(tmp)!=4):continue - self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]] + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]] # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() # pad for semantic tokens @@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset): # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read() # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz # self.hz=int(data[:-2])# - self.hz=int(os.environ.get("hz","25hz")[:-2]) + self.hz = int(os.environ.get("hz", "25hz")[:-2]) # max seconds of semantic token self.max_sec = max_sec @@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset): # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large") # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large") - def init_batch(self): semantic_data_len = len(self.semantic_data) phoneme_data_len = len(self.phoneme_data.keys()) @@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset): for i in range(semantic_data_len): # 先依次遍历 # get str - item_name = self.semantic_data['item_name'][i] + item_name = self.semantic_data["item_name"][i] # print(self.phoneme_data) try: phoneme, word2ph, text = self.phoneme_data[item_name] @@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset): num_not_in += 1 continue - semantic_str = self.semantic_data['semantic_audio'][i] + semantic_str = self.semantic_data["semantic_audio"][i] # get token list - semantic_ids = [int(idx) for idx in semantic_str.split(' ')] + semantic_ids = [int(idx) for idx in semantic_str.split(" ")] # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len # 过滤掉太长的样本 - if len(semantic_ids) > self.max_sec * self.hz:#########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k + if ( + len(semantic_ids) > self.max_sec * self.hz + ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k num_deleted_bigger += 1 continue # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理#### - phoneme = phoneme.split(' ') + phoneme = phoneme.split(" ") try: phoneme_ids = cleaned_text_to_sequence(phoneme) @@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset): num_not_in += 1 continue # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行 - if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2:改为恒定限制为semantic/2.5就行 + if ( + len(phoneme_ids) > self.max_sec * self.hz / 2.5 + ): ###########2:改为恒定限制为semantic/2.5就行 num_deleted_ps += 1 continue # if len(semantic_ids) > 1000:###########3 @@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset): ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz) - if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25#每秒多少个phone + if ( + ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio + ): ##########4#3~25#每秒多少个phone num_deleted_ps += 1 # print(item_name) continue @@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset): idx += 1 
self.item_names.append(item_name) - min_num=100#20直接不补#30补了也不存ckpt - leng =len(self.semantic_phoneme) - if(leng<min_num): - tmp1=self.semantic_phoneme - tmp2=self.item_names - self.semantic_phoneme=[] - self.item_names=[] - for _ in range(max(2,int(min_num/leng))): - self.semantic_phoneme+=tmp1 - self.item_names+=tmp2 + min_num = 100 # 20直接不补#30补了也不存ckpt + leng = len(self.semantic_phoneme) + if leng < min_num: + tmp1 = self.semantic_phoneme + tmp2 = self.item_names + self.semantic_phoneme = [] + self.item_names = [] + for _ in range(max(2, int(min_num / leng))): + self.semantic_phoneme += tmp1 + self.item_names += tmp2 if num_not_in > 0: print(f"there are {num_not_in} semantic datas not in phoneme datas") if num_deleted_bigger > 0: @@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset): print( f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}" ) - ''' + """ there are 31 semantic datas not in phoneme datas deleted 34 audios who's duration are bigger than 54 seconds deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3 dataset.__len__(): 366463 - ''' + """ # 345410 for LibriTTS print("dataset.__len__():", self.__len__()) @@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset): # semantic tokens target semantic_ids_len = len(semantic_ids) - flag=0 + flag = 0 path_bert = "%s/%s.pt" % (self.path3, item_name) - if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu") - else:flag=1 - if(flag==1): + if os.path.exists(path_bert) == True: + bert_feature = torch.load(path_bert, map_location="cpu") + else: + flag = 1 + if flag == 1: # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32) - bert_feature=None + bert_feature = None else: assert bert_feature.shape[-1] == len(phoneme_ids) return { - 'idx': idx, - 'phoneme_ids': phoneme_ids, - 'phoneme_ids_len': phoneme_ids_len, - 'semantic_ids': semantic_ids, - 'semantic_ids_len': semantic_ids_len, - 'bert_feature': bert_feature, + "idx": idx, + "phoneme_ids": phoneme_ids, + "phoneme_ids_len": phoneme_ids_len, + "semantic_ids": semantic_ids, + "semantic_ids_len": semantic_ids_len, + "bert_feature": bert_feature, } def get_sample_length(self, idx: int): @@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset): semantic_ids_lens: List[int] = [] # return - for item in examples: sample_index.append(item["idx"]) phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) @@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset): bert_padded.zero_() for idx, item in enumerate(examples): - bert = item['bert_feature'] - if(bert!=None): - bert_padded[idx, :, :bert.shape[-1]] = bert + bert = item["bert_feature"] + if bert != None: + bert_padded[idx, :, : bert.shape[-1]] = bert return { # List[int] @@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset): } -if __name__ == '__main__': - root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/' +if __name__ == "__main__": + root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/" dataset = Text2SemanticDataset( - phoneme_path=root_dir + 'phoneme_train.npy', - semantic_path=root_dir + 'semantic_train.tsv') + phoneme_path=root_dir + "phoneme_train.npy", + semantic_path=root_dir + "semantic_train.tsv", + ) batch_size = 12 dataloader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=dataset.collate, - shuffle=False) + dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False ) for i, batch in enumerate(dataloader): - if(i%1000==0):print(i) + if i % 1000 == 0: + print(i) # if i == 0: # print('batch["ids"]:', batch["ids"]) - # print('batch["phoneme_ids"]:', batch["phoneme_ids"], - # batch["phoneme_ids"].shape) - # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"], - # batch["phoneme_ids_len"].shape) - # print('batch["semantic_ids"]:', batch["semantic_ids"], - # batch["semantic_ids"].shape) - # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"], - # batch["semantic_ids_len"].shape) + # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
+ # batch["phoneme_ids"].shape) + # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"], + # batch["phoneme_ids_len"].shape) + # print('batch["semantic_ids"]:', batch["semantic_ids"], + # batch["semantic_ids"].shape) + # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"], + # batch["semantic_ids_len"].shape) diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py index 149d88e..f9dfc64 100644 --- a/GPT_SoVITS/AR/models/t2s_lightning_module.py +++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py @@ -1,5 +1,6 @@ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py -import os,sys +import os, sys + now_dir = os.getcwd() sys.path.append(now_dir) from typing import Dict @@ -12,29 +13,35 @@ from AR.modules.optim import ScaledAdam class Text2SemanticLightningModule(LightningModule): - def __init__(self, config, output_dir,is_train=True): + def __init__(self, config, output_dir, is_train=True): super().__init__() self.config = config self.top_k = 3 self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) - pretrained_s1=config.get("pretrained_s1") - if(pretrained_s1 and is_train): + pretrained_s1 = config.get("pretrained_s1") + if pretrained_s1 and is_train: # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) - print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"])) + print( + self.load_state_dict( + torch.load(pretrained_s1, map_location="cpu")["weight"] + ) + ) if is_train: self.automatic_optimization = False self.save_hyperparameters() - self.eval_dir = output_dir / 'eval' + self.eval_dir = output_dir / "eval" self.eval_dir.mkdir(parents=True, exist_ok=True) def training_step(self, batch: Dict, batch_idx: int): - opt = self.optimizers() scheduler = self.lr_schedulers() loss, acc = self.model.forward( - batch['phoneme_ids'], batch['phoneme_ids_len'], - batch['semantic_ids'], batch['semantic_ids_len'], - batch['bert_feature']) + batch["phoneme_ids"], + batch["phoneme_ids_len"], + batch["semantic_ids"], + batch["semantic_ids_len"], + batch["bert_feature"], + ) self.manual_backward(loss) if batch_idx > 0 and batch_idx % 4 == 0: opt.step() @@ -47,63 +54,67 @@ class Text2SemanticLightningModule(LightningModule): on_step=True, on_epoch=True, prog_bar=True, - sync_dist=True) + sync_dist=True, + ) self.log( "lr", scheduler.get_last_lr()[0], on_epoch=True, prog_bar=True, - sync_dist=True) + sync_dist=True, + ) self.log( f"top_{self.top_k}_acc", acc, on_step=True, on_epoch=True, prog_bar=True, - sync_dist=True) + sync_dist=True, + ) - def validation_step(self, batch: Dict, batch_idx: int):return - # # get loss - # loss, acc = self.model.forward( - # batch['phoneme_ids'], batch['phoneme_ids_len'], - # batch['semantic_ids'], batch['semantic_ids_len'], - # batch['bert_feature'] - # ) - # - # self.log( - # "val_total_loss", - # loss, - # on_step=True, - # on_epoch=True, - # prog_bar=True, - # sync_dist=True) - # self.log( - # f"val_top_{self.top_k}_acc", - # acc, - # on_step=True, - # on_epoch=True, - # prog_bar=True, - # sync_dist=True) - # - # # get infer output - # semantic_len = batch['semantic_ids'].size(1) - # prompt_len = min(int(semantic_len * 0.5), 150) - # prompt = batch['semantic_ids'][:, :prompt_len] - # pred_semantic = self.model.infer(batch['phoneme_ids'], - # batch['phoneme_ids_len'], prompt, - # batch['bert_feature'] - # ) - # save_name = f'semantic_toks_{batch_idx}.pt' - # save_path = 
os.path.join(self.eval_dir, save_name) - # torch.save(pred_semantic.detach().cpu(), save_path) + def validation_step(self, batch: Dict, batch_idx: int): + return + + # # get loss + # loss, acc = self.model.forward( + # batch['phoneme_ids'], batch['phoneme_ids_len'], + # batch['semantic_ids'], batch['semantic_ids_len'], + # batch['bert_feature'] + # ) + # + # self.log( + # "val_total_loss", + # loss, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # sync_dist=True) + # self.log( + # f"val_top_{self.top_k}_acc", + # acc, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # sync_dist=True) + # + # # get infer output + # semantic_len = batch['semantic_ids'].size(1) + # prompt_len = min(int(semantic_len * 0.5), 150) + # prompt = batch['semantic_ids'][:, :prompt_len] + # pred_semantic = self.model.infer(batch['phoneme_ids'], + # batch['phoneme_ids_len'], prompt, + # batch['bert_feature'] + # ) + # save_name = f'semantic_toks_{batch_idx}.pt' + # save_path = os.path.join(self.eval_dir, save_name) + # torch.save(pred_semantic.detach().cpu(), save_path) def configure_optimizers(self): model_parameters = self.model.parameters() parameters_names = [] - parameters_names.append([ - name_param_pair[0] - for name_param_pair in self.model.named_parameters() - ]) + parameters_names.append( + [name_param_pair[0] for name_param_pair in self.model.named_parameters()] + ) lm_opt = ScaledAdam( model_parameters, lr=0.01, @@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule): clipping_scale=2.0, parameters_names=parameters_names, show_dominant_parameters=False, - clipping_update_period=1000, ) + clipping_update_period=1000, + ) return { "optimizer": lm_opt, "lr_scheduler": { - "scheduler": - WarmupCosineLRSchedule( + "scheduler": WarmupCosineLRSchedule( lm_opt, - init_lr=self.config['optimizer']['lr_init'], - peak_lr=self.config['optimizer']['lr'], - end_lr=self.config['optimizer']['lr_end'], - warmup_steps=self.config['optimizer']['warmup_steps'], - total_steps=self.config['optimizer']['decay_steps']) - } + init_lr=self.config["optimizer"]["lr_init"], + peak_lr=self.config["optimizer"]["lr"], + end_lr=self.config["optimizer"]["lr_end"], + warmup_steps=self.config["optimizer"]["warmup_steps"], + total_steps=self.config["optimizer"]["decay_steps"], + ) + }, } diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 9f5337e..9f8330b 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -3,7 +3,12 @@ import torch from tqdm import tqdm from AR.models.utils import make_pad_mask -from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync +from AR.models.utils import ( + topk_sampling, + sample, + logits_to_probs, + multinomial_sample_one_no_sync, +) from AR.modules.embedding import SinePositionalEmbedding from AR.modules.embedding import TokenEmbedding from AR.modules.transformer import LayerNorm @@ -22,35 +27,39 @@ default_config = { "p_dropout": 0.0, "vocab_size": 1024 + 1, "phoneme_vocab_size": 512, - "EOS": 1024 + "EOS": 1024, } class Text2SemanticDecoder(nn.Module): def __init__(self, config, norm_first=False, top_k=3): super(Text2SemanticDecoder, self).__init__() - self.model_dim = config['model']["hidden_dim"] - self.embedding_dim = config['model']["embedding_dim"] - self.num_head = config['model']["head"] - self.num_layers = config['model']["n_layer"] + self.model_dim = config["model"]["hidden_dim"] + self.embedding_dim = config["model"]["embedding_dim"] + self.num_head = 
config["model"]["head"] + self.num_layers = config["model"]["n_layer"] self.norm_first = norm_first - self.vocab_size = config['model']["vocab_size"] - self.phoneme_vocab_size = config['model']["phoneme_vocab_size"] - self.p_dropout = config['model']["dropout"] - self.EOS = config['model']["EOS"] + self.vocab_size = config["model"]["vocab_size"] + self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"] + self.p_dropout = config["model"]["dropout"] + self.EOS = config["model"]["EOS"] self.norm_first = norm_first assert self.EOS == self.vocab_size - 1 # should be same as num of kmeans bin # assert self.EOS == 1024 self.bert_proj = nn.Linear(1024, self.embedding_dim) self.ar_text_embedding = TokenEmbedding( - self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) + self.embedding_dim, self.phoneme_vocab_size, self.p_dropout + ) self.ar_text_position = SinePositionalEmbedding( - self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.embedding_dim, dropout=0.1, scale=False, alpha=True + ) self.ar_audio_embedding = TokenEmbedding( - self.embedding_dim, self.vocab_size, self.p_dropout) + self.embedding_dim, self.vocab_size, self.p_dropout + ) self.ar_audio_position = SinePositionalEmbedding( - self.embedding_dim, dropout=0.1, scale=False, alpha=True) + self.embedding_dim, dropout=0.1, scale=False, alpha=True + ) self.h = TransformerEncoder( TransformerEncoderLayer( @@ -59,28 +68,30 @@ class Text2SemanticDecoder(nn.Module): dim_feedforward=self.model_dim * 4, dropout=0.1, batch_first=True, - norm_first=norm_first, ), + norm_first=norm_first, + ), num_layers=self.num_layers, - norm=LayerNorm(self.model_dim) if norm_first else None, ) + norm=LayerNorm(self.model_dim) if norm_first else None, + ) - self.ar_predict_layer = nn.Linear( - self.model_dim, self.vocab_size, bias=False) - self.loss_fct = nn.CrossEntropyLoss(reduction='sum') + self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.loss_fct = nn.CrossEntropyLoss(reduction="sum") self.ar_accuracy_metric = MulticlassAccuracy( self.vocab_size, top_k=top_k, average="micro", multidim_average="global", - ignore_index=self.EOS, ) + ignore_index=self.EOS, + ) def forward(self, x, x_lens, y, y_lens, bert_feature): - ''' + """ x: phoneme_ids y: semantic_ids - ''' + """ x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1,2)) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) x_mask = make_pad_mask(x_lens) @@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module): x_attn_mask = F.pad( torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), (0, y_len), - value=True, ) + value=True, + ) y_attn_mask = F.pad( torch.triu( torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), - diagonal=1, ), + diagonal=1, + ), (x_len, 0), - value=False, ) + value=False, + ) xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) bsz, src_len = x.shape[0], x_len + y_len - _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.num_head, -1, -1) - .reshape(bsz * self.num_head, 1, src_len)) + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_head, -1, -1) + .reshape(bsz * self.num_head, 1, src_len) + ) xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) @@ -122,26 +138,28 @@ class Text2SemanticDecoder(nn.Module): xy_pos = torch.concat([x, y_pos], dim=1) 
xy_dec, _ = self.h( (xy_pos, None), - mask=xy_attn_mask, ) + mask=xy_attn_mask, + ) logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) # loss # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum - loss = F.cross_entropy(logits, targets, reduction='sum') + loss = F.cross_entropy(logits, targets, reduction="sum") acc = self.ar_accuracy_metric(logits.detach(), targets).item() return loss, acc # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么 - def infer(self, - x, - x_lens, - prompts, - bert_feature, - top_k: int=-100, - early_stop_num: int=-1, - temperature: float=1.0): - + def infer( + self, + x, + x_lens, + prompts, + bert_feature, + top_k: int = -100, + early_stop_num: int = -1, + temperature: float = 1.0, + ): x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1,2)) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) # AR Decoder @@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module): x_attn_mask_pad = F.pad( x_attn_mask, (0, y_len), - value=True, ) + value=True, + ) y_attn_mask = F.pad( - torch.triu( - torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), - value=False, ) - xy_attn_mask = torch.concat( - [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device) + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + y.device + ) xy_dec, _ = self.h( (xy_pos, None), - mask=xy_attn_mask, ) + mask=xy_attn_mask, + ) logits = self.ar_predict_layer(xy_dec[:, -1]) samples = topk_sampling( - logits, top_k=top_k, top_p=1.0, temperature=temperature) + logits, top_k=top_k, top_p=1.0, temperature=temperature + ) - if early_stop_num != -1 and (y.shape[1] - prefix_len - ) > early_stop_num: + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: print("use early stop num:", early_stop_num) stop = True - if torch.argmax( - logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) stop = True if stop: if prompts.shape[1] == y.shape[1]: y = torch.concat([y, torch.zeros_like(samples)], dim=1) - print('bad zero prediction') + print("bad zero prediction") print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") break # 本次生成的 semantic_ids 和之前的 y 构成新的 y @@ -198,23 +218,24 @@ class Text2SemanticDecoder(nn.Module): return y def pad_y_eos(self, y, y_mask_int, eos_id): - targets = F.pad( - y, (0, 1), value=0) + eos_id * F.pad( - y_mask_int, (0, 1), value=1) + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( + y_mask_int, (0, 1), value=1 + ) # 错位 return targets[:, :-1], targets[:, 1:] - def infer_panel(self, - x,#####全部文本token - x_lens, - prompts,####参考音频token - bert_feature, - top_k: int=-100, - early_stop_num: int=-1, - temperature: float=1.0): - + def infer_panel( + self, + x, #####全部文本token + x_lens, + prompts, ####参考音频token + bert_feature, + top_k: int = -100, + early_stop_num: int = -1, + temperature: float = 1.0, + ): x = self.ar_text_embedding(x) - x = x + self.bert_proj(bert_feature.transpose(1,2)) + x = x + self.bert_proj(bert_feature.transpose(1, 2)) x = self.ar_text_position(x) # AR Decoder @@ -224,75 +245,81 @@ class Text2SemanticDecoder(nn.Module): x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) stop = False # print(1111111,self.num_layers) - cache={ - "all_stage":self.num_layers, - 
"k":[None]*self.num_layers,###根据配置自己手写 - "v":[None]*self.num_layers, + cache = { + "all_stage": self.num_layers, + "k": [None] * self.num_layers, ###根据配置自己手写 + "v": [None] * self.num_layers, # "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存,每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了 - "y_emb":None,##只需要对最新的samples求emb,再拼历史的就行 + "y_emb": None, ##只需要对最新的samples求emb,再拼历史的就行 # "logits":None,###原版就已经只对结尾求再拼接了,不用管 # "xy_dec":None,###不需要,本来只需要最后一个做logits - "first_infer":1, - "stage":0 + "first_infer": 1, + "stage": 0, } for idx in tqdm(range(1500)): - if(cache["first_infer"]==1): + if cache["first_infer"] == 1: y_emb = self.ar_audio_embedding(y) else: - y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1) - cache["y_emb"]=y_emb + y_emb = torch.cat( + [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1 + ) + cache["y_emb"] = y_emb y_pos = self.ar_audio_position(y_emb) # x 和逐渐增长的 y 一起输入给模型 - if(cache["first_infer"]==1): + if cache["first_infer"] == 1: xy_pos = torch.concat([x, y_pos], dim=1) else: - xy_pos=y_pos[:,-1:] + xy_pos = y_pos[:, -1:] y_len = y_pos.shape[1] ###以下3个不做缓存 - if (cache["first_infer"] == 1): + if cache["first_infer"] == 1: x_attn_mask_pad = F.pad( - x_attn_mask, - (0, y_len),###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) - value=True, ) - y_attn_mask = F.pad(###yy的右上1扩展到左边xy的0,(y,x+y) - torch.triu( - torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), - value=False, ) - xy_attn_mask = torch.concat( - [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device) + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + y.device + ) else: ###最右边一列(是错的) # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device) # xy_attn_mask[:,-1]=False ###最下面一行(是对的) - xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device) + xy_attn_mask = torch.zeros( + (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device + ) # pdb.set_trace() ###缓存重头戏 # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len) - xy_dec, _ = self.h( - (xy_pos, None), - mask=xy_attn_mask,cache=cache ) - logits = self.ar_predict_layer(xy_dec[:, -1])##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 + xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache) + logits = self.ar_predict_layer( + xy_dec[:, -1] + ) ##不用改,如果用了cache的默认就是只有一帧,取最后一帧一样的 # samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) - samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) - if early_stop_num != -1 and (y.shape[1] - prefix_len - ) > early_stop_num: + samples = sample( + logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35 + )[0].unsqueeze(0) + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: print("use early stop num:", early_stop_num) stop = True - if torch.argmax( - logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) stop = True if stop: if prompts.shape[1] == y.shape[1]: y = torch.concat([y, torch.zeros_like(samples)], dim=1) - print('bad zero prediction') + print("bad zero prediction") print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") break # 本次生成的 semantic_ids 和之前的 y 构成新的 y # 
print(samples.shape)#[1,1]#第一个1是bs y = torch.concat([y, samples], dim=1) - cache["first_infer"]=0 - return y,idx + cache["first_infer"] = 0 + return y, idx diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py index dfe1d8a..25fe446 100644 --- a/GPT_SoVITS/AR/models/utils.py +++ b/GPT_SoVITS/AR/models/utils.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F + def sequence_mask(length, max_length=None): if max_length is None: max_length = length.max() @@ -9,7 +10,7 @@ def sequence_mask(length, max_length=None): return x.unsqueeze(0) < length.unsqueeze(1) -def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor: +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: """ Args: lengths: @@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor: # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py -def top_k_top_p_filtering(logits, - top_k=0, - top_p=1.0, - filter_value=-float("Inf"), - min_tokens_to_keep=1): +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size, vocabulary size) @@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits, From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), - logits.size(-1)) # Safety check + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum( - F.softmax(sorted_logits, dim=-1), dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p @@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits, # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1].clone() + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove) + 1, sorted_indices, sorted_indices_to_remove + ) logits[indices_to_remove] = filter_value return logits @@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): from typing import Optional, Tuple + + def multinomial_sample_one_no_sync( probs_sort, ): # Does multinomial sampling without a cuda synchronization @@ -115,7 +114,7 @@ def logits_to_probs( top_p: Optional[int] = None, repetition_penalty: float = 1.0, ): - previous_tokens=previous_tokens.squeeze() + previous_tokens = previous_tokens.squeeze() # print(logits.shape,previous_tokens.shape) # pdb.set_trace() if previous_tokens is not None and repetition_penalty != 1.0: @@ -159,4 +158,3 @@ def sample( ) idx_next = 
multinomial_sample_one_no_sync(probs) return idx_next, probs - diff --git a/GPT_SoVITS/AR/modules/activation.py b/GPT_SoVITS/AR/modules/activation.py index 50631e9..5ca888b 100644 --- a/GPT_SoVITS/AR/modules/activation.py +++ b/GPT_SoVITS/AR/modules/activation.py @@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter from torch.nn import functional as F from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched -F.multi_head_attention_forward=multi_head_attention_forward_patched + +F.multi_head_attention_forward = multi_head_attention_forward_patched + class MultiheadAttention(Module): r"""Allows the model to jointly attend to information @@ -76,66 +78,71 @@ class MultiheadAttention(Module): bias_v: Optional[torch.Tensor] def __init__( - self, - embed_dim, - num_heads, - dropout=0.0, - bias=True, - add_bias_kv=False, - add_zero_attn=False, - kdim=None, - vdim=None, - batch_first=False, - linear1_cls=Linear, - linear2_cls=Linear, - device=None, - dtype=None, ) -> None: + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = (self.kdim == embed_dim and - self.vdim == embed_dim) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout self.batch_first = batch_first self.head_dim = embed_dim // num_heads - assert (self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" if add_bias_kv: - self.bias_k = Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None if linear1_cls == Linear: if not self._qkv_same_embed_dim: self.q_proj_weight = Parameter( - torch.empty((embed_dim, embed_dim), **factory_kwargs)) + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) self.k_proj_weight = Parameter( - torch.empty((embed_dim, self.kdim), **factory_kwargs)) + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) self.v_proj_weight = Parameter( - torch.empty((embed_dim, self.vdim), **factory_kwargs)) + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) self.register_parameter("in_proj_weight", None) else: self.in_proj_weight = Parameter( - torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)) + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = Parameter( - torch.empty(3 * embed_dim, **factory_kwargs)) + torch.empty(3 * embed_dim, **factory_kwargs) + ) else: self.register_parameter("in_proj_bias", None) self.out_proj = NonDynamicallyQuantizableLinear( - embed_dim, embed_dim, bias=bias, **factory_kwargs) + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) 
self._reset_parameters() else: @@ -143,7 +150,8 @@ class MultiheadAttention(Module): raise NotImplementedError else: self.in_proj_linear = linear1_cls( - embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) self.in_proj_weight = self.in_proj_linear.weight self.register_parameter("q_proj_weight", None) @@ -156,7 +164,8 @@ class MultiheadAttention(Module): self.register_parameter("in_proj_bias", None) self.out_proj = linear2_cls( - embed_dim, embed_dim, bias=bias, **factory_kwargs) + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) if self.bias_k is not None: xavier_normal_(self.bias_k) @@ -190,14 +199,15 @@ class MultiheadAttention(Module): super(MultiheadAttention, self).__setstate__(state) def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - key_padding_mask: Optional[Tensor]=None, - need_weights: bool=True, - attn_mask: Optional[Tensor]=None, - average_attn_weights: bool=True,cache=None + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + cache=None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -251,23 +261,26 @@ class MultiheadAttention(Module): if key_padding_mask is not None: _kpm_dtype = key_padding_mask.dtype if _kpm_dtype != torch.bool and not torch.is_floating_point( - key_padding_mask): + key_padding_mask + ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) why_not_fast_path = "" if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) elif query is not key or key is not value: # When lifting this restriction, don't forget to either # enforce that the dtypes all match or test cases where # they don't! why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif (self.in_proj_bias is not None and - query.dtype != self.in_proj_bias.dtype): + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif (self.in_proj_weight is not None and - query.dtype != self.in_proj_weight.dtype): + elif ( + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype + ): # this case will fail anyway, but at least they'll get a useful error message. 
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" elif self.training: @@ -288,29 +301,41 @@ class MultiheadAttention(Module): why_not_fast_path = "attn_mask was not None" elif query.is_nested and key_padding_mask is not None: why_not_fast_path = ( - "key_padding_mask is not supported with NestedTensor input") + "key_padding_mask is not supported with NestedTensor input" + ) elif self.num_heads % 2 == 1: why_not_fast_path = "num_heads is odd" elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" if not why_not_fast_path: - tensor_args = (query, key, value, self.in_proj_weight, - self.in_proj_bias, self.out_proj.weight, - self.out_proj.bias, ) + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) # We have to use list comprehensions below because TorchScript does not support # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) - for x in tensor_args]): - why_not_fast_path = ( - "some Tensor argument is neither CUDA nor CPU") + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any( - [x is not None and x.requires_grad for x in tensor_args]): + [x is not None and x.requires_grad for x in tensor_args] + ): why_not_fast_path = ( "grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") + "input/output projection weights or biases requires_grad" + ) if not why_not_fast_path: return torch._native_multi_head_attention( query, @@ -322,17 +347,21 @@ class MultiheadAttention(Module): self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, - key_padding_mask - if key_padding_mask is not None else attn_mask, + key_padding_mask if key_padding_mask is not None else attn_mask, need_weights, average_attn_weights, - 1 if key_padding_mask is not None else 0 - if attn_mask is not None else None, ) + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) any_nested = query.is_nested or key.is_nested or value.is_nested assert not any_nested, ( "MultiheadAttention does not support NestedTensor outside of its fast path. 
" - + f"The fast path was not hit because {why_not_fast_path}") + + f"The fast path was not hit because {why_not_fast_path}" + ) if self.batch_first and is_batched: # make sure that the transpose op does not affect the "is" property @@ -343,9 +372,7 @@ class MultiheadAttention(Module): query, key = [x.transpose(1, 0) for x in (query, key)] value = key else: - query, key, value = [ - x.transpose(1, 0) for x in (query, key, value) - ] + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: attn_output, attn_output_weights = F.multi_head_attention_forward( @@ -370,7 +397,9 @@ class MultiheadAttention(Module): q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, - average_attn_weights=average_attn_weights,cache=cache ) + average_attn_weights=average_attn_weights, + cache=cache, + ) else: attn_output, attn_output_weights = F.multi_head_attention_forward( query, @@ -390,7 +419,9 @@ class MultiheadAttention(Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, - average_attn_weights=average_attn_weights,cache=cache ) + average_attn_weights=average_attn_weights, + cache=cache, + ) if self.batch_first and is_batched: return attn_output.transpose(1, 0), attn_output_weights else: diff --git a/GPT_SoVITS/AR/modules/embedding.py b/GPT_SoVITS/AR/modules/embedding.py index 35063c7..3a382f9 100644 --- a/GPT_SoVITS/AR/modules/embedding.py +++ b/GPT_SoVITS/AR/modules/embedding.py @@ -7,10 +7,11 @@ from torch import nn class TokenEmbedding(nn.Module): def __init__( - self, - embedding_dim: int, - vocab_size: int, - dropout: float=0.0, ): + self, + embedding_dim: int, + vocab_size: int, + dropout: float = 0.0, + ): super().__init__() self.vocab_size = vocab_size @@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module): return self.word_embeddings.weight def embedding(self, index: int) -> torch.Tensor: - return self.word_embeddings.weight[index:index + 1] + return self.word_embeddings.weight[index : index + 1] def forward(self, x: torch.Tensor): x = self.word_embeddings(x) @@ -34,11 +35,12 @@ class TokenEmbedding(nn.Module): class SinePositionalEmbedding(nn.Module): def __init__( - self, - embedding_dim: int, - dropout: float=0.0, - scale: bool=False, - alpha: bool=False, ): + self, + embedding_dim: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): super().__init__() self.embedding_dim = embedding_dim self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 @@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module): pe = torch.zeros(x.size(1), self.embedding_dim) if self.reverse: position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) else: - position = torch.arange( - 0, x.size(1), dtype=torch.float32).unsqueeze(1) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( - torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * - -(math.log(10000.0) / self.embedding_dim)) + torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.embedding_dim) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) @@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: self.extend_pe(x) output = x.unsqueeze(-1) if x.ndim == 2 else x - output = output * self.x_scale + self.alpha * 
self.pe[:, :x.size(1)] + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] return self.dropout(output) diff --git a/GPT_SoVITS/AR/modules/lr_schedulers.py b/GPT_SoVITS/AR/modules/lr_schedulers.py index 955d804..7dec462 100644 --- a/GPT_SoVITS/AR/modules/lr_schedulers.py +++ b/GPT_SoVITS/AR/modules/lr_schedulers.py @@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. """ - def __init__(self, - optimizer, - init_lr, - peak_lr, - end_lr, - warmup_steps=10000, - total_steps=400000, - current_step=0): + def __init__( + self, + optimizer, + init_lr, + peak_lr, + end_lr, + warmup_steps=10000, + total_steps=400000, + current_step=0, + ): self.init_lr = init_lr self.peak_lr = peak_lr self.end_lr = end_lr @@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): self._last_lr = [self.lr] def set_lr(self, lr): - self._last_lr = [g['lr'] for g in self.optimizer.param_groups] + self._last_lr = [g["lr"] for g in self.optimizer.param_groups] for g in self.optimizer.param_groups: # g['lr'] = lr - g['lr'] = self.end_lr###锁定用线性 + g["lr"] = self.end_lr ###锁定用线性 def step(self): if self._current_step < self.warmup_steps: @@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): else: decay_ratio = (self._current_step - self.warmup_steps) / ( - self.total_steps - self.warmup_steps) + self.total_steps - self.warmup_steps + ) if decay_ratio < 0.0 or decay_ratio > 1.0: raise RuntimeError( "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings." @@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) - self.lr=lr=self.end_lr=0.002###锁定用线性###不听话,直接锁定! + self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定! self.set_lr(lr) self.lr = lr self._current_step += 1 return self.lr - -if __name__ == '__main__': +if __name__ == "__main__": m = nn.Linear(10, 10) opt = Adam(m.parameters(), lr=1e-4) s = WarmupCosineLRSchedule( - opt, - 1e-6, - 2e-4, - 1e-6, - warmup_steps=2000, - total_steps=20000, - current_step=0) + opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0 + ) lrs = [] for i in range(25000): s.step() diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py index bfb748e..5720670 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py @@ -1,9 +1,16 @@ from torch.nn.functional import * -from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed +from torch.nn.functional import ( + _mha_shape_check, + _canonical_mask, + _none_or_dtype, + _in_projection_packed, +) + # import torch # Tensor = torch.Tensor # from typing import Callable, List, Optional, Tuple, Union + def multi_head_attention_forward_patched( query: Tensor, key: Tensor, @@ -29,7 +36,8 @@ def multi_head_attention_forward_patched( static_k: Optional[Tensor] = None, static_v: Optional[Tensor] = None, average_attn_weights: bool = True, - is_causal: bool = False,cache=None + is_causal: bool = False, + cache=None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -105,7 +113,17 @@ def multi_head_attention_forward_patched( :math:`S` is the source sequence length. 
If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. """ - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) if has_torch_function(tens_ops): return handle_torch_function( multi_head_attention_forward, @@ -134,10 +152,13 @@ def multi_head_attention_forward_patched( v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v, - average_attn_weights=average_attn_weights,cache=cache + average_attn_weights=average_attn_weights, + cache=cache, ) - is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input # is batched, run the computation and before returning squeeze the @@ -159,7 +180,7 @@ def multi_head_attention_forward_patched( mask_name="key_padding_mask", other_type=_none_or_dtype(attn_mask), other_name="attn_mask", - target_type=query.dtype + target_type=query.dtype, ) if is_causal and attn_mask is None: @@ -184,59 +205,84 @@ def multi_head_attention_forward_patched( check_other=False, ) - if key_padding_mask is not None: # We have the attn_mask, and use that to merge kpm into it. # Turn off use of is_causal hint, as the merged mask is no # longer causal. is_causal = False - assert embed_dim == embed_dim_to_check, \ - f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" if isinstance(embed_dim, torch.Tensor): # embed_dim can be a tensor when JIT tracing - head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert key.shape[:2] == value.shape[:2], \ - f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" else: - assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" # # compute in-projection # if not use_separate_proj_weight: - assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) else: - assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" - assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" - assert v_proj_weight is not None, 
"use_separate_proj_weight is True but v_proj_weight is None" + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" if in_proj_bias is None: b_q = b_k = b_v = None else: b_q, b_k, b_v = in_proj_bias.chunk(3) - q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) - if(cache!=None): - if(cache["first_infer"]==1): - cache["k"][cache["stage"]]=k + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) + if cache != None: + if cache["first_infer"] == 1: + cache["k"][cache["stage"]] = k # print(0,cache["k"].shape) - cache["v"][cache["stage"]]=v - else:###12个layer每个都要留自己的cache_kv + cache["v"][cache["stage"]] = v + else: ###12个layer每个都要留自己的cache_kv # print(1,cache["k"].shape) - cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1,但是proj的时候可能transpose了所以时序到0维了 - cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0) + cache["k"][cache["stage"]] = torch.cat( + [cache["k"][cache["stage"]], k], 0 + ) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了 + cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0) # print(2, cache["k"].shape) src_len = cache["k"][cache["stage"]].shape[0] - k=cache["k"][cache["stage"]] - v=cache["v"][cache["stage"]] + k = cache["k"][cache["stage"]] + v = cache["v"][cache["stage"]] # if attn_mask is not None: # attn_mask=attn_mask[-1:,] - # print(attn_mask.shape,attn_mask) + # print(attn_mask.shape,attn_mask) cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] # print(2333,cache) # prep attention mask @@ -255,14 +301,20 @@ def multi_head_attention_forward_patched( if attn_mask.dim() == 2: correct_2d_size = (tgt_len, src_len) if attn_mask.shape != correct_2d_size: - raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) attn_mask = attn_mask.unsqueeze(0) elif attn_mask.dim() == 3: correct_3d_size = (bsz * num_heads, tgt_len, src_len) if attn_mask.shape != correct_3d_size: - raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." 
+ ) else: - raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) # add bias along batch dimension (currently second) if bias_k is not None and bias_v is not None: @@ -286,26 +338,34 @@ def multi_head_attention_forward_patched( k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert static_k.size(0) == bsz * num_heads, \ - f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" - assert static_k.size(2) == head_dim, \ - f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert ( + static_k.size(2) == head_dim + ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" k = static_k if static_v is None: v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert static_v.size(0) == bsz * num_heads, \ - f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" - assert static_v.size(2) == head_dim, \ - f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert ( + static_v.size(2) == head_dim + ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" v = static_v # add zero attention along batch dimension (now first) if add_zero_attn: zero_attn_shape = (bsz * num_heads, 1, head_dim) - k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: @@ -316,10 +376,15 @@ def multi_head_attention_forward_patched( # merge key padding and attention masks if key_padding_mask is not None: - assert key_padding_mask.shape == (bsz, src_len), \ - f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" - key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). 
\ - expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) if attn_mask is None: attn_mask = key_padding_mask else: @@ -337,10 +402,14 @@ def multi_head_attention_forward_patched( B, Nt, E = q.shape q_scaled = q / math.sqrt(E) - assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + assert not ( + is_causal and attn_mask is None + ), "FIXME: is_causal not implemented for need_weights" if attn_mask is not None: - attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + attn_output_weights = torch.baddbmm( + attn_mask, q_scaled, k.transpose(-2, -1) + ) else: attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = softmax(attn_output_weights, dim=-1) @@ -349,7 +418,9 @@ def multi_head_attention_forward_patched( attn_output = torch.bmm(attn_output_weights, v) - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) @@ -377,8 +448,12 @@ def multi_head_attention_forward_patched( k = k.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim) - attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) - attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) diff --git a/GPT_SoVITS/AR/modules/scaling.py b/GPT_SoVITS/AR/modules/scaling.py index ec31d61..9256a8c 100644 --- a/GPT_SoVITS/AR/modules/scaling.py +++ b/GPT_SoVITS/AR/modules/scaling.py @@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor) - ) + torch.rand_like(deriv) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - (d, ) = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. 
floor = -0.043637 ceil = 1.2 @@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module): class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, ) -> Tensor: + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim @@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return (x_grad - neg_delta_grad, None, None, None, ) + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, ) -> Tensor: + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -145,23 +153,25 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ( - (min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor) + min=0, max=max_factor + ) return below_threshold - above_threshold def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, ) -> Tensor: + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -171,18 +181,18 @@ def _compute_sign_factor( else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ((min_positive - proportion_positive) * - (gain_factor / min_positive)).clamp_( - min=0, max=max_factor) + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. 
- factor2 = ((proportion_positive - max_positive) * - (gain_factor / (1.0 - max_positive))).clamp_( - min=0, max=max_factor) + factor2 = ( + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) @@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module): """ def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float=0.05, - max_positive: float=0.95, - max_factor: float=0.04, - sign_gain_factor: float=0.01, - scale_gain_factor: float=0.02, - min_abs: float=0.2, - max_abs: float=100.0, - min_prob: float=0.1, ): + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim @@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module): self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - torch.jit.is_tracing()): + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5**(1 + (count / 4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) if random.random() < prob: sign_gain_factor = 0.5 @@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module): self.min_positive, self.max_positive, gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, ) + max_factor=self.max_factor, + ) else: sign_factor = None @@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module): min_abs=self.min_abs, max_abs=self.max_abs, gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, ) + max_factor=self.max_factor, + ) return ActivationBalancerFunction.apply( x, scale_factor, sign_factor, - self.channel_dim, ) + self.channel_dim, + ) else: return _no_op(x) -def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, - min_prob=0.25) -> nn.Sequential: +def BalancedDoubleSwish( + d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25 +) -> nn.Sequential: """ ActivationBalancer -> DoubleSwish """ balancer = ActivationBalancer( - d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob) + d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob + ) return nn.Sequential( balancer, - DoubleSwish(), ) + DoubleSwish(), + ) diff --git a/GPT_SoVITS/AR/modules/transformer.py b/GPT_SoVITS/AR/modules/transformer.py index 04f0b1b..7921f48 100644 --- a/GPT_SoVITS/AR/modules/transformer.py +++ b/GPT_SoVITS/AR/modules/transformer.py @@ -26,26 +26,28 @@ class LayerNorm(nn.Module): elementwise_affine: bool def __init__( - self, - normalized_shape: _shape_t, - eps: float=1e-5, - elementwise_affine: bool=True, - device=None, - dtype=None, ) -> None: + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(LayerNorm, 
self).__init__() if isinstance(normalized_shape, numbers.Integral): # mypy error: incompatible types in assignment - normalized_shape = (normalized_shape, ) # type: ignore[assignment] - self.normalized_shape = tuple( - normalized_shape) # type: ignore[arg-type] + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: self.weight = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs)) + torch.empty(self.normalized_shape, **factory_kwargs) + ) self.bias = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs)) + torch.empty(self.normalized_shape, **factory_kwargs) + ) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -57,36 +59,43 @@ class LayerNorm(nn.Module): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) - def forward(self, input: Tensor, embedding: Any=None) -> Tensor: + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: if isinstance(input, tuple): input, embedding = input - return (F.layer_norm( - input, - self.normalized_shape, - self.weight, - self.bias, - self.eps, ), embedding, ) + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) assert embedding is None - return F.layer_norm(input, self.normalized_shape, self.weight, - self.bias, self.eps) + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) def extra_repr(self) -> str: return ( "{normalized_shape}, eps={eps}, " - "elementwise_affine={elementwise_affine}".format(**self.__dict__)) + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) class IdentityNorm(nn.Module): def __init__( - self, - d_model: int, - eps: float=1e-5, - device=None, - dtype=None, ) -> None: + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: super(IdentityNorm, self).__init__() - def forward(self, input: Tensor, embedding: Any=None) -> Tensor: + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: if isinstance(input, tuple): return input @@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module): self.norm = norm def forward( - self, - src: Tensor, - mask: Optional[Tensor]=None, - src_key_padding_mask: Optional[Tensor]=None, - return_layer_states: bool=False,cache=None ) -> Tensor: + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + cache=None, + ) -> Tensor: r"""Pass the input through the encoder layers in turn. 
Args: @@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module): output = mod( output, src_mask=mask, - src_key_padding_mask=src_key_padding_mask, cache=cache) + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) layer_states.append(output[0]) if self.norm is not None: @@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module): output = src for mod in self.layers: - output = mod(output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, cache=cache) + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + cache=cache, + ) if self.norm is not None: output = self.norm(output) @@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module): __constants__ = ["batch_first", "norm_first"] def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int=2048, - dropout: float=0.1, - activation: Union[str, Callable[[Tensor], Tensor]]=F.relu, - batch_first: bool=False, - norm_first: bool=False, - device=None, - dtype=None, - linear1_self_attention_cls: nn.Module=nn.Linear, - linear2_self_attention_cls: nn.Module=nn.Linear, - linear1_feedforward_cls: nn.Module=nn.Linear, - linear2_feedforward_cls: nn.Module=nn.Linear, - layer_norm_cls: nn.Module=LayerNorm, - layer_norm_eps: float=1e-5, - adaptive_layer_norm=False, ) -> None: + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super(TransformerEncoderLayer, self).__init__() # print(233333333333,d_model,nhead) # import os # os._exit(2333333) self.self_attn = MultiheadAttention( - d_model,#512 16 + d_model, # 512 16 nhead, dropout=dropout, batch_first=batch_first, linear1_cls=linear1_self_attention_cls, linear2_cls=linear2_self_attention_cls, - **factory_kwargs, ) + **factory_kwargs, + ) # Implementation of Feedforward model - self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, - **factory_kwargs) + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) self.dropout = nn.Dropout(dropout) - self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, - **factory_kwargs) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) self.norm_first = norm_first self.dropout1 = nn.Dropout(dropout) @@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module): norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if layer_norm_cls == IdentityNorm: - norm2 = BalancedBasicNorm( - d_model, eps=layer_norm_eps, **factory_kwargs) + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) else: - norm2 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs) + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if adaptive_layer_norm: self.norm1 = AdaptiveLayerNorm(d_model, norm1) @@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module): self.activation = F.relu def forward( - self, - src: Tensor, - src_mask: Optional[Tensor]=None, - src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor: + self, + 
src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + cache=None, + ) -> Tensor: r"""Pass the input through the encoder layer. Args: @@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module): if src_key_padding_mask is not None: _skpm_dtype = src_key_padding_mask.dtype if _skpm_dtype != torch.bool and not torch.is_floating_point( - src_key_padding_mask): + src_key_padding_mask + ): raise AssertionError( "only bool and floating types of key_padding_mask are supported" ) @@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module): x = x + self._sa_block( self.norm1(x, stage_embedding), src_mask, - src_key_padding_mask,cache=cache ) + src_key_padding_mask, + cache=cache, + ) x = x + self._ff_block(self.norm2(x, stage_embedding)) else: x = self.norm1( - x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache), - stage_embedding, ) + x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), + stage_embedding, + ) x = self.norm2(x + self._ff_block(x), stage_embedding) if is_src_tuple: @@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module): # self-attention block def _sa_block( - self, - x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor],cache=None ) -> Tensor: + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + cache=None, + ) -> Tensor: # print(x.shape,attn_mask.shape,key_padding_mask) - #torch.Size([1, 188, 512]) torch.Size([188, 188]) None + # torch.Size([1, 188, 512]) torch.Size([188, 188]) None # import os # os._exit(23333) x = self.self_attn( @@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module): x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, - need_weights=False,cache=cache )[0] + need_weights=False, + cache=cache, + )[0] return self.dropout1(x) # feed forward block @@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module): self.d_model = d_model self.eps = self.norm.eps - def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor: + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: if isinstance(input, tuple): input, embedding = input weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, - dim=-1, ) + dim=-1, + ) return (weight * self.norm(input) + bias, embedding) weight, bias = torch.split( self.project_layer(embedding), split_size_or_sections=self.d_model, - dim=-1, ) + dim=-1, + ) return weight * self.norm(input) + bias + def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) diff --git a/GPT_SoVITS/AR/text_processing/phonemizer.py b/GPT_SoVITS/AR/text_processing/phonemizer.py index 83ecfb7..9fcf5c0 100644 --- a/GPT_SoVITS/AR/text_processing/phonemizer.py +++ b/GPT_SoVITS/AR/text_processing/phonemizer.py @@ -27,46 +27,44 @@ class GruutPhonemizer: "—": "—", "…": "… ", "«": "«", - "»": "»" + "»": "»", } - self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])" + self._punctuation_regexp: str = ( + rf"([{''.join(self._special_cases_dict.keys())}])" + ) def _normalize_punctuation(self, text: str) -> str: - text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text) - text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text) + text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text) + text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text) text = regex.sub(r"\pZ+", r" ", text) return text.strip() def _convert_punctuation(self, word: Word) -> str: if not 
word.phonemes: - return '' - if word.phonemes[0] in ['‖', '|']: + return "" + if word.phonemes[0] in ["‖", "|"]: return word.text.strip() - phonemes = ''.join(word.phonemes) + phonemes = "".join(word.phonemes) # remove modifier characters ˈˌː with regex - phonemes = re.sub(r'[ˈˌː͡]', '', phonemes) + phonemes = re.sub(r"[ˈˌː͡]", "", phonemes) return phonemes.strip() - def phonemize(self, text: str, espeak: bool=False) -> str: + def phonemize(self, text: str, espeak: bool = False) -> str: text_to_phonemize: str = self._normalize_punctuation(text) sents: List[Sentence] = [ sent - for sent in self._phonemizer( - text_to_phonemize, lang="en-us", espeak=espeak) + for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak) ] words: List[str] = [ self._convert_punctuation(word) for word in itertools.chain(*sents) ] - return ' '.join(words) + return " ".join(words) def transform(self, phonemes): # convert phonemes to ids # dictionary is in symbols.py - return [ - self.symbol_to_id[p] for p in phonemes - if p in self.symbol_to_id.keys() - ] + return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()] if __name__ == "__main__": diff --git a/GPT_SoVITS/AR/text_processing/symbols.py b/GPT_SoVITS/AR/text_processing/symbols.py index 6bc9a0c..c57e2d4 100644 --- a/GPT_SoVITS/AR/text_processing/symbols.py +++ b/GPT_SoVITS/AR/text_processing/symbols.py @@ -1,7 +1,7 @@ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py -PAD = '_' +PAD = "_" PUNCTUATION = ';:,.!?¡¿—…"«»“” ' -LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) SPACE_ID = SYMBOLS.index(" ") diff --git a/GPT_SoVITS/AR/utils/io.py b/GPT_SoVITS/AR/utils/io.py index 24f1be6..52f1f3c 100644 --- a/GPT_SoVITS/AR/utils/io.py +++ b/GPT_SoVITS/AR/utils/io.py @@ -11,22 +11,24 @@ def load_yaml_config(path): def save_config_to_yaml(config, path): - assert path.endswith('.yaml') - with open(path, 'w') as f: + assert path.endswith(".yaml") + with open(path, "w") as f: f.write(yaml.dump(config)) f.close() def write_args(args, path): - args_dict = dict((name, getattr(args, name)) for name in dir(args) - if not name.startswith('_')) - with open(path, 'a') as args_file: - args_file.write('==> torch version: {}\n'.format(torch.__version__)) + args_dict = dict( + (name, getattr(args, name)) for name in dir(args) if not name.startswith("_") + ) + with open(path, "a") as args_file: + args_file.write("==> torch version: {}\n".format(torch.__version__)) args_file.write( - '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) - args_file.write('==> Cmd:\n') + "==> cudnn version: {}\n".format(torch.backends.cudnn.version()) + ) + args_file.write("==> Cmd:\n") args_file.write(str(sys.argv)) - args_file.write('\n==> args:\n') + args_file.write("\n==> args:\n") for k, v in sorted(args_dict.items()): - args_file.write(' %s: %s\n' % (str(k), str(v))) + args_file.write(" %s: %s\n" % (str(k), str(v))) args_file.close() diff --git a/GPT_SoVITS/configs/s1.yaml b/GPT_SoVITS/configs/s1.yaml index 5481b9b..f8ae17d 100644 --- a/GPT_SoVITS/configs/s1.yaml +++ b/GPT_SoVITS/configs/s1.yaml @@ -1,31 +1,31 @@ train: - seed: 1234 - epochs: 300 - batch_size: 8 - gradient_accumulation: 4 - save_every_n_epoch: 1 - 
precision: 16 - gradient_clip: 1.0 + seed: 1234 + epochs: 300 + batch_size: 8 + gradient_accumulation: 4 + save_every_n_epoch: 1 + precision: 16 + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 54 - num_workers: 1 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 54 + num_workers: 1 + pad_val: 1024 # same with EOS in model model: - vocab_size: 1025 - phoneme_vocab_size: 512 - embedding_dim: 512 - hidden_dim: 512 - head: 16 - linear_units: 2048 - n_layer: 12 - dropout: 0 - EOS: 1024 + vocab_size: 1025 + phoneme_vocab_size: 512 + embedding_dim: 512 + hidden_dim: 512 + head: 16 + linear_units: 2048 + n_layer: 12 + dropout: 0 + EOS: 1024 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/configs/s1big.yaml b/GPT_SoVITS/configs/s1big.yaml index 3a17ae5..a811150 100644 --- a/GPT_SoVITS/configs/s1big.yaml +++ b/GPT_SoVITS/configs/s1big.yaml @@ -1,31 +1,31 @@ train: - seed: 1234 - epochs: 300 - batch_size: 8 - gradient_accumulation: 4 - save_every_n_epoch: 1 - precision: 16-mixed - gradient_clip: 1.0 + seed: 1234 + epochs: 300 + batch_size: 8 + gradient_accumulation: 4 + save_every_n_epoch: 1 + precision: 16-mixed + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 54 - num_workers: 1 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 54 + num_workers: 1 + pad_val: 1024 # same with EOS in model model: - vocab_size: 1025 - phoneme_vocab_size: 512 - embedding_dim: 1024 - hidden_dim: 1024 - head: 16 - linear_units: 2048 - n_layer: 16 - dropout: 0 - EOS: 1024 + vocab_size: 1025 + phoneme_vocab_size: 512 + embedding_dim: 1024 + hidden_dim: 1024 + head: 16 + linear_units: 2048 + n_layer: 16 + dropout: 0 + EOS: 1024 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/configs/s1big2.yaml b/GPT_SoVITS/configs/s1big2.yaml index 1037fc7..b8b889b 100644 --- a/GPT_SoVITS/configs/s1big2.yaml +++ b/GPT_SoVITS/configs/s1big2.yaml @@ -1,31 +1,31 @@ train: - seed: 1234 - epochs: 300 - batch_size: 12 - gradient_accumulation: 4 - save_every_n_epoch: 1 - precision: 16-mixed - gradient_clip: 1.0 + seed: 1234 + epochs: 300 + batch_size: 12 + gradient_accumulation: 4 + save_every_n_epoch: 1 + precision: 16-mixed + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 54 - num_workers: 1 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 54 + num_workers: 1 + pad_val: 1024 # same with EOS in model model: - vocab_size: 1025 - phoneme_vocab_size: 512 - embedding_dim: 1024 - hidden_dim: 1024 - head: 16 - linear_units: 2048 - n_layer: 6 - dropout: 0 - EOS: 1024 + vocab_size: 1025 + phoneme_vocab_size: 512 + embedding_dim: 1024 + hidden_dim: 1024 + head: 16 + linear_units: 2048 + n_layer: 6 + dropout: 0 + EOS: 1024 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/configs/s1longer.yaml b/GPT_SoVITS/configs/s1longer.yaml index b238abd..3f57abd 100644 --- 
a/GPT_SoVITS/configs/s1longer.yaml +++ b/GPT_SoVITS/configs/s1longer.yaml @@ -1,31 +1,31 @@ train: - seed: 1234 - epochs: 20 - batch_size: 8 - save_every_n_epoch: 1 - precision: 16-mixed - gradient_clip: 1.0 + seed: 1234 + epochs: 20 + batch_size: 8 + save_every_n_epoch: 1 + precision: 16-mixed + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 54 - num_workers: 4 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 54 + num_workers: 4 + pad_val: 1024 # same with EOS in model model: - vocab_size: 1025 - phoneme_vocab_size: 512 - embedding_dim: 512 - hidden_dim: 512 - head: 16 - linear_units: 2048 - n_layer: 24 - dropout: 0 - EOS: 1024 - random_bert: 0 + vocab_size: 1025 + phoneme_vocab_size: 512 + embedding_dim: 512 + hidden_dim: 512 + head: 16 + linear_units: 2048 + n_layer: 24 + dropout: 0 + EOS: 1024 + random_bert: 0 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/configs/s1mq.yaml b/GPT_SoVITS/configs/s1mq.yaml index 19aac92..b554fd3 100644 --- a/GPT_SoVITS/configs/s1mq.yaml +++ b/GPT_SoVITS/configs/s1mq.yaml @@ -1,77 +1,77 @@ train: - seed: 1234 - epochs: 100 - batch_size: 6 - gradient_accumulation: 4 - save_every_n_epoch: 1 - precision: 32 - gradient_clip: 1.0 + seed: 1234 + epochs: 100 + batch_size: 6 + gradient_accumulation: 4 + save_every_n_epoch: 1 + precision: 32 + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 40 - num_workers: 1 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 40 + num_workers: 1 + pad_val: 1024 # same with EOS in model model: - saving_path: "ckpt/" - resume_checkpoint: null - vocoder_config_path: "quantizer/new_ckpt/config.json" - vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000" - datadir: "/home/liweiche/GigaSpeech/wavs" - metapath: "/home/liweiche/GigaSpeech/train2.json" - val_metapath: "/home/liweiche/GigaSpeech/dev2.json" - sampledir: "logs/" - pretrained_path: null - lr: 0.0001 - batch_size: 200.0 - train_bucket_size: 8192 - training_step: 800000 - optim_flat_percent: 0.0 - warmup_step: 50 - adam_beta1: 0.9 - adam_beta2: 0.98 - ffd_size: 3072 - hidden_size: 768 - enc_nlayers: 6 - dec_nlayers: 6 - nheads: 12 - ar_layer: 4 - ar_ffd_size: 1024 - ar_hidden_size: 256 - ar_nheads: 4 - aligner_softmax_temp: 1.0 - layer_norm_eps: 0.00001 - speaker_embed_dropout: 0.05 - label_smoothing: 0.0 - val_check_interval: 5000 - check_val_every_n_epoch: 1 - precision: "fp16" - nworkers: 16 - distributed: true - accelerator: "ddp" - version: null - accumulate_grad_batches: 1 - use_repetition_token: true - use_repetition_gating: false - repetition_penalty: 1.0 - sampling_temperature: 1.0 - top_k: -1 - min_top_k: 3 - top_p: 0.8 - sample_num: 4 - length_penalty_max_length: 15000 - length_penalty_max_prob: 0.95 - max_input_length: 2048 - max_output_length: 2000 - sample_rate: 16000 - n_codes: 1024 - n_cluster_groups: 1 - phone_context_window: 4 - phoneset_size: 1000 + saving_path: "ckpt/" + resume_checkpoint: null + vocoder_config_path: "quantizer/new_ckpt/config.json" + vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000" + datadir: "/home/liweiche/GigaSpeech/wavs" + metapath: 
"/home/liweiche/GigaSpeech/train2.json" + val_metapath: "/home/liweiche/GigaSpeech/dev2.json" + sampledir: "logs/" + pretrained_path: null + lr: 0.0001 + batch_size: 200.0 + train_bucket_size: 8192 + training_step: 800000 + optim_flat_percent: 0.0 + warmup_step: 50 + adam_beta1: 0.9 + adam_beta2: 0.98 + ffd_size: 3072 + hidden_size: 768 + enc_nlayers: 6 + dec_nlayers: 6 + nheads: 12 + ar_layer: 4 + ar_ffd_size: 1024 + ar_hidden_size: 256 + ar_nheads: 4 + aligner_softmax_temp: 1.0 + layer_norm_eps: 0.00001 + speaker_embed_dropout: 0.05 + label_smoothing: 0.0 + val_check_interval: 5000 + check_val_every_n_epoch: 1 + precision: "fp16" + nworkers: 16 + distributed: true + accelerator: "ddp" + version: null + accumulate_grad_batches: 1 + use_repetition_token: true + use_repetition_gating: false + repetition_penalty: 1.0 + sampling_temperature: 1.0 + top_k: -1 + min_top_k: 3 + top_p: 0.8 + sample_num: 4 + length_penalty_max_length: 15000 + length_penalty_max_prob: 0.95 + max_input_length: 2048 + max_output_length: 2000 + sample_rate: 16000 + n_codes: 1024 + n_cluster_groups: 1 + phone_context_window: 4 + phoneset_size: 1000 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/configs/train.yaml b/GPT_SoVITS/configs/train.yaml index a61e90d..be53335 100644 --- a/GPT_SoVITS/configs/train.yaml +++ b/GPT_SoVITS/configs/train.yaml @@ -1,32 +1,32 @@ gpu: - n_card: 1 - n_process_per_card: 2 + n_card: 1 + n_process_per_card: 2 io: - text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS - save_every_n_epoch: 1 - precision: 16-mixed - gradient_clip: 1.0 + text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS + save_every_n_epoch: 1 + precision: 16-mixed + gradient_clip: 1.0 optimizer: - lr: 0.01 - lr_init: 0.00001 - lr_end: 0.0001 - warmup_steps: 2000 - decay_steps: 40000 + lr: 0.01 + lr_init: 0.00001 + lr_end: 0.0001 + warmup_steps: 2000 + decay_steps: 40000 data: - max_eval_sample: 8 - max_sec: 54 - num_workers: 1 - pad_val: 1024 # same with EOS in model + max_eval_sample: 8 + max_sec: 54 + num_workers: 1 + pad_val: 1024 # same with EOS in model model: - vocab_size: 1025 - phoneme_vocab_size: 512 - embedding_dim: 512 - hidden_dim: 512 - head: 16 - linear_units: 2048 - n_layer: 24 - dropout: 0 - EOS: 1024 - random_bert: 0 + vocab_size: 1025 + phoneme_vocab_size: 512 + embedding_dim: 512 + hidden_dim: 512 + head: 16 + linear_units: 2048 + n_layer: 24 + dropout: 0 + EOS: 1024 + random_bert: 0 inference: - top_k: 5 \ No newline at end of file + top_k: 5 diff --git a/GPT_SoVITS/feature_extractor/cnhubert.py b/GPT_SoVITS/feature_extractor/cnhubert.py index 048dc85..dc155bd 100644 --- a/GPT_SoVITS/feature_extractor/cnhubert.py +++ b/GPT_SoVITS/feature_extractor/cnhubert.py @@ -11,23 +11,30 @@ logging.getLogger("numba").setLevel(logging.WARNING) from transformers import ( Wav2Vec2FeatureExtractor, HubertModel, - Wav2Vec2Model, ) import utils import torch.nn as nn -cnhubert_base_path=None +cnhubert_base_path = None + + class CNHubert(nn.Module): def __init__(self): super().__init__() self.model = HubertModel.from_pretrained(cnhubert_base_path) - self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path) + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + cnhubert_base_path + ) + def forward(self, x): - input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) + input_values = self.feature_extractor( + x, return_tensors="pt", sampling_rate=16000 + ).input_values.to(x.device) feats = 
self.model(input_values)["last_hidden_state"] return feats + # class CNHubertLarge(nn.Module): # def __init__(self): # super().__init__() @@ -59,12 +66,12 @@ class CNHubert(nn.Module): # return feats - def get_model(): model = CNHubert() model.eval() return model + # def get_large_model(): # model = CNHubertLarge() # model.eval() @@ -80,18 +87,18 @@ def get_model(): # model.eval() # return model + def get_content(hmodel, wav_16k_tensor): with torch.no_grad(): feats = hmodel(wav_16k_tensor) - return feats.transpose(1,2) + return feats.transpose(1, 2) -if __name__ == '__main__': +if __name__ == "__main__": model = get_model() src_path = "/Users/Shared/原音频2.wav" wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) model = model wav_16k_tensor = wav_16k_tensor - feats = get_content(model,wav_16k_tensor) + feats = get_content(model, wav_16k_tensor) print(feats.shape) - diff --git a/GPT_SoVITS/feature_extractor/whisper_enc.py b/GPT_SoVITS/feature_extractor/whisper_enc.py index 023f751..983c3e4 100644 --- a/GPT_SoVITS/feature_extractor/whisper_enc.py +++ b/GPT_SoVITS/feature_extractor/whisper_enc.py @@ -3,20 +3,23 @@ import torch def get_model(): import whisper - model = whisper.load_model("small", device='cpu') + + model = whisper.load_model("small", device="cpu") return model.encoder def get_content(model=None, wav_16k_tensor=None): from whisper import log_mel_spectrogram, pad_or_trim + dev = next(model.parameters()).device mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] # if torch.cuda.is_available(): # mel = mel.to(torch.float16) feature_len = mel.shape[-1] // 2 - assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频" + assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频" with torch.no_grad(): - feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2) + feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[ + :1, :feature_len, : + ].transpose(1, 2) return feature - diff --git a/GPT_SoVITS/module/attentions.py b/GPT_SoVITS/module/attentions.py index 07672e2..a2e9e51 100644 --- a/GPT_SoVITS/module/attentions.py +++ b/GPT_SoVITS/module/attentions.py @@ -4,315 +4,432 @@ from torch import nn from torch.nn import functional as F from module import commons -from module. 
modules import LayerNorm - +from module.modules import LayerNorm + class Encoder(nn.Module): - def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,isflow=False, **kwargs): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.window_size = window_size + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=False, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size - self.drop = nn.Dropout(p_dropout) - self.attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - if isflow: - cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) - self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) - self.cond_layer = weight_norm_modules(cond_layer, name='weight') - self.gin_channels = kwargs["gin_channels"] - def forward(self, x, x_mask, g=None): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - if g is not None: - g = self.cond_layer(g) + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + if isflow: + cond_layer = torch.nn.Conv1d( + kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1 + ) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name="weight") + self.gin_channels = kwargs["gin_channels"] - for i in range(self.n_layers): - if g is not None: - x = self.cond_pre(x) - cond_offset = i * 2 * self.hidden_channels - g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] - x = commons.fused_add_tanh_sigmoid_multiply( - x, - g_l, - torch.IntTensor([self.hidden_channels])) - y = self.attn_layers[i](x, x, attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + if g is not None: + g = self.cond_layer(g) - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x + for i in range(self.n_layers): + if g is not 
None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + x = commons.fused_add_tanh_sigmoid_multiply( + x, g_l, torch.IntTensor([self.hidden_channels]) + ) + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x class Decoder(nn.Module): - def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init - self.drop = nn.Dropout(p_dropout) - self.self_attn_layers = nn.ModuleList() - self.norm_layers_0 = nn.ModuleList() - self.encdec_attn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_2 = nn.ModuleList() - for i in range(self.n_layers): - self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) - self.norm_layers_0.append(LayerNorm(hidden_channels)) - self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) - self.norm_layers_2.append(LayerNorm(hidden_channels)) + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask, h, h_mask): - """ - x: decoder input - h: encoder output - """ - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) - encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - x = x * x_mask - for i in range(self.n_layers): - 
y = self.self_attn_layers[i](x, x, self_attn_mask) - y = self.drop(y) - x = self.norm_layers_0[i](x + y) + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) - y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x class MultiHeadAttention(nn.Module): - def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): - super().__init__() - assert channels % n_heads == 0 + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.p_dropout = p_dropout - self.window_size = window_size - self.heads_share = heads_share - self.block_length = block_length - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - self.attn = None + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None - self.k_channels = channels // n_heads - self.conv_q = nn.Conv1d(channels, channels, 1) - self.conv_k = nn.Conv1d(channels, channels, 1) - self.conv_v = nn.Conv1d(channels, channels, 1) - self.conv_o = nn.Conv1d(channels, out_channels, 1) - self.drop = nn.Dropout(p_dropout) + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) - if window_size is not None: - n_heads_rel = 1 if heads_share else n_heads - rel_stddev = self.k_channels**-0.5 - self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) - self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) - nn.init.xavier_uniform_(self.conv_q.weight) - nn.init.xavier_uniform_(self.conv_k.weight) - nn.init.xavier_uniform_(self.conv_v.weight) - if proximal_init: - with torch.no_grad(): - 
self.conv_k.weight.copy_(self.conv_q.weight) - self.conv_k.bias.copy_(self.conv_q.bias) - - def forward(self, x, c, attn_mask=None): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, self.attn = self.attention(q, k, v, mask=attn_mask) + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) - x = self.conv_o(x) - return x + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) - def attention(self, query, key, value, mask=None): - # reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s, t_t = (*key.size(), query.size(2)) - query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) - key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + x, self.attn = self.attention(q, k, v, mask=attn_mask) - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) - if self.window_size is not None: - assert t_s == t_t, "Relative attention is only available for self-attention." - key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) - scores_local = self._relative_position_to_absolute_position(rel_logits) - scores = scores + scores_local - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - assert t_s == t_t, "Local attention is only available for self-attention." 
- block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) - scores = scores.masked_fill(block_mask == 0, -1e4) - p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - if self.window_size is not None: - relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) - output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] - return output, p_attn + x = self.conv_o(x) + return x - def _matmul_with_relative_values(self, x, y): - """ - x: [b, h, l, m] - y: [h or 1, m, d] - ret: [b, h, l, d] - """ - ret = torch.matmul(x, y.unsqueeze(0)) - return ret + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) - def _matmul_with_relative_keys(self, x, y): - """ - x: [b, h, l, d] - y: [h or 1, m, d] - ret: [b, h, l, m] - """ - ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) - return ret + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn - def _get_relative_embeddings(self, relative_embeddings, length): - max_relative_position = 2 * self.window_size + 1 - # Pad first before slice to avoid using cond ops. 
- pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) - slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) - else: - padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] - return used_relative_embeddings + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret - def _relative_position_to_absolute_position(self, x): - """ - x: [b, h, l, 2*l-1] - ret: [b, h, l, l] - """ - batch, heads, length, _ = x.size() - # Concat columns of pad to shift from relative to absolute indexing. - x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret - # Concat extra elements so to add up to shape (len+1, 2*len-1). - x_flat = x.view([batch, heads, length * 2 * length]) - x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings - # Reshape and slice out the padded elements. - x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] - return x_final + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) - def _absolute_position_to_relative_position(self, x): - """ - x: [b, h, l, l] - ret: [b, h, l, 2*l-1] - """ - batch, heads, length, _ = x.size() - # padd along column - x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) - x_flat = x.view([batch, heads, length**2 + length*(length -1)]) - # add 0's in the beginning that will skew the elements after reshape - x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) - x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] - return x_final + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) - def _attention_bias_proximal(self, length): - """Bias for self-attention to encourage attention to close positions. - Args: - length: an integer scalar. 
- Returns: - a Tensor with shape [1, 1, length, length] - """ - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) class FFN(nn.Module): - def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.activation = activation - self.causal = causal + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal - if causal: - self.padding = self._causal_padding - else: - self.padding = self._same_padding + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) - self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) - self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) - def forward(self, x, x_mask): - x = self.conv_1(self.padding(x * x_mask)) - if self.activation == "gelu": - x = x * torch.sigmoid(1.702 * x) - else: - x = torch.relu(x) - x = self.drop(x) - x = self.conv_2(self.padding(x * x_mask)) - return x * x_mask - - def _causal_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = self.kernel_size - 1 - pad_r = 0 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask - def _same_padding(self, x): - if self.kernel_size == 1: - return x - pad_l = (self.kernel_size - 1) // 2 - 
pad_r = self.kernel_size // 2 - padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, commons.convert_pad_shape(padding)) - return x + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x import torch.nn as nn @@ -320,195 +437,273 @@ from torch.nn.utils import remove_weight_norm, weight_norm class Depthwise_Separable_Conv1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - bias=True, - padding_mode='zeros', # TODO: refine this type - device=None, - dtype=None - ): - super().__init__() - self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, - groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, - padding_mode=padding_mode, device=device, dtype=dtype) - self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, - device=device, dtype=dtype) + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=True, + padding_mode="zeros", # TODO: refine this type + device=None, + dtype=None, + ): + super().__init__() + self.depth_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.point_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=bias, + device=device, + dtype=dtype, + ) - def forward(self, input): - return self.point_conv(self.depth_conv(input)) + def forward(self, input): + return self.point_conv(self.depth_conv(input)) - def weight_norm(self): - self.depth_conv = weight_norm(self.depth_conv, name='weight') - self.point_conv = weight_norm(self.point_conv, name='weight') + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name="weight") + self.point_conv = weight_norm(self.point_conv, name="weight") - def remove_weight_norm(self): - self.depth_conv = remove_weight_norm(self.depth_conv, name='weight') - self.point_conv = remove_weight_norm(self.point_conv, name='weight') + def remove_weight_norm(self): + self.depth_conv = remove_weight_norm(self.depth_conv, name="weight") + self.point_conv = remove_weight_norm(self.point_conv, name="weight") class Depthwise_Separable_TransposeConv1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - bias=True, - dilation=1, - padding_mode='zeros', # TODO: refine this type - device=None, - dtype=None - ): - super().__init__() - self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, - groups=in_channels, stride=stride, output_padding=output_padding, - padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, - device=device, dtype=dtype) - self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, - device=device, dtype=dtype) + def 
__init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + bias=True, + dilation=1, + padding_mode="zeros", # TODO: refine this type + device=None, + dtype=None, + ): + super().__init__() + self.depth_conv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + output_padding=output_padding, + padding=padding, + dilation=dilation, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + self.point_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=bias, + device=device, + dtype=dtype, + ) - def forward(self, input): - return self.point_conv(self.depth_conv(input)) + def forward(self, input): + return self.point_conv(self.depth_conv(input)) - def weight_norm(self): - self.depth_conv = weight_norm(self.depth_conv, name='weight') - self.point_conv = weight_norm(self.point_conv, name='weight') + def weight_norm(self): + self.depth_conv = weight_norm(self.depth_conv, name="weight") + self.point_conv = weight_norm(self.point_conv, name="weight") - def remove_weight_norm(self): - remove_weight_norm(self.depth_conv, name='weight') - remove_weight_norm(self.point_conv, name='weight') + def remove_weight_norm(self): + remove_weight_norm(self.depth_conv, name="weight") + remove_weight_norm(self.point_conv, name="weight") -def weight_norm_modules(module, name='weight', dim=0): - if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): - module.weight_norm() - return module - else: - return weight_norm(module, name, dim) +def weight_norm_modules(module, name="weight", dim=0): + if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( + module, Depthwise_Separable_TransposeConv1D + ): + module.weight_norm() + return module + else: + return weight_norm(module, name, dim) -def remove_weight_norm_modules(module, name='weight'): - if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): - module.remove_weight_norm() - else: - remove_weight_norm(module, name) +def remove_weight_norm_modules(module, name="weight"): + if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( + module, Depthwise_Separable_TransposeConv1D + ): + module.remove_weight_norm() + else: + remove_weight_norm(module, name) class FFT(nn.Module): - def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., - proximal_bias=False, proximal_init=True, isflow = False, **kwargs): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.proximal_bias = proximal_bias - self.proximal_init = proximal_init - if isflow: - cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1) - self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) - self.cond_layer = weight_norm_modules(cond_layer, name='weight') - self.gin_channels = kwargs["gin_channels"] - self.drop = nn.Dropout(p_dropout) - self.self_attn_layers = nn.ModuleList() - self.norm_layers_0 = nn.ModuleList() - self.ffn_layers = nn.ModuleList() - self.norm_layers_1 = nn.ModuleList() - for i in range(self.n_layers): - self.self_attn_layers.append( - MultiHeadAttention(hidden_channels, hidden_channels, n_heads, 
p_dropout=p_dropout, proximal_bias=proximal_bias, - proximal_init=proximal_init)) - self.norm_layers_0.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) - self.norm_layers_1.append(LayerNorm(hidden_channels)) + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers=1, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + isflow=False, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + if isflow: + cond_layer = torch.nn.Conv1d( + kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1 + ) + self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) + self.cond_layer = weight_norm_modules(cond_layer, name="weight") + self.gin_channels = kwargs["gin_channels"] + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask, g = None): - """ - x: decoder input - h: encoder output - """ - if g is not None: - g = self.cond_layer(g) + def forward(self, x, x_mask, g=None): + """ + x: decoder input + h: encoder output + """ + if g is not None: + g = self.cond_layer(g) - self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) - x = x * x_mask - for i in range(self.n_layers): - if g is not None: - x = self.cond_pre(x) - cond_offset = i * 2 * self.hidden_channels - g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] - x = commons.fused_add_tanh_sigmoid_multiply( - x, - g_l, - torch.IntTensor([self.hidden_channels])) - y = self.self_attn_layers[i](x, x, self_attn_mask) - y = self.drop(y) - x = self.norm_layers_0[i](x + y) - - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - x = x * x_mask - return x + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + x = x * x_mask + for i in range(self.n_layers): + if g is not None: + x = self.cond_pre(x) + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + x = commons.fused_add_tanh_sigmoid_multiply( + x, g_l, torch.IntTensor([self.hidden_channels]) + ) + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + x = x * x_mask + return x class TransformerCouplingLayer(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - n_layers, - n_heads, - p_dropout=0, - filter_channels=0, - mean_only=False, - wn_sharing_parameter=None, - 
gin_channels = 0 - ): - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only + def __init__( + self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels=0, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, isflow = True, gin_channels = gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = ( + Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=gin_channels, + ) + if wn_sharing_parameter is None + else wn_sharing_parameter + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels]*2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels]*2, 1) - else: - m = stats - logs = torch.zeros_like(m) + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1,2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x \ No newline at end of file + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/GPT_SoVITS/module/commons.py b/GPT_SoVITS/module/commons.py index 7c9b028..e96cf92 100644 --- a/GPT_SoVITS/module/commons.py +++ b/GPT_SoVITS/module/commons.py @@ -1,189 +1,189 @@ import math -import numpy as np import torch -from torch import nn from torch.nn import functional as F def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) def get_padding(kernel_size, dilation=1): - return int((kernel_size*dilation - dilation)/2) + return int((kernel_size * dilation - dilation) / 2) def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item 
for sublist in l for item in sublist] - return pad_shape + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) - return kl + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str -def get_timing_signal_1d( - length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal 
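For context, get_timing_signal_1d above follows the standard Transformer-style sinusoidal positional encoding: channels // 2 timescales spaced geometrically between min_timescale and max_timescale, with the sine and cosine halves concatenated along the channel axis and reshaped to [1, channels, length] so the signal broadcasts over the batch; the add_timing_signal_1d wrapper just below simply adds it to an input tensor. A minimal standalone usage sketch (the import path mirrors the repo's own "from module.commons import ..." convention, and the tensor sizes are illustrative rather than values used by the patch):

import torch

from module.commons import add_timing_signal_1d, get_timing_signal_1d

# Dummy [batch, channels, length] features, shaped like the hidden states in this module.
x = torch.zeros(2, 192, 100)

# Build the signal once for a given (length, channels); it broadcasts over the batch dim.
signal = get_timing_signal_1d(100, 192)
print(signal.shape)  # torch.Size([1, 192, 100])

# add_timing_signal_1d recomputes the same signal and adds it to x in x's dtype/device.
y = add_timing_signal_1d(x)
print(torch.allclose(y, x + signal))  # True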
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2,3) * mask - return path + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not 
None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1. / norm_type) - return total_norm + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm def squeeze(x, x_mask=None, n_sqz=2): - b, c, t = x.size() + b, c, t = x.size() - t = (t // n_sqz) * n_sqz - x = x[:, :, :t] - x_sqz = x.view(b, c, t // n_sqz, n_sqz) - x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) + t = (t // n_sqz) * n_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // n_sqz, n_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) - if x_mask is not None: - x_mask = x_mask[:, :, n_sqz - 1::n_sqz] - else: - x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) - return x_sqz * x_mask, x_mask + if x_mask is not None: + x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz] + else: + x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask def unsqueeze(x, x_mask=None, n_sqz=2): - b, c, t = x.size() + b, c, t = x.size() - x_unsqz = x.view(b, n_sqz, c // n_sqz, t) - x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) + x_unsqz = x.view(b, n_sqz, c // n_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) - if x_mask is not None: - x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) - else: - x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) - return x_unsqz * x_mask, x_mask + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) + else: + x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask diff --git a/GPT_SoVITS/module/core_vq.py b/GPT_SoVITS/module/core_vq.py index 9121f3a..a5e22d6 100644 --- a/GPT_SoVITS/module/core_vq.py +++ b/GPT_SoVITS/module/core_vq.py @@ -76,10 +76,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): print("kmeans start ... ") for _ in tqdm(range(num_iters)): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) @@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. 
""" + def __init__( self, dim: int, @@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module): ): super().__init__() self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size @@ -147,7 +148,7 @@ class EuclideanCodebook(nn.Module): self.cluster_size.data.copy_(cluster_size) self.inited.data.copy_(torch.Tensor([True])) # Make sure all buffers across workers are in sync after initialization - #broadcast_tensors(self.buffers()) + # broadcast_tensors(self.buffers()) def replace_(self, samples, mask): modified_codebook = torch.where( @@ -165,7 +166,7 @@ class EuclideanCodebook(nn.Module): batch_samples = rearrange(batch_samples, "... d -> (...) d") self.replace_(batch_samples, mask=expired_codes) - #broadcast_tensors(self.buffers()) + # broadcast_tensors(self.buffers()) def preprocess(self, x): x = rearrange(x, "... d -> (...) d") @@ -246,6 +247,7 @@ class VectorQuantization(nn.Module): randomly selected vector from the current batch. commitment_weight (float): Weight for commitment loss. """ + def __init__( self, dim: int, @@ -256,22 +258,31 @@ class VectorQuantization(nn.Module): kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, - commitment_weight: float = 1., + commitment_weight: float = 1.0, ): super().__init__() _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) self.epsilon = epsilon self.commitment_weight = commitment_weight - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) self.codebook_size = codebook_size @property @@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module): """Residual vector quantization implementation. Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__(self, *, num_quantizers, **kwargs): super().__init__() self.layers = nn.ModuleList( [VectorQuantization(**kwargs) for _ in range(num_quantizers)] ) - def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): + def forward( + self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None + ): quantized_out = 0.0 residual = x @@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module): out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) return quantized_out, out_indices, out_losses, out_quantized - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor: + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: residual = x all_indices = [] n_q = n_q or len(self.layers) @@ -358,10 +374,10 @@ class ResidualVectorQuantization(nn.Module): out_indices = torch.stack(all_indices) return out_indices - def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor: + def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: quantized_out = torch.tensor(0.0, device=q_indices.device) for i, indices in enumerate(q_indices): layer = self.layers[st + i] quantized = layer.decode(indices) quantized_out = quantized_out + quantized - return quantized_out \ No newline at end of file + return quantized_out diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py index ea3fe77..15f401d 100644 --- a/GPT_SoVITS/module/data_utils.py +++ b/GPT_SoVITS/module/data_utils.py @@ -1,6 +1,6 @@ -import time,logging +import time, logging import os -import random,traceback +import random, traceback import numpy as np import torch import torch.utils.data @@ -16,41 +16,44 @@ import torch import requests from scipy.io import wavfile from io import BytesIO + # from config import exp_dir from my_utils import load_audio + class TextAudioSpeakerLoader(torch.utils.data.Dataset): """ - 1) loads audio, speaker_id, text pairs - 2) normalizes text and converts them to sequences of integers - 3) computes spectrograms from audio files. + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. 
""" def __init__(self, hparams, val=False): - exp_dir=hparams.exp_dir - self.path2="%s/2-name2text.txt"%exp_dir - self.path4="%s/4-cnhubert"%exp_dir - self.path5="%s/5-wav32k"%exp_dir + exp_dir = hparams.exp_dir + self.path2 = "%s/2-name2text.txt" % exp_dir + self.path4 = "%s/4-cnhubert" % exp_dir + self.path5 = "%s/5-wav32k" % exp_dir assert os.path.exists(self.path2) assert os.path.exists(self.path4) assert os.path.exists(self.path5) - names4=set([name[:-3]for name in list(os.listdir(self.path4))])#去除.pt后缀 - names5=set(os.listdir(self.path5)) - self.phoneme_data={} - with open(self.path2,"r",encoding="utf8")as f: - lines=f.read().strip("\n").split("\n") + names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 + names5 = set(os.listdir(self.path5)) + self.phoneme_data = {} + with open(self.path2, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") for line in lines: - tmp=line.split("\t") - if(len(tmp)!=4):continue - self.phoneme_data[tmp[0]]=[tmp[1]] + tmp = line.split("\t") + if len(tmp) != 4: + continue + self.phoneme_data[tmp[0]] = [tmp[1]] - self.audiopaths_sid_text=list(set(self.phoneme_data)&names4&names5) - tmp=self.audiopaths_sid_text - leng=len(tmp) - min_num=100 - if(leng duration > 0.6 or self.val): + if 54 > duration > 0.6 or self.val: audiopaths_sid_text_new.append([audiopath, phoneme_ids]) lengths.append(size // (2 * self.hop_length)) else: @@ -90,7 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): continue print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur) print("total left: ", len(audiopaths_sid_text_new)) - assert len(audiopaths_sid_text_new)>1#至少能凑够batch size,这里todo + assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo self.audiopaths_sid_text = audiopaths_sid_text_new self.lengths = lengths @@ -98,30 +101,41 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): audiopath, phoneme_ids = audiopath_sid_text text = torch.FloatTensor(phoneme_ids) try: - spec, wav = self.get_audio("%s/%s"%(self.path5,audiopath)) + spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath)) with torch.no_grad(): - ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu") - if(ssl.shape[-1]!=spec.shape[-1]): - typee=ssl.dtype - ssl=F.pad(ssl.float(),(0,1),mode="replicate").to(typee) - ssl.requires_grad=False + ssl = torch.load( + "%s/%s.pt" % (self.path4, audiopath), map_location="cpu" + ) + if ssl.shape[-1] != spec.shape[-1]: + typee = ssl.dtype + ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) + ssl.requires_grad = False except: traceback.print_exc() spec = torch.zeros(1025, 100) - wav = torch.zeros(1, 100*self.hop_length) - ssl=torch.zeros(1,768,100) - text=text[-1:] + wav = torch.zeros(1, 100 * self.hop_length) + ssl = torch.zeros(1, 768, 100) + text = text[-1:] print("load audio or ssl error!!!!!!", audiopath) # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad) return (ssl, spec, wav, text) def get_audio(self, filename): - audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768 + audio_array = load_audio( + filename, self.sampling_rate + ) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768 # print(filename,audio_array.max(),audio_array.min(),audio_array.mean()) - audio=torch.FloatTensor(audio_array)#/32768 + audio = torch.FloatTensor(audio_array) # /32768 audio_norm = audio audio_norm = audio_norm.unsqueeze(0) - spec = spectrogram_torch(audio_norm, self.filter_length,self.sampling_rate, self.hop_length, 
self.win_length,center=False) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) spec = torch.squeeze(spec, 0) return spec, audio_norm @@ -131,39 +145,51 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): def __getitem__(self, index): # with torch.no_grad(): - return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) + return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) def __len__(self): return len(self.audiopaths_sid_text) def random_slice(self, ssl, wav, mel): - assert abs(ssl.shape[-1]- wav.shape[-1]//self.hop_length) < 3, ("first", ssl.shape, wav.shape) + assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ( + "first", + ssl.shape, + wav.shape, + ) len_mel = mel.shape[1] if self.val: - reference_mel = mel[:, :len_mel//3] + reference_mel = mel[:, : len_mel // 3] return reference_mel, ssl, wav, mel dir = random.randint(0, 1) - sep_point = random.randint(int(len_mel//3), int(len_mel//3*2)) + sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2)) if dir == 0: reference_mel = mel[:, :sep_point] ssl = ssl[:, :, sep_point:] - wav2 = wav[:, sep_point*self.hop_length:] + wav2 = wav[:, sep_point * self.hop_length :] mel = mel[:, sep_point:] else: reference_mel = mel[:, sep_point:] ssl = ssl[:, :, :sep_point] - wav2 = wav[:, :sep_point*self.hop_length] + wav2 = wav[:, : sep_point * self.hop_length] mel = mel[:, :sep_point] - assert abs(ssl.shape[-1]- wav2.shape[-1]//self.hop_length) < 3, (ssl.shape, wav.shape,wav2.shape, mel.shape, sep_point,self.hop_length, sep_point*self.hop_length, dir) + assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( + ssl.shape, + wav.shape, + wav2.shape, + mel.shape, + sep_point, + self.hop_length, + sep_point * self.hop_length, + dir, + ) return reference_mel, ssl, wav2, mel -class TextAudioSpeakerCollate(): - """ Zero-pads model inputs and targets - """ +class TextAudioSpeakerCollate: + """Zero-pads model inputs and targets""" def __init__(self, return_ids=False): self.return_ids = return_ids @@ -176,8 +202,8 @@ class TextAudioSpeakerCollate(): """ # Right zero-pad all one-hot text sequences to max input length _, ids_sorted_decreasing = torch.sort( - torch.LongTensor([x[1].size(1) for x in batch]), - dim=0, descending=True) + torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True + ) max_ssl_len = max([x[0].size(2) for x in batch]) max_ssl_len = int(2 * ((max_ssl_len // 2) + 1)) @@ -194,7 +220,7 @@ class TextAudioSpeakerCollate(): spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len) - text_padded = torch.LongTensor(len(batch), max_text_len) + text_padded = torch.LongTensor(len(batch), max_text_len) spec_padded.zero_() wav_padded.zero_() @@ -205,23 +231,31 @@ class TextAudioSpeakerCollate(): row = batch[ids_sorted_decreasing[i]] ssl = row[0] - ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] + ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :] ssl_lengths[i] = ssl.size(2) spec = row[1] - spec_padded[i, :, :spec.size(1)] = spec + spec_padded[i, :, : spec.size(1)] = spec spec_lengths[i] = spec.size(1) wav = row[2] - wav_padded[i, :, :wav.size(1)] = wav + wav_padded[i, :, : wav.size(1)] = wav wav_lengths[i] = wav.size(1) text = row[3] - text_padded[i, :text.size(0)] = text + text_padded[i, : text.size(0)] = text 
text_lengths[i] = text.size(0) - - return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths + return ( + ssl_padded, + ssl_lengths, + spec_padded, + spec_lengths, + wav_padded, + wav_lengths, + text_padded, + text_lengths, + ) class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): @@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. """ - def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) self.lengths = dataset.lengths # print(233333333333333,self.lengths,dir(dataset)) @@ -254,7 +296,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): buckets[idx_bucket].append(i) for i in range(len(buckets) - 1, 0, -1): - # for i in range(len(buckets) - 1, -1, -1): + # for i in range(len(buckets) - 1, -1, -1): if len(buckets[i]) == 0: buckets.pop(i) self.boundaries.pop(i + 1) @@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): for i in range(len(buckets)): len_bucket = len(buckets[i]) total_batch_size = self.num_replicas * self.batch_size - rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size num_samples_per_bucket.append(len_bucket + rem) return buckets, num_samples_per_bucket @@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): # add extra samples to make it evenly divisible rem = num_samples_bucket - len_bucket - ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) # subsample - ids_bucket = ids_bucket[self.rank::self.num_replicas] + ids_bucket = ids_bucket[self.rank :: self.num_replicas] # batching for j in range(len(ids_bucket) // self.batch_size): - batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] batches.append(batch) if self.shuffle: diff --git a/GPT_SoVITS/module/losses.py b/GPT_SoVITS/module/losses.py index 50fdf85..b23fc8c 100644 --- a/GPT_SoVITS/module/losses.py +++ b/GPT_SoVITS/module/losses.py @@ -5,64 +5,69 @@ from torch.nn import functional as F def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - rl = rl.float().detach() - gl = gl.float() - loss += torch.mean(torch.abs(rl - gl)) + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 + return loss * 2 def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - dr = dr.float() - dg = dg.float() - r_loss = torch.mean((1-dr)**2) - g_loss = torch.mean(dg**2) - loss += (r_loss + g_loss) - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) + loss 
= 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) - return loss, r_losses, g_losses + return loss, r_losses, g_losses def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - dg = dg.float() - l = torch.mean((1-dg)**2) - gen_losses.append(l) - loss += l + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l - return loss, gen_losses + return loss, gen_losses def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): - """ - z_p, logs_q: [b, h, t_t] - m_p, logs_p: [b, h, t_t] - """ - z_p = z_p.float() - logs_q = logs_q.float() - m_p = m_p.float() - logs_p = logs_p.float() - z_mask = z_mask.float() + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l - kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) - kl = torch.sum(kl * z_mask) - l = kl / torch.sum(z_mask) - return l def mle_loss(z, m, logs, logdet, mask): - l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term - l = l - torch.sum(logdet) # log jacobian determinant - l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes - l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term - return l \ No newline at end of file + l = torch.sum(logs) + 0.5 * torch.sum( + torch.exp(-2 * logs) * ((z - m) ** 2) + ) # neg normal likelihood w/o the constant term + l = l - torch.sum(logdet) # log jacobian determinant + l = l / torch.sum( + torch.ones_like(z) * mask + ) # averaging across batch, channel and time axes + l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term + return l diff --git a/GPT_SoVITS/module/mel_processing.py b/GPT_SoVITS/module/mel_processing.py index 0ef5608..503825e 100644 --- a/GPT_SoVITS/module/mel_processing.py +++ b/GPT_SoVITS/module/mel_processing.py @@ -49,21 +49,37 @@ hann_window = {} def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) global hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - wnsize_dtype_device = str(win_size) + '_' + dtype_device + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + 
mode="reflect", + ) y = y.squeeze(1) - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec @@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): global mel_basis - dtype_device = str(spec.dtype) + '_' + str(spec.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) spec = spectral_normalize_torch(spec) return spec -def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) global mel_basis, hann_window - dtype_device = str(y.dtype) + '_' + str(y.device) - fmax_dtype_device = str(fmax) + '_' + dtype_device - wnsize_dtype_device = str(win_size) + '_' + dtype_device + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=y.dtype, device=y.device + ) if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) y = y.squeeze(1) - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], - center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + 
window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index 2361b64..c99485c 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -12,12 +12,21 @@ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from module.commons import init_weights, get_padding from module.mrte_model import MRTE -from module.quantize import ResidualVectorQuantizer +from module.quantize import ResidualVectorQuantizer from text import symbols from torch.cuda.amp import autocast + class StochasticDurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + p_dropout, + n_flows=4, + gin_channels=0, + ): super().__init__() filter_channels = in_channels # it needs to be removed from future version. self.in_channels = in_channels @@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module): self.flows = nn.ModuleList() self.flows.append(modules.ElementwiseAffine(2)) for i in range(n_flows): - self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) self.flows.append(modules.Flip()) self.post_pre = nn.Conv1d(1, filter_channels, 1) self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) self.post_flows = nn.ModuleList() self.post_flows.append(modules.ElementwiseAffine(2)) for i in range(4): - self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append( + modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) self.post_flows.append(modules.Flip()) self.pre = nn.Conv1d(in_channels, filter_channels, 1) self.proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.convs = modules.DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1) @@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module): h_w = self.post_pre(w) h_w = self.post_convs(h_w, x_mask) h_w = self.post_proj(h_w) * x_mask - e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) z_q = e_q for flow in self.post_flows: z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) @@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module): z_u, z1 = torch.split(z_q, [1, 1], 1) u = torch.sigmoid(z_u) * x_mask z0 = (w - u) * x_mask - logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) - logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) logdet_tot 
= 0 z0, logdet = self.log_flow(z0, x_mask) @@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module): for flow in flows: z, logdet = flow(z, x_mask, g=x, reverse=reverse) logdet_tot = logdet_tot + logdet - nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) return nll + logq # [b] else: flows = list(reversed(self.flows)) flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) for flow in flows: z = flow(z, x_mask, g=x, reverse=reverse) z0, z1 = torch.split(z, [1, 1], 1) @@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module): class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): super().__init__() self.in_channels = in_channels @@ -108,9 +141,13 @@ class DurationPredictor(nn.Module): self.gin_channels = gin_channels self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) self.norm_2 = modules.LayerNorm(filter_channels) self.proj = nn.Conv1d(filter_channels, 1, 1) @@ -135,15 +172,17 @@ class DurationPredictor(nn.Module): class TextEncoder(nn.Module): - def __init__(self, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - latent_channels=192): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + latent_channels=192, + ): super().__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels @@ -160,17 +199,14 @@ class TextEncoder(nn.Module): hidden_channels, filter_channels, n_heads, - n_layers//2, + n_layers // 2, kernel_size, - p_dropout) + p_dropout, + ) self.encoder_text = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.mrte = MRTE() @@ -179,21 +215,25 @@ class TextEncoder(nn.Module): hidden_channels, filter_channels, n_heads, - n_layers//2, + n_layers // 2, kernel_size, - p_dropout) - + p_dropout, + ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, y, y_lengths, text, text_lengths, ge, test=None): - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( + y.dtype + ) y = self.ssl_proj(y * y_mask) * y_mask y = self.encoder_ssl(y * y_mask, y_mask) - text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype) - if test == 1 : + text_mask = torch.unsqueeze( + commons.sequence_mask(text_lengths, text.size(1)), 1 + ).to(y.dtype) + if 
test == 1: text[:, :] = 0 text = self.text_embedding(text).transpose(1, 2) text = self.encoder_text(text * text_mask, text_mask) @@ -208,9 +248,9 @@ class TextEncoder(nn.Module): def extract_latent(self, x): x = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(x) - return codes.transpose(0,1) - def decode_latent(self, codes, y_mask, refer,refer_mask, ge): + return codes.transpose(0, 1) + def decode_latent(self, codes, y_mask, refer, refer_mask, ge): quantized = self.quantizer.decode(codes) y = self.vq_proj(quantized) * y_mask @@ -224,15 +264,18 @@ class TextEncoder(nn.Module): m, logs = torch.split(stats, self.out_channels, dim=1) return y, m, logs, y_mask, quantized + class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module): self.flows = nn.ModuleList() for i in range(n_flows): self.flows.append( - modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, - gin_channels=gin_channels, mean_only=True)) + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) self.flows.append(modules.Flip()) def forward(self, x, x_mask, g=None, reverse=False): @@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module): class PosteriorEncoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module): self.gin_channels = gin_channels self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, g=None): - if(g!=None): + if g != None: g = g.detach() - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask @@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module): class WNEncoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -312,11 +375,20 @@ class WNEncoder(nn.Module): self.gin_channels = gin_channels self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + 
self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.norm = modules.LayerNorm(out_channels) + def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) out = self.proj(x) * x_mask @@ -325,24 +397,45 @@ class WNEncoder(nn.Module): class Generator(torch.nn.Module): - def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, - upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append(weight_norm( - ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), - k, u, padding=(k - u) // 2))) + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): self.resblocks.append(resblock(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) @@ -373,7 +466,7 @@ class Generator(torch.nn.Module): return x def remove_weight_norm(self): - print('Removing weight norm...') + print("Removing weight norm...") for l in self.ups: remove_weight_norm(l) for l in self.resblocks: @@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module): self.period = period self.use_spectral_norm = use_spectral_norm norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + 
(stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): @@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) def forward(self, x): @@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module): periods = [2, 3, 5, 7, 11] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] self.discriminators = nn.ModuleList(discs) def forward(self, y, y_hat): @@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module): return y_d_rs, y_d_gs, fmap_rs, fmap_gs + class ReferenceEncoder(nn.Module): - ''' + """ inputs --- [N, Ty/r, n_mels*r] mels outputs --- [N, ref_enc_gru_size] - ''' + """ def __init__(self, spec_channels, gin_channels=0): - super().__init__() self.spec_channels = spec_channels ref_enc_filters = [32, 32, 64, 64, 128, 128] K = len(ref_enc_filters) filters = [1] + ref_enc_filters - convs = [weight_norm(nn.Conv2d(in_channels=filters[i], - out_channels=filters[i + 1], - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1))) for i in range(K)] + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] self.convs = nn.ModuleList(convs) # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) - self.gru = nn.GRU(input_size=ref_enc_filters[-1] * out_channels, - hidden_size=256 // 2, - batch_first=True) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) self.proj = nn.Linear(128, gin_channels) def forward(self, inputs): @@ -527,23 +675,31 @@ class Quantizer_module(torch.nn.Module): self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) def forward(self, x): - d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) - 2 * torch.matmul(x, self.embedding.weight.T) + d = ( + torch.sum(x**2, 1, keepdim=True) + + torch.sum(self.embedding.weight**2, 1) + - 2 * torch.matmul(x, self.embedding.weight.T) + ) 
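# The tensor d computed just above is the pairwise squared Euclidean distance
# between every input row and every codebook vector, using the expansion
#     ||x_i - e_j||^2 = ||x_i||^2 + ||e_j||^2 - 2 * <x_i, e_j>,
# so the whole comparison reduces to a single matmul against the embedding
# matrix; the argmin over dim 1 below then picks the nearest code per row.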
min_indicies = torch.argmin(d, 1) z_q = self.embedding(min_indicies) return z_q, min_indicies + class Quantizer(torch.nn.Module): def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160): super(Quantizer, self).__init__() assert embed_dim % n_code_groups == 0 - self.quantizer_modules = nn.ModuleList([ - Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups) - ]) + self.quantizer_modules = nn.ModuleList( + [ + Quantizer_module(n_codes, embed_dim // n_code_groups) + for _ in range(n_code_groups) + ] + ) self.n_code_groups = n_code_groups self.embed_dim = embed_dim def forward(self, xin): - #B, C, T + # B, C, T B, C, T = xin.shape xin = xin.transpose(1, 2) x = xin.reshape(-1, self.embed_dim) @@ -553,38 +709,41 @@ class Quantizer(torch.nn.Module): for _x, m in zip(x, self.quantizer_modules): _z_q, _min_indicies = m(_x) z_q.append(_z_q) - min_indicies.append(_min_indicies) #B * T, + min_indicies.append(_min_indicies) # B * T, z_q = torch.cat(z_q, -1).reshape(xin.shape) - loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2) + loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean( + (z_q - xin.detach()) ** 2 + ) z_q = xin + (z_q - xin).detach() z_q = z_q.transpose(1, 2) codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups) return z_q, loss, codes.transpose(1, 2) def embed(self, x): - #idx: N, 4, T - x=x.transpose(1, 2) + # idx: N, 4, T + x = x.transpose(1, 2) x = torch.split(x, 1, 2) ret = [] for q, embed in zip(x, self.quantizer_modules): q = embed.embedding(q.squeeze(-1)) ret.append(q) ret = torch.cat(ret, -1) - return ret.transpose(1, 2) #N, C, T + return ret.transpose(1, 2) # N, C, T class CodePredictor(nn.Module): - def __init__(self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - n_q=8, - dims=1024, - ssl_dim=768 - ): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + n_q=8, + dims=1024, + ssl_dim=768, + ): super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -594,19 +753,18 @@ class CodePredictor(nn.Module): self.p_dropout = p_dropout self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1) - self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels) + self.ref_enc = modules.MelStyleEncoder( + ssl_dim, style_vector_dim=hidden_channels + ) self.encoder = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) + hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout + ) - self.out_proj = nn.Conv1d(hidden_channels, (n_q-1) * dims, 1) + self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1) self.n_q = n_q self.dims = dims + def forward(self, x, x_mask, refer, codes, infer=False): x = x.detach() x = self.vq_proj(x * x_mask) * x_mask @@ -614,7 +772,9 @@ class CodePredictor(nn.Module): x = x + g x = self.encoder(x * x_mask, x_mask) x = self.out_proj(x * x_mask) * x_mask - logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3) + logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose( + 2, 3 + ) target = codes[1:].transpose(0, 1) if not infer: logits = logits.reshape(-1, self.dims) @@ -626,44 +786,44 @@ class CodePredictor(nn.Module): correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1) top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item() - 
print('Top-10 Accuracy:', top3_acc, "%") + print("Top-10 Accuracy:", top3_acc, "%") pred_codes = torch.argmax(logits, dim=-1) acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item() - print('Top-1 Accuracy:', acc, "%") + print("Top-1 Accuracy:", acc, "%") return pred_codes.transpose(0, 1) - class SynthesizerTrn(nn.Module): """ - Synthesizer for Training - """ - - def __init__(self, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=0, - gin_channels=0, - use_sdp=True, - semantic_frame_rate=None, - freeze_quantizer=None, - **kwargs): + Synthesizer for Training + """ + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + semantic_frame_rate=None, + freeze_quantizer=None, + **kwargs + ): super().__init__() self.spec_channels = spec_channels self.inter_channels = inter_channels @@ -685,34 +845,50 @@ class SynthesizerTrn(nn.Module): self.use_sdp = use_sdp self.enc_p = TextEncoder( - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, - upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, - gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + ) + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) - self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels) + self.ref_enc = modules.MelStyleEncoder( + spec_channels, style_vector_dim=gin_channels + ) ssl_dim = 768 - assert semantic_frame_rate in ['25hz', "50hz"] + assert semantic_frame_rate in ["25hz", "50hz"] self.semantic_frame_rate = semantic_frame_rate - if semantic_frame_rate == '25hz': + if semantic_frame_rate == "25hz": self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2) else: self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) - self.quantizer = ResidualVectorQuantizer( - dimension=ssl_dim, - n_q=1, - bins=1024 - ) + self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) if freeze_quantizer: self.ssl_proj.requires_grad_(False) self.quantizer.requires_grad_(False) @@ -721,56 +897,85 @@ class SynthesizerTrn(nn.Module): # self.enc_p.mrte.requires_grad_(False) def forward(self, ssl, y, y_lengths, text, text_lengths): - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 
1).to(y.dtype) + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( + y.dtype + ) ge = self.ref_enc(y * y_mask, y_mask) with autocast(enabled=False): ssl = self.ssl_proj(ssl) - quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0]) + quantized, codes, commit_loss, quantized_list = self.quantizer( + ssl, layers=[0] + ) - if self.semantic_frame_rate == '25hz': - quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate( + quantized, size=int(quantized.shape[-1] * 2), mode="nearest" + ) - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask = self.enc_p( + quantized, y_lengths, text, text_lengths, ge + ) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z_p = self.flow(z, y_mask, g=ge) - z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) o = self.dec(z_slice, g=ge) - return o, commit_loss, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized + return ( + o, + commit_loss, + ids_slice, + y_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + quantized, + ) def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5): - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( + y.dtype + ) ge = self.ref_enc(y * y_mask, y_mask) - ssl = self.ssl_proj(ssl) + ssl = self.ssl_proj(ssl) quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0]) - if self.semantic_frame_rate == '25hz': - quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate( + quantized, size=int(quantized.shape[-1] * 2), mode="nearest" + ) - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test) + x, m_p, logs_p, y_mask = self.enc_p( + quantized, y_lengths, text, text_lengths, ge, test=test + ) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) o = self.dec((z * y_mask)[:, :, :], g=ge) - return o,y_mask, (z, z_p, m_p, logs_p) - + return o, y_mask, (z, z_p, m_p, logs_p) @torch.no_grad() - def decode(self, codes,text, refer, noise_scale=0.5): + def decode(self, codes, text, refer, noise_scale=0.5): refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) - refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype) + refer_mask = torch.unsqueeze( + commons.sequence_mask(refer_lengths, refer.size(2)), 1 + ).to(refer.dtype) ge = self.ref_enc(refer * refer_mask, refer_mask) - y_lengths = torch.LongTensor([codes.size(2)*2]).to(codes.device) + y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) quantized = self.quantizer.decode(codes) - if self.semantic_frame_rate == '25hz': - quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") + if self.semantic_frame_rate == "25hz": + quantized = F.interpolate( + quantized, size=int(quantized.shape[-1] * 2), mode="nearest" + ) - x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) + x, m_p, logs_p, y_mask = self.enc_p( + quantized, 
y_lengths, text, text_lengths, ge + ) z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z = self.flow(z_p, y_mask, g=ge, reverse=True) @@ -779,6 +984,6 @@ class SynthesizerTrn(nn.Module): return o def extract_latent(self, x): - ssl = self.ssl_proj(x) + ssl = self.ssl_proj(x) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) - return codes.transpose(0,1) + return codes.transpose(0, 1) diff --git a/GPT_SoVITS/module/modules.py b/GPT_SoVITS/module/modules.py index 711cc5b..f444745 100644 --- a/GPT_SoVITS/module/modules.py +++ b/GPT_SoVITS/module/modules.py @@ -17,193 +17,282 @@ LRELU_SLOPE = 0.1 class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - class ConvReluNorm(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - assert n_layers > 1, "Number of layers should be larger than 0." + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
- self.conv_layers = nn.ModuleList() - self.norm_layers = nn.ModuleList() - self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = nn.Sequential( - nn.ReLU(), - nn.Dropout(p_dropout)) - for _ in range(n_layers-1): - self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout + """ + Dialted and Depth-Separable Convolution + """ - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + 
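# For odd kernel sizes, padding = (kernel_size * dilation - dilation) // 2
# equals dilation * (kernel_size - 1) // 2, i.e. "same" padding: each
# depthwise (groups=channels) dilated conv appended above preserves the time
# length, while the 1x1 convs appended below mix information across channels.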
self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask class WN(torch.nn.Module): - def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): - super(WN, self).__init__() - assert(kernel_size % 2 == 1) - self.hidden_channels =hidden_channels - self.kernel_size = kernel_size, - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - self.p_dropout = p_dropout + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() - self.drop = nn.Dropout(p_dropout) + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) - if gin_channels != 0: - cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + if gin_channels != 0: + cond_layer = torch.nn.Conv1d( + gin_channels, 2 * hidden_channels * n_layers, 1 + ) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") - for i in range(n_layers): - dilation = dilation_rate ** i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, - dilation=dilation, padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') - self.in_layers.append(in_layer) + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) - # last one is not necessary - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels - res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') - self.res_skip_layers.append(res_skip_layer) + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) - def forward(self, x, x_mask, g=None, **kwargs): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) - if g is not None: - g = self.cond_layer(g) + if g is not None: + g = 
self.cond_layer(g) - for i in range(self.n_layers): - x_in = self.in_layers[i](x) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] - else: - g_l = torch.zeros_like(x_in) + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) - acts = commons.fused_add_tanh_sigmoid_multiply( - x_in, - g_l, - n_channels_tensor) - acts = self.drop(acts) + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - res_acts = res_skip_acts[:,:self.hidden_channels,:] - x = (x + res_acts) * x_mask - output = output + res_skip_acts[:,self.hidden_channels:,:] - else: - output = output + res_skip_acts - return output * x_mask + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask - def remove_weight_norm(self): - if self.gin_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) - for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) - for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() - self.convs1 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]))) - ]) + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) self.convs1.apply(init_weights) - self.convs2 = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1))) - ]) + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + 
Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) self.convs2.apply(init_weights) def forward(self, x, x_mask=None): @@ -231,12 +320,30 @@ class ResBlock1(torch.nn.Module): class ResBlock2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super(ResBlock2, self).__init__() - self.convs = nn.ModuleList([ - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))) - ]) + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) self.convs.apply(init_weights) def forward(self, x, x_mask=None): @@ -256,147 +363,169 @@ class ResBlock2(torch.nn.Module): class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + class Flip(nn.Module): - def forward(self, x, *args, reverse=False, **kwargs): - x = torch.flip(x, [1]) - if not reverse: - logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) - return x, logdet - else: - return x + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels,1)) - self.logs = nn.Parameter(torch.zeros(channels,1)) + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1,2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x class ResidualCouplingLayer(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=0, - gin_channels=0, - mean_only=False): - assert channels % 2 == 0, "channels should be divisible by 2" - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.half_channels = channels // 2 - self.mean_only = mean_only + def __init__( + self, + channels, + hidden_channels, + 
kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only - self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) - self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) - self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) - self.post.weight.data.zero_() - self.post.bias.data.zero_() + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels]*2, 1) - h = self.pre(x0) * x_mask - h = self.enc(h, x_mask, g=g) - stats = self.post(h) * x_mask - if not self.mean_only: - m, logs = torch.split(stats, [self.half_channels]*2, 1) - else: - m = stats - logs = torch.zeros_like(m) + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) - if not reverse: - x1 = m + x1 * torch.exp(logs) * x_mask - x = torch.cat([x0, x1], 1) - logdet = torch.sum(logs, [1,2]) - return x, logdet - else: - x1 = (x1 - m) * torch.exp(-logs) * x_mask - x = torch.cat([x0, x1], 1) - return x + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x class ConvFlow(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.num_bins = num_bins - self.tail_bound = tail_bound - self.half_channels = in_channels // 2 + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 - self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 
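# The projection below emits (num_bins * 3 - 1) values per half-channel and
# time step: num_bins unnormalized bin widths, num_bins unnormalized bin
# heights, and (num_bins - 1) interior knot derivatives for the piecewise
# rational quadratic spline; with "linear" tails the two boundary derivatives
# are handled inside the spline transform itself.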
- self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 - 1), 1 + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels]*2, 1) - h = self.pre(x0) - h = self.convs(h, x_mask, g=g) - h = self.proj(h) * x_mask + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask - b, c, t = x0.shape - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] - unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_derivatives = h[..., 2 * self.num_bins:] + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] - x1, logabsdet = piecewise_rational_quadratic_transform(x1, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=reverse, - tails='linear', - tail_bound=self.tail_bound - ) - - x = torch.cat([x0, x1], 1) * x_mask - logdet = torch.sum(logabsdet * x_mask, [1,2]) - if not reverse: - return x, logdet - else: - return x + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x class LinearNorm(nn.Module): - def __init__(self, - in_channels, - out_channels, - bias=True, - spectral_norm=False, - ): + def __init__( + self, + in_channels, + out_channels, + bias=True, + spectral_norm=False, + ): super(LinearNorm, self).__init__() self.fc = nn.Linear(in_channels, out_channels, bias) @@ -417,10 +546,10 @@ class Mish(nn.Module): class Conv1dGLU(nn.Module): - ''' + """ Conv1d + GLU(Gated Linear Unit) with residual connection. For GLU refer to https://arxiv.org/abs/1612.08083 paper. 
- ''' + """ def __init__(self, in_channels, out_channels, kernel_size, dropout): super(Conv1dGLU, self).__init__() @@ -438,29 +567,32 @@ class Conv1dGLU(nn.Module): class ConvNorm(nn.Module): - def __init__(self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=None, - dilation=1, - bias=True, - spectral_norm=False, - ): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + spectral_norm=False, + ): super(ConvNorm, self).__init__() if padding is None: - assert (kernel_size % 2 == 1) + assert kernel_size % 2 == 1 padding = int(dilation * (kernel_size - 1) / 2) - self.conv = torch.nn.Conv1d(in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias) + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) if spectral_norm: self.conv = nn.utils.spectral_norm(self.conv) @@ -471,9 +603,9 @@ class ConvNorm(nn.Module): class MultiHeadAttention(nn.Module): - ''' Multi-Head Attention module ''' + """Multi-Head Attention module""" - def __init__(self, n_head, d_model, d_k, d_v, dropout=0., spectral_norm=False): + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False): super().__init__() self.n_head = n_head @@ -484,7 +616,9 @@ class MultiHeadAttention(nn.Module): self.w_ks = nn.Linear(d_model, n_head * d_k) self.w_vs = nn.Linear(d_model, n_head * d_v) - self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout) + self.attention = ScaledDotProductAttention( + temperature=np.power(d_model, 0.5), dropout=dropout + ) self.fc = nn.Linear(n_head * d_v, d_model) self.dropout = nn.Dropout(dropout) @@ -504,12 +638,9 @@ class MultiHeadAttention(nn.Module): q = self.w_qs(x).view(sz_b, len_x, n_head, d_k) k = self.w_ks(x).view(sz_b, len_x, n_head, d_k) v = self.w_vs(x).view(sz_b, len_x, n_head, d_v) - q = q.permute(2, 0, 1, 3).contiguous().view(-1, - len_x, d_k) # (n*b) x lq x dk - k = k.permute(2, 0, 1, 3).contiguous().view(-1, - len_x, d_k) # (n*b) x lk x dk - v = v.permute(2, 0, 1, 3).contiguous().view(-1, - len_x, d_v) # (n*b) x lv x dv + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv if mask is not None: slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
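A minimal, self-contained sketch of the head-folding convention used in MultiHeadAttention above (the sizes and variable names here are illustrative, not taken from the patch): a (batch, length, n_head, d_k) tensor is folded into (n_head * batch, length, d_k) for the scaled dot-product step and unfolded back afterwards.

import torch

sz_b, len_x, n_head, d_k = 2, 7, 4, 16          # illustrative sizes only
q = torch.randn(sz_b, len_x, n_head, d_k)

# fold heads into the batch dimension: (b, l, n, d) -> (n*b, l, d)
q_folded = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k)
assert q_folded.shape == (n_head * sz_b, len_x, d_k)

# ... attention runs once over all (head, batch) pairs on q_folded ...

# unfold back: (n*b, l, d) -> (b, l, n*d), matching the output reshape
out = q_folded.view(n_head, sz_b, len_x, d_k).permute(1, 2, 0, 3).contiguous()
out = out.view(sz_b, len_x, -1)
assert out.shape == (sz_b, len_x, n_head * d_k)

Folding the heads into the batch dimension lets one batched matmul serve all heads at once, which is also why the attention mask in the hunk above is repeated n_head times along dim 0.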
@@ -518,8 +649,9 @@ class MultiHeadAttention(nn.Module): output, attn = self.attention(q, k, v, mask=slf_mask) output = output.view(n_head, sz_b, len_x, d_v) - output = output.permute(1, 2, 0, 3).contiguous().view( - sz_b, len_x, -1) # b x lq x (n*dv) + output = ( + output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) + ) # b x lq x (n*dv) output = self.fc(output) @@ -528,7 +660,7 @@ class MultiHeadAttention(nn.Module): class ScaledDotProductAttention(nn.Module): - ''' Scaled Dot-Product Attention ''' + """Scaled Dot-Product Attention""" def __init__(self, temperature, dropout): super().__init__() @@ -551,14 +683,17 @@ class ScaledDotProductAttention(nn.Module): class MelStyleEncoder(nn.Module): - ''' MelStyleEncoder ''' + """MelStyleEncoder""" - def __init__(self, n_mel_channels=80, - style_hidden=128, - style_vector_dim=256, - style_kernel_size=5, - style_head=2, - dropout=0.1): + def __init__( + self, + n_mel_channels=80, + style_hidden=128, + style_vector_dim=256, + style_kernel_size=5, + style_head=2, + dropout=0.1, + ): super(MelStyleEncoder, self).__init__() self.in_dim = n_mel_channels self.hidden_dim = style_hidden @@ -573,7 +708,7 @@ class MelStyleEncoder(nn.Module): nn.Dropout(self.dropout), LinearNorm(self.hidden_dim, self.hidden_dim), Mish(), - nn.Dropout(self.dropout) + nn.Dropout(self.dropout), ) self.temporal = nn.Sequential( @@ -581,9 +716,13 @@ class MelStyleEncoder(nn.Module): Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), ) - self.slf_attn = MultiHeadAttention(self.n_head, self.hidden_dim, - self.hidden_dim // self.n_head, self.hidden_dim // self.n_head, - self.dropout) + self.slf_attn = MultiHeadAttention( + self.n_head, + self.hidden_dim, + self.hidden_dim // self.n_head, + self.hidden_dim // self.n_head, + self.dropout, + ) self.fc = LinearNorm(self.hidden_dim, self.out_dim) @@ -598,11 +737,13 @@ class MelStyleEncoder(nn.Module): return out def forward(self, x, mask=None): - x = x.transpose(1,2) + x = x.transpose(1, 2) if mask is not None: - mask = (mask.int()==0).squeeze(1) + mask = (mask.int() == 0).squeeze(1) max_len = x.shape[1] - slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None + slf_attn_mask = ( + mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None + ) # spectral x = self.spectral(x) @@ -644,7 +785,9 @@ class MelStyleEncoderVAE(nn.Module): mu = self.fc1(enc_out) logvar = self.fc2(enc_out) posterior = D.Normal(mu, torch.exp(logvar)) - kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))) + kl_divergence = D.kl_divergence( + posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)) + ) loss_kl = kl_divergence.mean() z = posterior.rsample() @@ -656,11 +799,12 @@ class MelStyleEncoderVAE(nn.Module): if manual_latent is None: if random_sample: dev = next(self.parameters()).device - posterior = D.Normal(torch.zeros(1, self.z_latent_dim, device=dev), - torch.ones(1, self.z_latent_dim, device=dev)) + posterior = D.Normal( + torch.zeros(1, self.z_latent_dim, device=dev), + torch.ones(1, self.z_latent_dim, device=dev), + ) z = posterior.rsample() else: - enc_out = self.ref_encoder(inputs.transpose(1, 2)) mu = self.fc1(enc_out) z = mu @@ -681,7 +825,9 @@ class ActNorm(nn.Module): def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): if x_mask is None: - x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_mask = torch.ones(x.size(0), 1, x.size(2)).to( + device=x.device, 
dtype=x.dtype + ) x_len = torch.sum(x_mask, [1, 2]) if not self.initialized: self.initialize(x, x_mask) @@ -707,10 +853,12 @@ class ActNorm(nn.Module): denom = torch.sum(x_mask, [0, 2]) m = torch.sum(x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m ** 2) + v = m_sq - (m**2) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) - bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + bias_init = ( + (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + ) logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) self.bias.data.copy_(bias_init) @@ -720,19 +868,21 @@ class ActNorm(nn.Module): class InvConvNear(nn.Module): def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): super().__init__() - assert (n_split % 2 == 0) + assert n_split % 2 == 0 self.channels = channels self.n_split = n_split self.no_jacobian = no_jacobian - w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] + w_init = torch.linalg.qr( + torch.FloatTensor(self.n_split, self.n_split).normal_() + )[0] if torch.det(w_init) < 0: w_init[:, 0] = -1 * w_init[:, 0] self.weight = nn.Parameter(w_init) def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs): b, c, t = x.size() - assert (c % self.n_split == 0) + assert c % self.n_split == 0 if x_mask is None: x_mask = 1 x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t @@ -740,7 +890,11 @@ class InvConvNear(nn.Module): x_len = torch.sum(x_mask, [1, 2]) x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) - x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) + x = ( + x.permute(0, 1, 3, 2, 4) + .contiguous() + .view(b, self.n_split, c // self.n_split, t) + ) if reverse: if hasattr(self, "weight_inv"): diff --git a/GPT_SoVITS/module/mrte_model.py b/GPT_SoVITS/module/mrte_model.py index e936c76..b0cd242 100644 --- a/GPT_SoVITS/module/mrte_model.py +++ b/GPT_SoVITS/module/mrte_model.py @@ -5,46 +5,74 @@ from torch import nn from torch.nn.utils import remove_weight_norm, weight_norm from module.attentions import MultiHeadAttention + class MRTE(nn.Module): - def __init__(self, - content_enc_channels=192, - hidden_size=512, - out_channels=192, - kernel_size=5, - n_heads=4, - ge_layer = 2 - ): + def __init__( + self, + content_enc_channels=192, + hidden_size=512, + out_channels=192, + kernel_size=5, + n_heads=4, + ge_layer=2, + ): super(MRTE, self).__init__() - self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads) - self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) - self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) - self.c_post = nn.Conv1d(hidden_size,out_channels, 1) + self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) + self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) + self.c_post = nn.Conv1d(hidden_size, out_channels, 1) def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None): - if(ge==None):ge=0 + if ge == None: + ge = 0 attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) ssl_enc = self.c_pre(ssl_enc * ssl_mask) text_enc = self.text_pre(text * text_mask) if test != None: if test == 0: - x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge + x = ( + self.cross_attention( + ssl_enc * ssl_mask, text_enc * text_mask, attn_mask + ) + + ssl_enc + + ge + ) elif test == 1: x = 
ssl_enc + ge - elif test ==2: - x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge + elif test == 2: + x = ( + self.cross_attention( + ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask + ) + + ge + ) else: raise ValueError("test should be 0,1,2") else: - x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge + x = ( + self.cross_attention( + ssl_enc * ssl_mask, text_enc * text_mask, attn_mask + ) + + ssl_enc + + ge + ) x = self.c_post(x * ssl_mask) return x - + class SpeakerEncoder(torch.nn.Module): - def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256): + def __init__( + self, + mel_n_channels=80, + model_num_layers=2, + model_hidden_size=256, + model_embedding_size=256, + ): super(SpeakerEncoder, self).__init__() - self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.lstm = nn.LSTM( + mel_n_channels, model_hidden_size, model_num_layers, batch_first=True + ) self.linear = nn.Linear(model_hidden_size, model_embedding_size) self.relu = nn.ReLU() @@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module): class MELEncoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -81,80 +111,82 @@ class MELEncoder(nn.Module): x = self.enc(x) x = self.proj(x) return x - + class WN(torch.nn.Module): - def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): - super(WN, self).__init__() - assert(kernel_size % 2 == 1) - self.hidden_channels =hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() - for i in range(n_layers): - dilation = dilation_rate ** i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, - dilation=dilation, padding=padding) - in_layer = weight_norm(in_layer) - self.in_layers.append(in_layer) + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer) + self.in_layers.append(in_layer) - # last one is not necessary - if i < n_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels - res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) - res_skip_layer = weight_norm(res_skip_layer, name='weight') - self.res_skip_layers.append(res_skip_layer) + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = 
weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) - def forward(self, x): - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) + def forward(self, x): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) - for i in range(self.n_layers): - x_in = self.in_layers[i](x) + for i in range(self.n_layers): + x_in = self.in_layers[i](x) - acts = fused_add_tanh_sigmoid_multiply( - x_in, - n_channels_tensor) + acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor) - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.n_layers - 1: - res_acts = res_skip_acts[:,:self.hidden_channels,:] - x = (x + res_acts) - output = output + res_skip_acts[:,self.hidden_channels:,:] - else: - output = output + res_skip_acts - return output + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = x + res_acts + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output - def remove_weight_norm(self): - for l in self.in_layers: - remove_weight_norm(l) - for l in self.res_skip_layers: - remove_weight_norm(l) + def remove_weight_norm(self): + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) @torch.jit.script def fused_add_tanh_sigmoid_multiply(input, n_channels): - n_channels_int = n_channels[0] - t_act = torch.tanh(input[:, :n_channels_int, :]) - s_act = torch.sigmoid(input[:, n_channels_int:, :]) - acts = t_act * s_act - return acts + n_channels_int = n_channels[0] + t_act = torch.tanh(input[:, :n_channels_int, :]) + s_act = torch.sigmoid(input[:, n_channels_int:, :]) + acts = t_act * s_act + return acts - -if __name__ == '__main__': - content_enc = torch.randn(3,192,100) - content_mask = torch.ones(3,1,100) - ref_mel = torch.randn(3,128,30) - ref_mask = torch.ones(3,1,30) +if __name__ == "__main__": + content_enc = torch.randn(3, 192, 100) + content_mask = torch.ones(3, 1, 100) + ref_mel = torch.randn(3, 128, 30) + ref_mask = torch.ones(3, 1, 30) model = MRTE() - out = model(content_enc,content_mask,ref_mel,ref_mask) - print(out.shape) \ No newline at end of file + out = model(content_enc, content_mask, ref_mel, ref_mask) + print(out.shape) diff --git a/GPT_SoVITS/module/quantize.py b/GPT_SoVITS/module/quantize.py index cdbdeea..f9a5c63 100644 --- a/GPT_SoVITS/module/quantize.py +++ b/GPT_SoVITS/module/quantize.py @@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ + def __init__( self, dimension: int = 256, @@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module): threshold_ema_dead_code=self.threshold_ema_dead_code, ) - def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult: + def forward( + self, + x: torch.Tensor, + n_q: tp.Optional[int] = None, + layers: tp.Optional[list] = None, + ) -> QuantizedResult: """Residual vector quantization on the given input tensor. Args: x (torch.Tensor): Input tensor. @@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module): """ n_q = n_q if n_q else self.n_q if layers and max(layers) >= n_q: - raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. 
A must less than B.') - quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) + raise ValueError( + f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B." + ) + quantized, codes, commit_loss, quantized_list = self.vq( + x, n_q=n_q, layers=layers + ) return quantized, codes, torch.mean(commit_loss), quantized_list - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: """Encode a given input tensor with the specified sample rate at the given bandwidth. The RVQ encode method sets the appropriate number of quantizer to use and returns indices for each quantizer. @@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module): st (int): Start to decode input codes from which layers. Default: 0. """ quantized = self.vq.decode(codes, st=st) - return quantized \ No newline at end of file + return quantized diff --git a/GPT_SoVITS/module/transforms.py b/GPT_SoVITS/module/transforms.py index 4793d67..a11f799 100644 --- a/GPT_SoVITS/module/transforms.py +++ b/GPT_SoVITS/module/transforms.py @@ -9,66 +9,63 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3 DEFAULT_MIN_DERIVATIVE = 1e-3 -def piecewise_rational_quadratic_transform(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): - +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): if tails is None: spline_fn = rational_quadratic_spline spline_kwargs = {} else: spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = { - 'tails': tails, - 'tail_bound': tail_bound - } + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs ) return outputs, logabsdet def searchsorted(bin_locations, inputs, eps=1e-6): bin_locations[..., -1] += eps - return torch.sum( - inputs[..., None] >= bin_locations, - dim=-1 - ) - 1 + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 -def unconstrained_rational_quadratic_spline(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails='linear', - tail_bound=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + 
min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask outputs = torch.zeros_like(inputs) logabsdet = torch.zeros_like(inputs) - if tails == 'linear': + if tails == "linear": unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives[..., 0] = constant @@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs, outputs[outside_interval_mask] = inputs[outside_interval_mask] logabsdet[outside_interval_mask] = 0 else: - raise RuntimeError('{} tails are not implemented.'.format(tails)) + raise RuntimeError("{} tails are not implemented.".format(tails)) - outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( inputs=inputs[inside_interval_mask], unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], inverse=inverse, - left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, min_bin_width=min_bin_width, min_bin_height=min_bin_height, - min_derivative=min_derivative + min_derivative=min_derivative, ) return outputs, logabsdet -def rational_quadratic_spline(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0., right=1., bottom=0., top=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError('Input to a transform is not within its domain') + raise ValueError("Input to a transform is not within its domain") num_bins = unnormalized_widths.shape[-1] if min_bin_width * num_bins > 1.0: - raise ValueError('Minimal bin width too large for the number of bins') + raise ValueError("Minimal bin width too large for the number of bins") if min_bin_height * num_bins > 1.0: - raise ValueError('Minimal bin height too large for the number of bins') + raise ValueError("Minimal bin height too large for the number of bins") widths = F.softmax(unnormalized_widths, dim=-1) widths = min_bin_width + (1 - min_bin_width * num_bins) * widths cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) cumwidths = (right - left) * cumwidths + left cumwidths[..., 0] = left cumwidths[..., -1] = right @@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs, heights = F.softmax(unnormalized_heights, dim=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = F.pad(cumheights, pad=(1, 0), 
mode="constant", value=0.0) cumheights = (top - bottom) * cumheights + bottom cumheights[..., 0] = bottom cumheights[..., -1] = top @@ -150,15 +159,13 @@ def rational_quadratic_spline(inputs, input_heights = heights.gather(-1, bin_idx)[..., 0] if inverse: - a = (((inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta) - + input_heights * (input_delta - input_derivatives))) - b = (input_heights * input_derivatives - - (inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta)) - c = - input_delta * (inputs - input_cumheights) + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c assert (discriminant >= 0).all() @@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs, outputs = root * input_bin_widths + input_cumwidths theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2)) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, -logabsdet @@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs, theta = (inputs - input_cumwidths) / input_bin_widths theta_one_minus_theta = theta * (1 - theta) - numerator = input_heights * (input_delta * theta.pow(2) - + input_derivatives * theta_one_minus_theta) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) outputs = input_cumheights + numerator / denominator - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2)) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) return outputs, logabsdet diff --git a/GPT_SoVITS/prepare_datasets/0-pipeline.py b/GPT_SoVITS/prepare_datasets/0-pipeline.py index 4b90a68..4979ed2 100644 --- a/GPT_SoVITS/prepare_datasets/0-pipeline.py +++ b/GPT_SoVITS/prepare_datasets/0-pipeline.py @@ -1,50 +1,81 @@ -import os,torch,sys +import os, torch, sys from subprocess import Popen + now_dir = os.getcwd() sys.path.append(now_dir) -from config import text_path,wav_dir,n_card,n_process_per_card,exp_name,n_parts,exp_dir -os.makedirs("%s/logs_s1"%exp_dir,exist_ok=True) 
-os.makedirs("%s/logs_s2"%exp_dir,exist_ok=True) +from config import ( + text_path, + wav_dir, + n_card, + exp_name, + n_parts, + exp_dir, +) + +os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True) +os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True) ##############step1 -ps=[] +ps = [] for i_part in range(n_parts): - cmd="python prepare/1-get-text.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) + cmd = "python prepare/1-get-text.py %s %s %s %s %s %s" % ( + text_path, + wav_dir, + exp_name, + i_part, + n_parts, + i_part % n_card, + ) print(cmd) p = Popen(cmd, shell=True) ps.append(p) for p in ps: p.wait() -opt=[] +opt = [] for i_part in range(n_parts): txt_path = "%s/2-name2text-%s.txt" % (exp_dir, i_part) - with open(txt_path,"r")as f: - opt+=f.read().strip("\n").split("\n") + with open(txt_path, "r") as f: + opt += f.read().strip("\n").split("\n") os.remove(txt_path) -with open("%s/2-name2text.txt"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") +with open("%s/2-name2text.txt" % exp_dir, "w") as f: + f.write("\n".join(opt) + "\n") ############step2 -ps=[] +ps = [] for i_part in range(n_parts): - cmd="python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card) + cmd = "python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s" % ( + text_path, + wav_dir, + exp_name, + i_part, + n_parts, + i_part % n_card, + ) print(cmd) p = Popen(cmd, shell=True) ps.append(p) for p in ps: p.wait() #############step3 -ps=[] +ps = [] for i_part in range(n_parts): - cmd="python prepare/3-get-semantic.py %s %s %s %s %s"%(text_path,exp_name,i_part,n_parts,i_part%n_card) + cmd = "python prepare/3-get-semantic.py %s %s %s %s %s" % ( + text_path, + exp_name, + i_part, + n_parts, + i_part % n_card, + ) print(cmd) p = Popen(cmd, shell=True) ps.append(p) for p in ps: p.wait() -opt=["item_name semantic_audio"] +opt = ["item_name semantic_audio"] for i_part in range(n_parts): semantic_path = "%s/6-name2semantic-%s.tsv" % (exp_dir, i_part) - with open(semantic_path,"r")as f: - opt+=f.read().strip("\n").split("\n") + with open(semantic_path, "r") as f: + opt += f.read().strip("\n").split("\n") os.remove(semantic_path) -with open("%s/6-name2semantic.tsv"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n") +with open("%s/6-name2semantic.tsv" % exp_dir, "w") as f: + f.write("\n".join(opt) + "\n") diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py index 5abd353..8579693 100644 --- a/GPT_SoVITS/prepare_datasets/1-get-text.py +++ b/GPT_SoVITS/prepare_datasets/1-get-text.py @@ -2,16 +2,16 @@ import os -inp_text= os.environ.get("inp_text") -inp_wav_dir= os.environ.get("inp_wav_dir") -exp_name= os.environ.get("exp_name") -i_part= os.environ.get("i_part") -all_parts= os.environ.get("all_parts") -os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES") -opt_dir= os.environ.get("opt_dir") -bert_pretrained_dir= os.environ.get("bert_pretrained_dir") -is_half=eval(os.environ.get("is_half","True")) -import sys,numpy as np,traceback,pdb +inp_text = os.environ.get("inp_text") +inp_wav_dir = os.environ.get("inp_wav_dir") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES") +opt_dir = os.environ.get("opt_dir") +bert_pretrained_dir = os.environ.get("bert_pretrained_dir") +is_half = eval(os.environ.get("is_half", "True")) +import sys, numpy as np, traceback, pdb 
import os.path from glob import glob from tqdm import tqdm @@ -31,25 +31,29 @@ import numpy as np from time import time as ttime import shutil -def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path - dir=os.path.dirname(path) - name=os.path.basename(path) - tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) - torch.save(fea,tmp_path) - shutil.move(tmp_path,"%s/%s"%(dir,name)) -txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part) -if(os.path.exists(txt_path)==False): - bert_dir="%s/3-bert"%(opt_dir) - os.makedirs(opt_dir,exist_ok=True) - os.makedirs(bert_dir,exist_ok=True) - device="cuda:0" + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) +if os.path.exists(txt_path) == False: + bert_dir = "%s/3-bert" % (opt_dir) + os.makedirs(opt_dir, exist_ok=True) + os.makedirs(bert_dir, exist_ok=True) + device = "cuda:0" tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir) - bert_model=AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) - if (is_half == True): + bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir) + if is_half == True: bert_model = bert_model.half().to(device) else: bert_model = bert_model.to(device) + def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") @@ -67,51 +71,55 @@ if(os.path.exists(txt_path)==False): phone_level_feature = torch.cat(phone_level_feature, dim=0) return phone_level_feature.T - def process(data,res): - for name,text,lan in data: + + def process(data, res): + for name, text, lan in data: try: - name=os.path.basename(name) - phones, word2ph, norm_text=clean_text(text.replace("%", '-').replace('¥', ','),lan) - path_bert="%s/%s.pt"%(bert_dir,name) - if (os.path.exists(path_bert) == False and lan == "zh"): + name = os.path.basename(name) + phones, word2ph, norm_text = clean_text( + text.replace("%", "-").replace("¥", ","), lan + ) + path_bert = "%s/%s.pt" % (bert_dir, name) + if os.path.exists(path_bert) == False and lan == "zh": bert_feature = get_bert_feature(norm_text, word2ph) assert bert_feature.shape[-1] == len(phones) # torch.save(bert_feature, path_bert) my_save(bert_feature, path_bert) phones = " ".join(phones) # res.append([name,phones]) - res.append([name,phones, word2ph, norm_text]) + res.append([name, phones, word2ph, norm_text]) except: print(name, text, traceback.format_exc()) - todo=[] - res=[] - with open(inp_text,"r",encoding="utf8")as f: - lines=f.read().strip("\n").split("\n") + todo = [] + res = [] + with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") - language_v1_to_language_v2={ - "ZH":"zh", - "zh":"zh", - "JP":"ja", - "jp":"ja", - "JA":"ja", - "ja":"ja", - "EN":"en", - "en":"en", - "En":"en", + language_v1_to_language_v2 = { + "ZH": "zh", + "zh": "zh", + "JP": "ja", + "jp": "ja", + "JA": "ja", + "ja": "ja", + "EN": "en", + "en": "en", + "En": "en", } - for line in lines[int(i_part)::int(all_parts)]: + for line in lines[int(i_part) :: int(all_parts)]: try: - wav_name,spk_name,language,text=line.split("|") + wav_name, spk_name, language, text = line.split("|") # todo.append([name,text,"zh"]) - todo.append([wav_name,text,language_v1_to_language_v2.get(language,language)]) + todo.append( + [wav_name, text, 
language_v1_to_language_v2.get(language, language)] + ) except: - print(line,traceback.format_exc()) - - process(todo,res) - opt=[] - for name,phones, word2ph, norm_text in res: - opt.append("%s\t%s\t%s\t%s"%(name,phones, word2ph, norm_text)) - with open(txt_path,"w",encoding="utf8")as f: - f.write("\n".join(opt)+"\n") + print(line, traceback.format_exc()) + process(todo, res) + opt = [] + for name, phones, word2ph, norm_text in res: + opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text)) + with open(txt_path, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py index a5075ff..25cb4a8 100644 --- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py +++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py @@ -1,20 +1,23 @@ # -*- coding: utf-8 -*- -import sys,os -inp_text= os.environ.get("inp_text") -inp_wav_dir= os.environ.get("inp_wav_dir") -exp_name= os.environ.get("exp_name") -i_part= os.environ.get("i_part") -all_parts= os.environ.get("all_parts") -os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES") -from feature_extractor import cnhubert -opt_dir= os.environ.get("opt_dir") -cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir") -is_half=eval(os.environ.get("is_half","True")) +import sys, os -import pdb,traceback,numpy as np,logging +inp_text = os.environ.get("inp_text") +inp_wav_dir = os.environ.get("inp_wav_dir") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES") +from feature_extractor import cnhubert + +opt_dir = os.environ.get("opt_dir") +cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir") +is_half = eval(os.environ.get("is_half", "True")) + +import pdb, traceback, numpy as np, logging from scipy.io import wavfile -import librosa,torch +import librosa, torch + now_dir = os.getcwd() sys.path.append(now_dir) from my_utils import load_audio @@ -32,63 +35,75 @@ from my_utils import load_audio from time import time as ttime import shutil -def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path - dir=os.path.dirname(path) - name=os.path.basename(path) - tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part) - torch.save(fea,tmp_path) - shutil.move(tmp_path,"%s/%s"%(dir,name)) -hubert_dir="%s/4-cnhubert"%(opt_dir) -wav32dir="%s/5-wav32k"%(opt_dir) -os.makedirs(opt_dir,exist_ok=True) -os.makedirs(hubert_dir,exist_ok=True) -os.makedirs(wav32dir,exist_ok=True) -maxx=0.95 -alpha=0.5 -device="cuda:0" -model=cnhubert.get_model() -if(is_half==True): - model=model.half().to(device) +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +hubert_dir = "%s/4-cnhubert" % (opt_dir) +wav32dir = "%s/5-wav32k" % (opt_dir) +os.makedirs(opt_dir, exist_ok=True) +os.makedirs(hubert_dir, exist_ok=True) +os.makedirs(wav32dir, exist_ok=True) + +maxx = 0.95 +alpha = 0.5 +device = "cuda:0" +model = cnhubert.get_model() +if is_half == True: + model = model.half().to(device) else: model = model.to(device) + + def name2go(wav_name): - hubert_path="%s/%s.pt"%(hubert_dir,wav_name) - if(os.path.exists(hubert_path)):return - wav_path="%s/%s"%(inp_wav_dir,wav_name) + 
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) + if os.path.exists(hubert_path): + return + wav_path = "%s/%s" % (inp_wav_dir, wav_name) tmp_audio = load_audio(wav_path, 32000) tmp_max = np.abs(tmp_audio).max() if tmp_max > 2.2: print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max)) return - tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio - tmp_audio = librosa.resample( - tmp_audio32, orig_sr=32000, target_sr=16000 - ) + tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ( + (1 - alpha) * 32768 + ) * tmp_audio + tmp_audio = librosa.resample(tmp_audio32, orig_sr=32000, target_sr=16000) tensor_wav16 = torch.from_numpy(tmp_audio) - if (is_half == True): - tensor_wav16=tensor_wav16.half().to(device) + if is_half == True: + tensor_wav16 = tensor_wav16.half().to(device) else: tensor_wav16 = tensor_wav16.to(device) - ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) - if np.isnan(ssl.detach().numpy()).sum()!= 0:return + ssl = ( + model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"] + .transpose(1, 2) + .cpu() + ) # torch.Size([1, 768, 215]) + if np.isnan(ssl.detach().numpy()).sum() != 0: + return wavfile.write( - "%s/%s"%(wav32dir,wav_name), + "%s/%s" % (wav32dir, wav_name), 32000, tmp_audio32.astype("int16"), ) # torch.save(ssl,hubert_path ) - my_save(ssl,hubert_path ) + my_save(ssl, hubert_path) -with open(inp_text,"r",encoding="utf8")as f: - lines=f.read().strip("\n").split("\n") -for line in lines[int(i_part)::int(all_parts)]: +with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") + +for line in lines[int(i_part) :: int(all_parts)]: try: # wav_name,text=line.split("\t") wav_name, spk_name, language, text = line.split("|") - wav_name=os.path.basename(wav_name) + wav_name = os.path.basename(wav_name) name2go(wav_name) except: - print(line,traceback.format_exc()) + print(line, traceback.format_exc()) diff --git a/GPT_SoVITS/prepare_datasets/3-get-semantic.py b/GPT_SoVITS/prepare_datasets/3-get-semantic.py index 69f8e3e..7cee6e4 100644 --- a/GPT_SoVITS/prepare_datasets/3-get-semantic.py +++ b/GPT_SoVITS/prepare_datasets/3-get-semantic.py @@ -1,24 +1,27 @@ import os -inp_text= os.environ.get("inp_text") -exp_name= os.environ.get("exp_name") -i_part= os.environ.get("i_part") -all_parts= os.environ.get("all_parts") -os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES") -opt_dir= os.environ.get("opt_dir") -pretrained_s2G= os.environ.get("pretrained_s2G") -s2config_path= os.environ.get("s2config_path") -is_half=eval(os.environ.get("is_half","True")) -import math,traceback + +inp_text = os.environ.get("inp_text") +exp_name = os.environ.get("exp_name") +i_part = os.environ.get("i_part") +all_parts = os.environ.get("all_parts") +os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES") +opt_dir = os.environ.get("opt_dir") +pretrained_s2G = os.environ.get("pretrained_s2G") +s2config_path = os.environ.get("s2config_path") +is_half = eval(os.environ.get("is_half", "True")) +import math, traceback import multiprocessing -import sys,pdb +import sys, pdb + now_dir = os.getcwd() sys.path.append(now_dir) from random import shuffle import torch.multiprocessing as mp from glob import glob from tqdm import tqdm -import logging,librosa,utils,torch +import logging, librosa, utils, torch from module.models import SynthesizerTrn + logging.getLogger("numba").setLevel(logging.WARNING) # from config import pretrained_s2G @@ 
-30,52 +33,58 @@ logging.getLogger("numba").setLevel(logging.WARNING) # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name -hubert_dir="%s/4-cnhubert"%(opt_dir) -semantic_path="%s/6-name2semantic-%s.tsv"%(opt_dir,i_part) -if(os.path.exists(semantic_path)==False): - os.makedirs(opt_dir,exist_ok=True) +hubert_dir = "%s/4-cnhubert" % (opt_dir) +semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) +if os.path.exists(semantic_path) == False: + os.makedirs(opt_dir, exist_ok=True) - device="cuda:0" + device = "cuda:0" hps = utils.get_hparams_from_file(s2config_path) vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, - **hps.model) - if(is_half==True): - vq_model=vq_model.half().to(device) + **hps.model + ) + if is_half == True: + vq_model = vq_model.half().to(device) else: vq_model = vq_model.to(device) vq_model.eval() # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True) # utils.load_checkpoint(pretrained_s2G, vq_model, None, True) - print(vq_model.load_state_dict(torch.load(pretrained_s2G,map_location="cpu")["weight"], strict=False)) + print( + vq_model.load_state_dict( + torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False + ) + ) - def name2go(wav_name,lines): + def name2go(wav_name, lines): hubert_path = "%s/%s.pt" % (hubert_dir, wav_name) - if(os.path.exists(hubert_path)==False):return + if os.path.exists(hubert_path) == False: + return ssl_content = torch.load(hubert_path, map_location="cpu") - if(is_half==True): - ssl_content=ssl_content.half().to(device) + if is_half == True: + ssl_content = ssl_content.half().to(device) else: ssl_content = ssl_content.to(device) codes = vq_model.extract_latent(ssl_content) semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()]) - lines.append("%s\t%s"%(wav_name,semantic)) + lines.append("%s\t%s" % (wav_name, semantic)) - with open(inp_text,"r",encoding="utf8")as f: - lines=f.read().strip("\n").split("\n") + with open(inp_text, "r", encoding="utf8") as f: + lines = f.read().strip("\n").split("\n") - lines1=[] - for line in lines[int(i_part)::int(all_parts)]: + lines1 = [] + for line in lines[int(i_part) :: int(all_parts)]: # print(line) try: # wav_name,text=line.split("\t") wav_name, spk_name, language, text = line.split("|") - wav_name=os.path.basename(wav_name) + wav_name = os.path.basename(wav_name) # name2go(name,lines1) - name2go(wav_name,lines1) + name2go(wav_name, lines1) except: - print(line,traceback.format_exc()) - with open(semantic_path,"w",encoding="utf8")as f:f.write("\n".join(lines1)) - + print(line, traceback.format_exc()) + with open(semantic_path, "w", encoding="utf8") as f: + f.write("\n".join(lines1)) diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py index 03bdefb..64c8818 100644 --- a/GPT_SoVITS/text/chinese.py +++ b/GPT_SoVITS/text/chinese.py @@ -6,49 +6,56 @@ import cn2an from pypinyin import lazy_pinyin, Style import sys + sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master") from text.symbols import punctuation from text.tone_sandhi import ToneSandhi current_file_path = os.path.dirname(__file__) -pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in - open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()} +pinyin_to_symbol_map = { + line.split("\t")[0]: line.strip().split("\t")[1] + for line in open(os.path.join(current_file_path, 
"opencpop-strict.txt")).readlines() +} import jieba.posseg as psg rep_map = { - ':': ',', - ';': ',', - ',': ',', - '。': '.', - '!': '!', - '?': '?', - '\n': '.', + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", "·": ",", - '、': ",", - '...': '…', - '$': '.', - '/': ',', - '—': "-" + "、": ",", + "...": "…", + "$": ".", + "/": ",", + "—": "-", } tone_modifier = ToneSandhi() + def replace_punctuation(text): - text = text.replace("嗯", "恩").replace("呣","母") - pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys())) + text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text) + replaced_text = re.sub( + r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text + ) return replaced_text + def g2p(text): - pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation)) - sentences = [i for i in re.split(pattern, text) if i.strip()!=''] + pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) + sentences = [i for i in re.split(pattern, text) if i.strip() != ""] phones, word2ph = _g2p(sentences) return phones, word2ph @@ -56,10 +63,10 @@ def g2p(text): def _get_initials_finals(word): initials = [] finals = [] - orig_initials = lazy_pinyin( - word, neutral_tone_with_five=True, style=Style.INITIALS) + orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) orig_finals = lazy_pinyin( - word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 + ) for c, v in zip(orig_initials, orig_finals): initials.append(c) finals.append(v) @@ -72,17 +79,16 @@ def _g2p(segments): for seg in segments: pinyins = [] # Replace all English words in the sentence - seg = re.sub('[a-zA-Z]+', '', seg) + seg = re.sub("[a-zA-Z]+", "", seg) seg_cut = psg.lcut(seg) initials = [] finals = [] seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) for word, pos in seg_cut: - if pos == 'eng': + if pos == "eng": continue sub_initials, sub_finals = _get_initials_finals(word) - sub_finals = tone_modifier.modified_tone(word, pos, - sub_finals) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) initials.append(sub_initials) finals.append(sub_finals) @@ -91,7 +97,7 @@ def _g2p(segments): finals = sum(finals, []) # for c, v in zip(initials, finals): - raw_pinyin = c+v + raw_pinyin = c + v # NOTE: post process for pypinyin outputs # we discriminate i, ii and iii if c == v: @@ -102,40 +108,40 @@ def _g2p(segments): v_without_tone = v[:-1] tone = v[-1] - pinyin = c+v_without_tone - assert tone in '12345' + pinyin = c + v_without_tone + assert tone in "12345" if c: # 多音节 v_rep_map = { - "uei": 'ui', - 'iou': 'iu', - 'uen': 'un', + "uei": "ui", + "iou": "iu", + "uen": "un", } if v_without_tone in v_rep_map.keys(): - pinyin = c+v_rep_map[v_without_tone] + pinyin = c + v_rep_map[v_without_tone] else: # 单音节 pinyin_rep_map = { - 'ing': 'ying', - 'i': 'yi', - 'in': 'yin', - 'u': 'wu', + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", } if pinyin in pinyin_rep_map.keys(): pinyin = pinyin_rep_map[pinyin] else: single_rep_map = { - 'v': 'yu', - 'e': 'e', - 'i': 'y', - 'u': 'w', + "v": "yu", + "e": "e", + "i": "y", + "u": "w", } if pinyin[0] in single_rep_map.keys(): - pinyin = single_rep_map[pinyin[0]]+pinyin[1:] + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] assert pinyin 
in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) - new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ') + new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") new_v = new_v + tone phone = [new_c, new_v] word2ph.append(len(phone)) @@ -144,9 +150,8 @@ def _g2p(segments): return phones_list, word2ph - def text_normalize(text): - numbers = re.findall(r'\d+(?:\.?\d+)?', text) + numbers = re.findall(r"\d+(?:\.?\d+)?", text) for number in numbers: text = text.replace(number, cn2an.an2cn(number), 1) text = replace_punctuation(text) @@ -154,7 +159,7 @@ def text_normalize(text): return text -if __name__ == '__main__': +if __name__ == "__main__": text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "你好" diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index dc4bd73..e5a9b1b 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -1,29 +1,27 @@ from text import chinese, japanese, cleaned_text_to_sequence, symbols, english -language_module_map = { - 'zh': chinese, - "ja": japanese, - 'en': english -} +language_module_map = {"zh": chinese, "ja": japanese, "en": english} special = [ - ('%', 'zh', "SP"), - ('¥', 'zh', "SP2"), - ('^', 'zh', "SP3"), + ("%", "zh", "SP"), + ("¥", "zh", "SP2"), + ("^", "zh", "SP3"), # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧 ] + + def clean_text(text, language): for special_s, special_l, target_symbol in special: if special_s in text and language == special_l: return clean_special(text, language, special_s, target_symbol) language_module = language_module_map[language] norm_text = language_module.text_normalize(text) - if(language=="zh"): + if language == "zh": phones, word2ph = language_module.g2p(norm_text) assert len(phones) == sum(word2ph) assert len(norm_text) == len(word2ph) else: phones = language_module.g2p(norm_text) - word2ph=None + word2ph = None for ph in phones: assert ph in symbols @@ -41,17 +39,17 @@ def clean_special(text, language, special_s, target_symbol): new_ph = [] for ph in phones: assert ph in symbols - if ph == ',': + if ph == ",": new_ph.append(target_symbol) else: new_ph.append(ph) return new_ph + def text_to_sequence(text, language): phones = clean_text(text) return cleaned_text_to_sequence(phones) -if __name__ == '__main__': - print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh')) - +if __name__ == "__main__": + print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh")) diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py index bf48db1..bd68ddf 100644 --- a/GPT_SoVITS/text/english.py +++ b/GPT_SoVITS/text/english.py @@ -8,20 +8,87 @@ from string import punctuation from text import symbols current_file_path = os.path.dirname(__file__) -CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep') -CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle') +CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") +CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle") _g2p = G2p() -arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} +arpa = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + 
"AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} def replace_phs(phs): - rep_map = { - ';': ',', - ':': ',', - '\'': '-', - '"': '-' - } + rep_map = {";": ",", ":": ",", "'": "-", '"': "-"} phs_new = [] for ph in phs: if ph in symbols: @@ -29,9 +96,10 @@ def replace_phs(phs): elif ph in rep_map.keys(): phs_new.append(rep_map[ph]) else: - print('ph not in symbols: ', ph) + print("ph not in symbols: ", ph) return phs_new + def read_dict(): g2p_dict = {} start_line = 49 @@ -41,13 +109,13 @@ def read_dict(): while line: if line_index >= start_line: line = line.strip() - word_split = line.split(' ') + word_split = line.split(" ") word = word_split[0] - syllable_split = word_split[1].split(' - ') + syllable_split = word_split[1].split(" - ") g2p_dict[word] = [] for syllable in syllable_split: - phone_split = syllable.split(' ') + phone_split = syllable.split(" ") g2p_dict[word].append(phone_split) line_index = line_index + 1 @@ -57,13 +125,13 @@ def read_dict(): def cache_dict(g2p_dict, file_path): - with open(file_path, 'wb') as pickle_file: + with open(file_path, "wb") as pickle_file: pickle.dump(g2p_dict, pickle_file) def get_dict(): if os.path.exists(CACHE_PATH): - with open(CACHE_PATH, 'rb') as pickle_file: + with open(CACHE_PATH, "rb") as pickle_file: g2p_dict = pickle.load(pickle_file) else: g2p_dict = read_dict() @@ -71,6 +139,7 @@ def get_dict(): return g2p_dict + eng_dict = get_dict() @@ -78,8 +147,8 @@ def text_normalize(text): # todo: eng text normalize return text.replace(";", ",") -def g2p(text): +def g2p(text): phones = [] words = re.split(r"([,;.\-\?\!\s+])", text) for w in words: @@ -97,6 +166,7 @@ def g2p(text): return replace_phs(phones) + if __name__ == "__main__": # print(get_dict()) print(g2p("hello")) @@ -106,4 +176,4 @@ if __name__ == "__main__": # for group in syllables: # for ph in group: # all_phones.add(ph) - # print(all_phones) \ No newline at end of file + # print(all_phones) diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py index f1263e4..1cef2db 100644 --- a/GPT_SoVITS/text/japanese.py +++ b/GPT_SoVITS/text/japanese.py @@ -8,57 +8,63 @@ from text import symbols # Regular expression matching Japanese without punctuation marks: _japanese_characters = re.compile( - r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) # Regular expression matching non-Japanese characters or punctuation marks: _japanese_marks = re.compile( - r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) # List of (symbol, Japanese) pairs for marks: -_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('%', 'パーセント') -]] +_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] # List of (consonant, sokuon) pairs: -_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ - (r'Q([↑↓]*[kg])', r'k#\1'), - 
(r'Q([↑↓]*[tdjʧ])', r't#\1'), - (r'Q([↑↓]*[sʃ])', r's\1'), - (r'Q([↑↓]*[pb])', r'p#\1') -]] +_real_sokuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"Q([↑↓]*[kg])", r"k#\1"), + (r"Q([↑↓]*[tdjʧ])", r"t#\1"), + (r"Q([↑↓]*[sʃ])", r"s\1"), + (r"Q([↑↓]*[pb])", r"p#\1"), + ] +] # List of (consonant, hatsuon) pairs: -_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ - (r'N([↑↓]*[pbm])', r'm\1'), - (r'N([↑↓]*[ʧʥj])', r'n^\1'), - (r'N([↑↓]*[tdn])', r'n\1'), - (r'N([↑↓]*[kg])', r'ŋ\1') -]] - +_real_hatsuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"N([↑↓]*[pbm])", r"m\1"), + (r"N([↑↓]*[ʧʥj])", r"n^\1"), + (r"N([↑↓]*[tdn])", r"n\1"), + (r"N([↑↓]*[kg])", r"ŋ\1"), + ] +] def post_replace_ph(ph): rep_map = { - ':': ',', - ';': ',', - ',': ',', - '。': '.', - '!': '!', - '?': '?', - '\n': '.', + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", "·": ",", - '、': ",", - '...': '…' + "、": ",", + "...": "…", } if ph in rep_map.keys(): ph = rep_map[ph] if ph in symbols: return ph if ph not in symbols: - ph = 'UNK' + ph = "UNK" return ph + def symbols_to_japanese(text): for regex, replacement in _symbols_to_japanese: text = re.sub(regex, replacement, text) @@ -66,7 +72,7 @@ def symbols_to_japanese(text): def preprocess_jap(text): - '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' + """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" text = symbols_to_japanese(text) sentences = re.split(_japanese_marks, text) marks = re.findall(_japanese_marks, text) @@ -77,13 +83,15 @@ def preprocess_jap(text): text += p.split(" ") if i < len(marks): - text += [marks[i].replace(' ', '')] + text += [marks[i].replace(" ", "")] return text + def text_normalize(text): # todo: jap text normalize return text + def g2p(norm_text): phones = preprocess_jap(norm_text) phones = [post_replace_ph(i) for i in phones] @@ -91,7 +99,7 @@ def g2p(norm_text): return phones -if __name__ == '__main__': +if __name__ == "__main__": for line in open("../../../Downloads/transcript_utf8.txt").readlines(): text = line.split(":")[1] phones = g2p(text) diff --git a/GPT_SoVITS/text/symbols.py b/GPT_SoVITS/text/symbols.py index 5322a92..97e3938 100644 --- a/GPT_SoVITS/text/symbols.py +++ b/GPT_SoVITS/text/symbols.py @@ -1,24 +1,397 @@ import os # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 -punctuation = ['!', '?', '…', ",", "."]#@是SP停顿 +punctuation = ["!", "?", "…", ",", "."] # @是SP停顿 punctuation.append("-") -pu_symbols = punctuation + ["SP", 'SP2', 'SP3', "UNK"] +pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"] # pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] -pad = '_' +pad = "_" -c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh'] -v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 
'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5'] +c = [ + "AA", + "EE", + "OO", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "w", + "x", + "y", + "z", + "zh", +] +v = [ + "E1", + "En1", + "a1", + "ai1", + "an1", + "ang1", + "ao1", + "e1", + "ei1", + "en1", + "eng1", + "er1", + "i1", + "i01", + "ia1", + "ian1", + "iang1", + "iao1", + "ie1", + "in1", + "ing1", + "iong1", + "ir1", + "iu1", + "o1", + "ong1", + "ou1", + "u1", + "ua1", + "uai1", + "uan1", + "uang1", + "ui1", + "un1", + "uo1", + "v1", + "van1", + "ve1", + "vn1", + "E2", + "En2", + "a2", + "ai2", + "an2", + "ang2", + "ao2", + "e2", + "ei2", + "en2", + "eng2", + "er2", + "i2", + "i02", + "ia2", + "ian2", + "iang2", + "iao2", + "ie2", + "in2", + "ing2", + "iong2", + "ir2", + "iu2", + "o2", + "ong2", + "ou2", + "u2", + "ua2", + "uai2", + "uan2", + "uang2", + "ui2", + "un2", + "uo2", + "v2", + "van2", + "ve2", + "vn2", + "E3", + "En3", + "a3", + "ai3", + "an3", + "ang3", + "ao3", + "e3", + "ei3", + "en3", + "eng3", + "er3", + "i3", + "i03", + "ia3", + "ian3", + "iang3", + "iao3", + "ie3", + "in3", + "ing3", + "iong3", + "ir3", + "iu3", + "o3", + "ong3", + "ou3", + "u3", + "ua3", + "uai3", + "uan3", + "uang3", + "ui3", + "un3", + "uo3", + "v3", + "van3", + "ve3", + "vn3", + "E4", + "En4", + "a4", + "ai4", + "an4", + "ang4", + "ao4", + "e4", + "ei4", + "en4", + "eng4", + "er4", + "i4", + "i04", + "ia4", + "ian4", + "iang4", + "iao4", + "ie4", + "in4", + "ing4", + "iong4", + "ir4", + "iu4", + "o4", + "ong4", + "ou4", + "u4", + "ua4", + "uai4", + "uan4", + "uang4", + "ui4", + "un4", + "uo4", + "v4", + "van4", + "ve4", + "vn4", + "E5", + "En5", + "a5", + "ai5", + "an5", + "ang5", + "ao5", + "e5", + "ei5", + "en5", + "eng5", + "er5", + "i5", + "i05", + "ia5", + "ian5", + "iang5", + "iao5", + "ie5", + "in5", + "ing5", + "iong5", + "ir5", + "iu5", + "o5", + "ong5", + "ou5", + "u5", + "ua5", + "uai5", + "uan5", + "uang5", + "ui5", + "un5", + "uo5", + "v5", + "van5", + "ve5", + "vn5", +] -v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] +v_without_tone = [ + "E", + "En", + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "i0", + "ia", + "ian", + "iang", + "iao", + "ie", + "in", + "ing", + "iong", + "ir", + "iu", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "ui", + "un", + "uo", + "v", + "van", + "ve", + "vn", +] # japanese -ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky', - 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 
'u', 'v', 'w', 'y', 'z'] +ja_symbols = [ + "I", + "N", + "U", + "a", + "b", + "by", + "ch", + "cl", + "d", + "dy", + "e", + "f", + "g", + "gy", + "h", + "hy", + "i", + "j", + "k", + "ky", + "m", + "my", + "n", + "ny", + "o", + "p", + "py", + "r", + "ry", + "s", + "sh", + "t", + "ts", + "u", + "v", + "w", + "y", + "z", +] -arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} +arpa = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) symbols = sorted(set(symbols)) -if __name__ == '__main__': - print(len(symbols)) \ No newline at end of file +if __name__ == "__main__": + print(len(symbols)) diff --git a/GPT_SoVITS/text/tone_sandhi.py b/GPT_SoVITS/text/tone_sandhi.py index bf3893f..f987a3f 100644 --- a/GPT_SoVITS/text/tone_sandhi.py +++ b/GPT_SoVITS/text/tone_sandhi.py @@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin from pypinyin import Style -class ToneSandhi(): +class ToneSandhi: def __init__(self): self.must_neural_tone_words = { - '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', - '难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', - '里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', - '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号', - '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当', - '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻', - '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', - '胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', - '老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', - '精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', - '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台', - '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算', - '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨', - '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', - '爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', - '溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', - '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事', - '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾', - '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼', - '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', - '扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', - '念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', - '干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', - '屁股', '尾巴', 
'少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', - '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈', - '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方', - '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', - '嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', - '咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', - '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹', - '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息', - '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤', - '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', - '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', - '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨', - '父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', - '幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', - '凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', - '扫把', '惦记' + "麻烦", + "麻利", + "鸳鸯", + "高粱", + "骨头", + "骆驼", + "马虎", + "首饰", + "馒头", + "馄饨", + "风筝", + "难为", + "队伍", + "阔气", + "闺女", + "门道", + "锄头", + "铺盖", + "铃铛", + "铁匠", + "钥匙", + "里脊", + "里头", + "部分", + "那么", + "道士", + "造化", + "迷糊", + "连累", + "这么", + "这个", + "运气", + "过去", + "软和", + "转悠", + "踏实", + "跳蚤", + "跟头", + "趔趄", + "财主", + "豆腐", + "讲究", + "记性", + "记号", + "认识", + "规矩", + "见识", + "裁缝", + "补丁", + "衣裳", + "衣服", + "衙门", + "街坊", + "行李", + "行当", + "蛤蟆", + "蘑菇", + "薄荷", + "葫芦", + "葡萄", + "萝卜", + "荸荠", + "苗条", + "苗头", + "苍蝇", + "芝麻", + "舒服", + "舒坦", + "舌头", + "自在", + "膏药", + "脾气", + "脑袋", + "脊梁", + "能耐", + "胳膊", + "胭脂", + "胡萝", + "胡琴", + "胡同", + "聪明", + "耽误", + "耽搁", + "耷拉", + "耳朵", + "老爷", + "老实", + "老婆", + "老头", + "老太", + "翻腾", + "罗嗦", + "罐头", + "编辑", + "结实", + "红火", + "累赘", + "糨糊", + "糊涂", + "精神", + "粮食", + "簸箕", + "篱笆", + "算计", + "算盘", + "答应", + "笤帚", + "笑语", + "笑话", + "窟窿", + "窝囊", + "窗户", + "稳当", + "稀罕", + "称呼", + "秧歌", + "秀气", + "秀才", + "福气", + "祖宗", + "砚台", + "码头", + "石榴", + "石头", + "石匠", + "知识", + "眼睛", + "眯缝", + "眨巴", + "眉毛", + "相声", + "盘算", + "白净", + "痢疾", + "痛快", + "疟疾", + "疙瘩", + "疏忽", + "畜生", + "生意", + "甘蔗", + "琵琶", + "琢磨", + "琉璃", + "玻璃", + "玫瑰", + "玄乎", + "狐狸", + "状元", + "特务", + "牲口", + "牙碜", + "牌楼", + "爽快", + "爱人", + "热闹", + "烧饼", + "烟筒", + "烂糊", + "点心", + "炊帚", + "灯笼", + "火候", + "漂亮", + "滑溜", + "溜达", + "温和", + "清楚", + "消息", + "浪头", + "活泼", + "比方", + "正经", + "欺负", + "模糊", + "槟榔", + "棺材", + "棒槌", + "棉花", + "核桃", + "栅栏", + "柴火", + "架势", + "枕头", + "枇杷", + "机灵", + "本事", + "木头", + "木匠", + "朋友", + "月饼", + "月亮", + "暖和", + "明白", + "时候", + "新鲜", + "故事", + "收拾", + "收成", + "提防", + "挖苦", + "挑剔", + "指甲", + "指头", + "拾掇", + "拳头", + "拨弄", + "招牌", + "招呼", + "抬举", + "护士", + "折腾", + "扫帚", + "打量", + "打算", + "打点", + "打扮", + "打听", + "打发", + "扎实", + "扁担", + "戒指", + "懒得", + "意识", + "意思", + "情形", + "悟性", + "怪物", + "思量", + "怎么", + "念头", + "念叨", + "快活", + "忙活", + "志气", + "心思", + "得罪", + "张罗", + "弟兄", + "开通", + "应酬", + "庄稼", + "干事", + "帮手", + "帐篷", + "希罕", + "师父", + "师傅", + "巴结", + "巴掌", + "差事", + "工夫", + "岁数", + "屁股", + "尾巴", + "少爷", + "小气", + "小伙", + "将就", + "对头", + "对付", + "寡妇", + "家伙", + "客气", + "实在", + "官司", + "学问", + "学生", + "字号", + "嫁妆", + "媳妇", + "媒人", + "婆家", + "娘家", + "委屈", + "姑娘", + "姐夫", + "妯娌", + "妥当", + "妖精", + "奴才", + "女婿", + "头发", + "太阳", + "大爷", + "大方", + "大意", + "大夫", + "多少", + "多么", + "外甥", + "壮实", + "地道", + "地方", + "在乎", + "困难", + "嘴巴", + "嘱咐", + "嘟囔", + "嘀咕", + "喜欢", + "喇嘛", + "喇叭", + "商量", + "唾沫", + "哑巴", + "哈欠", + "哆嗦", + "咳嗽", + "和尚", + "告诉", + "告示", + "含糊", + "吓唬", + "后头", + "名字", + "名堂", 
+ "合同", + "吆喝", + "叫唤", + "口袋", + "厚道", + "厉害", + "千斤", + "包袱", + "包涵", + "匀称", + "勤快", + "动静", + "动弹", + "功夫", + "力气", + "前头", + "刺猬", + "刺激", + "别扭", + "利落", + "利索", + "利害", + "分析", + "出息", + "凑合", + "凉快", + "冷战", + "冤枉", + "冒失", + "养活", + "关系", + "先生", + "兄弟", + "便宜", + "使唤", + "佩服", + "作坊", + "体面", + "位置", + "似的", + "伙计", + "休息", + "什么", + "人家", + "亲戚", + "亲家", + "交情", + "云彩", + "事情", + "买卖", + "主意", + "丫头", + "丧气", + "两口", + "东西", + "东家", + "世故", + "不由", + "不在", + "下水", + "下巴", + "上头", + "上司", + "丈夫", + "丈人", + "一辈", + "那个", + "菩萨", + "父亲", + "母亲", + "咕噜", + "邋遢", + "费用", + "冤家", + "甜头", + "介绍", + "荒唐", + "大人", + "泥鳅", + "幸福", + "熟悉", + "计划", + "扑腾", + "蜡烛", + "姥爷", + "照顾", + "喉咙", + "吉他", + "弄堂", + "蚂蚱", + "凤凰", + "拖沓", + "寒碜", + "糟蹋", + "倒腾", + "报复", + "逻辑", + "盘缠", + "喽啰", + "牢骚", + "咖喱", + "扫把", + "惦记", } self.must_not_neural_tone_words = { - "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎" + "男子", + "女子", + "分子", + "原子", + "量子", + "莲子", + "石子", + "瓜子", + "电子", + "人人", + "虎虎", } self.punc = ":,;。?!“”‘’':,;.?!" @@ -72,14 +463,15 @@ class ToneSandhi(): # word: "家里" # pos: "s" # finals: ['ia1', 'i3'] - def _neural_sandhi(self, word: str, pos: str, - finals: List[str]) -> List[str]: - + def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 for j, item in enumerate(word): - if j - 1 >= 0 and item == word[j - 1] and pos[0] in { - "n", "v", "a" - } and word not in self.must_not_neural_tone_words: + if ( + j - 1 >= 0 + and item == word[j - 1] + and pos[0] in {"n", "v", "a"} + and word not in self.must_not_neural_tone_words + ): finals[j] = finals[j][:-1] + "5" ge_idx = word.find("个") if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": @@ -89,9 +481,12 @@ class ToneSandhi(): # e.g. 走了, 看着, 去过 elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: finals[-1] = finals[-1][:-1] + "5" - elif len(word) > 1 and word[-1] in "们子" and pos in { - "r", "n" - } and word not in self.must_not_neural_tone_words: + elif ( + len(word) > 1 + and word[-1] in "们子" + and pos in {"r", "n"} + and word not in self.must_not_neural_tone_words + ): finals[-1] = finals[-1][:-1] + "5" # e.g. 
桌上, 地下, 家里 elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: @@ -100,21 +495,26 @@ class ToneSandhi(): elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": finals[-1] = finals[-1][:-1] + "5" # 个做量词 - elif (ge_idx >= 1 and - (word[ge_idx - 1].isnumeric() or - word[ge_idx - 1] in "几有两半多各整每做是")) or word == '个': + elif ( + ge_idx >= 1 + and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是") + ) or word == "个": finals[ge_idx] = finals[ge_idx][:-1] + "5" else: - if word in self.must_neural_tone_words or word[ - -2:] in self.must_neural_tone_words: + if ( + word in self.must_neural_tone_words + or word[-2:] in self.must_neural_tone_words + ): finals[-1] = finals[-1][:-1] + "5" word_list = self._split_word(word) - finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]] + finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] for i, word in enumerate(word_list): # conventional neural in Chinese - if word in self.must_neural_tone_words or word[ - -2:] in self.must_neural_tone_words: + if ( + word in self.must_neural_tone_words + or word[-2:] in self.must_neural_tone_words + ): finals_list[i][-1] = finals_list[i][-1][:-1] + "5" finals = sum(finals_list, []) return finals @@ -126,15 +526,15 @@ class ToneSandhi(): else: for i, char in enumerate(word): # "不" before tone4 should be bu2, e.g. 不怕 - if char == "不" and i + 1 < len(word) and finals[i + - 1][-1] == "4": + if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4": finals[i] = finals[i][:-1] + "2" return finals def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: # "一" in number sequences, e.g. 一零零, 二一零 if word.find("一") != -1 and all( - [item.isnumeric() for item in word if item != "一"]): + [item.isnumeric() for item in word if item != "一"] + ): return finals # "一" between reduplication words shold be yi5, e.g. 看一看 elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: @@ -161,10 +561,10 @@ class ToneSandhi(): first_subword = word_list[0] first_begin_idx = word.find(first_subword) if first_begin_idx == 0: - second_subword = word[len(first_subword):] + second_subword = word[len(first_subword) :] new_word_list = [first_subword, second_subword] else: - second_subword = word[:-len(first_subword)] + second_subword = word[: -len(first_subword)] new_word_list = [second_subword, first_subword] return new_word_list @@ -182,18 +582,19 @@ class ToneSandhi(): elif len(word_list[0]) == 1: finals[1] = finals[1][:-1] + "2" else: - finals_list = [ - finals[:len(word_list[0])], finals[len(word_list[0]):] - ] + finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] if len(finals_list) == 2: for i, sub in enumerate(finals_list): # e.g. 所有/人 if self._all_tone_three(sub) and len(sub) == 2: finals_list[i][0] = finals_list[i][0][:-1] + "2" # e.g. 
好/喜欢 - elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \ - finals_list[0][-1][-1] == "3": - + elif ( + i == 1 + and not self._all_tone_three(sub) + and finals_list[i][0][-1] == "3" + and finals_list[0][-1][-1] == "3" + ): finals_list[0][-1] = finals_list[0][-1][:-1] + "2" finals = sum(finals_list, []) # split idiom into two words who's length is 2 @@ -222,7 +623,7 @@ class ToneSandhi(): new_seg.append((word, pos)) last_word = word[:] if last_word == "不": - new_seg.append((last_word, 'd')) + new_seg.append((last_word, "d")) last_word = "" return new_seg @@ -236,12 +637,21 @@ class ToneSandhi(): new_seg = [] # function 1 for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][ - 0] == seg[i + 1][0] and seg[i - 1][1] == "v": + if ( + i - 1 >= 0 + and word == "一" + and i + 1 < len(seg) + and seg[i - 1][0] == seg[i + 1][0] + and seg[i - 1][1] == "v" + ): new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0] else: - if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][ - 0] == word and pos == "v": + if ( + i - 2 >= 0 + and seg[i - 1][0] == "一" + and seg[i - 2][0] == word + and pos == "v" + ): continue else: new_seg.append([word, pos]) @@ -257,22 +667,27 @@ class ToneSandhi(): # the first and the second words are all_tone_three def _merge_continuous_three_tones( - self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + self, seg: List[Tuple[str, str]] + ) -> List[Tuple[str, str]]: new_seg = [] sub_finals_list = [ - lazy_pinyin( - word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg ] assert len(sub_finals_list) == len(seg) merge_last = [False] * len(seg) for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and self._all_tone_three( - sub_finals_list[i - 1]) and self._all_tone_three( - sub_finals_list[i]) and not merge_last[i - 1]: + if ( + i - 1 >= 0 + and self._all_tone_three(sub_finals_list[i - 1]) + and self._all_tone_three(sub_finals_list[i]) + and not merge_last[i - 1] + ): # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi - if not self._is_reduplication(seg[i - 1][0]) and len( - seg[i - 1][0]) + len(seg[i][0]) <= 3: + if ( + not self._is_reduplication(seg[i - 1][0]) + and len(seg[i - 1][0]) + len(seg[i][0]) <= 3 + ): new_seg[-1][0] = new_seg[-1][0] + seg[i][0] merge_last[i] = True else: @@ -287,21 +702,27 @@ class ToneSandhi(): # the last char of first word and the first char of second word is tone_three def _merge_continuous_three_tones_2( - self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + self, seg: List[Tuple[str, str]] + ) -> List[Tuple[str, str]]: new_seg = [] sub_finals_list = [ - lazy_pinyin( - word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg ] assert len(sub_finals_list) == len(seg) merge_last = [False] * len(seg) for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \ - merge_last[i - 1]: + if ( + i - 1 >= 0 + and sub_finals_list[i - 1][-1][-1] == "3" + and sub_finals_list[i][0][-1] == "3" + and not merge_last[i - 1] + ): # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi - if not self._is_reduplication(seg[i - 1][0]) and len( - seg[i - 1][0]) + len(seg[i][0]) <= 3: + if ( + not 
self._is_reduplication(seg[i - 1][0]) + and len(seg[i - 1][0]) + len(seg[i][0]) <= 3 + ): new_seg[-1][0] = new_seg[-1][0] + seg[i][0] merge_last[i] = True else: @@ -313,14 +734,13 @@ class ToneSandhi(): def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and word == "儿" and seg[i-1][0] != "#": + if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#": new_seg[-1][0] = new_seg[-1][0] + seg[i][0] else: new_seg.append([word, pos]) return new_seg - def _merge_reduplication( - self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] for i, (word, pos) in enumerate(seg): if new_seg and word == new_seg[-1][0]: @@ -329,8 +749,7 @@ class ToneSandhi(): new_seg.append([word, pos]) return new_seg - def pre_merge_for_modify( - self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: seg = self._merge_bu(seg) try: seg = self._merge_yi(seg) @@ -349,8 +768,7 @@ class ToneSandhi(): seg = self._merge_er(seg) return seg - def modified_tone(self, word: str, pos: str, - finals: List[str]) -> List[str]: + def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]: finals = self._bu_sandhi(word, finals) finals = self._yi_sandhi(word, finals) finals = self._neural_sandhi(word, pos, finals)
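# --- Usage sketch (illustrative only, not part of the diff above) ---
# A minimal example of how the reformatted ToneSandhi entry points are
# typically driven: POS-tag a sentence, pre-merge the segments, then apply
# the sandhi rules to pypinyin finals. The jieba segmenter and the import
# path below are assumptions for this sketch; the repository's own text
# frontend may wire these pieces up differently.
import jieba.posseg as psg  # segmenter assumed for this sketch
from pypinyin import lazy_pinyin, Style

from GPT_SoVITS.text.tone_sandhi import ToneSandhi  # import path assumed

tone_modifier = ToneSandhi()
text = "我们一起去公园散步"

# posseg yields (word, POS) pairs; pre_merge_for_modify merges 不/一/儿,
# reduplications, and adjacent third-tone words before sandhi is applied.
seg = psg.lcut(text)
seg = tone_modifier.pre_merge_for_modify(seg)

for word, pos in seg:
    # FINALS_TONE3 with neutral_tone_with_five=True is the same pinyin
    # style ToneSandhi itself uses when it inspects sub-word tones.
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    finals = tone_modifier.modified_tone(word, pos, finals)
    print(word, pos, finals)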