Merge branch 'main' into main

This commit is contained in:
RVC-Boss 2024-01-17 15:51:17 +08:00 committed by GitHub
commit e205abb87d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
68 changed files with 6179 additions and 3322 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
env
runtime

View File

@ -16,7 +16,7 @@ __all__ = [
"DistributedBucketSampler",
]
T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)
class DistributedBucketSampler(Sampler[T_co]):
@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
sort batches
"""
def __init__(self,
dataset: Dataset,
num_replicas: Optional[int]=None,
rank: Optional[int]=None,
shuffle: bool=True,
seed: int=0,
drop_last: bool=False,
batch_size: int=32) -> None:
def __init__(
self,
dataset: Dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
batch_size: int = 32,
) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(
self.
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) /
self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
len(self.dataset) / self.num_replicas
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
id_with_lengths.sort(key=lambda x: x[1])
return id_with_lengths
def make_buckets(self, bucket_width: float=2.0):
def make_buckets(self, bucket_width: float = 2.0):
buckets = []
cur = []
max_sec = bucket_width
@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size:(b + 1) *
grouped_batch_size] for b in range(n_batch)
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
shuffle(batches)
indices = list(itertools.chain(*batches))
@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)

View File

@ -6,14 +6,21 @@ from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
def __init__(self, config, train_semantic_path, train_phoneme_path,dev_semantic_path=None, dev_phoneme_path=None):
def __init__(
self,
config,
train_semantic_path,
train_phoneme_path,
dev_semantic_path=None,
dev_phoneme_path=None,
):
super().__init__()
self.config = config
self.train_semantic_path = train_semantic_path
self.train_phoneme_path = train_phoneme_path
self.dev_semantic_path = dev_semantic_path
self.dev_phoneme_path = dev_phoneme_path
self.num_workers = self.config['data']['num_workers']
self.num_workers = self.config["data"]["num_workers"]
def prepare_data(self):
pass
@ -22,8 +29,9 @@ class Text2SemanticDataModule(LightningDataModule):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config['data']['max_sec'],
pad_val=self.config['data']['pad_val'])
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
@ -33,9 +41,8 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size = self.config['train']['batch_size']
sampler = DistributedBucketSampler(
self._train_dataset, batch_size=batch_size)
batch_size = self.config["train"]["batch_size"]
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
batch_size=batch_size,
@ -43,7 +50,7 @@ class Text2SemanticDataModule(LightningDataModule):
collate_fn=self._train_dataset.collate,
num_workers=self.num_workers,
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
def val_dataloader(self):
@ -52,9 +59,9 @@ class Text2SemanticDataModule(LightningDataModule):
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate,
num_workers=max(self.num_workers,12),
num_workers=max(self.num_workers, 12),
persistent_workers=True,
prefetch_factor=16
prefetch_factor=16,
)
# 这个会使用到嘛?
@ -63,4 +70,5 @@ class Text2SemanticDataModule(LightningDataModule):
self._dev_dataset,
batch_size=1,
shuffle=False,
collate_fn=self._train_dataset.collate)
collate_fn=self._train_dataset.collate,
)

View File

@ -1,21 +1,24 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback,os
import traceback, os
from typing import Dict
from typing import List
import numpy as np
import pandas as pd
import torch,json
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from text import cleaned_text_to_sequence
# from config import exp_dir
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
seq = sequences[0]
ndim = seq.ndim
@ -28,44 +31,52 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
ndim - axis - 1)
padded_seq = np.pad(
seq, padding, mode='constant', constant_values=pad_value)
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
return batch
class Text2SemanticDataset(Dataset):
"""dataset class for text tokens to semantic model training."""
def __init__(self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25) -> None:
def __init__(
self,
phoneme_path: str,
semantic_path: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25,
) -> None:
super().__init__()
self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
)
# get dict
self.path2=phoneme_path#"%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3="%s/3-bert"%(os.path.basename(phoneme_path))#"%s/3-bert"%exp_dir#bert_dir
self.path6=semantic_path#"%s/6-name2semantic.tsv"%exp_dir#semantic_path
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.basename(phoneme_path)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1],tmp[2],tmp[3]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
# pad for semantic tokens
@ -74,7 +85,7 @@ class Text2SemanticDataset(Dataset):
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
# self.hz=int(data[:-2])#
self.hz=int(os.environ.get("hz","25hz")[:-2])
self.hz = int(os.environ.get("hz", "25hz")[:-2])
# max seconds of semantic token
self.max_sec = max_sec
@ -100,7 +111,6 @@ class Text2SemanticDataset(Dataset):
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
def init_batch(self):
semantic_data_len = len(self.semantic_data)
phoneme_data_len = len(self.phoneme_data.keys())
@ -113,7 +123,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# 先依次遍历
# get str
item_name = self.semantic_data['item_name'][i]
item_name = self.semantic_data["item_name"][i]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@ -123,16 +133,18 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data['semantic_audio'][i]
semantic_str = self.semantic_data["semantic_audio"][i]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
# 过滤掉太长的样本
if len(semantic_ids) > self.max_sec * self.hz:#########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
if (
len(semantic_ids) > self.max_sec * self.hz
): #########1###根据token个数推测总时长过滤时长60sconfig里#40*25=1k
num_deleted_bigger += 1
continue
# (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
try:
phoneme_ids = cleaned_text_to_sequence(phoneme)
@ -142,7 +154,9 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
if len(phoneme_ids) >self.max_sec * self.hz/2.5:###########2改为恒定限制为semantic/2.5就行
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ###########2改为恒定限制为semantic/2.5就行
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@ -151,7 +165,9 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:##########4#3~25#每秒多少个phone
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ##########4#3~25#每秒多少个phone
num_deleted_ps += 1
# print(item_name)
continue
@ -160,16 +176,16 @@ class Text2SemanticDataset(Dataset):
idx += 1
self.item_names.append(item_name)
min_num=100#20直接不补#30补了也不存ckpt
leng =len(self.semantic_phoneme)
if(leng<min_num):
tmp1=self.semantic_phoneme
tmp2=self.item_names
self.semantic_phoneme=[]
self.item_names=[]
for _ in range(max(2,int(min_num/leng))):
self.semantic_phoneme+=tmp1
self.item_names+=tmp2
min_num = 100 # 20直接不补#30补了也不存ckpt
leng = len(self.semantic_phoneme)
if leng < min_num:
tmp1 = self.semantic_phoneme
tmp2 = self.item_names
self.semantic_phoneme = []
self.item_names = []
for _ in range(max(2, int(min_num / leng))):
self.semantic_phoneme += tmp1
self.item_names += tmp2
if num_not_in > 0:
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
@ -181,13 +197,13 @@ class Text2SemanticDataset(Dataset):
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
)
'''
"""
there are 31 semantic datas not in phoneme datas
deleted 34 audios who's duration are bigger than 54 seconds
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
dataset.__len__(): 366463
'''
"""
# 345410 for LibriTTS
print("dataset.__len__():", self.__len__())
@ -204,22 +220,24 @@ class Text2SemanticDataset(Dataset):
# semantic tokens target
semantic_ids_len = len(semantic_ids)
flag=0
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if(os.path.exists(path_bert)==True):bert_feature = torch.load(path_bert,map_location="cpu")
else:flag=1
if(flag==1):
if os.path.exists(path_bert) == True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1
if flag == 1:
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature=None
bert_feature = None
else:
assert bert_feature.shape[-1] == len(phoneme_ids)
return {
'idx': idx,
'phoneme_ids': phoneme_ids,
'phoneme_ids_len': phoneme_ids_len,
'semantic_ids': semantic_ids,
'semantic_ids_len': semantic_ids_len,
'bert_feature': bert_feature,
"idx": idx,
"phoneme_ids": phoneme_ids,
"phoneme_ids_len": phoneme_ids_len,
"semantic_ids": semantic_ids,
"semantic_ids_len": semantic_ids_len,
"bert_feature": bert_feature,
}
def get_sample_length(self, idx: int):
@ -235,7 +253,6 @@ class Text2SemanticDataset(Dataset):
semantic_ids_lens: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
@ -256,9 +273,9 @@ class Text2SemanticDataset(Dataset):
bert_padded.zero_()
for idx, item in enumerate(examples):
bert = item['bert_feature']
if(bert!=None):
bert_padded[idx, :, :bert.shape[-1]] = bert
bert = item["bert_feature"]
if bert != None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
@ -276,27 +293,27 @@ class Text2SemanticDataset(Dataset):
}
if __name__ == '__main__':
root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
dataset = Text2SemanticDataset(
phoneme_path=root_dir + 'phoneme_train.npy',
semantic_path=root_dir + 'semantic_train.tsv')
phoneme_path=root_dir + "phoneme_train.npy",
semantic_path=root_dir + "semantic_train.tsv",
)
batch_size = 12
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False)
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
)
for i, batch in enumerate(dataloader):
if(i%1000==0):print(i)
if i % 1000 == 0:
print(i)
# if i == 0:
# print('batch["ids"]:', batch["ids"])
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
# batch["phoneme_ids"].shape)
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
# batch["phoneme_ids_len"].shape)
# print('batch["semantic_ids"]:', batch["semantic_ids"],
# batch["semantic_ids"].shape)
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
# batch["semantic_ids_len"].shape)

View File

@ -1,5 +1,6 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os,sys
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict
@ -12,29 +13,35 @@ from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir,is_train=True):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
self.config = config
self.top_k = 3
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
pretrained_s1=config.get("pretrained_s1")
if(pretrained_s1 and is_train):
pretrained_s1 = config.get("pretrained_s1")
if pretrained_s1 and is_train:
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
)
if is_train:
self.automatic_optimization = False
self.save_hyperparameters()
self.eval_dir = output_dir / 'eval'
self.eval_dir = output_dir / "eval"
self.eval_dir.mkdir(parents=True, exist_ok=True)
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
loss, acc = self.model.forward(
batch['phoneme_ids'], batch['phoneme_ids_len'],
batch['semantic_ids'], batch['semantic_ids_len'],
batch['bert_feature'])
batch["phoneme_ids"],
batch["phoneme_ids_len"],
batch["semantic_ids"],
batch["semantic_ids_len"],
batch["bert_feature"],
)
self.manual_backward(loss)
if batch_idx > 0 and batch_idx % 4 == 0:
opt.step()
@ -47,63 +54,67 @@ class Text2SemanticLightningModule(LightningModule):
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
"lr",
scheduler.get_last_lr()[0],
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
self.log(
f"top_{self.top_k}_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True)
sync_dist=True,
)
def validation_step(self, batch: Dict, batch_idx: int):return
# # get loss
# loss, acc = self.model.forward(
# batch['phoneme_ids'], batch['phoneme_ids_len'],
# batch['semantic_ids'], batch['semantic_ids_len'],
# batch['bert_feature']
# )
#
# self.log(
# "val_total_loss",
# loss,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
# self.log(
# f"val_top_{self.top_k}_acc",
# acc,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
#
# # get infer output
# semantic_len = batch['semantic_ids'].size(1)
# prompt_len = min(int(semantic_len * 0.5), 150)
# prompt = batch['semantic_ids'][:, :prompt_len]
# pred_semantic = self.model.infer(batch['phoneme_ids'],
# batch['phoneme_ids_len'], prompt,
# batch['bert_feature']
# )
# save_name = f'semantic_toks_{batch_idx}.pt'
# save_path = os.path.join(self.eval_dir, save_name)
# torch.save(pred_semantic.detach().cpu(), save_path)
def validation_step(self, batch: Dict, batch_idx: int):
return
# # get loss
# loss, acc = self.model.forward(
# batch['phoneme_ids'], batch['phoneme_ids_len'],
# batch['semantic_ids'], batch['semantic_ids_len'],
# batch['bert_feature']
# )
#
# self.log(
# "val_total_loss",
# loss,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
# self.log(
# f"val_top_{self.top_k}_acc",
# acc,
# on_step=True,
# on_epoch=True,
# prog_bar=True,
# sync_dist=True)
#
# # get infer output
# semantic_len = batch['semantic_ids'].size(1)
# prompt_len = min(int(semantic_len * 0.5), 150)
# prompt = batch['semantic_ids'][:, :prompt_len]
# pred_semantic = self.model.infer(batch['phoneme_ids'],
# batch['phoneme_ids_len'], prompt,
# batch['bert_feature']
# )
# save_name = f'semantic_toks_{batch_idx}.pt'
# save_path = os.path.join(self.eval_dir, save_name)
# torch.save(pred_semantic.detach().cpu(), save_path)
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append([
name_param_pair[0]
for name_param_pair in self.model.named_parameters()
])
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,
@ -111,18 +122,19 @@ class Text2SemanticLightningModule(LightningModule):
clipping_scale=2.0,
parameters_names=parameters_names,
show_dominant_parameters=False,
clipping_update_period=1000, )
clipping_update_period=1000,
)
return {
"optimizer": lm_opt,
"lr_scheduler": {
"scheduler":
WarmupCosineLRSchedule(
"scheduler": WarmupCosineLRSchedule(
lm_opt,
init_lr=self.config['optimizer']['lr_init'],
peak_lr=self.config['optimizer']['lr'],
end_lr=self.config['optimizer']['lr_end'],
warmup_steps=self.config['optimizer']['warmup_steps'],
total_steps=self.config['optimizer']['decay_steps'])
}
init_lr=self.config["optimizer"]["lr_init"],
peak_lr=self.config["optimizer"]["lr"],
end_lr=self.config["optimizer"]["lr_end"],
warmup_steps=self.config["optimizer"]["warmup_steps"],
total_steps=self.config["optimizer"]["decay_steps"],
)
},
}

View File

@ -3,7 +3,12 @@ import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import topk_sampling,sample,logits_to_probs,multinomial_sample_one_no_sync
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
@ -22,35 +27,39 @@ default_config = {
"p_dropout": 0.0,
"vocab_size": 1024 + 1,
"phoneme_vocab_size": 512,
"EOS": 1024
"EOS": 1024,
}
class Text2SemanticDecoder(nn.Module):
def __init__(self, config, norm_first=False, top_k=3):
super(Text2SemanticDecoder, self).__init__()
self.model_dim = config['model']["hidden_dim"]
self.embedding_dim = config['model']["embedding_dim"]
self.num_head = config['model']["head"]
self.num_layers = config['model']["n_layer"]
self.model_dim = config["model"]["hidden_dim"]
self.embedding_dim = config["model"]["embedding_dim"]
self.num_head = config["model"]["head"]
self.num_layers = config["model"]["n_layer"]
self.norm_first = norm_first
self.vocab_size = config['model']["vocab_size"]
self.phoneme_vocab_size = config['model']["phoneme_vocab_size"]
self.p_dropout = config['model']["dropout"]
self.EOS = config['model']["EOS"]
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout)
self.embedding_dim, self.vocab_size, self.p_dropout
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True)
self.embedding_dim, dropout=0.1, scale=False, alpha=True
)
self.h = TransformerEncoder(
TransformerEncoderLayer(
@ -59,28 +68,30 @@ class Text2SemanticDecoder(nn.Module):
dim_feedforward=self.model_dim * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first, ),
norm_first=norm_first,
),
num_layers=self.num_layers,
norm=LayerNorm(self.model_dim) if norm_first else None, )
norm=LayerNorm(self.model_dim) if norm_first else None,
)
self.ar_predict_layer = nn.Linear(
self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction='sum')
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
self.ar_accuracy_metric = MulticlassAccuracy(
self.vocab_size,
top_k=top_k,
average="micro",
multidim_average="global",
ignore_index=self.EOS, )
ignore_index=self.EOS,
)
def forward(self, x, x_lens, y, y_lens, bert_feature):
'''
"""
x: phoneme_ids
y: semantic_ids
'''
"""
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
x_mask = make_pad_mask(x_lens)
@ -102,18 +113,23 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = F.pad(
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
diagonal=1, ),
diagonal=1,
),
(x_len, 0),
value=False, )
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
bsz, src_len = x.shape[0], x_len + y_len
_xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len))
_xy_padding_mask = (
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, self.num_head, -1, -1)
.reshape(bsz * self.num_head, 1, src_len)
)
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
@ -122,26 +138,28 @@ class Text2SemanticDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = F.cross_entropy(logits, targets, reduction='sum')
loss = F.cross_entropy(logits, targets, reduction="sum")
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
return loss, acc
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
def infer(self,
x,
x_lens,
prompts,
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@ -159,35 +177,37 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),
value=True, )
value=True,
)
y_attn_mask = F.pad(
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask, )
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature)
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
@ -198,23 +218,24 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(
y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
# 错位
return targets[:, :-1], targets[:, 1:]
def infer_panel(self,
x,#####全部文本token
x_lens,
prompts,####参考音频token
bert_feature,
top_k: int=-100,
early_stop_num: int=-1,
temperature: float=1.0):
def infer_panel(
self,
x, #####全部文本token
x_lens,
prompts, ####参考音频token
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1,2))
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
# AR Decoder
@ -224,75 +245,81 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
# print(1111111,self.num_layers)
cache={
"all_stage":self.num_layers,
"k":[None]*self.num_layers,###根据配置自己手写
"v":[None]*self.num_layers,
cache = {
"all_stage": self.num_layers,
"k": [None] * self.num_layers, ###根据配置自己手写
"v": [None] * self.num_layers,
# "xy_pos":None,##y_pos位置编码每次都不一样的没法缓存每次都要重新拼xy_pos.主要还是写法原因,其实是可以历史统一一样的,但也没啥计算量就不管了
"y_emb":None,##只需要对最新的samples求emb再拼历史的就行
"y_emb": None, ##只需要对最新的samples求emb再拼历史的就行
# "logits":None,###原版就已经只对结尾求再拼接了,不用管
# "xy_dec":None,###不需要本来只需要最后一个做logits
"first_infer":1,
"stage":0
"first_infer": 1,
"stage": 0,
}
for idx in tqdm(range(1500)):
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"],self.ar_audio_embedding(y[:,-1:])],1)
cache["y_emb"]=y_emb
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# x 和逐渐增长的 y 一起输入给模型
if(cache["first_infer"]==1):
if cache["first_infer"] == 1:
xy_pos = torch.concat([x, y_pos], dim=1)
else:
xy_pos=y_pos[:,-1:]
xy_pos = y_pos[:, -1:]
y_len = y_pos.shape[1]
###以下3个不做缓存
if (cache["first_infer"] == 1):
if cache["first_infer"] == 1:
x_attn_mask_pad = F.pad(
x_attn_mask,
(0, y_len),###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True, )
y_attn_mask = F.pad(###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
x_attn_mask,
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1(x,x+y)
value=True,
)
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0),
value=False, )
xy_attn_mask = torch.concat(
[x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
else:
###最右边一列(是错的)
# xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
# xy_attn_mask[:,-1]=False
###最下面一行(是对的)
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool, device=xy_pos.device)
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)
# pdb.set_trace()
###缓存重头戏
# print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,cache=cache )
logits = self.ar_predict_layer(xy_dec[:, -1])##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
) ##不用改如果用了cache的默认就是只有一帧取最后一帧一样的
# samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len
) > early_stop_num:
samples = sample(
logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
)[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True
if torch.argmax(
logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print('bad zero prediction')
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
break
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
# print(samples.shape)#[1,1]#第一个1是bs
y = torch.concat([y, samples], dim=1)
cache["first_infer"]=0
return y,idx
cache["first_infer"] = 0
return y, idx

View File

@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
@ -9,7 +10,7 @@ def sequence_mask(length, max_length=None):
return x.unsqueeze(0) < length.unsqueeze(1)
def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
@ -38,11 +39,9 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor:
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1):
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
@ -53,16 +52,14 @@ def top_k_top_p_filtering(logits,
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep),
logits.size(-1)) # Safety check
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
@ -70,13 +67,13 @@ def top_k_top_p_filtering(logits,
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1].clone()
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
@ -100,6 +97,8 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
from typing import Optional, Tuple
def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
@ -115,7 +114,7 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
previous_tokens=previous_tokens.squeeze()
previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
@ -159,4 +158,3 @@ def sample(
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

View File

@ -13,7 +13,9 @@ from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward=multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
@ -76,66 +78,71 @@ class MultiheadAttention(Module):
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None, ) -> None:
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
linear1_cls=Linear,
linear2_cls=Linear,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = (self.kdim == embed_dim and
self.vdim == embed_dim)
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(
torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs))
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs))
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs))
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs))
torch.empty(3 * embed_dim, **factory_kwargs)
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self._reset_parameters()
else:
@ -143,7 +150,8 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
)
self.in_proj_weight = self.in_proj_linear.weight
self.register_parameter("q_proj_weight", None)
@ -156,7 +164,8 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs)
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
@ -190,14 +199,15 @@ class MultiheadAttention(Module):
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor]=None,
need_weights: bool=True,
attn_mask: Optional[Tensor]=None,
average_attn_weights: bool=True,cache=None
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -251,23 +261,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask):
key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif (self.in_proj_bias is not None and
query.dtype != self.in_proj_bias.dtype):
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (self.in_proj_weight is not None and
query.dtype != self.in_proj_weight.dtype):
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
@ -288,29 +301,41 @@ class MultiheadAttention(Module):
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input")
"key_padding_mask is not supported with NestedTensor input"
)
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (query, key, value, self.in_proj_weight,
self.in_proj_bias, self.out_proj.weight,
self.out_proj.bias, )
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args]):
why_not_fast_path = (
"some Tensor argument is neither CUDA nor CPU")
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]):
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@ -322,17 +347,21 @@ class MultiheadAttention(Module):
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
key_padding_mask
if key_padding_mask is not None else attn_mask,
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1 if key_padding_mask is not None else 0
if attn_mask is not None else None, )
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}")
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
@ -343,9 +372,7 @@ class MultiheadAttention(Module):
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [
x.transpose(1, 0) for x in (query, key, value)
]
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
@ -370,7 +397,9 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
@ -390,7 +419,9 @@ class MultiheadAttention(Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,cache=cache )
average_attn_weights=average_attn_weights,
cache=cache,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:

View File

@ -7,10 +7,11 @@ from torch import nn
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float=0.0, ):
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
@ -24,7 +25,7 @@ class TokenEmbedding(nn.Module):
return self.word_embeddings.weight
def embedding(self, index: int) -> torch.Tensor:
return self.word_embeddings.weight[index:index + 1]
return self.word_embeddings.weight[index : index + 1]
def forward(self, x: torch.Tensor):
x = self.word_embeddings(x)
@ -34,11 +35,12 @@ class TokenEmbedding(nn.Module):
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float=0.0,
scale: bool=False,
alpha: bool=False, ):
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
@ -59,13 +61,14 @@ class SinePositionalEmbedding(nn.Module):
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(
0, x.size(1), dtype=torch.float32).unsqueeze(1)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.embedding_dim))
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
@ -74,5 +77,5 @@ class SinePositionalEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.extend_pe(x)
output = x.unsqueeze(-1) if x.ndim == 2 else x
output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)]
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(output)

View File

@ -12,14 +12,16 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
"""
def __init__(self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0):
def __init__(
self,
optimizer,
init_lr,
peak_lr,
end_lr,
warmup_steps=10000,
total_steps=400000,
current_step=0,
):
self.init_lr = init_lr
self.peak_lr = peak_lr
self.end_lr = end_lr
@ -33,10 +35,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
self._last_lr = [self.lr]
def set_lr(self, lr):
self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
for g in self.optimizer.param_groups:
# g['lr'] = lr
g['lr'] = self.end_lr###锁定用线性
g["lr"] = self.end_lr ###锁定用线性
def step(self):
if self._current_step < self.warmup_steps:
@ -47,7 +49,8 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps)
self.total_steps - self.warmup_steps
)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
@ -55,25 +58,19 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
self.lr=lr=self.end_lr=0.002###锁定用线性###不听话,直接锁定!
self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定!
self.set_lr(lr)
self.lr = lr
self._current_step += 1
return self.lr
if __name__ == '__main__':
if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0)
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
)
lrs = []
for i in range(25000):
s.step()

View File

@ -1,9 +1,16 @@
from torch.nn.functional import *
from torch.nn.functional import _mha_shape_check,_canonical_mask,_none_or_dtype,_in_projection_packed
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
# import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
@ -29,7 +36,8 @@ def multi_head_attention_forward_patched(
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,cache=None
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -105,7 +113,17 @@ def multi_head_attention_forward_patched(
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
"""
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if has_torch_function(tens_ops):
return handle_torch_function(
multi_head_attention_forward,
@ -134,10 +152,13 @@ def multi_head_attention_forward_patched(
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
average_attn_weights=average_attn_weights,cache=cache
average_attn_weights=average_attn_weights,
cache=cache,
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@ -159,7 +180,7 @@ def multi_head_attention_forward_patched(
mask_name="key_padding_mask",
other_type=_none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype
target_type=query.dtype,
)
if is_causal and attn_mask is None:
@ -184,59 +205,84 @@ def multi_head_attention_forward_patched(
check_other=False,
)
if key_padding_mask is not None:
# We have the attn_mask, and use that to merge kpm into it.
# Turn off use of is_causal hint, as the merged mask is no
# longer causal.
is_causal = False
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.chunk(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
if(cache!=None):
if(cache["first_infer"]==1):
cache["k"][cache["stage"]]=k
q, k, v = _in_projection(
query,
key,
value,
q_proj_weight,
k_proj_weight,
v_proj_weight,
b_q,
b_k,
b_v,
)
if cache != None:
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
# print(0,cache["k"].shape)
cache["v"][cache["stage"]]=v
else:###12个layer每个都要留自己的cache_kv
cache["v"][cache["stage"]] = v
else: ###12个layer每个都要留自己的cache_kv
# print(1,cache["k"].shape)
cache["k"][cache["stage"]]=torch.cat([cache["k"][cache["stage"]],k],0)##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]]=torch.cat([cache["v"][cache["stage"]],v],0)
cache["k"][cache["stage"]] = torch.cat(
[cache["k"][cache["stage"]], k], 0
) ##本来时序是1但是proj的时候可能transpose了所以时序到0维了
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
# print(2, cache["k"].shape)
src_len = cache["k"][cache["stage"]].shape[0]
k=cache["k"][cache["stage"]]
v=cache["v"][cache["stage"]]
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# if attn_mask is not None:
# attn_mask=attn_mask[-1:,]
# print(attn_mask.shape,attn_mask)
# print(attn_mask.shape,attn_mask)
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
# print(2333,cache)
# prep attention mask
@ -255,14 +301,20 @@ def multi_head_attention_forward_patched(
if attn_mask.dim() == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@ -286,26 +338,34 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.size(0) == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.size(2) == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.size(0) == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.size(2) == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@ -316,10 +376,15 @@ def multi_head_attention_forward_patched(
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
assert key_padding_mask.shape == (
bsz,
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
else:
@ -337,10 +402,14 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@ -349,7 +418,9 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@ -377,8 +448,12 @@ def multi_head_attention_forward_patched(
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

View File

@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module):
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -145,23 +153,25 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs)_mean , min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
min=0, max=max_factor
)
return below_threshold - above_threshold
def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@ -171,18 +181,18 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module):
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob:
sign_gain_factor = 0.5
@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
else:
sign_factor = None
@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
self.channel_dim,
)
else:
return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential(
balancer,
DoubleSwish(), )
DoubleSwish(),
)

View File

@ -26,26 +26,28 @@ class LayerNorm(nn.Module):
elementwise_affine: bool
def __init__(
self,
normalized_shape: _shape_t,
eps: float=1e-5,
elementwise_affine: bool=True,
device=None,
dtype=None, ) -> None:
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
elementwise_affine: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape, ) # type: ignore[assignment]
self.normalized_shape = tuple(
normalized_shape) # type: ignore[arg-type]
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs))
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@ -57,36 +59,43 @@ class LayerNorm(nn.Module):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
return (F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps, ), embedding, )
return (
F.layer_norm(
input,
self.normalized_shape,
self.weight,
self.bias,
self.eps,
),
embedding,
)
assert embedding is None
return F.layer_norm(input, self.normalized_shape, self.weight,
self.bias, self.eps)
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__))
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
class IdentityNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float=1e-5,
device=None,
dtype=None, ) -> None:
self,
d_model: int,
eps: float = 1e-5,
device=None,
dtype=None,
) -> None:
super(IdentityNorm, self).__init__()
def forward(self, input: Tensor, embedding: Any=None) -> Tensor:
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
if isinstance(input, tuple):
return input
@ -121,11 +130,13 @@ class TransformerEncoder(nn.Module):
self.norm = norm
def forward(
self,
src: Tensor,
mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,
return_layer_states: bool=False,cache=None ) -> Tensor:
self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
return_layer_states: bool = False,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
@ -144,7 +155,9 @@ class TransformerEncoder(nn.Module):
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
layer_states.append(output[0])
if self.norm is not None:
@ -154,9 +167,12 @@ class TransformerEncoder(nn.Module):
output = src
for mod in self.layers:
output = mod(output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask, cache=cache)
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
cache=cache,
)
if self.norm is not None:
output = self.norm(output)
@ -168,43 +184,47 @@ class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int=2048,
dropout: float=0.1,
activation: Union[str, Callable[[Tensor], Tensor]]=F.relu,
batch_first: bool=False,
norm_first: bool=False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module=nn.Linear,
linear2_self_attention_cls: nn.Module=nn.Linear,
linear1_feedforward_cls: nn.Module=nn.Linear,
linear2_feedforward_cls: nn.Module=nn.Linear,
layer_norm_cls: nn.Module=LayerNorm,
layer_norm_eps: float=1e-5,
adaptive_layer_norm=False, ) -> None:
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
batch_first: bool = False,
norm_first: bool = False,
device=None,
dtype=None,
linear1_self_attention_cls: nn.Module = nn.Linear,
linear2_self_attention_cls: nn.Module = nn.Linear,
linear1_feedforward_cls: nn.Module = nn.Linear,
linear2_feedforward_cls: nn.Module = nn.Linear,
layer_norm_cls: nn.Module = LayerNorm,
layer_norm_eps: float = 1e-5,
adaptive_layer_norm=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super(TransformerEncoderLayer, self).__init__()
# print(233333333333,d_model,nhead)
# import os
# os._exit(2333333)
self.self_attn = MultiheadAttention(
d_model,#512 16
d_model, # 512 16
nhead,
dropout=dropout,
batch_first=batch_first,
linear1_cls=linear1_self_attention_cls,
linear2_cls=linear2_self_attention_cls,
**factory_kwargs, )
**factory_kwargs,
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward,
**factory_kwargs)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model,
**factory_kwargs)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@ -230,11 +250,9 @@ class TransformerEncoderLayer(nn.Module):
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if layer_norm_cls == IdentityNorm:
norm2 = BalancedBasicNorm(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
else:
norm2 = layer_norm_cls(
d_model, eps=layer_norm_eps, **factory_kwargs)
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
if adaptive_layer_norm:
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
@ -249,10 +267,12 @@ class TransformerEncoderLayer(nn.Module):
self.activation = F.relu
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor]=None,
src_key_padding_mask: Optional[Tensor]=None,cache=None ) -> Tensor:
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
cache=None,
) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@ -272,7 +292,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask):
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
@ -281,12 +302,15 @@ class TransformerEncoderLayer(nn.Module):
x = x + self._sa_block(
self.norm1(x, stage_embedding),
src_mask,
src_key_padding_mask,cache=cache )
src_key_padding_mask,
cache=cache,
)
x = x + self._ff_block(self.norm2(x, stage_embedding))
else:
x = self.norm1(
x + self._sa_block(x, src_mask, src_key_padding_mask,cache=cache),
stage_embedding, )
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
stage_embedding,
)
x = self.norm2(x + self._ff_block(x), stage_embedding)
if is_src_tuple:
@ -295,12 +319,14 @@ class TransformerEncoderLayer(nn.Module):
# self-attention block
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],cache=None ) -> Tensor:
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
cache=None,
) -> Tensor:
# print(x.shape,attn_mask.shape,key_padding_mask)
#torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None
# import os
# os._exit(23333)
x = self.self_attn(
@ -309,7 +335,9 @@ class TransformerEncoderLayer(nn.Module):
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,cache=cache )[0]
need_weights=False,
cache=cache,
)[0]
return self.dropout1(x)
# feed forward block
@ -328,20 +356,23 @@ class AdaptiveLayerNorm(nn.Module):
self.d_model = d_model
self.eps = self.norm.eps
def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor:
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
if isinstance(input, tuple):
input, embedding = input
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return (weight * self.norm(input) + bias, embedding)
weight, bias = torch.split(
self.project_layer(embedding),
split_size_or_sections=self.d_model,
dim=-1, )
dim=-1,
)
return weight * self.norm(input) + bias
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

View File

@ -27,46 +27,44 @@ class GruutPhonemizer:
"": "",
"": "",
"«": "«",
"»": "»"
"»": "»",
}
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text)
return text.strip()
def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes:
return ''
if word.phonemes[0] in ['', '|']:
return ""
if word.phonemes[0] in ["", "|"]:
return word.text.strip()
phonemes = ''.join(word.phonemes)
phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex
phonemes = re.sub(r'[ˈˌː͡]', '', phonemes)
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip()
def phonemize(self, text: str, espeak: bool=False) -> str:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(
text_to_phonemize, lang="en-us", espeak=espeak)
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
return ' '.join(words)
return " ".join(words)
def transform(self, phonemes):
# convert phonemes to ids
# dictionary is in symbols.py
return [
self.symbol_to_id[p] for p in phonemes
if p in self.symbol_to_id.keys()
]
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
if __name__ == "__main__":

View File

@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
PAD = '_'
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")

View File

@ -11,22 +11,24 @@ def load_yaml_config(path):
def save_config_to_yaml(config, path):
assert path.endswith('.yaml')
with open(path, 'w') as f:
assert path.endswith(".yaml")
with open(path, "w") as f:
f.write(yaml.dump(config))
f.close()
def write_args(args, path):
args_dict = dict((name, getattr(args, name)) for name in dir(args)
if not name.startswith('_'))
with open(path, 'a') as args_file:
args_file.write('==> torch version: {}\n'.format(torch.__version__))
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
'==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
args_file.write('==> Cmd:\n')
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write('\n==> args:\n')
args_file.write("\n==> args:\n")
for k, v in sorted(args_dict.items()):
args_file.write(' %s: %s\n' % (str(k), str(v)))
args_file.write(" %s: %s\n" % (str(k), str(v)))
args_file.close()

View File

@ -1,31 +1,31 @@
train:
seed: 1234
epochs: 300
batch_size: 8
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16
gradient_clip: 1.0
seed: 1234
epochs: 300
batch_size: 8
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
model:
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 12
dropout: 0
EOS: 1024
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 12
dropout: 0
EOS: 1024
inference:
top_k: 5
top_k: 5

View File

@ -1,31 +1,31 @@
train:
seed: 1234
epochs: 300
batch_size: 8
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
seed: 1234
epochs: 300
batch_size: 8
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
model:
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 1024
hidden_dim: 1024
head: 16
linear_units: 2048
n_layer: 16
dropout: 0
EOS: 1024
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 1024
hidden_dim: 1024
head: 16
linear_units: 2048
n_layer: 16
dropout: 0
EOS: 1024
inference:
top_k: 5
top_k: 5

View File

@ -1,31 +1,31 @@
train:
seed: 1234
epochs: 300
batch_size: 12
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
seed: 1234
epochs: 300
batch_size: 12
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
model:
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 1024
hidden_dim: 1024
head: 16
linear_units: 2048
n_layer: 6
dropout: 0
EOS: 1024
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 1024
hidden_dim: 1024
head: 16
linear_units: 2048
n_layer: 6
dropout: 0
EOS: 1024
inference:
top_k: 5
top_k: 5

View File

@ -1,31 +1,31 @@
train:
seed: 1234
epochs: 20
batch_size: 8
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
seed: 1234
epochs: 20
batch_size: 8
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 54
num_workers: 4
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 54
num_workers: 4
pad_val: 1024 # same with EOS in model
model:
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 24
dropout: 0
EOS: 1024
random_bert: 0
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 24
dropout: 0
EOS: 1024
random_bert: 0
inference:
top_k: 5
top_k: 5

View File

@ -1,77 +1,77 @@
train:
seed: 1234
epochs: 100
batch_size: 6
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 32
gradient_clip: 1.0
seed: 1234
epochs: 100
batch_size: 6
gradient_accumulation: 4
save_every_n_epoch: 1
precision: 32
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 40
num_workers: 1
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 40
num_workers: 1
pad_val: 1024 # same with EOS in model
model:
saving_path: "ckpt/"
resume_checkpoint: null
vocoder_config_path: "quantizer/new_ckpt/config.json"
vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
datadir: "/home/liweiche/GigaSpeech/wavs"
metapath: "/home/liweiche/GigaSpeech/train2.json"
val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
sampledir: "logs/"
pretrained_path: null
lr: 0.0001
batch_size: 200.0
train_bucket_size: 8192
training_step: 800000
optim_flat_percent: 0.0
warmup_step: 50
adam_beta1: 0.9
adam_beta2: 0.98
ffd_size: 3072
hidden_size: 768
enc_nlayers: 6
dec_nlayers: 6
nheads: 12
ar_layer: 4
ar_ffd_size: 1024
ar_hidden_size: 256
ar_nheads: 4
aligner_softmax_temp: 1.0
layer_norm_eps: 0.00001
speaker_embed_dropout: 0.05
label_smoothing: 0.0
val_check_interval: 5000
check_val_every_n_epoch: 1
precision: "fp16"
nworkers: 16
distributed: true
accelerator: "ddp"
version: null
accumulate_grad_batches: 1
use_repetition_token: true
use_repetition_gating: false
repetition_penalty: 1.0
sampling_temperature: 1.0
top_k: -1
min_top_k: 3
top_p: 0.8
sample_num: 4
length_penalty_max_length: 15000
length_penalty_max_prob: 0.95
max_input_length: 2048
max_output_length: 2000
sample_rate: 16000
n_codes: 1024
n_cluster_groups: 1
phone_context_window: 4
phoneset_size: 1000
saving_path: "ckpt/"
resume_checkpoint: null
vocoder_config_path: "quantizer/new_ckpt/config.json"
vocoder_ckpt_path: "quantizer/new_ckpt/g_00600000"
datadir: "/home/liweiche/GigaSpeech/wavs"
metapath: "/home/liweiche/GigaSpeech/train2.json"
val_metapath: "/home/liweiche/GigaSpeech/dev2.json"
sampledir: "logs/"
pretrained_path: null
lr: 0.0001
batch_size: 200.0
train_bucket_size: 8192
training_step: 800000
optim_flat_percent: 0.0
warmup_step: 50
adam_beta1: 0.9
adam_beta2: 0.98
ffd_size: 3072
hidden_size: 768
enc_nlayers: 6
dec_nlayers: 6
nheads: 12
ar_layer: 4
ar_ffd_size: 1024
ar_hidden_size: 256
ar_nheads: 4
aligner_softmax_temp: 1.0
layer_norm_eps: 0.00001
speaker_embed_dropout: 0.05
label_smoothing: 0.0
val_check_interval: 5000
check_val_every_n_epoch: 1
precision: "fp16"
nworkers: 16
distributed: true
accelerator: "ddp"
version: null
accumulate_grad_batches: 1
use_repetition_token: true
use_repetition_gating: false
repetition_penalty: 1.0
sampling_temperature: 1.0
top_k: -1
min_top_k: 3
top_p: 0.8
sample_num: 4
length_penalty_max_length: 15000
length_penalty_max_prob: 0.95
max_input_length: 2048
max_output_length: 2000
sample_rate: 16000
n_codes: 1024
n_cluster_groups: 1
phone_context_window: 4
phoneset_size: 1000
inference:
top_k: 5
top_k: 5

View File

@ -1,32 +1,32 @@
gpu:
n_card: 1
n_process_per_card: 2
n_card: 1
n_process_per_card: 2
io:
text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
text_path: D:\RVC1006\GPT-SoVITS\GPT_SoVITS
save_every_n_epoch: 1
precision: 16-mixed
gradient_clip: 1.0
optimizer:
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
lr: 0.01
lr_init: 0.00001
lr_end: 0.0001
warmup_steps: 2000
decay_steps: 40000
data:
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
max_eval_sample: 8
max_sec: 54
num_workers: 1
pad_val: 1024 # same with EOS in model
model:
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 24
dropout: 0
EOS: 1024
random_bert: 0
vocab_size: 1025
phoneme_vocab_size: 512
embedding_dim: 512
hidden_dim: 512
head: 16
linear_units: 2048
n_layer: 24
dropout: 0
EOS: 1024
random_bert: 0
inference:
top_k: 5
top_k: 5

View File

@ -11,23 +11,30 @@ logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
Wav2Vec2FeatureExtractor,
HubertModel,
Wav2Vec2Model,
)
import utils
import torch.nn as nn
cnhubert_base_path=None
cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self):
super().__init__()
self.model = HubertModel.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cnhubert_base_path)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
cnhubert_base_path
)
def forward(self, x):
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats
# class CNHubertLarge(nn.Module):
# def __init__(self):
# super().__init__()
@ -59,12 +66,12 @@ class CNHubert(nn.Module):
# return feats
def get_model():
model = CNHubert()
model.eval()
return model
# def get_large_model():
# model = CNHubertLarge()
# model.eval()
@ -80,18 +87,18 @@ def get_model():
# model.eval()
# return model
def get_content(hmodel, wav_16k_tensor):
with torch.no_grad():
feats = hmodel(wav_16k_tensor)
return feats.transpose(1,2)
return feats.transpose(1, 2)
if __name__ == '__main__':
if __name__ == "__main__":
model = get_model()
src_path = "/Users/Shared/原音频2.wav"
wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
model = model
wav_16k_tensor = wav_16k_tensor
feats = get_content(model,wav_16k_tensor)
feats = get_content(model, wav_16k_tensor)
print(feats.shape)

View File

@ -3,20 +3,23 @@ import torch
def get_model():
import whisper
model = whisper.load_model("small", device='cpu')
model = whisper.load_model("small", device="cpu")
return model.encoder
def get_content(model=None, wav_16k_tensor=None):
from whisper import log_mel_spectrogram, pad_or_trim
dev = next(model.parameters()).device
mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
# if torch.cuda.is_available():
# mel = mel.to(torch.float16)
feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
assert mel.shape[-1] < 3000, "输入音频过长只允许输入30以内音频"
with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2)
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
return feature

View File

@ -1,49 +1,50 @@
import os
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
infer_ttswebui=os.environ.get("infer_ttswebui",9872)
infer_ttswebui=int(infer_ttswebui)
if("_CUDA_VISIBLE_DEVICES"in os.environ):
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
is_half=eval(os.environ.get("is_half","True"))
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
)
sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth")
cnhubert_base_path = os.environ.get(
"cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
"bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import sys,torch,numpy as np
from pathlib import Path
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
from random import shuffle
from AR.utils import get_newest_ckpt
from glob import glob
from tqdm import tqdm
import numpy as np
import librosa,torch
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence
from text.cleaner import text_to_sequence, clean_text
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
device="cuda"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model=AutoModelForMaskedLM.from_pretrained(bert_path)
if(is_half==True):bert_model=bert_model.half().to(device)
else:bert_model=bert_model.to(device)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
# bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)#####输入是long不用管精度问题精度随bert_model
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
@ -55,218 +56,305 @@ def get_bert_feature(text, word2ph):
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
class DictToAttrRecursive:
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
# 如果值是字典,递归调用构造函数
setattr(self, key, DictToAttrRecursive(value))
else:
setattr(self, key, value)
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate="25hz"
dict_s1=torch.load(gpt_path,map_location="cpu")
config=dict_s1["config"]
ssl_model=cnhubert.get_model()
if(is_half==True):ssl_model=ssl_model.half().to(device)
else:ssl_model=ssl_model.to(device)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if(is_half==True):vq_model=vq_model.half().to(device)
else:vq_model=vq_model.to(device)
**hps.model
)
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"],strict=False))
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config['data']['max_sec']
max_sec = config["data"]["max_sec"]
# t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
t2s_model = Text2SemanticLightningModule(config,"ojbk",is_train=False)
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if(is_half==True):t2s_model=t2s_model.half()
t2s_model=t2s_model.to(device)
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename):
audio=load_audio(filename,int(hps.data.sampling_rate))
audio=torch.FloatTensor(audio)
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
spec = spectrogram_torch(
audio_norm,
hps.data.filter_length,
hps.data.sampling_rate,
hps.data.hop_length,
hps.data.win_length,
center=False,
)
return spec
dict_language={
"中文":"zh",
"英文":"en",
"日文":"ja"
}
def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text=prompt_text.strip("\n")
prompt_language,text=prompt_language,text.strip("\n")
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
wav16k = torch.from_numpy(wav16k)
if(is_half==True):wav16k=wav16k.half().to(device)
else:wav16k=wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
if is_half == True:
wav16k = wav16k.half().to(device)
else:
wav16k = wav16k.to(device)
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language=dict_language[prompt_language]
text_language=dict_language[text_language]
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1=cleaned_text_to_sequence(phones1)
texts=text.split("\n")
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
)
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
if prompt_language == "zh":
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros(
(1024, len(phones1)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
if text_language == "zh":
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic,idx = t2s_model.model.infer_panel(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
top_k=config["inference"]["top_k"],
early_stop_num=hz * max_sec,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path)#.to(device)
if(is_half==True):refer=refer.half().to(device)
else:refer=refer.to(device)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if is_half == True:
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
} # 不考虑省略号
splits={"","","","",",",".","?","!","~",":","","","",}#不考虑省略号
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if (todo_text[-1] not in splits): todo_text += ""
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while (1):
if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if (todo_text[i_split_head] in splits):
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp=inp.strip("\n")
inps=split(inp)
split_idx=list(range(0,len(inps),5))
split_idx[-1]=None
if(len(split_idx)>1):
opts=[]
for idx in range(len(split_idx)-1):
opts.append("".join(inps[split_idx[idx]:split_idx[idx+1]]))
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 5))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts=[inp]
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp=inp.strip("\n")
inps=split(inp)
if(len(inps)<2):return [inp]
opts=[]
summ=0
tmp_str=""
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return [inp]
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ+=len(inps[i])
tmp_str+=inps[i]
if(summ>50):
summ=0
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str=""
if(tmp_str!=""):opts.append(tmp_str)
if(len(opts[-1])<50):##如果最后一个太短了,和前一个合一起
opts[-2]=opts[-2]+opts[-1]
opts=opts[:-1]
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
if len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp=inp.strip("\n")
return "\n".join(["%s"%item for item in inp.strip("").split("")])
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
)
# with gr.Tabs():
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group():
gr.Markdown(
value=
"*请上传并填写参考信息"
)
gr.Markdown(value="*请上传并填写参考信息")
with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text= gr.Textbox(label="参考音频的文本",value="")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文")
gr.Markdown(
value=
"*请填写需要合成的目标文本"
)
prompt_text = gr.Textbox(label="参考音频的文本", value="")
prompt_language = gr.Dropdown(
label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
)
gr.Markdown(value="*请填写需要合成的目标文本")
with gr.Row():
text=gr.Textbox(label="需要合成的文本",value="")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文")
inference_button=gr.Button("合成语音", variant="primary")
text = gr.Textbox(label="需要合成的文本", value="")
text_language = gr.Dropdown(
label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
)
inference_button = gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音")
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
gr.Markdown(
value=
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language],
[output],
)
gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
with gr.Row():
text_inp=gr.Textbox(label="需要合成的切分前文本",value="")
text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
button1 = gr.Button("凑五句一切", variant="primary")
button2 = gr.Button("凑50字一切", variant="primary")
button3 = gr.Button("按中文句号。切", variant="primary")
text_opt = gr.Textbox(label="切分后文本", value="")
button1.click(cut1,[text_inp],[text_opt])
button2.click(cut2,[text_inp],[text_opt])
button3.click(cut3,[text_inp],[text_opt])
gr.Markdown(
value=
"后续将支持混合语种编码文本输入。"
)
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
gr.Markdown(value="后续将支持混合语种编码文本输入。")
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
server_port=infer_ttswebui,
quiet=True,
)
)

File diff suppressed because it is too large Load Diff

View File

@ -1,189 +1,189 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
return kl
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1.0 / norm_type)
return total_norm
def squeeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size()
b, c, t = x.size()
t = (t // n_sqz) * n_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // n_sqz, n_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
t = (t // n_sqz) * n_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // n_sqz, n_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
else:
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask
if x_mask is not None:
x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
else:
x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, n_sqz=2):
b, c, t = x.size()
b, c, t = x.size()
x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
else:
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
else:
x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask

View File

@ -76,10 +76,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
print("kmeans start ... ")
for _ in tqdm(range(num_iters)):
diffs = rearrange(samples, "n d -> n () d") - rearrange(
means, "c d -> () c d"
)
dists = -(diffs ** 2).sum(dim=-1)
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
@ -110,6 +108,7 @@ class EuclideanCodebook(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dim: int,
@ -122,7 +121,9 @@ class EuclideanCodebook(nn.Module):
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
@ -147,7 +148,7 @@ class EuclideanCodebook(nn.Module):
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
# Make sure all buffers across workers are in sync after initialization
#broadcast_tensors(self.buffers())
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
@ -165,7 +166,7 @@ class EuclideanCodebook(nn.Module):
batch_samples = rearrange(batch_samples, "... d -> (...) d")
self.replace_(batch_samples, mask=expired_codes)
#broadcast_tensors(self.buffers())
# broadcast_tensors(self.buffers())
def preprocess(self, x):
x = rearrange(x, "... d -> (...) d")
@ -246,6 +247,7 @@ class VectorQuantization(nn.Module):
randomly selected vector from the current batch.
commitment_weight (float): Weight for commitment loss.
"""
def __init__(
self,
dim: int,
@ -256,22 +258,31 @@ class VectorQuantization(nn.Module):
kmeans_init: bool = True,
kmeans_iters: int = 50,
threshold_ema_dead_code: int = 2,
commitment_weight: float = 1.,
commitment_weight: float = 1.0,
):
super().__init__()
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.epsilon = epsilon
self.commitment_weight = commitment_weight
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
decay=decay, epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code)
self._codebook = EuclideanCodebook(
dim=_codebook_dim,
codebook_size=codebook_size,
kmeans_init=kmeans_init,
kmeans_iters=kmeans_iters,
decay=decay,
epsilon=epsilon,
threshold_ema_dead_code=threshold_ema_dead_code,
)
self.codebook_size = codebook_size
@property
@ -316,13 +327,16 @@ class ResidualVectorQuantization(nn.Module):
"""Residual vector quantization implementation.
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
"""
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
quantized_out = 0.0
residual = x
@ -345,7 +359,9 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor:
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
@ -358,10 +374,10 @@ class ResidualVectorQuantization(nn.Module):
out_indices = torch.stack(all_indices)
return out_indices
def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor:
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
return quantized_out
return quantized_out

View File

@ -1,6 +1,6 @@
import time,logging
import time, logging
import os
import random,traceback
import random, traceback
import numpy as np
import torch
import torch.utils.data
@ -16,41 +16,44 @@ import torch
import requests
from scipy.io import wavfile
from io import BytesIO
# from config import exp_dir
from my_utils import load_audio
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
exp_dir=hparams.exp_dir
self.path2="%s/2-name2text.txt"%exp_dir
self.path4="%s/4-cnhubert"%exp_dir
self.path5="%s/5-wav32k"%exp_dir
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
names4=set([name[:-3]for name in list(os.listdir(self.path4))])#去除.pt后缀
names5=set(os.listdir(self.path5))
self.phoneme_data={}
with open(self.path2,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp=line.split("\t")
if(len(tmp)!=4):continue
self.phoneme_data[tmp[0]]=[tmp[1]]
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text=list(set(self.phoneme_data)&names4&names5)
tmp=self.audiopaths_sid_text
leng=len(tmp)
min_num=100
if(leng<min_num):
self.audiopaths_sid_text=[]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
self.max_wav_value = hparams.max_wav_value
@ -69,20 +72,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audiopaths_sid_text_new = []
lengths = []
skipped_phone = 0
skipped_dur = 0
skipped_phone = 0
skipped_dur = 0
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
skipped_phone += 1
skipped_phone += 1
continue
size=os.path.getsize("%s/%s"%(self.path5,audiopath))
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2
if (54 > duration > 0.6 or self.val):
if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length))
else:
@ -90,7 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
continue
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
print("total left: ", len(audiopaths_sid_text_new))
assert len(audiopaths_sid_text_new)>1#至少能凑够batch size这里todo
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
@ -98,30 +101,41 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, wav = self.get_audio("%s/%s"%(self.path5,audiopath))
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt"%(self.path4,audiopath),map_location="cpu")
if(ssl.shape[-1]!=spec.shape[-1]):
typee=ssl.dtype
ssl=F.pad(ssl.float(),(0,1),mode="replicate").to(typee)
ssl.requires_grad=False
ssl = torch.load(
"%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
)
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
except:
traceback.print_exc()
spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100*self.hop_length)
ssl=torch.zeros(1,768,100)
text=text[-1:]
wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100)
text = text[-1:]
print("load audio or ssl error!!!!!!", audiopath)
# print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
return (ssl, spec, wav, text)
def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio_array = load_audio(
filename, self.sampling_rate
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
# print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
audio=torch.FloatTensor(audio_array)#/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,self.sampling_rate, self.hop_length, self.win_length,center=False)
spec = spectrogram_torch(
audio_norm,
self.filter_length,
self.sampling_rate,
self.hop_length,
self.win_length,
center=False,
)
spec = torch.squeeze(spec, 0)
return spec, audio_norm
@ -131,39 +145,51 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
def __getitem__(self, index):
# with torch.no_grad():
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
def __len__(self):
return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1]- wav.shape[-1]//self.hop_length) < 3, ("first", ssl.shape, wav.shape)
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first",
ssl.shape,
wav.shape,
)
len_mel = mel.shape[1]
if self.val:
reference_mel = mel[:, :len_mel//3]
reference_mel = mel[:, : len_mel // 3]
return reference_mel, ssl, wav, mel
dir = random.randint(0, 1)
sep_point = random.randint(int(len_mel//3), int(len_mel//3*2))
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
if dir == 0:
reference_mel = mel[:, :sep_point]
ssl = ssl[:, :, sep_point:]
wav2 = wav[:, sep_point*self.hop_length:]
wav2 = wav[:, sep_point * self.hop_length :]
mel = mel[:, sep_point:]
else:
reference_mel = mel[:, sep_point:]
ssl = ssl[:, :, :sep_point]
wav2 = wav[:, :sep_point*self.hop_length]
wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point]
assert abs(ssl.shape[-1]- wav2.shape[-1]//self.hop_length) < 3, (ssl.shape, wav.shape,wav2.shape, mel.shape, sep_point,self.hop_length, sep_point*self.hop_length, dir)
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@ -176,8 +202,8 @@ class TextAudioSpeakerCollate():
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@ -194,7 +220,7 @@ class TextAudioSpeakerCollate():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
spec_padded.zero_()
wav_padded.zero_()
@ -205,23 +231,31 @@ class TextAudioSpeakerCollate():
row = batch[ids_sorted_decreasing[i]]
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@ -234,7 +268,15 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
"""
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
def __init__(
self,
dataset,
batch_size,
boundaries,
num_replicas=None,
rank=None,
shuffle=True,
):
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
self.lengths = dataset.lengths
# print(233333333333333,self.lengths,dir(dataset))
@ -254,7 +296,7 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
buckets[idx_bucket].append(i)
for i in range(len(buckets) - 1, 0, -1):
# for i in range(len(buckets) - 1, -1, -1):
# for i in range(len(buckets) - 1, -1, -1):
if len(buckets[i]) == 0:
buckets.pop(i)
self.boundaries.pop(i + 1)
@ -263,7 +305,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
for i in range(len(buckets)):
len_bucket = len(buckets[i])
total_batch_size = self.num_replicas * self.batch_size
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
rem = (
total_batch_size - (len_bucket % total_batch_size)
) % total_batch_size
num_samples_per_bucket.append(len_bucket + rem)
return buckets, num_samples_per_bucket
@ -289,14 +333,23 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
# add extra samples to make it evenly divisible
rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
ids_bucket = (
ids_bucket
+ ids_bucket * (rem // len_bucket)
+ ids_bucket[: (rem % len_bucket)]
)
# subsample
ids_bucket = ids_bucket[self.rank::self.num_replicas]
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
# batching
for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
batch = [
bucket[idx]
for idx in ids_bucket[
j * self.batch_size : (j + 1) * self.batch_size
]
]
batches.append(batch)
if self.shuffle:

View File

@ -5,64 +5,69 @@ from torch.nn import functional as F
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
gl = gl.float()
loss += torch.mean(torch.abs(rl - gl))
return loss * 2
return loss * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1-dr)**2)
g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
dr = dr.float()
dg = dg.float()
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1-dg)**2)
gen_losses.append(l)
loss += l
loss = 0
gen_losses = []
for dg in disc_outputs:
dg = dg.float()
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses
return loss, gen_losses
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
"""
z_p, logs_q: [b, h, t_t]
m_p, logs_p: [b, h, t_t]
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l
def mle_loss(z, m, logs, logdet, mask):
l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l
l = torch.sum(logs) + 0.5 * torch.sum(
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l

View File

@ -49,21 +49,37 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
@ -71,37 +87,63 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + '_' + str(spec.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
return spec
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
fmax_dtype_device = str(fmax) + '_' + dtype_device
wnsize_dtype_device = str(win_size) + '_' + dtype_device
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

View File

@ -12,12 +12,21 @@ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
from module.quantize import ResidualVectorQuantizer
from text import symbols
from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
def __init__(
self,
in_channels,
filter_channels,
kernel_size,
p_dropout,
n_flows=4,
gin_channels=0,
):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
@ -31,21 +40,29 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@ -66,7 +83,10 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@ -74,8 +94,13 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@ -84,12 +109,18 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@ -98,7 +129,9 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
@ -108,9 +141,13 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@ -135,15 +172,17 @@ class DurationPredictor(nn.Module):
class TextEncoder(nn.Module):
def __init__(self,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
latent_channels=192):
def __init__(
self,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
latent_channels=192,
):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
@ -160,17 +199,14 @@ class TextEncoder(nn.Module):
hidden_channels,
filter_channels,
n_heads,
n_layers//2,
n_layers // 2,
kernel_size,
p_dropout)
p_dropout,
)
self.encoder_text = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
@ -179,21 +215,25 @@ class TextEncoder(nn.Module):
hidden_channels,
filter_channels,
n_heads,
n_layers//2,
n_layers // 2,
kernel_size,
p_dropout)
p_dropout,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
if test == 1 :
text_mask = torch.unsqueeze(
commons.sequence_mask(text_lengths, text.size(1)), 1
).to(y.dtype)
if test == 1:
text[:, :] = 0
text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask)
@ -208,9 +248,9 @@ class TextEncoder(nn.Module):
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0,1)
def decode_latent(self, codes, y_mask, refer,refer_mask, ge):
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
@ -224,15 +264,18 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0,
):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
@ -245,8 +288,16 @@ class ResidualCouplingBlock(nn.Module):
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(
modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
gin_channels=gin_channels, mean_only=True))
modules.ResidualCouplingLayer(
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
mean_only=True,
)
)
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
@ -260,14 +311,16 @@ class ResidualCouplingBlock(nn.Module):
class PosteriorEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -278,13 +331,21 @@ class PosteriorEncoder(nn.Module):
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
if(g!=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@ -294,14 +355,16 @@ class PosteriorEncoder(nn.Module):
class WNEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -312,11 +375,20 @@ class WNEncoder(nn.Module):
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.enc = modules.WN(
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=gin_channels,
)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@ -325,24 +397,45 @@ class WNEncoder(nn.Module):
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
def __init__(
self,
initial_channel,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2)))
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@ -373,7 +466,7 @@ class Generator(torch.nn.Module):
return x
def remove_weight_norm(self):
print('Removing weight norm...')
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
@ -386,13 +479,55 @@ class DiscriminatorP(torch.nn.Module):
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
32,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(get_padding(kernel_size, 1), 0),
)
),
norm_f(
Conv2d(
1024,
1024,
(kernel_size, 1),
1,
padding=(get_padding(kernel_size, 1), 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
@ -421,14 +556,16 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
@ -451,7 +588,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@ -469,31 +608,40 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class ReferenceEncoder(nn.Module):
'''
"""
inputs --- [N, Ty/r, n_mels*r] mels
outputs --- [N, ref_enc_gru_size]
'''
"""
def __init__(self, spec_channels, gin_channels=0):
super().__init__()
self.spec_channels = spec_channels
ref_enc_filters = [32, 32, 64, 64, 128, 128]
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [weight_norm(nn.Conv2d(in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1))) for i in range(K)]
convs = [
weight_norm(
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
self.gru = nn.GRU(input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=256 // 2,
batch_first=True,
)
self.proj = nn.Linear(128, gin_channels)
def forward(self, inputs):
@ -527,23 +675,31 @@ class Quantizer_module(torch.nn.Module):
self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
def forward(self, x):
d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) - 2 * torch.matmul(x, self.embedding.weight.T)
d = (
torch.sum(x**2, 1, keepdim=True)
+ torch.sum(self.embedding.weight**2, 1)
- 2 * torch.matmul(x, self.embedding.weight.T)
)
min_indicies = torch.argmin(d, 1)
z_q = self.embedding(min_indicies)
return z_q, min_indicies
class Quantizer(torch.nn.Module):
def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList([
Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)
])
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
def forward(self, xin):
#B, C, T
# B, C, T
B, C, T = xin.shape
xin = xin.transpose(1, 2)
x = xin.reshape(-1, self.embed_dim)
@ -553,38 +709,41 @@ class Quantizer(torch.nn.Module):
for _x, m in zip(x, self.quantizer_modules):
_z_q, _min_indicies = m(_x)
z_q.append(_z_q)
min_indicies.append(_min_indicies) #B * T,
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
return z_q, loss, codes.transpose(1, 2)
def embed(self, x):
#idx: N, 4, T
x=x.transpose(1, 2)
# idx: N, 4, T
x = x.transpose(1, 2)
x = torch.split(x, 1, 2)
ret = []
for q, embed in zip(x, self.quantizer_modules):
q = embed.embedding(q.squeeze(-1))
ret.append(q)
ret = torch.cat(ret, -1)
return ret.transpose(1, 2) #N, C, T
return ret.transpose(1, 2) # N, C, T
class CodePredictor(nn.Module):
def __init__(self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
n_q=8,
dims=1024,
ssl_dim=768
):
def __init__(
self,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
n_q=8,
dims=1024,
ssl_dim=768,
):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
@ -594,19 +753,18 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.out_proj = nn.Conv1d(hidden_channels, (n_q-1) * dims, 1)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
self.dims = dims
def forward(self, x, x_mask, refer, codes, infer=False):
x = x.detach()
x = self.vq_proj(x * x_mask) * x_mask
@ -614,7 +772,9 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@ -626,44 +786,44 @@ class CodePredictor(nn.Module):
correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
top3_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()
print('Top-10 Accuracy:', top3_acc, "%")
print("Top-10 Accuracy:", top3_acc, "%")
pred_codes = torch.argmax(logits, dim=-1)
acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
print('Top-1 Accuracy:', acc, "%")
print("Top-1 Accuracy:", acc, "%")
return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs):
Synthesizer for Training
"""
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
**kwargs
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -685,34 +845,50 @@ class SynthesizerTrn(nn.Module):
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
)
self.dec = Generator(
inter_channels,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
self.ref_enc = modules.MelStyleEncoder(
spec_channels, style_vector_dim=gin_channels
)
ssl_dim = 768
assert semantic_frame_rate in ['25hz', "50hz"]
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == '25hz':
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(
dimension=ssl_dim,
n_q=1,
bins=1024
)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
if freeze_quantizer:
self.ssl_proj.requires_grad_(False)
self.quantizer.requires_grad_(False)
@ -721,56 +897,85 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.mrte.requires_grad_(False)
def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
with autocast(enabled=False):
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]
)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
z_slice, ids_slice = commons.rand_slice_segments(
z, y_lengths, self.segment_size
)
o = self.dec(z_slice, g=ge)
return o, commit_loss, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
return (
o,
commit_loss,
ids_slice,
y_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
quantized,
)
def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype
)
ge = self.ref_enc(y * y_mask, y_mask)
ssl = self.ssl_proj(ssl)
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge, test=test
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
o = self.dec((z * y_mask)[:, :, :], g=ge)
return o,y_mask, (z, z_p, m_p, logs_p)
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes,text, refer, noise_scale=0.5):
def decode(self, codes, text, refer, noise_scale=0.5):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
refer_mask = torch.unsqueeze(
commons.sequence_mask(refer_lengths, refer.size(2)), 1
).to(refer.dtype)
ge = self.ref_enc(refer * refer_mask, refer_mask)
y_lengths = torch.LongTensor([codes.size(2)*2]).to(codes.device)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == '25hz':
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -779,6 +984,6 @@ class SynthesizerTrn(nn.Module):
return o
def extract_latent(self, x):
ssl = self.ssl_proj(x)
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0,1)
return codes.transpose(0, 1)

File diff suppressed because it is too large Load Diff

View File

@ -5,46 +5,74 @@ from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
class MRTE(nn.Module):
def __init__(self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer = 2
):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads)
self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size,out_channels, 1)
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if(ge==None):ge=0
if ge == None:
ge = 0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
elif test == 1:
x = ssl_enc + ge
elif test ==2:
x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
else:
raise ValueError("test should be 0,1,2")
else:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x
class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256):
def __init__(
self,
mel_n_channels=80,
model_num_layers=2,
model_hidden_size=256,
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
@ -56,13 +84,15 @@ class SpeakerEncoder(torch.nn.Module):
class MELEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@ -81,80 +111,82 @@ class MELEncoder(nn.Module):
x = self.enc(x)
x = self.proj(x)
return x
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__()
assert kernel_size % 2 == 1
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer)
for i in range(n_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(
hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding,
)
in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
def forward(self, x):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
acts = fused_add_tanh_sigmoid_multiply(
x_in,
n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts)
output = output + res_skip_acts[:,self.hidden_channels:,:]
else:
output = output + res_skip_acts
return output
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:, : self.hidden_channels, :]
x = x + res_acts
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output
def remove_weight_norm(self):
for l in self.in_layers:
remove_weight_norm(l)
for l in self.res_skip_layers:
remove_weight_norm(l)
def remove_weight_norm(self):
for l in self.in_layers:
remove_weight_norm(l)
for l in self.res_skip_layers:
remove_weight_norm(l)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input, n_channels):
n_channels_int = n_channels[0]
t_act = torch.tanh(input[:, :n_channels_int, :])
s_act = torch.sigmoid(input[:, n_channels_int:, :])
acts = t_act * s_act
return acts
n_channels_int = n_channels[0]
t_act = torch.tanh(input[:, :n_channels_int, :])
s_act = torch.sigmoid(input[:, n_channels_int:, :])
acts = t_act * s_act
return acts
if __name__ == '__main__':
content_enc = torch.randn(3,192,100)
content_mask = torch.ones(3,1,100)
ref_mel = torch.randn(3,128,30)
ref_mask = torch.ones(3,1,30)
if __name__ == "__main__":
content_enc = torch.randn(3, 192, 100)
content_mask = torch.ones(3, 1, 100)
ref_mel = torch.randn(3, 128, 30)
ref_mask = torch.ones(3, 1, 30)
model = MRTE()
out = model(content_enc,content_mask,ref_mel,ref_mask)
print(out.shape)
out = model(content_enc, content_mask, ref_mel, ref_mask)
print(out.shape)

View File

@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
"""
n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q:
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.')
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module):
st (int): Start to decode input codes from which layers. Default: 0.
"""
quantized = self.vq.decode(codes, st=st)
return quantized
return quantized

View File

@ -9,66 +9,63 @@ DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {
'tails': tails,
'tail_bound': tail_bound
}
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
) - 1
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails='linear',
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == 'linear':
if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
@ -77,45 +74,57 @@ def unconstrained_rational_quadratic_spline(inputs,
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError('{} tails are not implemented.'.format(tails))
raise RuntimeError("{} tails are not implemented.".format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative
min_derivative=min_derivative,
)
return outputs, logabsdet
def rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0., right=1., bottom=0., top=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE,
):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain')
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins')
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins')
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
@ -126,7 +135,7 @@ def rational_quadratic_spline(inputs,
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
@ -150,15 +159,13 @@ def rational_quadratic_spline(inputs,
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (((inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta)
+ input_heights * (input_delta - input_derivatives)))
b = (input_heights * input_derivatives
- (inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta))
c = - input_delta * (inputs - input_cumheights)
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
@ -167,11 +174,15 @@ def rational_quadratic_spline(inputs,
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2))
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
@ -179,15 +190,20 @@ def rational_quadratic_spline(inputs,
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2)
+ input_derivatives * theta_one_minus_theta)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2))
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2)
)
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

View File

@ -1,50 +1,81 @@
import os,torch,sys
import os, torch, sys
from subprocess import Popen
now_dir = os.getcwd()
sys.path.append(now_dir)
from config import text_path,wav_dir,n_card,n_process_per_card,exp_name,n_parts,exp_dir
os.makedirs("%s/logs_s1"%exp_dir,exist_ok=True)
os.makedirs("%s/logs_s2"%exp_dir,exist_ok=True)
from config import (
text_path,
wav_dir,
n_card,
exp_name,
n_parts,
exp_dir,
)
os.makedirs("%s/logs_s1" % exp_dir, exist_ok=True)
os.makedirs("%s/logs_s2" % exp_dir, exist_ok=True)
##############step1
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/1-get-text.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/1-get-text.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
opt=[]
opt = []
for i_part in range(n_parts):
txt_path = "%s/2-name2text-%s.txt" % (exp_dir, i_part)
with open(txt_path,"r")as f:
opt+=f.read().strip("\n").split("\n")
with open(txt_path, "r") as f:
opt += f.read().strip("\n").split("\n")
os.remove(txt_path)
with open("%s/2-name2text.txt"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n")
with open("%s/2-name2text.txt" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")
############step2
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s"%(text_path,wav_dir,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/2-get-hubert-wav32k.py %s %s %s %s %s %s" % (
text_path,
wav_dir,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
#############step3
ps=[]
ps = []
for i_part in range(n_parts):
cmd="python prepare/3-get-semantic.py %s %s %s %s %s"%(text_path,exp_name,i_part,n_parts,i_part%n_card)
cmd = "python prepare/3-get-semantic.py %s %s %s %s %s" % (
text_path,
exp_name,
i_part,
n_parts,
i_part % n_card,
)
print(cmd)
p = Popen(cmd, shell=True)
ps.append(p)
for p in ps:
p.wait()
opt=["item_name semantic_audio"]
opt = ["item_name semantic_audio"]
for i_part in range(n_parts):
semantic_path = "%s/6-name2semantic-%s.tsv" % (exp_dir, i_part)
with open(semantic_path,"r")as f:
opt+=f.read().strip("\n").split("\n")
with open(semantic_path, "r") as f:
opt += f.read().strip("\n").split("\n")
os.remove(semantic_path)
with open("%s/6-name2semantic.tsv"%exp_dir,"w")as f:f.write("\n".join(opt)+"\n")
with open("%s/6-name2semantic.tsv" % exp_dir, "w") as f:
f.write("\n".join(opt) + "\n")

View File

@ -2,16 +2,16 @@
import os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir= os.environ.get("opt_dir")
bert_pretrained_dir= os.environ.get("bert_pretrained_dir")
is_half=eval(os.environ.get("is_half","True"))
import sys,numpy as np,traceback,pdb
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
is_half = eval(os.environ.get("is_half", "True"))
import sys, numpy as np, traceback, pdb
import os.path
from glob import glob
from tqdm import tqdm
@ -31,25 +31,29 @@ import numpy as np
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part)
if(os.path.exists(txt_path)==False):
bert_dir="%s/3-bert"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(bert_dir,exist_ok=True)
device="cuda:0"
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True)
device = "cuda:0"
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model=AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if (is_half == True):
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
@ -67,51 +71,55 @@ if(os.path.exists(txt_path)==False):
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def process(data,res):
for name,text,lan in data:
def process(data, res):
for name, text, lan in data:
try:
name=os.path.basename(name)
phones, word2ph, norm_text=clean_text(text.replace("%", '-').replace('', ','),lan)
path_bert="%s/%s.pt"%(bert_dir,name)
if (os.path.exists(path_bert) == False and lan == "zh"):
name = os.path.basename(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","), lan
)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert)
my_save(bert_feature, path_bert)
phones = " ".join(phones)
# res.append([name,phones])
res.append([name,phones, word2ph, norm_text])
res.append([name, phones, word2ph, norm_text])
except:
print(name, text, traceback.format_exc())
todo=[]
res=[]
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
todo = []
res = []
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
language_v1_to_language_v2={
"ZH":"zh",
"zh":"zh",
"JP":"ja",
"jp":"ja",
"JA":"ja",
"ja":"ja",
"EN":"en",
"en":"en",
"En":"en",
language_v1_to_language_v2 = {
"ZH": "zh",
"zh": "zh",
"JP": "ja",
"jp": "ja",
"JA": "ja",
"ja": "ja",
"EN": "en",
"en": "en",
"En": "en",
}
for line in lines[int(i_part)::int(all_parts)]:
for line in lines[int(i_part) :: int(all_parts)]:
try:
wav_name,spk_name,language,text=line.split("|")
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
todo.append([wav_name,text,language_v1_to_language_v2.get(language,language)])
todo.append(
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
except:
print(line,traceback.format_exc())
process(todo,res)
opt=[]
for name,phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s"%(name,phones, word2ph, norm_text))
with open(txt_path,"w",encoding="utf8")as f:
f.write("\n".join(opt)+"\n")
print(line, traceback.format_exc())
process(todo, res)
opt = []
for name, phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
with open(txt_path, "w", encoding="utf8") as f:
f.write("\n".join(opt) + "\n")

View File

@ -1,20 +1,23 @@
# -*- coding: utf-8 -*-
import sys,os
inp_text= os.environ.get("inp_text")
inp_wav_dir= os.environ.get("inp_wav_dir")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
is_half=eval(os.environ.get("is_half","True"))
import sys, os
import pdb,traceback,numpy as np,logging
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
is_half = eval(os.environ.get("is_half", "True"))
import pdb, traceback, numpy as np, logging
from scipy.io import wavfile
import librosa,torch
import librosa, torch
now_dir = os.getcwd()
sys.path.append(now_dir)
from my_utils import load_audio
@ -32,63 +35,75 @@ from my_utils import load_audio
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95
alpha=0.5
device="cuda:0"
model=cnhubert.get_model()
if(is_half==True):
model=model.half().to(device)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s/%s%s.pth" % (dir, ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
device = "cuda:0"
model = cnhubert.get_model()
if is_half == True:
model = model.half().to(device)
else:
model = model.to(device)
def name2go(wav_name):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return
wav_path="%s/%s"%(inp_wav_dir,wav_name)
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio = librosa.resample(
tmp_audio32, orig_sr=32000, target_sr=16000
)
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + (
(1 - alpha) * 32768
) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32, orig_sr=32000, target_sr=16000)
tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True):
tensor_wav16=tensor_wav16.half().to(device)
if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:return
ssl = (
model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"]
.transpose(1, 2)
.cpu()
) # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
# torch.save(ssl,hubert_path )
my_save(ssl,hubert_path )
my_save(ssl, hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]:
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=os.path.basename(wav_name)
wav_name = os.path.basename(wav_name)
name2go(wav_name)
except:
print(line,traceback.format_exc())
print(line, traceback.format_exc())

View File

@ -1,24 +1,27 @@
import os
inp_text= os.environ.get("inp_text")
exp_name= os.environ.get("exp_name")
i_part= os.environ.get("i_part")
all_parts= os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"]= os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir= os.environ.get("opt_dir")
pretrained_s2G= os.environ.get("pretrained_s2G")
s2config_path= os.environ.get("s2config_path")
is_half=eval(os.environ.get("is_half","True"))
import math,traceback
inp_text = os.environ.get("inp_text")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("_CUDA_VISIBLE_DEVICES")
opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
is_half = eval(os.environ.get("is_half", "True"))
import math, traceback
import multiprocessing
import sys,pdb
import sys, pdb
now_dir = os.getcwd()
sys.path.append(now_dir)
from random import shuffle
import torch.multiprocessing as mp
from glob import glob
from tqdm import tqdm
import logging,librosa,utils,torch
import logging, librosa, utils, torch
from module.models import SynthesizerTrn
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G
@ -30,52 +33,58 @@ logging.getLogger("numba").setLevel(logging.WARNING)
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
hubert_dir="%s/4-cnhubert"%(opt_dir)
semantic_path="%s/6-name2semantic-%s.tsv"%(opt_dir,i_part)
if(os.path.exists(semantic_path)==False):
os.makedirs(opt_dir,exist_ok=True)
hubert_dir = "%s/4-cnhubert" % (opt_dir)
semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir, exist_ok=True)
device="cuda:0"
device = "cuda:0"
hps = utils.get_hparams_from_file(s2config_path)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if(is_half==True):
vq_model=vq_model.half().to(device)
**hps.model
)
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
# utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
print(vq_model.load_state_dict(torch.load(pretrained_s2G,map_location="cpu")["weight"], strict=False))
print(
vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
)
)
def name2go(wav_name,lines):
def name2go(wav_name, lines):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if(os.path.exists(hubert_path)==False):return
if os.path.exists(hubert_path) == False:
return
ssl_content = torch.load(hubert_path, map_location="cpu")
if(is_half==True):
ssl_content=ssl_content.half().to(device)
if is_half == True:
ssl_content = ssl_content.half().to(device)
else:
ssl_content = ssl_content.to(device)
codes = vq_model.extract_latent(ssl_content)
semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
lines.append("%s\t%s"%(wav_name,semantic))
lines.append("%s\t%s" % (wav_name, semantic))
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
lines1=[]
for line in lines[int(i_part)::int(all_parts)]:
lines1 = []
for line in lines[int(i_part) :: int(all_parts)]:
# print(line)
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name=os.path.basename(wav_name)
wav_name = os.path.basename(wav_name)
# name2go(name,lines1)
name2go(wav_name,lines1)
name2go(wav_name, lines1)
except:
print(line,traceback.format_exc())
with open(semantic_path,"w",encoding="utf8")as f:f.write("\n".join(lines1))
print(line, traceback.format_exc())
with open(semantic_path, "w", encoding="utf8") as f:
f.write("\n".join(lines1))

View File

@ -1,11 +1,12 @@
import os
import sys
import traceback
from collections import OrderedDict
import torch
from i18n.i18n import I18nAuto
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def savee(ckpt, name, epoch, steps, hps):
try:
opt = OrderedDict()
@ -15,8 +16,8 @@ def savee(ckpt, name, epoch, steps, hps):
continue
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch,steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir,name))
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()

View File

@ -2,56 +2,84 @@
import os
import pdb
if("_CUDA_VISIBLE_DEVICES"in os.environ):
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
from pathlib import Path
import torch,platform
import torch, platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger#WandbLogger
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high')
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
class my_model_ckpt(ModelCheckpoint):
def __init__(self,config,if_save_latest,if_save_every_weights,half_weights_save_dir,exp_name,**kwargs):
def __init__(
self,
config,
if_save_latest,
if_save_every_weights,
half_weights_save_dir,
exp_name,
**kwargs
):
super().__init__(**kwargs)
self.if_save_latest=if_save_latest
self.if_save_every_weights=if_save_every_weights
self.half_weights_save_dir=half_weights_save_dir
self.exp_name=exp_name
self.config=config
self.if_save_latest = if_save_latest
self.if_save_every_weights = if_save_every_weights
self.half_weights_save_dir = half_weights_save_dir
self.exp_name = exp_name
self.config = config
def on_train_epoch_end(self, trainer, pl_module):
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if not self._should_skip_saving_checkpoint(
trainer
) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
if(self.if_save_latest==True):####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean=list(os.listdir(self.dirpath))
if (
self._every_n_epochs >= 1
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates)
if (self.if_save_latest == True):
if self.if_save_latest == True:
for name in to_clean:
try:
os.remove("%s/%s"%(self.dirpath,name))
except:pass
if(self.if_save_every_weights==True):
to_save_od=OrderedDict()
to_save_od["weight"]=OrderedDict()
dictt=trainer.strategy._lightning_module.state_dict()
for key in dictt:to_save_od["weight"][key]=dictt[key].half()
to_save_od["config"]=self.config
to_save_od["info"]="GPT-e%s"%(trainer.current_epoch+1)
torch.save(to_save_od,"%s/%s-e%s.ckpt"%(self.half_weights_save_dir,self.exp_name,trainer.current_epoch+1))
os.remove("%s/%s" % (self.dirpath, name))
except:
pass
if self.if_save_every_weights == True:
to_save_od = OrderedDict()
to_save_od["weight"] = OrderedDict()
dictt = trainer.strategy._lightning_module.state_dict()
for key in dictt:
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
torch.save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
self.half_weights_save_dir,
self.exp_name,
trainer.current_epoch + 1,
),
)
self._save_last_checkpoint(trainer, monitor_candidates)
@ -61,41 +89,45 @@ def main(args):
output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)
ckpt_dir = output_dir / 'ckpt'
ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True)
seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt(
config=config,
if_save_latest=config["train"]["if_save_latest"], if_save_every_weights=config["train"]["if_save_every_weights"], half_weights_save_dir=config["train"]["half_weights_save_dir"], exp_name=config["train"]["exp_name"],
if_save_latest=config["train"]["if_save_latest"],
if_save_every_weights=config["train"]["if_save_every_weights"],
half_weights_save_dir=config["train"]["half_weights_save_dir"],
exp_name=config["train"]["exp_name"],
save_top_k=-1,
monitor='top_3_acc',
mode='max',
monitor="top_3_acc",
mode="max",
save_on_train_epoch_end=True,
every_n_epochs=config["train"]["save_every_n_epoch"],
dirpath=ckpt_dir,
)
logger = TensorBoardLogger(
name=output_dir.stem,
save_dir=output_dir
)
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
accelerator='gpu',
accelerator="gpu",
# val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None,
limit_val_batches=0,
devices=-1,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl"if platform.system()!="Windows"else "gloo"),
strategy=DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
),
precision=config["train"]["precision"],
logger=logger,num_sanity_val_steps=0,
callbacks=[ckpt_callback])
logger=logger,
num_sanity_val_steps=0,
callbacks=[ckpt_callback],
)
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir)
config, output_dir
)
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,
@ -116,14 +148,15 @@ def main(args):
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'-c',
'--config_file',
"-c",
"--config_file",
type=str,
default='configs/s1longer.yaml',
help='path of config file')
default="configs/s1longer.yaml",
help="path of config file",
)
# args for dataset
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')

View File

@ -1,4 +1,5 @@
import utils,os
import utils, os
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
@ -6,11 +7,12 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist,traceback
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging,traceback
import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
@ -20,37 +22,42 @@ from module import commons
from module.data_utils import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import (
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('medium')#最低精度但最快(也就快一丁点),对于结果造成不了影响
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
def main():
"""Assume Single Node Multi GPUs Training Only"""
assert torch.cuda.is_available(), "CPU training is not allowed."
n_gpus = torch.cuda.device_count()
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(randint(20000, 55555))
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps):
@ -62,21 +69,54 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus,rank=rank)
dist.init_process_group(
backend="gloo" if os.name == "nt" else "nccl",
init_method="env://",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data)########
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[32, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True)
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True,
collate_fn=collate_fn, batch_sampler=train_sampler,persistent_workers=True,prefetch_factor=16)
train_loader = DataLoader(
train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=16,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
@ -87,17 +127,21 @@ def run(rank, n_gpus, hps):
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model).cuda(rank)
**hps.model,
).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
for name, param in net_g.named_parameters():
if not param.requires_grad:
print(name,"not requires_grad")
print(name, "not requires_grad")
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
base_params = filter(lambda p: id(p) not in te_p+et_p+mrte_p and p.requires_grad, net_g.parameters())
base_params = filter(
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
net_g.parameters(),
)
# te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters()
@ -106,31 +150,46 @@ def run(rank, n_gpus, hps):
optim_g = torch.optim.AdamW(
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
[
{"params":base_params,"lr":hps.train.learning_rate},
{"params":net_g.enc_p.text_embedding.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params":net_g.enc_p.encoder_text.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params":net_g.enc_p.mrte.parameters(),"lr":hps.train.learning_rate*hps.train.text_low_lr_rate},
{"params": base_params, "lr": hps.train.learning_rate},
{
"params": net_g.enc_p.text_embedding.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.encoder_text.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
{
"params": net_g.enc_p.mrte.parameters(),
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
},
],
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps)
eps=hps.train.eps,
)
optim_d = torch.optim.AdamW(
net_d.parameters(),
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps)
net_g = DDP(net_g, device_ids=[rank],find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank],find_unused_parameters=True)
eps=hps.train.eps,
)
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "D_*.pth"), net_d, optim_d
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
net_d,
optim_d,
) # D多半加载没事
if rank == 0:
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2"%hps.data.exp_dir, "G_*.pth"), net_g, optim_g
utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
net_g,
optim_g,
)
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
@ -144,7 +203,8 @@ def run(rank, n_gpus, hps):
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print(
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],strict=False
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
) ##测试不加载优化器
if hps.train.pretrained_s2D != "":
@ -159,8 +219,12 @@ def run(rank, n_gpus, hps):
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=-1)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str):
scheduler_g.step()
scheduler_d.step()
@ -169,17 +233,39 @@ def run(rank, n_gpus, hps):
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None], logger, [writer, writer_eval])
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None],
logger,
[writer, writer_eval],
)
else:
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
[train_loader, None], None, None)
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step()
scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@ -192,17 +278,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train()
net_d.train()
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in tqdm(enumerate(train_loader)):
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad=False
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
with autocast(enabled=hps.train.fp16_run):
y_hat, kl_ssl, ids_slice, x_mask, z_mask, \
(z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl = net_g(ssl, spec, spec_lengths, text, text_lengths)
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
@ -210,8 +318,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax)
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
hps.data.mel_fmax,
)
y_mel = commons.slice_segments(
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1),
hps.data.filter_length,
@ -220,15 +331,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
hps.data.mel_fmax,
)
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
y = commons.slice_segments(
y, ids_slice * hps.data.hop_length, hps.train.segment_size
) # slice
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r, y_d_hat_g
)
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
@ -256,32 +371,54 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info(
"Train Epoch: {} [{:.0f}%]".format(
epoch, 100.0 * batch_idx / len(train_loader)
)
)
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
"grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
scalar_dict = {
"loss/g/total": loss_gen_all,
"loss/d/total": loss_disc_all,
"learning_rate": lr,
"grad_norm_d": grad_norm_d,
"grad_norm_g": grad_norm_g,
}
scalar_dict.update(
{"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl_ssl": kl_ssl, "loss/g/kl": loss_kl})
{
"loss/g/fm": loss_fm,
"loss/g/mel": loss_mel,
"loss/g/kl_ssl": kl_ssl,
"loss/g/kl": loss_kl,
}
)
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
"all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy()
),
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].data.cpu().numpy()
),
"all/mel": utils.plot_spectrogram_to_numpy(
mel[0].data.cpu().numpy()
),
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
stats_ssl[0].data.cpu().numpy()
),
}
utils.summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict)
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
@ -290,14 +427,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(global_step)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(global_step)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
),
)
else:
utils.save_checkpoint(
@ -305,14 +446,18 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
optim_g,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "G_{}.pth".format(233333333333)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
),
)
utils.save_checkpoint(
net_d,
optim_d,
hps.train.learning_rate,
epoch,
os.path.join("%s/logs_s2"%hps.data.exp_dir, "D_{}.pth".format(233333333333)),
os.path.join(
"%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
@ -334,11 +479,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
)
if rank == 0:
logger.info('====> Epoch: {}'.format(epoch))
logger.info("====> Epoch: {}".format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
@ -347,15 +489,25 @@ def evaluate(hps, generator, eval_loader, writer_eval):
audio_dict = {}
print("Evaluating ...")
with torch.no_grad():
for batch_idx, (ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths) in enumerate(eval_loader):
for batch_idx, (
ssl,
ssl_lengths,
spec,
spec_lengths,
y,
y_lengths,
text,
text_lengths,
) in enumerate(eval_loader):
print(111)
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda()
for test in [0, 1]:
y_hat, mask, *_ = generator.module.infer(ssl,spec, spec_lengths,text, text_lengths, test=test)
y_hat, mask, *_ = generator.module.infer(
ssl, spec, spec_lengths, text, text_lengths, test=test
)
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel_torch(
@ -364,7 +516,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.n_mel_channels,
hps.data.sampling_rate,
hps.data.mel_fmin,
hps.data.mel_fmax)
hps.data.mel_fmax,
)
y_hat_mel = mel_spectrogram_torch(
y_hat.squeeze(1).float(),
hps.data.filter_length,
@ -373,16 +526,26 @@ def evaluate(hps, generator, eval_loader, writer_eval):
hps.data.hop_length,
hps.data.win_length,
hps.data.mel_fmin,
hps.data.mel_fmax
hps.data.mel_fmax,
)
image_dict.update({
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
})
audio_dict.update({
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, :y_hat_lengths[0]]
})
image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :y_lengths[0]]})
image_dict.update(
{
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
y_hat_mel[0].cpu().numpy()
)
}
)
audio_dict.update(
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
)
image_dict.update(
{
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
mel[0].cpu().numpy()
)
}
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
# audio_dict.update({
@ -394,9 +557,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
global_step=global_step,
images=image_dict,
audios=audio_dict,
audio_sampling_rate=hps.data.sampling_rate
audio_sampling_rate=hps.data.sampling_rate,
)
generator.train()
if __name__ == "__main__":
main()

View File

@ -6,49 +6,56 @@ import cn2an
from pypinyin import lazy_pinyin, Style
import sys
sys.path.append("/data/docker/liujing04/gpt-vits/gpt-vits-master")
from text.symbols import punctuation
from text.tone_sandhi import ToneSandhi
current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in
open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()}
pinyin_to_symbol_map = {
line.split("\t")[0]: line.strip().split("\t")[1]
for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}
import jieba.posseg as psg
rep_map = {
'': ',',
'': ',',
'': ',',
'': '.',
'': '!',
'': '?',
'\n': '.',
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"·": ",",
'': ",",
'...': '',
'$': '.',
'/': ',',
'': "-"
"": ",",
"...": "",
"$": ".",
"/": ",",
"": "-",
}
tone_modifier = ToneSandhi()
def replace_punctuation(text):
text = text.replace("", "").replace("","")
pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys()))
text = text.replace("", "").replace("", "")
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def g2p(text):
pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip()!='']
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
phones, word2ph = _g2p(sentences)
return phones, word2ph
@ -56,10 +63,10 @@ def g2p(text):
def _get_initials_finals(word):
initials = []
finals = []
orig_initials = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
orig_finals = lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
)
for c, v in zip(orig_initials, orig_finals):
initials.append(c)
finals.append(v)
@ -72,17 +79,16 @@ def _g2p(segments):
for seg in segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg)
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
initials = []
finals = []
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
for word, pos in seg_cut:
if pos == 'eng':
if pos == "eng":
continue
sub_initials, sub_finals = _get_initials_finals(word)
sub_finals = tone_modifier.modified_tone(word, pos,
sub_finals)
sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
initials.append(sub_initials)
finals.append(sub_finals)
@ -91,7 +97,7 @@ def _g2p(segments):
finals = sum(finals, [])
#
for c, v in zip(initials, finals):
raw_pinyin = c+v
raw_pinyin = c + v
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c == v:
@ -102,40 +108,40 @@ def _g2p(segments):
v_without_tone = v[:-1]
tone = v[-1]
pinyin = c+v_without_tone
assert tone in '12345'
pinyin = c + v_without_tone
assert tone in "12345"
if c:
# 多音节
v_rep_map = {
"uei": 'ui',
'iou': 'iu',
'uen': 'un',
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c+v_rep_map[v_without_tone]
pinyin = c + v_rep_map[v_without_tone]
else:
# 单音节
pinyin_rep_map = {
'ing': 'ying',
'i': 'yi',
'in': 'yin',
'u': 'wu',
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
'v': 'yu',
'e': 'e',
'i': 'y',
'u': 'w',
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]]+pinyin[1:]
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ')
new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ")
new_v = new_v + tone
phone = [new_c, new_v]
word2ph.append(len(phone))
@ -144,9 +150,8 @@ def _g2p(segments):
return phones_list, word2ph
def text_normalize(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
numbers = re.findall(r"\d+(?:\.?\d+)?", text)
for number in numbers:
text = text.replace(number, cn2an.an2cn(number), 1)
text = replace_punctuation(text)
@ -154,7 +159,7 @@ def text_normalize(text):
return text
if __name__ == '__main__':
if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?"
text = "你好"

View File

@ -1,29 +1,27 @@
from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
language_module_map = {
'zh': chinese,
"ja": japanese,
'en': english
}
language_module_map = {"zh": chinese, "ja": japanese, "en": english}
special = [
('%', 'zh', "SP"),
('', 'zh', "SP2"),
('^', 'zh', "SP3"),
("%", "zh", "SP"),
("", "zh", "SP2"),
("^", "zh", "SP3"),
# ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧
]
def clean_text(text, language):
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
return clean_special(text, language, special_s, target_symbol)
language_module = language_module_map[language]
norm_text = language_module.text_normalize(text)
if(language=="zh"):
if language == "zh":
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
else:
phones = language_module.g2p(norm_text)
word2ph=None
word2ph = None
for ph in phones:
assert ph in symbols
@ -41,17 +39,17 @@ def clean_special(text, language, special_s, target_symbol):
new_ph = []
for ph in phones:
assert ph in symbols
if ph == ',':
if ph == ",":
new_ph.append(target_symbol)
else:
new_ph.append(ph)
return new_ph
def text_to_sequence(text, language):
phones = clean_text(text)
return cleaned_text_to_sequence(phones)
if __name__ == '__main__':
print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh'))
if __name__ == "__main__":
print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh"))

View File

@ -8,20 +8,87 @@ from string import punctuation
from text import symbols
current_file_path = os.path.dirname(__file__)
CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep')
CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle')
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
_g2p = G2p()
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'}
arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
def replace_phs(phs):
rep_map = {
';': ',',
':': ',',
'\'': '-',
'"': '-'
}
rep_map = {";": ",", ":": ",", "'": "-", '"': "-"}
phs_new = []
for ph in phs:
if ph in symbols:
@ -29,9 +96,10 @@ def replace_phs(phs):
elif ph in rep_map.keys():
phs_new.append(rep_map[ph])
else:
print('ph not in symbols: ', ph)
print("ph not in symbols: ", ph)
return phs_new
def read_dict():
g2p_dict = {}
start_line = 49
@ -41,13 +109,13 @@ def read_dict():
while line:
if line_index >= start_line:
line = line.strip()
word_split = line.split(' ')
word_split = line.split(" ")
word = word_split[0]
syllable_split = word_split[1].split(' - ')
syllable_split = word_split[1].split(" - ")
g2p_dict[word] = []
for syllable in syllable_split:
phone_split = syllable.split(' ')
phone_split = syllable.split(" ")
g2p_dict[word].append(phone_split)
line_index = line_index + 1
@ -57,13 +125,13 @@ def read_dict():
def cache_dict(g2p_dict, file_path):
with open(file_path, 'wb') as pickle_file:
with open(file_path, "wb") as pickle_file:
pickle.dump(g2p_dict, pickle_file)
def get_dict():
if os.path.exists(CACHE_PATH):
with open(CACHE_PATH, 'rb') as pickle_file:
with open(CACHE_PATH, "rb") as pickle_file:
g2p_dict = pickle.load(pickle_file)
else:
g2p_dict = read_dict()
@ -71,6 +139,7 @@ def get_dict():
return g2p_dict
eng_dict = get_dict()
@ -78,8 +147,8 @@ def text_normalize(text):
# todo: eng text normalize
return text.replace(";", ",")
def g2p(text):
def g2p(text):
phones = []
words = re.split(r"([,;.\-\?\!\s+])", text)
for w in words:
@ -97,6 +166,7 @@ def g2p(text):
return replace_phs(phones)
if __name__ == "__main__":
# print(get_dict())
print(g2p("hello"))
@ -106,4 +176,4 @@ if __name__ == "__main__":
# for group in syllables:
# for ph in group:
# all_phones.add(ph)
# print(all_phones)
# print(all_phones)

View File

@ -8,57 +8,63 @@ from text import symbols
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)
# List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
('', 'パーセント')
]]
_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("", "パーセント")]]
# List of (consonant, sokuon) pairs:
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'Q([↑↓]*[kg])', r'k#\1'),
(r'Q([↑↓]*[tdjʧ])', r't#\1'),
(r'Q([↑↓]*[sʃ])', r's\1'),
(r'Q([↑↓]*[pb])', r'p#\1')
]]
_real_sokuon = [
(re.compile("%s" % x[0]), x[1])
for x in [
(r"Q([↑↓]*[kg])", r"k#\1"),
(r"Q([↑↓]*[tdjʧ])", r"t#\1"),
(r"Q([↑↓]*[sʃ])", r"s\1"),
(r"Q([↑↓]*[pb])", r"p#\1"),
]
]
# List of (consonant, hatsuon) pairs:
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'N([↑↓]*[pbm])', r'm\1'),
(r'N([↑↓]*[ʧʥj])', r'n^\1'),
(r'N([↑↓]*[tdn])', r'n\1'),
(r'N([↑↓]*[kg])', r'ŋ\1')
]]
_real_hatsuon = [
(re.compile("%s" % x[0]), x[1])
for x in [
(r"N([↑↓]*[pbm])", r"m\1"),
(r"N([↑↓]*[ʧʥj])", r"n^\1"),
(r"N([↑↓]*[tdn])", r"n\1"),
(r"N([↑↓]*[kg])", r"ŋ\1"),
]
]
def post_replace_ph(ph):
rep_map = {
'': ',',
'': ',',
'': ',',
'': '.',
'': '!',
'': '?',
'\n': '.',
"": ",",
"": ",",
"": ",",
"": ".",
"": "!",
"": "?",
"\n": ".",
"·": ",",
'': ",",
'...': ''
"": ",",
"...": "",
}
if ph in rep_map.keys():
ph = rep_map[ph]
if ph in symbols:
return ph
if ph not in symbols:
ph = 'UNK'
ph = "UNK"
return ph
def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text)
@ -66,7 +72,7 @@ def symbols_to_japanese(text):
def preprocess_jap(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
"""Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
text = symbols_to_japanese(text)
sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text)
@ -77,13 +83,15 @@ def preprocess_jap(text):
text += p.split(" ")
if i < len(marks):
text += [marks[i].replace(' ', '')]
text += [marks[i].replace(" ", "")]
return text
def text_normalize(text):
# todo: jap text normalize
return text
def g2p(norm_text):
phones = preprocess_jap(norm_text)
phones = [post_replace_ph(i) for i in phones]
@ -91,7 +99,7 @@ def g2p(norm_text):
return phones
if __name__ == '__main__':
if __name__ == "__main__":
for line in open("../../../Downloads/transcript_utf8.txt").readlines():
text = line.split(":")[1]
phones = g2p(text)

View File

@ -1,24 +1,397 @@
import os
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ['!', '?', '', ",", "."]#@是SP停顿
punctuation = ["!", "?", "", ",", "."] # @是SP停顿
punctuation.append("-")
pu_symbols = punctuation + ["SP", 'SP2', 'SP3', "UNK"]
pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"]
# pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"]
pad = '_'
pad = "_"
c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh']
v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5']
c = [
"AA",
"EE",
"OO",
"b",
"c",
"ch",
"d",
"f",
"g",
"h",
"j",
"k",
"l",
"m",
"n",
"p",
"q",
"r",
"s",
"sh",
"t",
"w",
"x",
"y",
"z",
"zh",
]
v = [
"E1",
"En1",
"a1",
"ai1",
"an1",
"ang1",
"ao1",
"e1",
"ei1",
"en1",
"eng1",
"er1",
"i1",
"i01",
"ia1",
"ian1",
"iang1",
"iao1",
"ie1",
"in1",
"ing1",
"iong1",
"ir1",
"iu1",
"o1",
"ong1",
"ou1",
"u1",
"ua1",
"uai1",
"uan1",
"uang1",
"ui1",
"un1",
"uo1",
"v1",
"van1",
"ve1",
"vn1",
"E2",
"En2",
"a2",
"ai2",
"an2",
"ang2",
"ao2",
"e2",
"ei2",
"en2",
"eng2",
"er2",
"i2",
"i02",
"ia2",
"ian2",
"iang2",
"iao2",
"ie2",
"in2",
"ing2",
"iong2",
"ir2",
"iu2",
"o2",
"ong2",
"ou2",
"u2",
"ua2",
"uai2",
"uan2",
"uang2",
"ui2",
"un2",
"uo2",
"v2",
"van2",
"ve2",
"vn2",
"E3",
"En3",
"a3",
"ai3",
"an3",
"ang3",
"ao3",
"e3",
"ei3",
"en3",
"eng3",
"er3",
"i3",
"i03",
"ia3",
"ian3",
"iang3",
"iao3",
"ie3",
"in3",
"ing3",
"iong3",
"ir3",
"iu3",
"o3",
"ong3",
"ou3",
"u3",
"ua3",
"uai3",
"uan3",
"uang3",
"ui3",
"un3",
"uo3",
"v3",
"van3",
"ve3",
"vn3",
"E4",
"En4",
"a4",
"ai4",
"an4",
"ang4",
"ao4",
"e4",
"ei4",
"en4",
"eng4",
"er4",
"i4",
"i04",
"ia4",
"ian4",
"iang4",
"iao4",
"ie4",
"in4",
"ing4",
"iong4",
"ir4",
"iu4",
"o4",
"ong4",
"ou4",
"u4",
"ua4",
"uai4",
"uan4",
"uang4",
"ui4",
"un4",
"uo4",
"v4",
"van4",
"ve4",
"vn4",
"E5",
"En5",
"a5",
"ai5",
"an5",
"ang5",
"ao5",
"e5",
"ei5",
"en5",
"eng5",
"er5",
"i5",
"i05",
"ia5",
"ian5",
"iang5",
"iao5",
"ie5",
"in5",
"ing5",
"iong5",
"ir5",
"iu5",
"o5",
"ong5",
"ou5",
"u5",
"ua5",
"uai5",
"uan5",
"uang5",
"ui5",
"un5",
"uo5",
"v5",
"van5",
"ve5",
"vn5",
]
v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn']
v_without_tone = [
"E",
"En",
"a",
"ai",
"an",
"ang",
"ao",
"e",
"ei",
"en",
"eng",
"er",
"i",
"i0",
"ia",
"ian",
"iang",
"iao",
"ie",
"in",
"ing",
"iong",
"ir",
"iu",
"o",
"ong",
"ou",
"u",
"ua",
"uai",
"uan",
"uang",
"ui",
"un",
"uo",
"v",
"van",
"ve",
"vn",
]
# japanese
ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky',
'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'v', 'w', 'y', 'z']
ja_symbols = [
"I",
"N",
"U",
"a",
"b",
"by",
"ch",
"cl",
"d",
"dy",
"e",
"f",
"g",
"gy",
"h",
"hy",
"i",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"p",
"py",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"u",
"v",
"w",
"y",
"z",
]
arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'}
arpa = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)
symbols = sorted(set(symbols))
if __name__ == '__main__':
print(len(symbols))
if __name__ == "__main__":
print(len(symbols))

View File

@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin
from pypinyin import Style
class ToneSandhi():
class ToneSandhi:
def __init__(self):
self.must_neural_tone_words = {
'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
'难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊',
'里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去',
'软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号',
'认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当',
'蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻',
'舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂',
'胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆',
'老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂',
'精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿',
'窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台',
'码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算',
'白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
'琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快',
'爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜',
'溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔',
'棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事',
'木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
'收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
'抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实',
'扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头',
'念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼',
'干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数',
'屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气',
'实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈',
'姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
'大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴',
'嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦',
'咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝',
'叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹',
'功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息',
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨',
'父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅',
'幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱',
'凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱',
'扫把', '惦记'
"麻烦",
"麻利",
"鸳鸯",
"高粱",
"骨头",
"骆驼",
"马虎",
"首饰",
"馒头",
"馄饨",
"风筝",
"难为",
"队伍",
"阔气",
"闺女",
"门道",
"锄头",
"铺盖",
"铃铛",
"铁匠",
"钥匙",
"里脊",
"里头",
"部分",
"那么",
"道士",
"造化",
"迷糊",
"连累",
"这么",
"这个",
"运气",
"过去",
"软和",
"转悠",
"踏实",
"跳蚤",
"跟头",
"趔趄",
"财主",
"豆腐",
"讲究",
"记性",
"记号",
"认识",
"规矩",
"见识",
"裁缝",
"补丁",
"衣裳",
"衣服",
"衙门",
"街坊",
"行李",
"行当",
"蛤蟆",
"蘑菇",
"薄荷",
"葫芦",
"葡萄",
"萝卜",
"荸荠",
"苗条",
"苗头",
"苍蝇",
"芝麻",
"舒服",
"舒坦",
"舌头",
"自在",
"膏药",
"脾气",
"脑袋",
"脊梁",
"能耐",
"胳膊",
"胭脂",
"胡萝",
"胡琴",
"胡同",
"聪明",
"耽误",
"耽搁",
"耷拉",
"耳朵",
"老爷",
"老实",
"老婆",
"老头",
"老太",
"翻腾",
"罗嗦",
"罐头",
"编辑",
"结实",
"红火",
"累赘",
"糨糊",
"糊涂",
"精神",
"粮食",
"簸箕",
"篱笆",
"算计",
"算盘",
"答应",
"笤帚",
"笑语",
"笑话",
"窟窿",
"窝囊",
"窗户",
"稳当",
"稀罕",
"称呼",
"秧歌",
"秀气",
"秀才",
"福气",
"祖宗",
"砚台",
"码头",
"石榴",
"石头",
"石匠",
"知识",
"眼睛",
"眯缝",
"眨巴",
"眉毛",
"相声",
"盘算",
"白净",
"痢疾",
"痛快",
"疟疾",
"疙瘩",
"疏忽",
"畜生",
"生意",
"甘蔗",
"琵琶",
"琢磨",
"琉璃",
"玻璃",
"玫瑰",
"玄乎",
"狐狸",
"状元",
"特务",
"牲口",
"牙碜",
"牌楼",
"爽快",
"爱人",
"热闹",
"烧饼",
"烟筒",
"烂糊",
"点心",
"炊帚",
"灯笼",
"火候",
"漂亮",
"滑溜",
"溜达",
"温和",
"清楚",
"消息",
"浪头",
"活泼",
"比方",
"正经",
"欺负",
"模糊",
"槟榔",
"棺材",
"棒槌",
"棉花",
"核桃",
"栅栏",
"柴火",
"架势",
"枕头",
"枇杷",
"机灵",
"本事",
"木头",
"木匠",
"朋友",
"月饼",
"月亮",
"暖和",
"明白",
"时候",
"新鲜",
"故事",
"收拾",
"收成",
"提防",
"挖苦",
"挑剔",
"指甲",
"指头",
"拾掇",
"拳头",
"拨弄",
"招牌",
"招呼",
"抬举",
"护士",
"折腾",
"扫帚",
"打量",
"打算",
"打点",
"打扮",
"打听",
"打发",
"扎实",
"扁担",
"戒指",
"懒得",
"意识",
"意思",
"情形",
"悟性",
"怪物",
"思量",
"怎么",
"念头",
"念叨",
"快活",
"忙活",
"志气",
"心思",
"得罪",
"张罗",
"弟兄",
"开通",
"应酬",
"庄稼",
"干事",
"帮手",
"帐篷",
"希罕",
"师父",
"师傅",
"巴结",
"巴掌",
"差事",
"工夫",
"岁数",
"屁股",
"尾巴",
"少爷",
"小气",
"小伙",
"将就",
"对头",
"对付",
"寡妇",
"家伙",
"客气",
"实在",
"官司",
"学问",
"学生",
"字号",
"嫁妆",
"媳妇",
"媒人",
"婆家",
"娘家",
"委屈",
"姑娘",
"姐夫",
"妯娌",
"妥当",
"妖精",
"奴才",
"女婿",
"头发",
"太阳",
"大爷",
"大方",
"大意",
"大夫",
"多少",
"多么",
"外甥",
"壮实",
"地道",
"地方",
"在乎",
"困难",
"嘴巴",
"嘱咐",
"嘟囔",
"嘀咕",
"喜欢",
"喇嘛",
"喇叭",
"商量",
"唾沫",
"哑巴",
"哈欠",
"哆嗦",
"咳嗽",
"和尚",
"告诉",
"告示",
"含糊",
"吓唬",
"后头",
"名字",
"名堂",
"合同",
"吆喝",
"叫唤",
"口袋",
"厚道",
"厉害",
"千斤",
"包袱",
"包涵",
"匀称",
"勤快",
"动静",
"动弹",
"功夫",
"力气",
"前头",
"刺猬",
"刺激",
"别扭",
"利落",
"利索",
"利害",
"分析",
"出息",
"凑合",
"凉快",
"冷战",
"冤枉",
"冒失",
"养活",
"关系",
"先生",
"兄弟",
"便宜",
"使唤",
"佩服",
"作坊",
"体面",
"位置",
"似的",
"伙计",
"休息",
"什么",
"人家",
"亲戚",
"亲家",
"交情",
"云彩",
"事情",
"买卖",
"主意",
"丫头",
"丧气",
"两口",
"东西",
"东家",
"世故",
"不由",
"不在",
"下水",
"下巴",
"上头",
"上司",
"丈夫",
"丈人",
"一辈",
"那个",
"菩萨",
"父亲",
"母亲",
"咕噜",
"邋遢",
"费用",
"冤家",
"甜头",
"介绍",
"荒唐",
"大人",
"泥鳅",
"幸福",
"熟悉",
"计划",
"扑腾",
"蜡烛",
"姥爷",
"照顾",
"喉咙",
"吉他",
"弄堂",
"蚂蚱",
"凤凰",
"拖沓",
"寒碜",
"糟蹋",
"倒腾",
"报复",
"逻辑",
"盘缠",
"喽啰",
"牢骚",
"咖喱",
"扫把",
"惦记",
}
self.must_not_neural_tone_words = {
"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎"
"男子",
"女子",
"分子",
"原子",
"量子",
"莲子",
"石子",
"瓜子",
"电子",
"人人",
"虎虎",
}
self.punc = ":,;。?!“”‘’':,;.?!"
@ -72,14 +463,15 @@ class ToneSandhi():
# word: "家里"
# pos: "s"
# finals: ['ia1', 'i3']
def _neural_sandhi(self, word: str, pos: str,
finals: List[str]) -> List[str]:
def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
for j, item in enumerate(word):
if j - 1 >= 0 and item == word[j - 1] and pos[0] in {
"n", "v", "a"
} and word not in self.must_not_neural_tone_words:
if (
j - 1 >= 0
and item == word[j - 1]
and pos[0] in {"n", "v", "a"}
and word not in self.must_not_neural_tone_words
):
finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("")
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
@ -89,9 +481,12 @@ class ToneSandhi():
# e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
finals[-1] = finals[-1][:-1] + "5"
elif len(word) > 1 and word[-1] in "们子" and pos in {
"r", "n"
} and word not in self.must_not_neural_tone_words:
elif (
len(word) > 1
and word[-1] in "们子"
and pos in {"r", "n"}
and word not in self.must_not_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@ -100,21 +495,26 @@ class ToneSandhi():
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
finals[-1] = finals[-1][:-1] + "5"
# 个做量词
elif (ge_idx >= 1 and
(word[ge_idx - 1].isnumeric() or
word[ge_idx - 1] in "几有两半多各整每做是")) or word == '':
elif (
ge_idx >= 1
and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
) or word == "":
finals[ge_idx] = finals[ge_idx][:-1] + "5"
else:
if word in self.must_neural_tone_words or word[
-2:] in self.must_neural_tone_words:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals[-1] = finals[-1][:-1] + "5"
word_list = self._split_word(word)
finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]]
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
for i, word in enumerate(word_list):
# conventional neural in Chinese
if word in self.must_neural_tone_words or word[
-2:] in self.must_neural_tone_words:
if (
word in self.must_neural_tone_words
or word[-2:] in self.must_neural_tone_words
):
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
finals = sum(finals_list, [])
return finals
@ -126,15 +526,15 @@ class ToneSandhi():
else:
for i, char in enumerate(word):
# "不" before tone4 should be bu2, e.g. 不怕
if char == "" and i + 1 < len(word) and finals[i +
1][-1] == "4":
if char == "" and i + 1 < len(word) and finals[i + 1][-1] == "4":
finals[i] = finals[i][:-1] + "2"
return finals
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
# "一" in number sequences, e.g. 一零零, 二一零
if word.find("") != -1 and all(
[item.isnumeric() for item in word if item != ""]):
[item.isnumeric() for item in word if item != ""]
):
return finals
# "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
@ -161,10 +561,10 @@ class ToneSandhi():
first_subword = word_list[0]
first_begin_idx = word.find(first_subword)
if first_begin_idx == 0:
second_subword = word[len(first_subword):]
second_subword = word[len(first_subword) :]
new_word_list = [first_subword, second_subword]
else:
second_subword = word[:-len(first_subword)]
second_subword = word[: -len(first_subword)]
new_word_list = [second_subword, first_subword]
return new_word_list
@ -182,18 +582,19 @@ class ToneSandhi():
elif len(word_list[0]) == 1:
finals[1] = finals[1][:-1] + "2"
else:
finals_list = [
finals[:len(word_list[0])], finals[len(word_list[0]):]
]
finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
if len(finals_list) == 2:
for i, sub in enumerate(finals_list):
# e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2:
finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \
finals_list[0][-1][-1] == "3":
elif (
i == 1
and not self._all_tone_three(sub)
and finals_list[i][0][-1] == "3"
and finals_list[0][-1][-1] == "3"
):
finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, [])
# split idiom into two words who's length is 2
@ -222,7 +623,7 @@ class ToneSandhi():
new_seg.append((word, pos))
last_word = word[:]
if last_word == "":
new_seg.append((last_word, 'd'))
new_seg.append((last_word, "d"))
last_word = ""
return new_seg
@ -236,12 +637,21 @@ class ToneSandhi():
new_seg = []
# function 1
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and i + 1 < len(seg) and seg[i - 1][
0] == seg[i + 1][0] and seg[i - 1][1] == "v":
if (
i - 1 >= 0
and word == ""
and i + 1 < len(seg)
and seg[i - 1][0] == seg[i + 1][0]
and seg[i - 1][1] == "v"
):
new_seg[i - 1][0] = new_seg[i - 1][0] + "" + new_seg[i - 1][0]
else:
if i - 2 >= 0 and seg[i - 1][0] == "" and seg[i - 2][
0] == word and pos == "v":
if (
i - 2 >= 0
and seg[i - 1][0] == ""
and seg[i - 2][0] == word
and pos == "v"
):
continue
else:
new_seg.append([word, pos])
@ -257,22 +667,27 @@ class ToneSandhi():
# the first and the second words are all_tone_three
def _merge_continuous_three_tones(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and self._all_tone_three(
sub_finals_list[i - 1]) and self._all_tone_three(
sub_finals_list[i]) and not merge_last[i - 1]:
if (
i - 1 >= 0
and self._all_tone_three(sub_finals_list[i - 1])
and self._all_tone_three(sub_finals_list[i])
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len(
seg[i - 1][0]) + len(seg[i][0]) <= 3:
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
@ -287,21 +702,27 @@ class ToneSandhi():
# the last char of first word and the first char of second word is tone_three
def _merge_continuous_three_tones_2(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
self, seg: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
new_seg = []
sub_finals_list = [
lazy_pinyin(
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
for (word, pos) in seg
]
assert len(sub_finals_list) == len(seg)
merge_last = [False] * len(seg)
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \
merge_last[i - 1]:
if (
i - 1 >= 0
and sub_finals_list[i - 1][-1][-1] == "3"
and sub_finals_list[i][0][-1] == "3"
and not merge_last[i - 1]
):
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
if not self._is_reduplication(seg[i - 1][0]) and len(
seg[i - 1][0]) + len(seg[i][0]) <= 3:
if (
not self._is_reduplication(seg[i - 1][0])
and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
):
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
merge_last[i] = True
else:
@ -313,14 +734,13 @@ class ToneSandhi():
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if i - 1 >= 0 and word == "" and seg[i-1][0] != "#":
if i - 1 >= 0 and word == "" and seg[i - 1][0] != "#":
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
else:
new_seg.append([word, pos])
return new_seg
def _merge_reduplication(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
new_seg = []
for i, (word, pos) in enumerate(seg):
if new_seg and word == new_seg[-1][0]:
@ -329,8 +749,7 @@ class ToneSandhi():
new_seg.append([word, pos])
return new_seg
def pre_merge_for_modify(
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
seg = self._merge_bu(seg)
try:
seg = self._merge_yi(seg)
@ -349,8 +768,7 @@ class ToneSandhi():
seg = self._merge_er(seg)
return seg
def modified_tone(self, word: str, pos: str,
finals: List[str]) -> List[str]:
def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals)

View File

@ -12,8 +12,9 @@ import numpy as np
from scipy.io.wavfile import read
import torch
import logging
logging.getLogger('numba').setLevel(logging.ERROR)
logging.getLogger('matplotlib').setLevel(logging.ERROR)
logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)
MATPLOTLIB_FLAG = False
@ -23,13 +24,17 @@ logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
learning_rate = checkpoint_dict['learning_rate']
if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
if hasattr(model, 'module'):
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
iteration = checkpoint_dict["iteration"]
learning_rate = checkpoint_dict["learning_rate"]
if (
optimizer is not None
and not skip_optimizer
and checkpoint_dict["optimizer"] is not None
):
optimizer.load_state_dict(checkpoint_dict["optimizer"])
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
@ -39,41 +44,63 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# assert "quantizer" not in k
# print("load", k)
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
assert saved_state_dict[k].shape == v.shape, (
saved_state_dict[k].shape,
v.shape,
)
except:
traceback.print_exc()
print("error, %s is not in the checkpoint" % k)#shape不对也会比如text_embedding当cleaner修改时
print(
"error, %s is not in the checkpoint" % k
) # shape不对也会比如text_embedding当cleaner修改时
new_state_dict[k] = v
if hasattr(model, 'module'):
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
print("load ")
logger.info("Loaded checkpoint '{}' (iteration {})".format(
checkpoint_path, iteration))
logger.info(
"Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
)
return model, optimizer, learning_rate, iteration
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
logger.info("Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path))
if hasattr(model, 'module'):
logger.info(
"Saving model and optimizer state at iteration {} to {}".format(
iteration, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save({'model': state_dict,
'iteration': iteration,
'optimizer': optimizer.state_dict(),
'learning_rate': learning_rate}, checkpoint_path)
torch.save(
{
"model": state_dict,
"iteration": iteration,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats='HWC')
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
@ -90,23 +117,23 @@ def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@ -116,26 +143,28 @@ def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
interpolation='none')
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
xlabel = "Decoder timestep"
if info is not None:
xlabel += '\n\n' + info
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
@ -147,16 +176,31 @@ def load_wav_to_torch(full_path):
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True, stage=1):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/s2.json",help='JSON file for configuration')
parser.add_argument('-p', '--pretrain', type=str, required=False,default=None,help='pretrain dir')
parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None,help='resume step')
parser.add_argument(
"-c",
"--config",
type=str,
default="./configs/s2.json",
help="JSON file for configuration",
)
parser.add_argument(
"-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
)
parser.add_argument(
"-rs",
"--resume_step",
type=int,
required=False,
default=None,
help="resume step",
)
# parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
# parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
# parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
@ -172,7 +216,7 @@ def get_hparams(init=True, stage=1):
hparams.pretrain = args.pretrain
hparams.resume_step = args.resume_step
# hparams.data.exp_dir = args.exp_dir
if stage ==1:
if stage == 1:
model_dir = hparams.s1_ckpt_dir
else:
model_dir = hparams.s2_ckpt_dir
@ -186,29 +230,38 @@ def get_hparams(init=True, stage=1):
return hparams
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
"""Freeing up space by deleting saved ckpts
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
Arguments:
path_to_models -- Path to the model directory
n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
import re
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
ckpts_files = [
f
for f in os.listdir(path_to_models)
if os.path.isfile(os.path.join(path_to_models, f))
]
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
key=sort_key)
to_del = [os.path.join(path_to_models, fn) for fn in
(x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
x_sorted = lambda _x: sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn)
for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
@ -228,12 +281,15 @@ def get_hparams_from_file(config_path):
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
))
logger.warn(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
)
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
@ -242,8 +298,11 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]))
logger.warn(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]
)
)
else:
open(path, "w").write(cur_hash)
@ -263,7 +322,7 @@ def get_logger(model_dir, filename="train.log"):
return logger
class HParams():
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
@ -294,5 +353,10 @@ class HParams():
def __repr__(self):
return self.__dict__.__repr__()
if __name__ == '__main__':
print(load_wav_to_torch('/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac'))
if __name__ == "__main__":
print(
load_wav_to_torch(
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
)
)

View File

@ -71,6 +71,8 @@ conda install ffmpeg
```bash
sudo apt install ffmpeg
sudo apt install libsox-dev
conda install -c conda-forge 'ffmpeg<7'
```
##### MacOS Users

View File

@ -18,3 +18,4 @@ modelscope
sentencepiece
transformers
chardet
PyYAML

View File

@ -1,7 +1,7 @@
import os
import traceback,gradio as gr
import logging
from i18n.i18n import I18nAuto
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
logger = logging.getLogger(__name__)

1321
webui.py

File diff suppressed because it is too large Load Diff