Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-06 03:57:44 +08:00.
Merge 77917ef6e25d37553c89de60718e4c8626fdcb9d into 9da7e17efe05041e31d3c3f42c8730ae890397f2 (commit fcc0d24825).

The hunks below are a formatting pass: long one-liners are reflowed into multi-line calls, duplicate imports are merged, unused imports are dropped, and trailing commas are normalized. Runtime behavior is unchanged.
@@ -1,5 +1,8 @@
 # Download moda ASR related models
 from modelscope import snapshot_download
-model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
-model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
+model_dir = snapshot_download(
+    "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
+)
+model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
+model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
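A side note on this download script: each call rebinds the same name model_dir, so only the punctuation model's path survives. A minimal sketch keeping each local cache path in its own variable (the variable names asr_dir/vad_dir/punc_dir are illustrative, not from the repo):

from modelscope import snapshot_download

asr_dir = snapshot_download(
    "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
)
vad_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
punc_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
print(asr_dir, vad_dir, punc_dir)  # each call returns the model's local cache directory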
@@ -4,14 +4,11 @@ import itertools
 import math
 import random
 from random import shuffle
-from typing import Iterator
-from typing import Optional
-from typing import TypeVar
+from typing import Iterator, Optional, TypeVar

 import torch
 import torch.distributed as dist
-from torch.utils.data import Dataset
-from torch.utils.data import Sampler
+from torch.utils.data import Dataset, Sampler

 __all__ = [
     "DistributedBucketSampler",
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
         if torch.cuda.is_available():
             torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
-            raise ValueError(
-                "Invalid rank {}, rank should be in the interval"
-                " [0, {}]".format(rank, num_replicas - 1)
-            )
+            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
         self.drop_last = drop_last
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
-        if (
-            self.drop_last and len(self.dataset) % self.num_replicas != 0
-        ):  # type: ignore[arg-type]
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
             # Split to nearest available length that is evenly divisible.
             # This is to ensure each rank receives the same amount of data when
             # using this Sampler.
             self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas)
-                / self.num_replicas  # type: ignore[arg-type]
+                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
             )
         else:
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas
+                len(self.dataset) / self.num_replicas,
             )  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             grouped_batch_size = self.batch_size * self.num_replicas
             shuffled_bucket = list(itertools.chain(*shuffled_bucket))
             n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
-            batches = [
-                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
-                for b in range(n_batch)
-            ]
+            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
             shuffle(batches)
             indices = list(itertools.chain(*batches))
         else:
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[
-                    :padding_size
-                ]
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
         else:
             # remove tail of data to make it evenly divisible.
             indices = indices[: self.total_size]
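The reflowed num_samples expressions above are worth a quick numeric check. A small sketch of the per-rank sample math from __init__, with illustrative numbers:

import math

dataset_len, num_replicas = 103, 4
with_drop_last = math.ceil((dataset_len - num_replicas) / num_replicas)  # 25 per rank, tail dropped
without_drop_last = math.ceil(dataset_len / num_replicas)                # 26 per rank
total_size = without_drop_last * num_replicas                            # 104, so 1 index is reused as padding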
@@ -1,9 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
 # reference: https://github.com/lifeiteng/vall-e
 from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
 from AR.data.bucket_sampler import DistributedBucketSampler
 from AR.data.dataset import Text2SemanticDataset
-from torch.utils.data import DataLoader


 class Text2SemanticDataModule(LightningDataModule):
@@ -42,7 +43,11 @@ class Text2SemanticDataModule(LightningDataModule):
         # pad_val=self.config['data']['pad_val'])

     def train_dataloader(self):
-        batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
+        batch_size = (
+            self.config["train"]["batch_size"] // 2
+            if self.config["train"].get("if_dpo", False) is True
+            else self.config["train"]["batch_size"]
+        )
         batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # 防止不保存
         sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
         return DataLoader(
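As a quick check of the reflowed batch-size logic: when if_dpo is set, the batch is halved (my reading: DPO trains on a chosen/rejected pair per sample, so the effective batch doubles otherwise), then clamped to [1, len(dataset) // 4]. Numbers illustrative:

config = {"train": {"batch_size": 32, "if_dpo": True}}
batch_size = (
    config["train"]["batch_size"] // 2
    if config["train"].get("if_dpo", False) is True
    else config["train"]["batch_size"]
)
assert batch_size == 16
# the clamp guards tiny datasets (the diff's comment 防止不保存 means "prevent failing to save"):
assert max(min(batch_size, 100 // 4), 1) == 16  # len(dataset) == 100 leaves it unchanged
assert max(min(batch_size, 8 // 4), 1) == 2     # len(dataset) == 8 clamps it to 2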
@@ -1,21 +1,17 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
 # reference: https://github.com/lifeiteng/vall-e
-import pdb
-import sys
-
 # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
-import traceback, os
-from typing import Dict
-from typing import List
+import os
+import traceback
+from typing import Dict, List

 import numpy as np
 import pandas as pd
-import torch, json
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-from transformers import AutoTokenizer
+import torch
+from torch.utils.data import DataLoader, Dataset

-version = os.environ.get('version',None)
+version = os.environ.get("version", None)

 from text import cleaned_text_to_sequence

@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):

     padded_sequences = []
     for seq, length in zip(sequences, seq_lengths):
-        padding = (
-            [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
-        )
+        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
         padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
         padded_sequences.append(padded_seq)
     batch = np.stack(padded_sequences)
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
         super().__init__()

         self.semantic_data = pd.read_csv(
-            semantic_path, delimiter="\t", encoding="utf-8"
+            semantic_path,
+            delimiter="\t",
+            encoding="utf-8",
         )
         # get dict
         self.path2 = phoneme_path  # "%s/2-name2text.txt"%exp_dir#phoneme_path
         self.path3 = "%s/3-bert" % (
-            os.path.dirname(phoneme_path)
+            os.path.dirname(
+                phoneme_path,
+            )
         )  # "%s/3-bert"%exp_dir#bert_dir
         self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
         assert os.path.exists(self.path2)
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
                 num_not_in += 1
                 continue
             # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
-            if (
-                len(phoneme_ids) > self.max_sec * self.hz / 2.5
-            ):  ###########2:改为恒定限制为semantic/2.5就行
+            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  ###########2:改为恒定限制为semantic/2.5就行
                 num_deleted_ps += 1
                 continue
             # if len(semantic_ids) > 1000:###########3
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):

             ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

-            if (
-                ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
-            ):  ##########4#3~25#每秒多少个phone
+            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  ##########4#3~25#每秒多少个phone
                 num_deleted_ps += 1
                 # print(item_name)
                 continue
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
             print(f"there are {num_not_in} semantic datas not in phoneme datas")
         if num_deleted_bigger > 0:
             print(
-                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
+                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
             )
         if num_deleted_ps > 0:
             # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
             print(
-                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
+                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
             )
         """
         there are 31 semantic datas not in phoneme datas
@@ -306,7 +300,10 @@ if __name__ == "__main__":

     batch_size = 12
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
+        dataset,
+        batch_size=batch_size,
+        collate_fn=dataset.collate,
+        shuffle=False,
     )
     for i, batch in enumerate(dataloader):
         if i % 1000 == 0:
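For orientation, the batch_sequences helper reflowed above right-pads variable-length numpy arrays along one axis so they stack into a batch. A minimal check of the padding-spec arithmetic, assuming 1-D inputs and axis=0:

import numpy as np

seqs = [np.array([1, 2, 3]), np.array([4, 5])]
max_length, axis, ndim = 3, 0, 1
for seq in seqs:
    length = seq.shape[axis]
    padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
    print(np.pad(seq, padding, mode="constant", constant_values=0))
# [1 2 3] and [4 5 0] -> stackable into a (2, 3) batch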
@@ -1,6 +1,7 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
 # reference: https://github.com/lifeiteng/vall-e
-import os, sys
+import os
+import sys

 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -8,10 +9,12 @@ from typing import Dict

 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam


 class Text2SemanticLightningModule(LightningModule):
     def __init__(self, config, output_dir, is_train=True):
         super().__init__()
@@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
                 )
             )
         if is_train:
@@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
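A side note on the checkpoint-loading lines reflowed above: torch.nn.Module.load_state_dict returns a result object listing missing and unexpected keys, so wrapping it in print() surfaces any mismatch between the checkpoint's "weight" dict and the module. A minimal sketch with a stand-in module (the checkpoint path and dummy model are illustrative; only the "weight" key layout is taken from the diff):

import torch
from torch import nn

model = nn.Linear(4, 4)
torch.save({"weight": model.state_dict()}, "pretrained_s1.ckpt")  # pretend checkpoint layout
result = model.load_state_dict(torch.load("pretrained_s1.ckpt", map_location="cpu")["weight"])
print(result)  # prints "<All keys matched successfully>" when the checkpoint lines up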
@@ -1,6 +1,7 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
 # reference: https://github.com/lifeiteng/vall-e
-import os, sys
+import os
+import sys

 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -8,6 +9,7 @@ from typing import Dict

 import torch
 from pytorch_lightning import LightningModule
+
 from AR.models.t2s_model_onnx import Text2SemanticDecoder
 from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
@@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
-                )
+                    torch.load(
+                        pretrained_s1,
+                        map_location="cpu",
+                    )["weight"],
+                ),
             )
         if is_train:
             self.automatic_optimization = False
@@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
     def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
-        parameters_names.append(
-            [name_param_pair[0] for name_param_pair in self.model.named_parameters()]
-        )
+        parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
|
@ -2,27 +2,24 @@
|
|||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from AR.models.utils import make_pad_mask, make_pad_mask_left
|
import torch
|
||||||
from AR.models.utils import (
|
|
||||||
topk_sampling,
|
|
||||||
sample,
|
|
||||||
logits_to_probs,
|
|
||||||
multinomial_sample_one_no_sync,
|
|
||||||
dpo_loss,
|
|
||||||
make_reject_y,
|
|
||||||
get_batch_logps
|
|
||||||
)
|
|
||||||
from AR.modules.embedding import SinePositionalEmbedding
|
|
||||||
from AR.modules.embedding import TokenEmbedding
|
|
||||||
from AR.modules.transformer import LayerNorm
|
|
||||||
from AR.modules.transformer import TransformerEncoder
|
|
||||||
from AR.modules.transformer import TransformerEncoderLayer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import MulticlassAccuracy
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from AR.models.utils import (
|
||||||
|
dpo_loss,
|
||||||
|
get_batch_logps,
|
||||||
|
make_pad_mask,
|
||||||
|
make_pad_mask_left,
|
||||||
|
make_reject_y,
|
||||||
|
sample,
|
||||||
|
topk_sampling,
|
||||||
|
)
|
||||||
|
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||||
|
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||||
|
|
||||||
default_config = {
|
default_config = {
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
@@ -36,9 +33,16 @@ default_config = {
     "EOS": 1024,
 }


 # @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
 # Efficient implementation equivalent to the following:
-def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
+def scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
     if scale is None:
         scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
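For readers tracing the reflowed signature above: this function reimplements torch's F.scaled_dot_product_attention so mask handling can be customized (and, per the comment, avoids torch.jit.script's slow first inference). A minimal dense sketch of the same computation, shapes illustrative and no masking:

import math
import torch

q = torch.randn(1, 8, 10, 64)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 8, 12, 64)  # (batch, heads, key_len, head_dim)
v = torch.randn(1, 8, 12, 64)
scale = 1 / math.sqrt(q.size(-1))
attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
out = attn_weight @ v          # (1, 8, 10, 64); matches F.scaled_dot_product_attention(q, k, v)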
@@ -65,6 +69,7 @@ def scaled_dot_product_attention

     return attn_weight @ value

+
 @torch.jit.script
 class T2SMLP:
     def __init__(self, w1, b1, w2, b2):
@@ -114,7 +119,11 @@ class T2SBlock:
         self.false = torch.tensor(False, dtype=torch.bool)

     @torch.jit.ignore
-    def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
+    def to_mask(
+        self,
+        x: torch.Tensor,
+        padding_mask: Optional[torch.Tensor],
+    ):
         if padding_mask is None:
             return x

@@ -123,9 +132,13 @@ class T2SBlock:
         else:
             return x * padding_mask

-    def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+    def process_prompt(
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         batch_size = q.shape[0]
@@ -149,9 +162,7 @@ class T2SBlock:
         attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)

         x = x + attn
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
             x,
@@ -162,7 +173,14 @@ class T2SBlock:
         )
         return x, k_cache, v_cache

-    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
+    def decode_next_token(
+        self,
+        x: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_mask: torch.Tensor = None,
+        torch_sdpa: bool = True,
+    ):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)

         k_cache = torch.cat([k_cache, k], dim=1)
@@ -176,7 +194,6 @@ class T2SBlock:
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
         v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)

-
         if torch_sdpa:
             attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
         else:
@@ -187,7 +204,11 @@ class T2SBlock:

         x = x + attn
         x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            x,
+            [self.hidden_dim],
+            self.norm_w1,
+            self.norm_b1,
+            self.norm_eps1,
         )
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
@@ -207,9 +228,11 @@ class T2STransformer:
         self.blocks = blocks

     def process_prompt(
-        self, x:torch.Tensor, attn_mask : torch.Tensor,
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor] = None,
-        torch_sdpa:bool=True
+        torch_sdpa: bool = True,
     ):
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
@@ -220,14 +243,17 @@ class T2STransformer:
         return x, k_cache, v_cache

     def decode_next_token(
-        self, x:torch.Tensor,
+        self,
+        x: torch.Tensor,
         k_cache: List[torch.Tensor],
         v_cache: List[torch.Tensor],
         attn_mask: torch.Tensor = None,
-        torch_sdpa:bool=True
+        torch_sdpa: bool = True,
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
+                x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
+            )
         return x, k_cache, v_cache

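The process_prompt / decode_next_token split reflowed above is a standard KV-cache pattern: the prompt is attended over once, and each later step appends a single key/value row instead of recomputing attention from scratch. A toy illustration of the cache bookkeeping (tensors and the 512 hidden size are illustrative; 512 matches hidden_dim in default_config):

import torch

k_cache = torch.zeros(1, 0, 512)    # (batch, seen_len, hidden) right before decoding
for step in range(3):
    k_new = torch.randn(1, 1, 512)  # key for the single new token
    k_cache = torch.cat([k_cache, k_new], dim=1)
print(k_cache.shape)                # torch.Size([1, 3, 512]); grows one row per step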
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
         self.ar_text_embedding = TokenEmbedding(
-            self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.phoneme_vocab_size,
+            self.p_dropout,
         )
         self.ar_text_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )
         self.ar_audio_embedding = TokenEmbedding(
-            self.embedding_dim, self.vocab_size, self.p_dropout
+            self.embedding_dim,
+            self.vocab_size,
+            self.p_dropout,
         )
         self.ar_audio_position = SinePositionalEmbedding(
-            self.embedding_dim, dropout=0.1, scale=False, alpha=True
+            self.embedding_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
         )

         self.h = TransformerEncoder(
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.linear1.weight,
                 layer.linear1.bias,
                 layer.linear2.weight,
-                layer.linear2.bias
+                layer.linear2.bias,
             )

             block = T2SBlock(
@@ -309,7 +345,7 @@ class Text2SemanticDecoder(nn.Module):
                 layer.norm1.eps,
                 layer.norm2.weight,
                 layer.norm2.bias,
-                layer.norm2.eps
+                layer.norm2.eps,
             )

             blocks.append(block)
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
         logits = self.ar_predict_layer(xy_dec[:, x_len:])

         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
+            x, x_lens, reject_y, reject_y_lens, bert_feature
+        )

         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
-            y.device
-        )
+        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)

         xy_dec, _ = self.h(
             (xy_pos, None),
             mask=xy_attn_mask,
         )
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = topk_sampling(
-            logits, top_k=top_k, top_p=1.0, temperature=temperature
-        )
+        samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)

         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             print("use early stop num:", early_stop_num)
@@ -542,9 +576,7 @@ class Text2SemanticDecoder(nn.Module):
         return y

     def pad_y_eos(self, y, y_mask_int, eos_id):
-        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
-            y_mask_int, (0, 1), value=1
-        )
+        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # 错位
         return targets[:, :-1], targets[:, 1:]

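pad_y_eos, reflowed above, appends EOS after the valid tokens and then builds teacher-forcing input/target pairs by shifting one position (the 错位 comment means "offset by one"). A tiny worked example, assuming eos_id = 1024 as in default_config:

import torch
import torch.nn.functional as F

y = torch.tensor([[7, 8, 9]])
y_mask_int = torch.tensor([[0, 0, 0]])  # 0 = real token, 1 = padding
eos_id = 1024
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
print(targets)                          # tensor([[   7,    8,    9, 1024]])
inputs, shifted = targets[:, :-1], targets[:, 1:]
print(inputs, shifted)                  # [[7, 8, 9]] and [[8, 9, 1024]]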
@@ -563,8 +595,17 @@ class Text2SemanticDecoder(nn.Module):
     ):
         if prompts is None:
             print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
-            return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
+            return self.infer_panel_naive_batched(
+                x,
+                x_lens,
+                prompts,
+                bert_feature,
+                top_k=top_k,
+                top_p=top_p,
+                early_stop_num=early_stop_num,
+                temperature=temperature,
+                **kwargs,
+            )

         max_len = kwargs.get("max_len", x_lens.max())
         x_list = []
@@ -574,11 +615,12 @@ class Text2SemanticDecoder(nn.Module):
             x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
             x_item = self.ar_text_position(x_item).squeeze(0)
             # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
-            x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
+            x_item = (
+                F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
+            )  ### padding left
             x_list.append(x_item)
         x: torch.Tensor = torch.stack(x_list, dim=0)


         # AR Decoder
         y = prompts

@@ -598,8 +640,6 @@ class Text2SemanticDecoder(nn.Module):
         y_pos = self.ar_audio_position(y_emb)
         xy_pos = torch.concat([x, y_pos], dim=1)

-
-
         ##### create mask #####
         bsz = x.shape[0]
         src_len = x_len + y_len
@@ -642,7 +682,6 @@ class Text2SemanticDecoder(nn.Module):
         attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
         attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()

-
         # 正确的attn_mask应该是这样的:
         # | pad_len | x_len | y_len |
         # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
@@ -655,7 +694,6 @@ class Text2SemanticDecoder(nn.Module):
         # [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
         # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]

-
         ###### decode #####
         y_list = [None] * y.shape[0]
         batch_idx_map = list(range(y.shape[0]))
@@ -665,9 +703,7 @@ class Text2SemanticDecoder(nn.Module):
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])

             if idx == 0:
                 attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
@@ -684,8 +720,7 @@ class Text2SemanticDecoder(nn.Module):
             ####### 移除batch中已经生成完毕的序列,进一步优化计算量
             tokens = torch.argmax(logits, dim=-1)
             reserved_idx_of_batch_for_y = None
-            if (self.EOS in samples[:, 0]) or \
-                (self.EOS in tokens):  ###如果生成到EOS,则停止
+            if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  ###如果生成到EOS,则停止
                 l1 = samples[:, 0] == self.EOS
                 l2 = tokens == self.EOS
                 l = l1.logical_or(l2)
@@ -709,7 +744,6 @@ class Text2SemanticDecoder(nn.Module):
                     k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
                     v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)

-
             if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
                 print("use early stop num:", early_stop_num)
                 stop = True
@@ -718,7 +752,7 @@ class Text2SemanticDecoder(nn.Module):
                     idx_list[batch_index] = idx
                     y_list[batch_index] = y[i, :-1]

-            if not (None in idx_list):
+            if None not in idx_list:
                 stop = True

             if stop:
@@ -730,9 +764,11 @@ class Text2SemanticDecoder(nn.Module):

             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)

-        if (None in idx_list):
+        if None in idx_list:
             for i in range(x.shape[0]):
                 if idx_list[i] is None:
                     idx_list[i] = 1500 - 1  ###如果没有生成到EOS,就用最大长度代替
@@ -742,7 +778,8 @@ class Text2SemanticDecoder(nn.Module):
         # print(idx_list)
         return y_list, idx_list

-    def infer_panel_naive_batched(self,
+    def infer_panel_naive_batched(
+        self,
         x: List[torch.LongTensor],  #####全部文本token
         x_lens: torch.LongTensor,
         prompts: torch.LongTensor,  ####参考音频token
@@ -752,12 +789,13 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
     ):
         y_list = []
         idx_list = []
         for i in range(len(x)):
-            y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
+            y, idx = self.infer_panel_naive(
+                x[i].unsqueeze(0),
                 x_lens[i],
                 prompts[i].unsqueeze(0) if prompts is not None else None,
                 bert_feature[i].unsqueeze(0),
@@ -766,7 +804,8 @@ class Text2SemanticDecoder(nn.Module):
                 early_stop_num,
                 temperature,
                 repetition_penalty,
-                **kwargs)
+                **kwargs,
+            )
             y_list.append(y[0])
             idx_list.append(idx)

@@ -783,7 +822,7 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -828,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
-            .unsqueeze(0)\
-            .expand(bsz*self.num_head, -1, -1)\
-            .view(bsz, self.num_head, src_len, src_len)\
-            .to(device=x.device, dtype=torch.bool)
+        xy_attn_mask = (
+            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+            .unsqueeze(0)
+            .expand(bsz * self.num_head, -1, -1)
+            .view(bsz, self.num_head, src_len, src_len)
+            .to(device=x.device, dtype=torch.bool)
+        )

         for idx in tqdm(range(1500)):
             if xy_attn_mask is not None:
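The xy_attn_mask assembled above concatenates a text block (fully visible to itself, blind to audio) with a causal audio block. A small sketch of the same construction for x_len=2, y_len=3, where True means "masked"; the x_attn_mask_pad line mirrors the F.pad call visible elsewhere in this diff:

import torch
import torch.nn.functional as F

x_len, y_len = 2, 3
x_attn_mask = torch.zeros(x_len, x_len, dtype=torch.bool)     # text attends to all text
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)  # text cannot see audio
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,                                              # audio sees all text
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
print(xy_attn_mask.int())  # 5x5: text rows mask the audio columns, audio rows are causal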
@@ -840,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)

-            logits = self.ar_predict_layer(
-                xy_dec[:, -1]
-            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])

             if idx == 0:
                 xy_attn_mask = None
-            if(idx<11):###至少预测出10个token不然不给停止(0.4s)
+            if idx < 11:  ###至少预测出10个token不然不给停止(0.4s)
                 logits = logits[:, :-1]

             samples = sample(
@@ -870,13 +909,14 @@ class Text2SemanticDecoder(nn.Module):

             ####################### update next step ###################################
             y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)

         if ref_free:
             return y[:, :-1], 0
         return y[:, :-1], idx


     def infer_panel(
         self,
         x: torch.LongTensor,  #####全部文本token
@@ -888,6 +928,8 @@ class Text2SemanticDecoder(nn.Module):
         early_stop_num: int = -1,
         temperature: float = 1.0,
         repetition_penalty: float = 1.35,
-        **kwargs
+        **kwargs,
     ):
-        return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
+        return self.infer_panel_naive(
+            x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
+        )
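One detail worth calling out in the decode loop reflowed above: for the first 10 steps the EOS logit is sliced off (logits[:, :-1]) so sampling cannot terminate before roughly 0.4 s of audio tokens exist. A toy version of that guard, assuming EOS is the last vocabulary index (EOS = 1024 matches default_config in this file; argmax stands in for the real sample() call):

import torch

vocab_size, EOS = 1025, 1024
logits = torch.randn(1, vocab_size)
idx = 5
if idx < 11:                 # forbid stopping in the first 10 steps
    logits = logits[:, :-1]  # drop the EOS column before sampling
next_token = torch.argmax(logits, dim=-1)
assert next_token.item() != EOS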
@@ -1,17 +1,13 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
 # reference: https://github.com/lifeiteng/vall-e
 import torch
-from tqdm import tqdm
-
-from AR.modules.embedding_onnx import SinePositionalEmbedding
-from AR.modules.embedding_onnx import TokenEmbedding
-from AR.modules.transformer_onnx import LayerNorm
-from AR.modules.transformer_onnx import TransformerEncoder
-from AR.modules.transformer_onnx import TransformerEncoderLayer
 from torch import nn
 from torch.nn import functional as F
 from torchmetrics.classification import MulticlassAccuracy

+from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
+from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+
 default_config = {
     "embedding_dim": 512,
     "hidden_dim": 512,
@@ -26,6 +22,7 @@ default_config = {

 inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()

+
 def logits_to_probs(
     logits,
     previous_tokens=None,
@@ -39,19 +36,27 @@ def logits_to_probs(
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=0, index=previous_tokens)
         score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
+            score < 0,
+            score * repetition_penalty,
+            score / repetition_penalty,
         )
         logits.scatter_(dim=0, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+            torch.nn.functional.softmax(
+                sorted_logits,
+                dim=-1,
+            ),
+            dim=-1,
         )
         sorted_indices_to_remove = cum_probs > top_p
         sorted_indices_to_remove[0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
+            dim=0,
+            index=sorted_indices,
+            src=sorted_indices_to_remove,
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

@@ -67,7 +72,7 @@ def logits_to_probs(


 def multinomial_sample_one_no_sync(
-    probs_sort
+    probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
     q = torch.randn_like(probs_sort)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@@ -79,7 +84,9 @@ def sample(
     **sampling_kwargs,
 ):
     probs = logits_to_probs(
-        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
+        logits=logits,
+        previous_tokens=previous_tokens,
+        **sampling_kwargs,
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):


 class T2SFirstStageDecoder(nn.Module):
-    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
-                 top_k, early_stop_num, num_layers):
+    def __init__(
+        self,
+        ar_audio_embedding,
+        ar_audio_position,
+        h,
+        ar_predict_layer,
+        loss_fct,
+        ar_accuracy_metric,
+        top_k,
+        early_stop_num,
+        num_layers,
+    ):
         super().__init__()
         self.ar_audio_embedding = ar_audio_embedding
         self.ar_audio_position = ar_audio_position
@@ -136,7 +153,11 @@ class T2SFirstStageDecoder(nn.Module):
         x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
         y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
         y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
-            torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
+            torch.ones_like(
+                y_example.transpose(0, 1),
+                dtype=torch.int64,
+            ),
+            dim=0,
         )
         y_attn_mask = y_attn_mask > 0

@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
         x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
         y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
-        cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
-            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
-        cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
-            .unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
+        cache["k"] = (
+            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
+            .unsqueeze(1)
+            .repeat(self.num_layers, 1, 1, 1)
+        )
+        cache["v"] = (
+            torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
+            .unsqueeze(1)
+            .repeat(self.num_layers, 1, 1, 1)
+        )

         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):


 class T2SStageDecoder(nn.Module):
-    def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
-                 top_k, early_stop_num, num_layers):
+    def __init__(
+        self,
+        ar_audio_embedding,
+        ar_audio_position,
+        h,
+        ar_predict_layer,
+        loss_fct,
+        ar_accuracy_metric,
+        top_k,
+        early_stop_num,
+        num_layers,
+    ):
         super().__init__()
         self.ar_audio_embedding = ar_audio_embedding
         self.ar_audio_position = ar_audio_position
@@ -184,7 +221,11 @@ class T2SStageDecoder(nn.Module):
         }

         y_emb = torch.cat(
-            [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+            [
+                cache["y_emb"],
+                self.ar_audio_embedding(y[:, -1:]),
+            ],
+            1,
         )
         cache["y_emb"] = y_emb
         y_pos = self.ar_audio_position(y_emb)
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):

     def init_onnx(self):
         self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
-        self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
-                                                        self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
-                                                        self.num_layers)
-        self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
-                                             self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
-                                             self.num_layers)
+        self.first_stage_decoder = T2SFirstStageDecoder(
+            self.ar_audio_embedding,
+            self.ar_audio_position,
+            self.h,
+            self.ar_predict_layer,
+            self.loss_fct,
+            self.ar_accuracy_metric,
+            self.top_k,
+            self.early_stop_num,
+            self.num_layers,
+        )
+        self.stage_decoder = T2SStageDecoder(
+            self.ar_audio_embedding,
+            self.ar_audio_position,
+            self.h,
+            self.ar_predict_layer,
+            self.loss_fct,
+            self.ar_accuracy_metric,
+            self.top_k,
+            self.early_stop_num,
+            self.num_layers,
+        )

     def forward(self, x, prompts, bert_feature):
         early_stop_num = self.early_stop_num
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
         if cache["first_infer"] == 1:
             y_emb = self.ar_audio_embedding(y)
         else:
-            y_emb = torch.cat(
-                [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
-            )
+            y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
         cache["y_emb"] = y_emb
         y_pos = self.ar_audio_position(y_emb)
         if cache["first_infer"] == 1:
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
             x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
             y_attn_mask = F.pad(
                 torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
-                (x_len, 0), value=False
+                (x_len, 0),
+                value=False,
             )
             xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
         else:
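Context for the ONNX variant above: inference is split into a first-stage decoder (which builds the k/v cache from the full prompt) and a per-step stage decoder, since exported graphs need static control flow rather than a Python loop. A heavily hedged sketch of how such a module pair might be exported; the dummy inputs, file name, and opset are illustrative assumptions, only torch.onnx.export itself is the real API:

import torch

# decoder = Text2SemanticDecoder(config); decoder.init_onnx()  # as defined above
torch.onnx.export(
    decoder.first_stage_decoder,   # module built by init_onnx()
    (dummy_x, dummy_prompts),      # hypothetical dummy tensors matching its forward()
    "t2s_first_stage.onnx",        # hypothetical output path
    opset_version=16,
)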
@@ -1,8 +1,10 @@
 # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
 # reference: https://github.com/lifeiteng/vall-e
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
-from typing import Tuple
+

 def sequence_mask(length, max_length=None):
     if max_length is None:
@@ -74,7 +76,11 @@ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:

 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
-    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+    logits,
+    top_k=0,
+    top_p=1.0,
+    filter_value=-float("Inf"),
+    min_tokens_to_keep=1,
 ):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
     Args:
@@ -105,9 +111,7 @@ def top_k_top_p_filtering(
         sorted_indices_to_remove[..., 0] = 0

         # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value
     return logits

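Aside (not part of the commit): top_k_top_p_filtering, reformatted in the hunks above, is the standard top-k plus nucleus (top-p) filter. A minimal sketch of the same logic under illustrative thresholds; filter_logits is a hypothetical stand-alone name, and the min_tokens_to_keep guard is omitted for brevity:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token crossing the threshold
        remove[..., 0] = False
        indices_to_remove = remove.scatter(1, sorted_indices, remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits

logits = torch.randn(1, 10)
print(filter_logits(logits.clone(), top_k=5, top_p=0.9))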
@@ -130,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
     return token


-from typing import Optional, Tuple
+from typing import Optional


 def multinomial_sample_one_no_sync(
@@ -156,19 +160,21 @@ def logits_to_probs(
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=1, index=previous_tokens)
         score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
+            score < 0,
+            score * repetition_penalty,
+            score / repetition_penalty,
         )
         logits.scatter_(dim=1, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
+        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
         sorted_indices_to_remove = cum_probs > top_p
         sorted_indices_to_remove[:, 0] = False  # keep at least one option
         indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=1, index=sorted_indices, src=sorted_indices_to_remove
+            dim=1,
+            index=sorted_indices,
+            src=sorted_indices_to_remove,
         )
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

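Aside (not part of the commit): the first change above touches the repetition penalty, which shrinks the scores of previously generated tokens before sampling (positive scores are divided, negative ones multiplied, so both move away from being picked again). A worked example with illustrative numbers:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
previous_tokens = torch.tensor([[0, 3]])  # token ids generated so far
repetition_penalty = 1.35

score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=1, index=previous_tokens, src=score)
print(logits)  # tensor([[1.4815, -1.0000, 0.5000, 2.2222]])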
@@ -188,18 +194,19 @@ def sample(
     previous_tokens: Optional[torch.Tensor] = None,
     **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(
-        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
-    )
+    probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

-def dpo_loss(policy_chosen_logps: torch.FloatTensor,
+
+def dpo_loss(
+    policy_chosen_logps: torch.FloatTensor,
     policy_rejected_logps: torch.FloatTensor,
     reference_chosen_logps: torch.FloatTensor,
     reference_rejected_logps: torch.FloatTensor,
     beta: float,
-    reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    reference_free: bool = False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
     pi_logratios = policy_chosen_logps - policy_rejected_logps
     ref_logratios = reference_chosen_logps - reference_rejected_logps

@@ -214,15 +221,26 @@ def dpo_loss(

     return losses.mean(), chosen_rewards, rejected_rewards

-def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+def get_batch_logps(
+    logits_target: torch.FloatTensor,
+    logits_reject: torch.FloatTensor,
+    labels_target: torch.LongTensor,
+    labels_reject: torch.LongTensor,
+    average_log_prob: bool = False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
     # dummy token; we'll ignore the losses on these tokens later

-    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
-    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+    per_token_logps_target = torch.gather(
+        logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
+    ).squeeze(2)
+    per_token_logps_reject = torch.gather(
+        logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
+    ).squeeze(2)

     return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)


 def make_reject_y(y_o, y_lens):
     def repeat_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
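Aside (not part of the commit): get_batch_logps gathers per-token log-probabilities for the chosen and rejected sequences, and dpo_loss applies the Direct Preference Optimization objective, -logsigmoid(beta * ((policy chosen - policy rejected) - (reference chosen - reference rejected))). A minimal sketch with illustrative sequence log-probabilities standing in for the gathered sums:

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen, policy_rejected = torch.tensor([-10.0]), torch.tensor([-14.0])
ref_chosen, ref_rejected = torch.tensor([-11.0]), torch.tensor([-13.0])

pi_logratios = policy_chosen - policy_rejected
ref_logratios = ref_chosen - ref_rejected
losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
chosen_rewards = beta * (policy_chosen - ref_chosen).detach()
rejected_rewards = beta * (policy_rejected - ref_rejected).detach()
print(losses.mean(), chosen_rewards, rejected_rewards)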
@@ -231,6 +249,7 @@ def make_reject_y(y_o, y_lens):
         range_text = y[range_idx[0] : range_idx[1]]
         new_y = torch.cat([pre, range_text, range_text, shf])
         return new_y
+
     def lost_P(y):
         range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
         pre = y[: range_idx[0]]
@@ -238,6 +257,7 @@ def make_reject_y(y_o, y_lens):
         range_text = y[range_idx[0] : range_idx[1]]
         new_y = torch.cat([pre, shf])
         return new_y
+
     bs = len(y_lens)
     reject_y = []
     reject_y_lens = []
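Aside (not part of the commit): make_reject_y builds the "rejected" side of a DPO pair by corrupting the ground-truth token sequence, either repeating a random span or dropping one. A minimal sketch; y is an illustrative token sequence and the helper names repeat_span and drop_span are hypothetical:

import torch

def repeat_span(y: torch.Tensor) -> torch.Tensor:
    # Duplicate a random span, mimicking the repetition failure mode.
    i, j = torch.randint(0, len(y), (2,)).sort().values.tolist()
    return torch.cat([y[:i], y[i:j], y[i:j], y[j:]])

def drop_span(y: torch.Tensor) -> torch.Tensor:
    # Remove a random span, mimicking the skipped-content failure mode.
    i, j = torch.randint(0, len(y), (2,)).sort().values.tolist()
    return torch.cat([y[:i], y[j:]])

y = torch.arange(10)
print(repeat_span(y), drop_span(y))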
@@ -1,17 +1,14 @@
 # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
-from typing import Optional
-from typing import Tuple
+from typing import Optional, Tuple
+
 import torch
 from torch import Tensor
-from torch.nn import Linear
-from torch.nn import Module
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.init import xavier_uniform_
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter

-from torch.nn import functional as F
 from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched

 F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
         >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+
     """

     __constants__ = ["batch_first"]
     bias_k: Optional[torch.Tensor]
     bias_v: Optional[torch.Tensor]
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

         if add_bias_kv:
             self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
         if linear1_cls == Linear:
             if not self._qkv_same_embed_dim:
                 self.q_proj_weight = Parameter(
-                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty((embed_dim, embed_dim), **factory_kwargs),
                 )
                 self.k_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
+                    torch.empty((embed_dim, self.kdim), **factory_kwargs),
                 )
                 self.v_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
+                    torch.empty((embed_dim, self.vdim), **factory_kwargs),
                 )
                 self.register_parameter("in_proj_weight", None)
             else:
                 self.in_proj_weight = Parameter(
-                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
                 )
                 self.register_parameter("q_proj_weight", None)
                 self.register_parameter("k_proj_weight", None)
                 self.register_parameter("v_proj_weight", None)

             if bias:
-                self.in_proj_bias = Parameter(
-                    torch.empty(3 * embed_dim, **factory_kwargs)
-                )
+                self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
             else:
                 self.register_parameter("in_proj_bias", None)
             self.out_proj = NonDynamicallyQuantizableLinear(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
+                embed_dim,
+                embed_dim,
+                bias=bias,
+                **factory_kwargs,
             )

             self._reset_parameters()
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
                 raise NotImplementedError
             else:
                 self.in_proj_linear = linear1_cls(
-                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+                    embed_dim,
+                    3 * embed_dim,
+                    bias=bias,
+                    **factory_kwargs,
                 )
                 self.in_proj_weight = self.in_proj_linear.weight

@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
                 self.register_parameter("in_proj_bias", None)

             self.out_proj = linear2_cls(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
+                embed_dim,
+                embed_dim,
+                bias=bias,
+                **factory_kwargs,
             )

             if self.bias_k is not None:
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
         if key_padding_mask is not None:
             _kpm_dtype = key_padding_mask.dtype
             if _kpm_dtype != torch.bool and not torch.is_floating_point(
-                key_padding_mask
+                key_padding_mask,
             ):
-                raise AssertionError(
-                    "only bool and floating types of key_padding_mask are supported"
-                )
+                raise AssertionError("only bool and floating types of key_padding_mask are supported")
         why_not_fast_path = ""
         if not is_batched:
-            why_not_fast_path = (
-                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
-            )
+            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
         elif query is not key or key is not value:
             # When lifting this restriction, don't forget to either
             # enforce that the dtypes all match or test cases where
             # they don't!
             why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
         elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
-            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
-        elif (
-            self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
-        ):
+            why_not_fast_path = (
+                f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+            )
+        elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
             # this case will fail anyway, but at least they'll get a useful error message.
-            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+            why_not_fast_path = (
+                f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+            )
         elif self.training:
             why_not_fast_path = "training is enabled"
         elif not self.batch_first:
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
         elif attn_mask is not None:
             why_not_fast_path = "attn_mask was not None"
         elif query.is_nested and key_padding_mask is not None:
-            why_not_fast_path = (
-                "key_padding_mask is not supported with NestedTensor input"
-            )
+            why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
         elif self.num_heads % 2 == 1:
             why_not_fast_path = "num_heads is odd"
         elif torch.is_autocast_enabled():
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
             # generator expressions.
             if torch.overrides.has_torch_function(tensor_args):
                 why_not_fast_path = "some Tensor argument has_torch_function"
-            elif not all(
-                [
-                    (x is None or x.is_cuda or "cpu" in str(x.device))
-                    for x in tensor_args
-                ]
-            ):
+            elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
                 why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-            elif torch.is_grad_enabled() and any(
-                [x is not None and x.requires_grad for x in tensor_args]
-            ):
-                why_not_fast_path = (
-                    "grad is enabled and at least one of query or the "
-                    "input/output projection weights or biases requires_grad"
-                )
+            elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
+                why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
             if not why_not_fast_path:
                 return torch._native_multi_head_attention(
                     query,
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
                     key_padding_mask if key_padding_mask is not None else attn_mask,
                     need_weights,
                     average_attn_weights,
-                    1
-                    if key_padding_mask is not None
-                    else 0
-                    if attn_mask is not None
-                    else None,
+                    1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
                 )

         any_nested = query.is_nested or key.is_nested or value.is_nested
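Aside (not part of the commit): the final positional argument collapsed in the last hunk is the mask-type flag passed to the native fast path (1 when a key-padding mask is supplied, 0 when only an attention mask is, None otherwise). A tiny sketch of how Python's chained conditional expression evaluates here, with illustrative placeholder values:

key_padding_mask, attn_mask = None, object()
mask_type = 1 if key_padding_mask is not None else 0 if attn_mask is not None else None
print(mask_type)  # 0: only an attn_mask was supplied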
@@ -1,17 +1,13 @@
 # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
-from typing import Optional
-from typing import Tuple
+from typing import Optional, Tuple
+
 import torch
 from torch import Tensor
-from torch.nn import Linear
-from torch.nn import Module
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.init import xavier_uniform_
+from torch.nn import Linear, Module
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter

-from torch.nn import functional as F
 from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched


@@ -47,9 +43,7 @@ class MultiheadAttention(Module):
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

         if add_bias_kv:
             self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -60,18 +54,30 @@ class MultiheadAttention(Module):
         if linear1_cls == Linear:
             if not self._qkv_same_embed_dim:
                 self.q_proj_weight = Parameter(
-                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty(
+                        (embed_dim, embed_dim),
+                        **factory_kwargs,
+                    )
                 )
                 self.k_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
+                    torch.empty(
+                        (embed_dim, self.kdim),
+                        **factory_kwargs,
+                    )
                 )
                 self.v_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
+                    torch.empty(
+                        (embed_dim, self.vdim),
+                        **factory_kwargs,
+                    )
                 )
                 self.register_parameter("in_proj_weight", None)
             else:
                 self.in_proj_weight = Parameter(
-                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+                    torch.empty(
+                        (3 * embed_dim, embed_dim),
+                        **factory_kwargs,
+                    )
                 )
                 self.register_parameter("q_proj_weight", None)
                 self.register_parameter("k_proj_weight", None)
@@ -79,13 +85,11 @@ class MultiheadAttention(Module):

             if bias:
                 self.in_proj_bias = Parameter(
-                    torch.empty(3 * embed_dim, **factory_kwargs)
+                    torch.empty(3 * embed_dim, **factory_kwargs),
                 )
             else:
                 self.register_parameter("in_proj_bias", None)
-            self.out_proj = NonDynamicallyQuantizableLinear(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
-            )
+            self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

             self._reset_parameters()
         else:
@@ -93,7 +97,10 @@ class MultiheadAttention(Module):
                 raise NotImplementedError
             else:
                 self.in_proj_linear = linear1_cls(
-                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+                    embed_dim,
+                    3 * embed_dim,
+                    bias=bias,
+                    **factory_kwargs,
                 )
                 self.in_proj_weight = self.in_proj_linear.weight

@@ -107,7 +114,10 @@ class MultiheadAttention(Module):
                 self.register_parameter("in_proj_bias", None)

             self.out_proj = linear2_cls(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
+                embed_dim,
+                embed_dim,
+                bias=bias,
+                **factory_kwargs,
             )

             if self.bias_k is not None:
@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
             return
         pe = torch.zeros(x.size(1), self.embedding_dim)
         if self.reverse:
-            position = torch.arange(
-                x.size(1) - 1, -1, -1.0, dtype=torch.float32
-            ).unsqueeze(1)
+            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
         else:
             position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.embedding_dim)
+            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
         )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
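Aside (not part of the commit): the hunk above fills the classic sinusoidal position table, where even channels get sin(pos / 10000^(2i/d)) and odd channels get cos of the same argument. A minimal sketch with illustrative sizes:

import math
import torch

seq_len, dim = 6, 8
position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(10000.0) / dim))
pe = torch.zeros(seq_len, dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe.shape)  # torch.Size([6, 8])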
@@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
             lr = self.end_lr

         else:
-            decay_ratio = (self._current_step - self.warmup_steps) / (
-                self.total_steps - self.warmup_steps
-            )
+            decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
             if decay_ratio < 0.0 or decay_ratio > 1.0:
-                raise RuntimeError(
-                    "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
-                )
+                raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
             coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
             lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)

@@ -70,7 +66,13 @@ if __name__ == "__main__":
     m = nn.Linear(10, 10)
     opt = Adam(m.parameters(), lr=1e-4)
     s = WarmupCosineLRSchedule(
-        opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
+        opt,
+        1e-6,
+        2e-4,
+        1e-6,
+        warmup_steps=2000,
+        total_steps=20000,
+        current_step=0,
     )
     lrs = []
     for i in range(25000):
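Aside (not part of the commit): the decay branch above is a standard cosine schedule from peak_lr down to end_lr. A minimal stand-alone sketch; the warmup branch is not shown in this hunk, so the linear ramp below is an assumption, and all step values are illustrative:

import math

def warmup_cosine_lr(step, init_lr=1e-6, peak_lr=2e-4, end_lr=1e-6, warmup_steps=2000, total_steps=20000):
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps  # assumed linear warmup
    decay_ratio = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return end_lr + coeff * (peak_lr - end_lr)

print(warmup_cosine_lr(0), warmup_cosine_lr(2000), warmup_cosine_lr(20000))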
@@ -16,8 +16,7 @@
 import contextlib
 import logging
 from collections import defaultdict
-from typing import List
-from typing import Tuple
+from typing import List, Tuple

 import torch
 from torch import Tensor
@@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
           group_params_names: name for each parameter in group,
                which is List[str].
        """
-        batches = defaultdict(
-            list
-        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
-        batches_names = defaultdict(
-            list
-        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
+        batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
@@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
            batches_names[key].append(named_p)

        batches_names_keys = list(batches_names.keys())
-        sorted_idx = sorted(
-            range(len(batches_names)), key=lambda i: batches_names_keys[i])
-        batches_names = [
-            batches_names[batches_names_keys[idx]] for idx in sorted_idx
-        ]
+        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        stacked_params_dict = dict()
@@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
            # group. class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
-            grad = torch.stack([
-                torch.zeros_like(p) if p.grad is None else p.grad for p in batch
-            ])
+            grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
            p_stacked.grad = grad
            stacked_params_dict[key] = p_stacked
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
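Aside (not part of the commit): BatchedOptimizer.batched_params groups parameters by (dtype, shape) and stacks each group into a single tensor, so the update kernels run once per group rather than once per parameter. A minimal sketch of the grouping idea with illustrative tensors:

import torch
from collections import defaultdict

params = [torch.randn(4, 3), torch.randn(4, 3), torch.randn(2)]
batches = defaultdict(list)  # (dtype_as_str, *shape) -> list of tensors
for p in params:
    key = (str(p.dtype), *p.shape)
    batches[key].append(p)

for key, batch in batches.items():
    stacked = torch.stack(batch)  # extra leading dim is the stacking dim
    print(key, stacked.shape)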
@@ -177,12 +167,11 @@ class ScaledAdam(BatchedOptimizer):
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
-        show_dominant_parameters=True, ):
+        show_dominant_parameters=True,
+    ):
        assert parameters_names is not None, (
-            "Please prepare parameters_names,"
-            "which is a List[List[str]]. Each List[str] is for a group"
-            "and each str is for a parameter")
+            "Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
+        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
@@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
-            clipping_update_period=clipping_update_period, )
+            clipping_update_period=clipping_update_period,
+        )

        super(ScaledAdam, self).__init__(params, defaults)
        assert len(self.param_groups) == len(parameters_names)
@@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):

        batch = True

-        for group, group_params_names in zip(self.param_groups,
-                                             self.parameters_names):
-
-            with self.batched_params(group["params"],
-                                     group_params_names) as batches:
-
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):
+            with self.batched_params(group["params"], group_params_names) as batches:
                # batches is list of pairs (stacked_param, state). stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.

-                if (len(batches[0][1]) ==
-                        0):  # if len(first state) == 0: not yet initialized
+                if len(batches[0][1]) == 0:  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)
@@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
                    if grad.is_sparse:
-                        raise RuntimeError(
-                            "ScaledAdam optimizer does not support sparse gradients"
-                        )
+                        raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
                    # State initialization
                    if len(state) == 0:
                        self._init_state(group, p, state)
@@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
        # parameter-change "delta", which combines all forms of
        # update. this is equivalent to how it's done in Adam,
        # except for the first few steps.
-        state["delta"] = torch.zeros_like(
-            p, memory_format=torch.preserve_format)
+        state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)

        batch_size = p.shape[0]
        numel = p.numel() // batch_size
@@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
-            param_rms = (
-                (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
+            param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-            state["scale_grads"] = torch.zeros(size_update_period,
-                                               *param_rms.shape, **kwargs)
+            state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)

        # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
-        state["exp_avg_sq"] = torch.zeros_like(
-            p, memory_format=torch.preserve_format)
+        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

-    def _get_clipping_scale(self,
-                            group: dict,
-                            tuples: List[Tuple[Tensor, dict, List[str]]]
-                            ) -> float:
+    def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.
@@ -325,11 +301,10 @@ class ScaledAdam(BatchedOptimizer):
        clipping_update_period = group["clipping_update_period"]

        tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in tuples:
+        for p, state, param_names in tuples:
            grad = p.grad
            if grad.is_sparse:
-                raise RuntimeError(
-                    "ScaledAdam optimizer does not support sparse gradients")
+                raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
            else:
@@ -337,8 +312,7 @@ class ScaledAdam(BatchedOptimizer):

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
-            first_state["model_norms"] = torch.zeros(
-                clipping_update_period, device=p.device)
+            first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
@@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
-                    (clipping_update_period // 4) * n, )
+                    (clipping_update_period // 4) * n,
+                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
-            percent_clipped = (first_state["num_clipped"] * 100.0 /
-                               clipping_update_period
-                               if "num_clipped" in first_state else 0.0)
+            percent_clipped = (
+                first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
+            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
-                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
-                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
@@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
            model_norm_threshold = first_state["model_norm_threshold"]
        except KeyError:
            logging.info(
-                "Warning: model_norm_threshold not in state: possibly "
-                "you changed config when restarting, adding clipping_scale option?"
+                "Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
            )
            return 1.0
        ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
        if ans < 1.0:
            first_state["num_clipped"] += 1
        if ans < 0.1:
-            logging.warn(
-                f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
-            )
+            logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
            if self.show_dominant_parameters:
                assert p.shape[0] == len(param_names)
                self._show_gradient_dominating_parameter(tuples, tot_sumsq)
        return ans

-    def _show_gradient_dominating_parameter(
-            self, tuples: List[Tuple[Tensor, dict, List[str]]],
-            tot_sumsq: Tensor):
+    def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
        """
        Show information of parameter wihch dominanting tot_sumsq.

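Aside (not part of the commit): the clipping logic in the hunks above is adaptive. It keeps a rolling window of recent gradient norms, sets the threshold to clipping_scale times the median of that window, and scales down any step whose norm exceeds it. A minimal sketch with illustrative norms:

import torch

clipping_scale = 2.0
model_norms = torch.tensor([1.0, 1.2, 0.9, 1.1, 5.0, 1.0, 1.3, 0.8])  # rolling window
sorted_norms = model_norms.sort(descending=True)[0]
median = sorted_norms[len(sorted_norms) // 2]
threshold = clipping_scale * median

tot_norm = torch.tensor(5.0)  # current step's gradient norm
ans = min(1.0, (threshold / (tot_norm + 1.0e-20)).item())
print(threshold.item(), ans)  # gradients are multiplied by `ans` (here 0.4)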
@@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
-        for (p, state, batch_param_names) in tuples:
+        for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
-            batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
-                dim=list(range(1, batch_grad.ndim)))
+            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))

-            for name, sumsq_orig, rms, grad in zip(batch_param_names,
+            for name, sumsq_orig, rms, grad in zip(
+                batch_param_names,
                batch_sumsq_orig,
-                batch_rms_orig, batch_grad):
+                batch_rms_orig,
+                batch_grad,
+            ):
                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
-            torch.tensor(1.0), )
+            torch.tensor(1.0),
+        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
-                reverse=True, )
+                reverse=True,
+            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
-        (dominant_proportion, dominant_sumsq, dominant_rms,
-         dominant_grad, ) = sorted_by_proportion[dominant_param_name]
-        logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+        (
+            dominant_proportion,
+            dominant_sumsq,
+            dominant_rms,
+            dominant_grad,
+        ) = sorted_by_proportion[dominant_param_name]
+        logging.info(
+            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
-            f" orig_rms_sq={(dominant_rms**2).item():.3e}")
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )

-    def _step_one_batch(self,
-                        group: dict,
-                        p: Tensor,
-                        state: dict,
-                        clipping_scale: float):
+    def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
@@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
-            scale_grads[step % size_update_period] = (p * grad).sum(
-                dim=list(range(1, p.ndim)), keepdim=True)
+            scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
-                param_rms.copy_((p**2)
-                                .mean(dim=list(range(1, p.ndim)), keepdim=True)
-                                .sqrt())
+                param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
@@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):

        state["step"] = step + 1

-    def _size_update(self,
+    def _size_update(
+        self,
        group: dict,
        scale_grads: Tensor,
        p: Tensor,
-        state: dict) -> None:
+        state: dict,
+    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
        # faster decay at this level.
        beta2_corr = beta2**size_update_period

-        scale_exp_avg_sq = state[
-            "scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
+        scale_exp_avg_sq = state["scale_exp_avg_sq"]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads**2).mean(dim=0),  # mean over dim `size_update_period`
-            alpha=1 - beta2_corr, )  # shape is (batch_size, 1, 1, ...)
+            alpha=1 - beta2_corr,
+        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
@@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):

        denom = scale_exp_avg_sq.sqrt() + eps

-        scale_step = (-size_lr * (bias_correction2**0.5) *
-                      scale_grads.sum(dim=0) / denom)
+        scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms
@@ -580,8 +552,7 @@ class ScaledAdam(BatchedOptimizer):
        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

-        this_step = state["step"] - (state["zero_step"]
-                                     if "zero_step" in state else 0)
+        this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
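Aside (not part of the commit): the _size_update hunks above learn a per-tensor scale. The correlation (p * grad).sum() acts as the gradient of the loss with respect to the tensor's overall scale, and it is normalized by an Adam-style running second moment. A simplified single-period sketch; the real optimizer accumulates scale_grads over size_update_period steps and applies bias correction, and all values below are illustrative:

import torch

p = torch.randn(4, 10)  # a batch of 4 stacked parameters
grad = torch.randn(4, 10)
size_lr, beta2, eps = 0.02, 0.98, 1e-8

scale_grad = (p * grad).sum(dim=1, keepdim=True)  # correlation of param and grad
scale_exp_avg_sq = torch.zeros_like(scale_grad)
scale_exp_avg_sq.mul_(beta2).add_(scale_grad**2, alpha=1 - beta2)
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = -size_lr * scale_grad / denom  # relative change applied to the scale
print(scale_step.shape)  # one scalar step per stacked parameter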
@@ -5,7 +5,6 @@ from torch.nn.functional import (
    _none_or_dtype,
    _in_projection_packed,
)
-from torch.nn import functional as F
import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
@@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
        cache=cache,
    )

-    is_batched = _mha_shape_check(
-        query, key, value, key_padding_mask, attn_mask, num_heads
-    )
+    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
@@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
    # longer causal.
    is_causal = False

-    assert (
-        embed_dim == embed_dim_to_check
-    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+    assert embed_dim == embed_dim_to_check, (
+        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+    )
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
-    assert (
-        head_dim * num_heads == embed_dim
-    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
-        assert (
-            key.shape[:2] == value.shape[:2]
-        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+        assert key.shape[:2] == value.shape[:2], (
+            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+        )
    else:
-        assert (
-            key.shape == value.shape
-        ), f"key shape {key.shape} does not match value shape {value.shape}"
+        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
-        assert (
-            in_proj_weight is not None
-        ), "use_separate_proj_weight is False but in_proj_weight is None"
+        assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
-        assert (
-            q_proj_weight is not None
-        ), "use_separate_proj_weight is True but q_proj_weight is None"
-        assert (
-            k_proj_weight is not None
-        ), "use_separate_proj_weight is True but k_proj_weight is None"
-        assert (
-            v_proj_weight is not None
-        ), "use_separate_proj_weight is True but v_proj_weight is None"
+        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
+        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
+        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
@@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
-            raise RuntimeError(
-                f"attn_mask's dimension {attn_mask.dim()} is not supported"
-            )
+            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
@@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-        assert (
-            static_k.size(0) == bsz * num_heads
-        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+        assert static_k.size(0) == bsz * num_heads, (
+            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+        )
-        assert (
-            static_k.size(2) == head_dim
-        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-        assert (
-            static_v.size(0) == bsz * num_heads
-        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+        assert static_v.size(0) == bsz * num_heads, (
+            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+        )
-        assert (
-            static_v.size(2) == head_dim
-        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
-        k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
-        )
-        v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
-        )
+        k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
+        v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
@@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
-            key_padding_mask.view(bsz, 1, 1, src_len)
-            .expand(-1, num_heads, -1, -1)
-            .reshape(bsz * num_heads, 1, src_len)
+            key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
@@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
         B, Nt, E = q.shape
         q_scaled = q / math.sqrt(E)
 
-        assert not (
-            is_causal and attn_mask is None
-        ), "FIXME: is_causal not implemented for need_weights"
+        assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
 
         if attn_mask is not None:
-            attn_output_weights = torch.baddbmm(
-                attn_mask, q_scaled, k.transpose(-2, -1)
-            )
+            attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
         else:
             attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
         attn_output_weights = softmax(attn_output_weights, dim=-1)
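A quick sanity check on the torch.baddbmm call kept above: it fuses the additive mask into the batched matmul, so the masked scores come out identical to computing q_scaled @ k^T first and adding the mask afterwards. A minimal, self-contained check (shapes are illustrative):

import math
import torch

B, Nt, Ns, E = 3, 4, 5, 8
q = torch.randn(B, Nt, E) / math.sqrt(E)   # pre-scaled queries
k = torch.randn(B, Ns, E)
mask = torch.randn(B, Nt, Ns)              # additive float mask

fused = torch.baddbmm(mask, q, k.transpose(-2, -1))   # mask + q @ k^T in one kernel
manual = torch.bmm(q, k.transpose(-2, -1)) + mask
assert torch.allclose(fused, manual, atol=1e-6)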
@@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
 
         attn_output = torch.bmm(attn_output_weights, v)
 
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
-        )
+        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
 
@@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
         v = v.view(bsz, num_heads, src_len, head_dim)
 
         # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
-        attn_output = scaled_dot_product_attention(
-            q, k, v, attn_mask, dropout_p, is_causal
-        )
+        attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
 
-        attn_output = (
-            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
-        )
+        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
 
         attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
         attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
 
@@ -1,11 +1,9 @@
 from torch.nn.functional import *
 from torch.nn.functional import (
-    _mha_shape_check,
     _canonical_mask,
-    _none_or_dtype,
-    _in_projection_packed,
 )
 
 
 def multi_head_attention_forward_patched(
     query,
     key,
@@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
     is_causal: bool = False,
     cache=None,
 ) -> Tuple[Tensor, Optional[Tensor]]:
 
     # set up shape vars
     _, _, embed_dim = query.shape
     attn_mask = _canonical_mask(
@@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
     q = q.view(num_heads, -1, head_dim).unsqueeze(0)
     k = k.view(num_heads, -1, head_dim).unsqueeze(0)
     v = v.view(num_heads, -1, head_dim).unsqueeze(0)
-    attn_output = scaled_dot_product_attention(
-        q, k, v, attn_mask, dropout_p, is_causal
-    )
-    attn_output = (
-        attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
-    )
+    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
+    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
     attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
     attn_output = attn_output.view(-1, 1, attn_output.size(1))
 
@@ -13,12 +13,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-import math
 import random
 from typing import Optional
 from typing import Tuple
-from typing import Union
 
 import torch
 import torch.nn as nn
@@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
         # floors), should be expectation-preserving.
         floor = -0.043637
         ceil = 1.2
-        d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
-            deriv
-        )
+        d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
         if __name__ == "__main__":
             # for self-testing only.
             assert d_scaled.min() >= 0.0
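For readers skimming the DoubleSwish hunk: d_scaled maps the derivative from [floor, ceil] onto [0, 256) and adds uniform noise before truncation, so the quantization is unbiased in expectation (the derivative can then be stored in 8 bits to save activation memory). A small standalone sketch of the same dithering idea; the uint8 round trip below is an assumption about how the saved value is used, not a copy of the file:

import torch

def quantize_dithered(x, floor=-0.043637, ceil=1.2):
    # Scale to [0, 255], add U[0,1) noise, truncate: E[dequantize(q)] ~= x.
    scaled = (x - floor) * (255.0 / (ceil - floor)) + torch.rand_like(x)
    return scaled.to(torch.uint8)

def dequantize(q, floor=-0.043637, ceil=1.2):
    return q.to(torch.float32) * ((ceil - floor) / 255.0) + floor

x = torch.rand(100000) * 1.24 - 0.04           # values inside [floor, ceil]
err = (dequantize(quantize_dithered(x)) - x).mean()
print(f"mean round-trip error: {err:.5f}")      # close to 0 thanks to the dither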
@@ -153,13 +148,9 @@ def _compute_scale_factor(
     else:
         # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
         # x_abs)_mean , min_abs.
-        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
-            min=0, max=max_factor
-        )
+        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
 
-    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
-        min=0, max=max_factor
-    )
+    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
 
     return below_threshold - above_threshold
 
@@ -181,18 +172,16 @@ def _compute_sign_factor(
     else:
         # 0 if proportion_positive >= min_positive, else can be
         # as large as max_factor.
-        factor1 = (
-            (min_positive - proportion_positive) * (gain_factor / min_positive)
-        ).clamp_(min=0, max=max_factor)
+        factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
 
     if max_positive == 1.0:
         factor2 = 0.0
     else:
         # 0 if self.proportion_positive <= max_positive, else can be
         # as large as -max_factor.
-        factor2 = (
-            (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
-        ).clamp_(min=0, max=max_factor)
+        factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
+            min=0, max=max_factor
+        )
     sign_factor = factor1 - factor2
     # require min_positive != 0 or max_positive != 1:
     assert not isinstance(sign_factor, float)
@@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
             return _no_op(x)
 
 
-def BalancedDoubleSwish(
-    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
-) -> nn.Sequential:
+def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
     """
     ActivationBalancer -> DoubleSwish
     """
-    balancer = ActivationBalancer(
-        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
-    )
+    balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
     return nn.Sequential(
         balancer,
         DoubleSwish(),
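For context on what this sequential chain computes: DoubleSwish is x * sigmoid(x - 1) (a shifted Swish), and the balancer only nudges gradients during training, so at inference the block behaves like the plain activation. A minimal functional equivalent for reference; this is the published formula, not the memory-optimized autograd.Function from the file:

import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # DoubleSwish(x) = x * sigmoid(x - 1.0)
    return x * torch.sigmoid(x - 1.0)

x = torch.linspace(-4.0, 4.0, 9)
print(double_swish(x))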
@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs)
-            )
-            self.bias = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs)
-            )
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
             )
 
         assert embedding is None
-        return F.layer_norm(
-            input, self.normalized_shape, self.weight, self.bias, self.eps
-        )
+        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
 
     def extra_repr(self) -> str:
-        return (
-            "{normalized_shape}, eps={eps}, "
-            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
-        )
+        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
 
 
 class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
     >>> src = torch.rand(10, 32, 512)
     >>> out = transformer_encoder(src)
     """
 
     __constants__ = ["norm"]
 
     def __init__(self, encoder_layer, num_layers, norm=None):
@@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
         )
 
         # Implementation of Feedforward model
-        self.linear1 = linear1_feedforward_cls(
-            d_model, dim_feedforward, **factory_kwargs
-        )
+        self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
         self.dropout = nn.Dropout(dropout)
-        self.linear2 = linear2_feedforward_cls(
-            dim_feedforward, d_model, **factory_kwargs
-        )
+        self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
 
         self.norm_first = norm_first
         self.dropout1 = nn.Dropout(dropout)
@@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
 
         if src_key_padding_mask is not None:
             _skpm_dtype = src_key_padding_mask.dtype
-            if _skpm_dtype != torch.bool and not torch.is_floating_point(
-                src_key_padding_mask
-            ):
-                raise AssertionError(
-                    "only bool and floating types of key_padding_mask are supported"
-                )
+            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
+                raise AssertionError("only bool and floating types of key_padding_mask are supported")
 
         if self.norm_first:
             x = x + self._sa_block(
@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
         self.eps = eps
         self.elementwise_affine = elementwise_affine
         if self.elementwise_affine:
-            self.weight = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs)
-            )
-            self.bias = nn.Parameter(
-                torch.empty(self.normalized_shape, **factory_kwargs)
-            )
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
             )
 
         assert embedding is None
-        return F.layer_norm(
-            input, self.normalized_shape, self.weight, self.bias, self.eps
-        )
+        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
 
     def extra_repr(self) -> str:
-        return (
-            "{normalized_shape}, eps={eps}, "
-            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
-        )
+        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
 
 
 class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
     >>> src = torch.rand(10, 32, 512)
     >>> out = transformer_encoder(src)
     """
 
     __constants__ = ["norm"]
 
     def __init__(self, encoder_layer, num_layers, norm=None):
@@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
 
 class TransformerEncoderLayer(nn.Module):
     __constants__ = ["batch_first", "norm_first"]
 
     def __init__(
         self,
         d_model: int,
@@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
             linear2_cls=linear2_self_attention_cls,
             **factory_kwargs,
         )
-        self.linear1 = linear1_feedforward_cls(
-            d_model, dim_feedforward, **factory_kwargs
-        )
+        self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
         self.dropout = nn.Dropout(dropout)
-        self.linear2 = linear2_feedforward_cls(
-            dim_feedforward, d_model, **factory_kwargs
-        )
+        self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
         self.norm_first = norm_first
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
@@ -30,9 +30,7 @@ class GruutPhonemizer:
             "«": "«",
             "»": "»",
         }
-        self._punctuation_regexp: str = (
-            rf"([{''.join(self._special_cases_dict.keys())}])"
-        )
+        self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
 
     def _normalize_punctuation(self, text: str) -> str:
         text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@@ -53,13 +51,8 @@ class GruutPhonemizer:
 
     def phonemize(self, text: str, espeak: bool = False) -> str:
         text_to_phonemize: str = self._normalize_punctuation(text)
-        sents: List[Sentence] = [
-            sent
-            for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
-        ]
-        words: List[str] = [
-            self._convert_punctuation(word) for word in itertools.chain(*sents)
-        ]
+        sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
+        words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
         return " ".join(words)
 
     def transform(self, phonemes):
@@ -3,7 +3,9 @@
 PAD = "_"
 PUNCTUATION = ';:,.!?¡¿—…"«»“” '
 LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+IPA_LETTERS = (
+    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+)
 SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
 SPACE_ID = SYMBOLS.index(" ")
 SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
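The symbol table above is the usual TTS text-frontend pattern: a fixed alphabet plus a dict for O(1) symbol-to-index lookup. A tiny usage sketch; the encode helper and the stand-in alphabet are illustrative, not part of the file:

PAD = "_"
SYMBOLS = [PAD] + list("abc ")            # stand-in alphabet for the demo
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}

def encode(phoneme_str):
    # Unknown symbols are skipped here; a real frontend might raise instead.
    return [SYMBOL_TO_ID[s] for s in phoneme_str if s in SYMBOL_TO_ID]

print(encode("abc ab"))  # -> [1, 2, 3, 4, 1, 2]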
@@ -2,12 +2,12 @@ import re
 
 
 def str2bool(str):
-    return True if str.lower() == 'true' else False
+    return True if str.lower() == "true" else False
 
 
 def get_newest_ckpt(string_list):
     # Define a regex pattern that matches the numbers embedded in each filename
-    pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
+    pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
 
     # Extract the numeric info from each string and build a list of tuples
     extracted_info = []
@@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
         step = int(match.group(2))
         extracted_info.append((epoch, step, string))
     # Sort by the epoch number, then by the step number
-    sorted_info = sorted(
-        extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
+    sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
     # Take the newest ckpt filename
     newest_ckpt = sorted_info[0][2]
     return newest_ckpt
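The sort key above orders checkpoints by (epoch, step) descending, so index 0 is the newest file. A quick self-contained check of the same logic (the sample filenames are made up):

import re

names = ["epoch=9-step=1000.ckpt", "epoch=10-step=200.ckpt", "epoch=10-step=900.ckpt"]
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"

parsed = [(int(m.group(1)), int(m.group(2)), n) for n in names if (m := re.match(pattern, n))]
newest = sorted(parsed, key=lambda x: (x[0], x[1]), reverse=True)[0][2]
print(newest)  # epoch=10-step=900.ckpt, since epoch wins over step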
@@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
 # Return the text when the file exists and is not empty
 def check_txt_file(file_path):
     try:
-        with open(file_path, 'r') as file:
+        with open(file_path, "r") as file:
             text = file.readline().strip()
-        assert text.strip() != ''
+        assert text.strip() != ""
         return text
     except Exception:
         return False
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 """Initialize modules for espnet2 neural networks."""
 
 import torch
 from typeguard import check_argument_types
 
@@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
 
 
 def write_args(args, path):
-    args_dict = dict(
-        (name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
-    )
+    args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
     with open(path, "a") as args_file:
         args_file.write("==> torch version: {}\n".format(torch.__version__))
-        args_file.write(
-            "==> cudnn version: {}\n".format(torch.backends.cudnn.version())
-        )
+        args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
         args_file.write("==> Cmd:\n")
         args_file.write(str(sys.argv))
         args_file.write("\n==> args:\n")
@@ -23,9 +23,7 @@ class Snake(nn.Module):
     >>> x = a1(x)
     """
 
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
-    ):
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         """
         Initialization.
         INPUT:
@@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
     >>> x = a1(x)
     """
 
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
-    ):
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
         """
         Initialization.
         INPUT:
@@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
 
     @staticmethod
    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
-        activation_results = anti_alias_activation_cuda.forward(
-            inputs, up_ftr, down_ftr, alpha, beta
-        )
+        activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
 
         return activation_results
 
@@ -61,17 +59,11 @@ class Activation1d(nn.Module):
         if self.act.__class__.__name__ == "Snake":
             beta = self.act.alpha.data  # Snake uses same params for alpha and beta
         else:
-            beta = (
-                self.act.beta.data
-            )  # Snakebeta uses different params for alpha and beta
+            beta = self.act.beta.data  # Snakebeta uses different params for alpha and beta
         alpha = self.act.alpha.data
-        if (
-            not self.act.alpha_logscale
-        ):  # Exp baked into cuda kernel, cancel it out with a log
+        if not self.act.alpha_logscale:  # Exp baked into cuda kernel, cancel it out with a log
             alpha = torch.log(alpha)
             beta = torch.log(beta)
 
-        x = FusedAntiAliasActivation.apply(
-            x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
-        )
+        x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
         return x
@@ -58,17 +58,13 @@ def load():
         srcpath / "anti_alias_activation.cpp",
         srcpath / "anti_alias_activation_cuda.cu",
     ]
-    anti_alias_activation_cuda = _cpp_extention_load_helper(
-        "anti_alias_activation_cuda", sources, extra_cuda_flags
-    )
+    anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
 
     return anti_alias_activation_cuda
 
 
 def _get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output(
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
-    )
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")
@@ -27,9 +27,7 @@ else:
 # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
 # https://adefossez.github.io/julius/julius/lowpass.html
 # LICENSE is in incl_licenses directory.
-def kaiser_sinc_filter1d(
-    cutoff, half_width, kernel_size
-):  # return filter [1,1,kernel_size]
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
     even = kernel_size % 2 == 0
     half_size = kernel_size // 2
 
@@ -11,18 +11,12 @@ class UpSample1d(nn.Module):
     def __init__(self, ratio=2, kernel_size=None):
         super().__init__()
         self.ratio = ratio
-        self.kernel_size = (
-            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-        )
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
         self.stride = ratio
         self.pad = self.kernel_size // ratio - 1
         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
-        self.pad_right = (
-            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
-        )
-        filter = kaiser_sinc_filter1d(
-            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
-        )
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
         self.register_buffer("filter", filter)
 
     # x: [B, C, T]
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
         _, C, _ = x.shape
 
         x = F.pad(x, (self.pad, self.pad), mode="replicate")
-        x = self.ratio * F.conv_transpose1d(
-            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
-        )
+        x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
         x = x[..., self.pad_left : -self.pad_right]
 
         return x
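The block above implements anti-aliased upsampling: zero-stuff by ratio via a grouped transposed convolution, then interpolate with a windowed-sinc low-pass (cutoff 0.5/ratio). A minimal sketch of the same idea with a plain Hann-windowed sinc; the filter design below is a simplified assumption, not the Kaiser design used in the file:

import torch
import torch.nn.functional as F

def sinc_upsample(x, ratio=2, taps=12):
    # x: (B, C, T). Insert zeros, then low-pass with a windowed sinc.
    n = torch.arange(-taps, taps + 1, dtype=torch.float32)
    h = torch.sinc(n / ratio) * torch.hann_window(2 * taps + 1, periodic=False)
    h = (h / h.sum() * ratio).view(1, 1, -1)   # restore unity gain after zero-stuffing
    B, C, T = x.shape
    up = torch.zeros(B, C, T * ratio)
    up[..., ::ratio] = x                        # zero-stuffing
    return F.conv1d(up, h.expand(C, -1, -1), padding=taps, groups=C)

y = sinc_upsample(torch.randn(1, 2, 16))
print(y.shape)  # torch.Size([1, 2, 32])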
@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
     def __init__(self, ratio=2, kernel_size=None):
         super().__init__()
         self.ratio = ratio
-        self.kernel_size = (
-            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
-        )
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
         self.lowpass = LowPassFilter1d(
             cutoff=0.5 / ratio,
             half_width=0.6 / ratio,
@@ -87,9 +87,7 @@ class AMPBlock1(torch.nn.Module):
         )
         self.convs2.apply(init_weights)
 
-        self.num_layers = len(self.convs1) + len(
-            self.convs2
-        )  # Total number of conv layers
+        self.num_layers = len(self.convs1) + len(self.convs2)  # Total number of conv layers
 
         # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
         if self.h.get("use_cuda_kernel", False):
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
         if activation == "snake":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.Snake(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
         elif activation == "snakebeta":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.SnakeBeta(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
         if activation == "snake":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.Snake(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
         elif activation == "snakebeta":
             self.activations = nn.ModuleList(
                 [
-                    Activation1d(
-                        activation=activations.SnakeBeta(
-                            channels, alpha_logscale=h.snake_logscale
-                        )
-                    )
+                    Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
                     for _ in range(self.num_layers)
                 ]
             )
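As background for the Snake/SnakeBeta choices above: the snake activation is x + (1/a) * sin^2(a*x), a learnable periodic nonlinearity that suits audio waveform generation. A one-line functional form for reference (this is the published formula; the per-channel and log-scale parameter handling from the module is omitted):

import torch

def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Snake(x) = x + (1/alpha) * sin^2(alpha * x)
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2

print(snake(torch.linspace(-2.0, 2.0, 5), alpha=1.0))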
@@ -283,9 +265,7 @@ class BigVGAN(
         self.num_upsamples = len(h.upsample_rates)
 
         # Pre-conv
-        self.conv_pre = weight_norm(
-            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
-        )
+        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
 
         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
         if h.resblock == "1":
@@ -293,9 +273,7 @@ class BigVGAN(
         elif h.resblock == "2":
             resblock_class = AMPBlock2
         else:
-            raise ValueError(
-                f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
-            )
+            raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
 
         # Transposed conv-based upsamplers. does not apply anti-aliasing
         self.ups = nn.ModuleList()
@@ -320,22 +298,14 @@ class BigVGAN(
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
-            ):
-                self.resblocks.append(
-                    resblock_class(h, ch, k, d, activation=h.activation)
-                )
+            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
 
         # Post-conv
         activation_post = (
             activations.Snake(ch, alpha_logscale=h.snake_logscale)
             if h.activation == "snake"
-            else (
-                activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
-                if h.activation == "snakebeta"
-                else None
-            )
+            else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
         )
         if activation_post is None:
             raise NotImplementedError(
@@ -346,9 +316,7 @@ class BigVGAN(
 
         # Whether to use bias for the final conv_post. Default to True for backward compatibility
         self.use_bias_at_final = h.get("use_bias_at_final", True)
-        self.conv_post = weight_norm(
-            Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
-        )
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
 
         # Weight initialization
         for i in range(len(self.ups)):
@@ -451,13 +419,13 @@ class BigVGAN(
         # instantiate BigVGAN using h
         if use_cuda_kernel:
             print(
-                f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
+                "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
             )
             print(
-                f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
+                "[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
            )
             print(
-                f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
+                "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
             )
         model = cls(h, use_cuda_kernel=use_cuda_kernel)
 
@@ -485,7 +453,7 @@ class BigVGAN(
             model.load_state_dict(checkpoint_dict["generator"])
         except RuntimeError:
             print(
-                f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+                "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
             )
             model.remove_weight_norm()
             model.load_state_dict(checkpoint_dict["generator"])
@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
 from env import AttrDict
 from utils import get_padding
 import typing
-from typing import Optional, List, Union, Dict, Tuple
+from typing import List, Tuple
 
 
 class DiscriminatorP(torch.nn.Module):
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
-        )
+        self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         self.mpd_reshapes = h.mpd_reshapes
         print(f"mpd_reshapes: {self.mpd_reshapes}")
         self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
-                for rs in self.mpd_reshapes
-            ]
+            [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
         )
 
-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
         super().__init__()
 
         self.resolution = resolution
-        assert (
-            len(self.resolution) == 3
-        ), f"MRD layer requires list with len=3, got {self.resolution}"
+        assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
         self.lrelu_slope = 0.1
 
         norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
         if hasattr(cfg, "mrd_use_spectral_norm"):
-            print(
-                f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
-            )
-            norm_f = (
-                weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
-            )
+            print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
+            norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
         self.d_mult = cfg.discriminator_channel_mult
         if hasattr(cfg, "mrd_channel_mult"):
             print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
                 ),
             ]
         )
-        self.conv_post = norm_f(
-            nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
-        )
+        self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         fmap = []
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
     def __init__(self, cfg, debug=False):
         super().__init__()
         self.resolutions = cfg.resolutions
-        assert (
-            len(self.resolutions) == 3
-        ), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
-        )
+        assert len(self.resolutions) == 3, (
+            f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
+        )
+        self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
 
-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
         convs = lambda: nn.ModuleList(
             [
                 weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
-                ),
-                weight_norm(
-                    nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
-                ),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
             ]
         )
         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
 
-        self.conv_post = weight_norm(
-            nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
-        )
+        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
 
     def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
         # Remove DC offset
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
         super().__init__()
         # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
         self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorB(window_length=w) for w in self.fft_sizes]
-        )
+        self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
 
-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
 
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
 
         in_chs = min(self.filters_scale * self.filters, self.max_filters)
         for i, dilation in enumerate(self.dilations):
-            out_chs = min(
-                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
-            )
+            out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
             self.convs.append(
                 weight_norm(
                     nn.Conv2d(
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
                         in_chs,
                         out_chs,
                         kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-                        padding=self.get_2d_padding(
-                            (self.kernel_size[0], self.kernel_size[0])
-                        ),
+                        padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                     )
                 )
             )
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
         self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
         if self.cqtd_normalize_volume:
             print(
-                f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
+                "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
             )
 
     def get_2d_padding(
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
         # Multi-scale params to loop over
        self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
         self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
-        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
-            "cqtd_bins_per_octaves", [24, 36, 48]
-        )
+        self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
 
         self.discriminators = nn.ModuleList(
             [
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
             ]
         )
 
-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
 
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
         super().__init__()
         self.discrimiantor = nn.ModuleList(list_discriminator)
 
-    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+    def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor
+    ) -> Tuple[
         List[torch.Tensor],
         List[torch.Tensor],
         List[List[torch.Tensor]],
         List[List[torch.Tensor]],
     ]:
 
         y_d_rs = []
         y_d_gs = []
         fmap_rs = []
@@ -35,9 +35,7 @@ def inference(a, h):
     with torch.no_grad():
         for i, filname in enumerate(filelist):
             # Load the ground truth audio and resample if necessary
-            wav, sr = librosa.load(
-                os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
-            )
+            wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
             wav = torch.FloatTensor(wav).to(device)
             # Compute mel spectrogram from the ground truth audio
             x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@@ -48,9 +46,7 @@ def inference(a, h):
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype("int16")
 
-            output_file = os.path.join(
-                a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
-            )
+            output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
             write(output_file, h.sampling_rate, audio)
             print(output_file)
 
@@ -61,9 +61,7 @@ def inference(a, h):
             audio = audio * MAX_WAV_VALUE
             audio = audio.cpu().numpy().astype("int16")
 
-            output_file = os.path.join(
-                a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
-            )
+            output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
             write(output_file, h.sampling_rate, audio)
             print(output_file)
 
@ -6,13 +6,12 @@
|
|||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from librosa.filters import mel as librosa_mel_fn
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
from typing import Optional, List, Union, Dict, Tuple
|
from typing import List, Tuple
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import math
|
import math
|
||||||
import functools
|
import functools
|
||||||
@ -123,9 +122,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
|||||||
B, C, T = wav.shape
|
B, C, T = wav.shape
|
||||||
|
|
||||||
if match_stride:
|
if match_stride:
|
||||||
assert (
|
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||||
hop_length == window_length // 4
|
|
||||||
), "For match_stride, hop must equal n_fft // 4"
|
|
||||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||||
pad = (window_length - hop_length) // 2
|
pad = (window_length - hop_length) // 2
|
||||||
else:
|
else:
|
||||||
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
         magnitude = torch.abs(stft)

         nf = magnitude.shape[2]
-        mel_basis = self.get_mel_filters(
-            self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
-        )
+        mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
         mel_basis = torch.from_numpy(mel_basis).to(wav.device)
         mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
         mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
         """

         loss = 0.0
-        for n_mels, fmin, fmax, s in zip(
-            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
-        ):
+        for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
             kwargs = {
                 "n_mels": n_mels,
                 "fmin": fmin,
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):

             x_mels = self.mel_spectrogram(x, **kwargs)
             y_mels = self.mel_spectrogram(y, **kwargs)
-            x_logmels = torch.log(
-                x_mels.clamp(min=self.clamp_eps).pow(self.pow)
-            ) / torch.log(torch.tensor(10.0))
-            y_logmels = torch.log(
-                y_mels.clamp(min=self.clamp_eps).pow(self.pow)
-            ) / torch.log(torch.tensor(10.0))
+            x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
+            y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))

             loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
             loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
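For reference, the collapsed `x_logmels`/`y_logmels` lines compute a base-10 log-mel via the natural log. A minimal standalone sketch (the helper name and default `clamp_eps`/`pow` values are illustrative):

    import torch

    def log10_mels(mels: torch.Tensor, clamp_eps: float = 1e-5, pow: float = 1.0) -> torch.Tensor:
        # clamp avoids log(0); dividing the natural log by log(10) gives log10
        return torch.log(mels.clamp(min=clamp_eps).pow(pow)) / torch.log(torch.tensor(10.0))

    x_logmels = log10_mels(torch.rand(1, 80, 100))
    y_logmels = log10_mels(torch.rand(1, 80, 100))
    print(torch.nn.functional.l1_loss(x_logmels, y_logmels))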
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):


 # Loss functions
-def feature_loss(
-    fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
-) -> torch.Tensor:
+def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
-
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
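Only the head of `feature_loss` is visible in this hunk; a sketch of the feature-matching objective such a signature conventionally implements (the final x2 weight follows the HiFi-GAN convention and is an assumption about this file, not its verbatim body):

    from typing import List
    import torch

    def feature_matching(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
        # L1 distance between discriminator feature maps of real and generated audio
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))
        return loss * 2  # HiFi-GAN-style weighting; assumed here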
@@ -226,7 +212,6 @@ def feature_loss(
 def discriminator_loss(
     disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
 ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
-
     loss = 0
     r_losses = []
     g_losses = []
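Again only the signature appears here; a minimal sketch of the least-squares GAN objective this signature suggests (an assumption about the body, hedged accordingly):

    from typing import List, Tuple
    import torch

    def lsgan_discriminator_loss(
        disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        # Push D(real) toward 1 and D(generated) toward 0
        loss = 0
        r_losses, g_losses = [], []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean((1 - dr) ** 2)
            g_loss = torch.mean(dg**2)
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())
        return loss, r_losses, g_losses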
@@ -243,7 +228,6 @@ def discriminator_loss(
 def generator_loss(
     disc_outputs: List[torch.Tensor],
 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
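Likewise for `generator_loss`: the complementary least-squares term pushes the discriminator outputs on generated audio toward 1. A sketch of the conventional body (an assumption, shown for orientation):

    from typing import List, Tuple
    import torch

    def lsgan_generator_loss(disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            g_loss = torch.mean((1 - dg) ** 2)
            gen_losses.append(g_loss)
            loss += g_loss
        return loss, gen_losses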
@@ -86,9 +86,7 @@ def mel_spectrogram(
     key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

     if key not in mel_basis_cache:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
         hann_window_cache[key] = torch.hann_window(win_size).to(device)

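The collapsed `librosa_mel_fn` call above feeds a per-configuration cache, so the mel filterbank and Hann window are built once per (STFT settings, device) pair and reused on every call. A runnable sketch of the same pattern (helper name is illustrative):

    import torch
    from librosa.filters import mel as librosa_mel_fn

    mel_basis_cache: dict = {}
    hann_window_cache: dict = {}

    def get_mel_basis(n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, device):
        key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
        if key not in mel_basis_cache:
            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
            mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
            hann_window_cache[key] = torch.hann_window(win_size).to(device)
        return mel_basis_cache[key], hann_window_cache[key]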
@@ -96,9 +94,7 @@ def mel_spectrogram(
     hann_window = hann_window_cache[key]

     padding = (n_fft - hop_size) // 2
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1), (padding, padding), mode="reflect"
-    ).squeeze(1)
+    y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)

     spec = torch.stft(
         y,
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):

     with open(a.input_training_file, "r", encoding="utf-8") as fi:
         training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
-            for x in fi.read().split("\n")
-            if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
         ]
     print(f"first training file: {training_files[0]}")

     with open(a.input_validation_file, "r", encoding="utf-8") as fi:
         validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
-            for x in fi.read().split("\n")
-            if len(x) > 0
+            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
         ]
     print(f"first validation file: {validation_files[0]}")

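The list comprehensions above parse pipe-delimited filelists where only the first field (the wav basename) matters. Equivalent standalone form (the function name is illustrative):

    import os

    def read_filelist(list_path: str, wavs_dir: str):
        # Each non-empty line looks like "<basename>|<transcript or metadata>"
        with open(list_path, "r", encoding="utf-8") as fi:
            return [os.path.join(wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0]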
@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
                 for x in fi.read().split("\n")
                 if len(x) > 0
             ]
-            print(
-                f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
-            )
+            print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
             list_unseen_validation_files.append(unseen_validation_files)

     return training_files, validation_files, list_unseen_validation_files
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):

             print("[INFO] checking dataset integrity...")
             for i in tqdm(range(len(self.audio_files))):
-                assert os.path.exists(
-                    self.audio_files[i]
-                ), f"{self.audio_files[i]} not found"
+                assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"

-    def __getitem__(
-        self, index: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
         try:
             filename = self.audio_files[index]

@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
                 # Obtain randomized audio chunk
                 if source_sampling_rate != self.sampling_rate:
                     # Adjust segment size to crop if the source sr is different
-                    target_segment_size = math.ceil(
-                        self.segment_size
-                        * (source_sampling_rate / self.sampling_rate)
-                    )
+                    target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
                 else:
                     target_segment_size = self.segment_size

                 # Compute upper bound index for the random chunk
-                random_chunk_upper_bound = max(
-                    0, audio.shape[0] - target_segment_size
-                )
+                random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)

                 # Crop or pad audio to obtain random chunk with target_segment_size
                 if audio.shape[0] >= target_segment_size:
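The chunking logic above scales the crop size by the ratio of source to target sample rates, so that after resampling the chunk spans exactly `segment_size` samples. A self-contained sketch of the same idea:

    import math
    import random
    import numpy as np

    def random_chunk(audio: np.ndarray, segment_size: int, source_sr: int, target_sr: int) -> np.ndarray:
        # Scale the crop so that, once resampled to target_sr, it spans segment_size samples
        if source_sr != target_sr:
            target_segment_size = math.ceil(segment_size * (source_sr / target_sr))
        else:
            target_segment_size = segment_size
        upper_bound = max(0, audio.shape[0] - target_segment_size)
        start = random.randint(0, upper_bound)
        chunk = audio[start : start + target_segment_size]
        if chunk.shape[0] < target_segment_size:  # file shorter than the chunk: zero-pad
            chunk = np.pad(chunk, (0, target_segment_size - chunk.shape[0]))
        return chunk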
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
             else:
                 # For fine-tuning, assert that the waveform is in the defined sampling_rate
                 # Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
-                assert (
-                    source_sampling_rate == self.sampling_rate
-                ), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
+                assert source_sampling_rate == self.sampling_rate, (
+                    f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
+                )

             # Cast ndarray to torch tensor
             audio = torch.FloatTensor(audio)
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
                     mel = mel[:, :, mel_start : mel_start + frames_per_seg]
                     audio = audio[
                         :,
-                        mel_start
-                        * self.hop_size : (mel_start + frames_per_seg)
-                        * self.hop_size,
+                        mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
                     ]

                 # Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
                 # NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
                 # To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
-                mel = torch.nn.functional.pad(
-                    mel, (0, frames_per_seg - mel.size(2)), "constant"
-                )
-                audio = torch.nn.functional.pad(
-                    audio, (0, self.segment_size - audio.size(1)), "constant"
-                )
+                mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
+                audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")

             # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
             mel_loss = mel_spectrogram(
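The two collapsed pad calls keep the <mel, audio> pair consistent: the audio length must stay `frames * hop_size`. A toy check of that invariant (the sizes are illustrative):

    import torch

    hop_size, frames_per_seg = 256, 32               # illustrative values
    segment_size = frames_per_seg * hop_size

    mel = torch.randn(1, 100, frames_per_seg - 1)    # one frame short
    audio = torch.randn(1, segment_size - hop_size)  # correspondingly short
    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
    audio = torch.nn.functional.pad(audio, (0, segment_size - audio.size(1)), "constant")
    assert audio.size(1) == mel.size(2) * hop_size   # the invariant the dataset enforces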
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):

             # Shape sanity checks
             assert (
-                audio.shape[1] == mel.shape[2] * self.hop_size
-                and audio.shape[1] == mel_loss.shape[2] * self.hop_size
-            ), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
+                audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
+            ), (
+                f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
+            )

             return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
             if self.fine_tuning:
                 raise e  # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
             else:
-                print(
-                    f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
-                )
+                print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
                 return self[random.randrange(len(self))]

     def __len__(self):
@@ -3,6 +3,7 @@

 import os
 import sys

 # to import modules from parent_dir
 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
     data = torch.rand((10, 10, 200), device="cuda")

     # Check activations.Snake cuda vs. torch
-    fused_anti_alias_activation = activation1d.Activation1d(
-        activation=Snake(10), fused=True
-    ).cuda()
+    fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
     fused_activation_output = fused_anti_alias_activation(data)

-    torch_anti_alias_activation = activation1d.Activation1d(
-        activation=Snake(10), fused=False
-    ).cuda()
+    torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
     torch_activation_output = torch_anti_alias_activation(data)

     test_result = (fused_activation_output - torch_activation_output).abs()
@@ -3,6 +3,7 @@

 import os
 import sys

 # to import modules from parent_dir
 parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
     data = torch.rand((10, 10, 200), device="cuda")

     # Check activations, Snake CUDA vs. Torch
-    fused_anti_alias_activation = activation1d.Activation1d(
-        activation=SnakeBeta(10), fused=True
-    ).cuda()
+    fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
     fused_activation_output = fused_anti_alias_activation(data)

-    torch_anti_alias_activation = activation1d.Activation1d(
-        activation=SnakeBeta(10), fused=False
-    ).cuda()
+    torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
     torch_activation_output = torch_anti_alias_activation(data)

     test_result = (fused_activation_output - torch_activation_output).abs()
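Both test files exercise `Activation1d` wrappers around Snake-family activations, comparing the fused CUDA kernel against the plain PyTorch path. For orientation, the periodic activation itself is typically defined as below (a sketch of the published Snake formula, not code from this diff):

    import torch

    def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
        # Snake: x + (1/alpha) * sin^2(alpha * x); the epsilon guards alpha ~ 0
        return x + (1.0 / (alpha + 1e-9)) * torch.sin(alpha * x).pow(2)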
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
     )

-


 if __name__ == "__main__":
     from alias_free_activation.cuda import load

@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):


 def get_mel(x, h):
-    return mel_spectrogram(
-        x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
-    )
+    return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)


 def load_checkpoint(filepath, device):
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Test script to check CUDA kernel correctness."
-    )
+    parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
     parser.add_argument(
         "--checkpoint_file",
         type=str,
@@ -109,9 +105,7 @@ if __name__ == "__main__":
         diff += test_result.mean(dim=-1).item()

     diff /= num_sample
-    if (
-        diff <= 2e-3
-    ):  # We can expect a small difference (~1e-3) which does not affect perceptual quality
+    if diff <= 2e-3:  # We can expect a small difference (~1e-3) which does not affect perceptual quality
         print(
             f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
             f"\n > mean_difference={diff}"
@@ -77,24 +77,18 @@ def train(rank, a, h):
     # Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
     # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
     if h.get("use_mbd_instead_of_mrd", False):  # Switch to MBD
-        print(
-            "[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
-        )
+        print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
         # Variable name is kept as "mrd" for backward compatibility & minimal code change
         mrd = MultiBandDiscriminator(h).to(device)
     elif h.get("use_cqtd_instead_of_mrd", False):  # Switch to CQTD
-        print(
-            "[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
-        )
+        print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
         mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
     else:  # Fallback to original MRD in BigVGAN-v1
         mrd = MultiResolutionDiscriminator(h).to(device)

     # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
     if h.get("use_multiscale_melloss", False):
-        print(
-            "[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
-        )
+        print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
         fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
             sampling_rate=h.sampling_rate
         )  # NOTE: accepts waveform as input
@@ -114,9 +108,7 @@ def train(rank, a, h):

     if os.path.isdir(a.checkpoint_path):
         # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
-        cp_g = scan_checkpoint(
-            a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
-        )
+        cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
         cp_do = scan_checkpoint(
             a.checkpoint_path,
             prefix="do_",
@@ -143,9 +135,7 @@ def train(rank, a, h):
         mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
         mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)

-    optim_g = torch.optim.AdamW(
-        generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
-    )
+    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
     optim_d = torch.optim.AdamW(
         itertools.chain(mrd.parameters(), mpd.parameters()),
         h.learning_rate,
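The collapsed `optim_g` line and the multi-line `optim_d` call show the split: one AdamW for the generator, one AdamW over both discriminators via `itertools.chain`. A minimal sketch with stand-in modules and illustrative hyperparameters:

    import itertools
    import torch

    generator = torch.nn.Linear(4, 4)  # stand-ins for the real models
    mpd = torch.nn.Linear(4, 4)
    mrd = torch.nn.Linear(4, 4)
    learning_rate, adam_b1, adam_b2 = 2e-4, 0.8, 0.99  # illustrative values

    optim_g = torch.optim.AdamW(generator.parameters(), learning_rate, betas=[adam_b1, adam_b2])
    # One optimizer updates both discriminators by chaining their parameters
    optim_d = torch.optim.AdamW(
        itertools.chain(mrd.parameters(), mpd.parameters()),
        learning_rate,
        betas=[adam_b1, adam_b2],
    )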
@@ -156,12 +146,8 @@ def train(rank, a, h):
         optim_g.load_state_dict(state_dict_do["optim_g"])
         optim_d.load_state_dict(state_dict_do["optim_d"])

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=h.lr_decay, last_epoch=last_epoch
-    )
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=h.lr_decay, last_epoch=last_epoch
-    )
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

     # Define training and validation datasets

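Both schedulers decay the learning rate by `lr_decay` once per epoch; passing the restored `last_epoch` keeps the decay curve continuous when training resumes from the "do_" checkpoint. A fresh-start sketch:

    import torch

    model = torch.nn.Linear(2, 2)
    optim = torch.optim.AdamW(model.parameters(), 2e-4)
    # gamma is the per-epoch decay factor (h.lr_decay above);
    # last_epoch=-1 starts a new schedule, a saved index resumes the same curve
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999, last_epoch=-1)
    for _ in range(3):
        optim.step()
        scheduler.step()
    print(scheduler.get_last_lr())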
@@ -169,9 +155,7 @@ def train(rank, a, h):
     unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
     Example: trained on LibriTTS, validate on VCTK
     """
-    training_filelist, validation_filelist, list_unseen_validation_filelist = (
-        get_dataset_filelist(a)
-    )
+    training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)

     trainset = MelDataset(
         training_filelist,
@@ -327,17 +311,12 @@ def train(rank, a, h):
                     val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()

                     # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
-                    if (
-                        not "nonspeech" in mode
-                    ):  # Skips if the name of dataset (in mode string) contains "nonspeech"
-
+                    if "nonspeech" not in mode:  # Skips if the name of dataset (in mode string) contains "nonspeech"
                         # Resample to 16000 for pesq
                         y_16k = pesq_resampler(y)
                         y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
                         y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
-                        y_g_hat_int_16k = (
-                            (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
-                        )
+                        y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
                         val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")

                     # MRSTFT calculation
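PESQ is only defined for 8/16 kHz speech, hence the resample-then-int16 conversion above. A sketch of the same evaluation path (the `pesq` package's `pesq(fs, ref, deg, mode)` call is assumed available, and the 24 kHz source rate is illustrative):

    import torch
    import torchaudio
    from pesq import pesq

    MAX_WAV_VALUE = 32768.0
    pesq_resampler = torchaudio.transforms.Resample(24000, 16000)  # 24 kHz source rate is illustrative

    def pesq_wideband(y: torch.Tensor, y_hat: torch.Tensor) -> float:
        # Resample to 16 kHz first, then scale to the int16 range PESQ expects
        y_int_16k = (pesq_resampler(y)[0] * MAX_WAV_VALUE).short().cpu().numpy()
        y_hat_int_16k = (pesq_resampler(y_hat)[0] * MAX_WAV_VALUE).short().cpu().numpy()
        return pesq(16000, y_int_16k, y_hat_int_16k, "wb")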
@@ -348,9 +327,7 @@ def train(rank, a, h):
                 if j % a.eval_subsample == 0:  # Subsample every nth from validation set
                     if steps >= 0:
                         sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
-                        if (
-                            a.save_audio
-                        ):  # Also save audio to disk if --save_audio is set to True
+                        if a.save_audio:  # Also save audio to disk if --save_audio is set to True
                             save_audio(
                                 y[0],
                                 os.path.join(
@@ -373,9 +350,7 @@ def train(rank, a, h):
                             steps,
                             h.sampling_rate,
                         )
-                        if (
-                            a.save_audio
-                        ):  # Also save audio to disk if --save_audio is set to True
+                        if a.save_audio:  # Also save audio to disk if --save_audio is set to True
                             save_audio(
                                 y_g_hat[0, 0],
                                 os.path.join(
@@ -487,15 +462,11 @@ def train(rank, a, h):

             # MPD
             y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
-            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
-                y_df_hat_r, y_df_hat_g
-            )
+            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

             # MRD
             y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
-            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
-                y_ds_hat_r, y_ds_hat_g
-            )
+            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

             loss_disc_all = loss_disc_s + loss_disc_f

@@ -505,17 +476,11 @@ def train(rank, a, h):
             # Whether to freeze D for initial training steps
             if steps >= a.freeze_step:
                 loss_disc_all.backward()
-                grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
-                    mpd.parameters(), clip_grad_norm
-                )
-                grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
-                    mrd.parameters(), clip_grad_norm
-                )
+                grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
+                grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
                 optim_d.step()
             else:
-                print(
-                    f"[WARNING] skipping D training for the first {a.freeze_step} steps"
-                )
+                print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
                 grad_norm_mpd = 0.0
                 grad_norm_mrd = 0.0

@@ -523,9 +488,7 @@ def train(rank, a, h):
             optim_g.zero_grad()

             # L1 Mel-Spectrogram Loss
-            lambda_melloss = h.get(
-                "lambda_melloss", 45.0
-            )  # Defaults to 45 in BigVGAN-v1 if not set
+            lambda_melloss = h.get("lambda_melloss", 45.0)  # Defaults to 45 in BigVGAN-v1 if not set
             if h.get("use_multiscale_melloss", False):  # uses wav <y, y_g_hat> for loss
                 loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
             else:  # Uses mel <y_mel, y_g_hat_mel> for loss
@@ -542,27 +505,19 @@ def train(rank, a, h):
             loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

             if steps >= a.freeze_step:
-                loss_gen_all = (
-                    loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
-                )
+                loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
             else:
-                print(
-                    f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
-                )
+                print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
                 loss_gen_all = loss_mel

             loss_gen_all.backward()
-            grad_norm_g = torch.nn.utils.clip_grad_norm_(
-                generator.parameters(), clip_grad_norm
-            )
+            grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
             optim_g.step()

             if rank == 0:
                 # STDOUT logging
                 if steps % a.stdout_interval == 0:
-                    mel_error = (
-                        loss_mel.item() / lambda_melloss
-                    )  # Log training mel regression loss to stdout
+                    mel_error = loss_mel.item() / lambda_melloss  # Log training mel regression loss to stdout
                     print(
                         f"Steps: {steps:d}, "
                         f"Gen Loss Total: {loss_gen_all:4.3f}, "
@@ -577,11 +532,7 @@ def train(rank, a, h):
                     checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
                     save_checkpoint(
                         checkpoint_path,
-                        {
-                            "generator": (
-                                generator.module if h.num_gpus > 1 else generator
-                            ).state_dict()
-                        },
+                        {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
                     )
                     checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
                     save_checkpoint(
@@ -598,9 +549,7 @@ def train(rank, a, h):

                 # Tensorboard summary logging
                 if steps % a.summary_interval == 0:
-                    mel_error = (
-                        loss_mel.item() / lambda_melloss
-                    )  # Log training mel regression loss to tensorboard
+                    mel_error = loss_mel.item() / lambda_melloss  # Log training mel regression loss to tensorboard
                     sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
                     sw.add_scalar("training/mel_spec_error", mel_error, steps)
                     sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@@ -612,12 +561,8 @@ def train(rank, a, h):
                     sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
                     sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
                     sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
-                    sw.add_scalar(
-                        "training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
-                    )
-                    sw.add_scalar(
-                        "training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
-                    )
+                    sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
+                    sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
                     sw.add_scalar("training/epoch", epoch + 1, steps)

                 # Validation
@@ -660,9 +605,7 @@ def train(rank, a, h):
         scheduler_d.step()

         if rank == 0:
-            print(
-                f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
-            )
+            print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")


 def main():
@@ -674,12 +617,8 @@ def main():

     parser.add_argument("--input_wavs_dir", default="LibriTTS")
     parser.add_argument("--input_mels_dir", default="ft_dataset")
-    parser.add_argument(
-        "--input_training_file", default="tests/LibriTTS/train-full.txt"
-    )
-    parser.add_argument(
-        "--input_validation_file", default="tests/LibriTTS/val-full.txt"
-    )
+    parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
+    parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")

     parser.add_argument(
         "--list_input_unseen_wavs_dir",
@@ -1,49 +1,61 @@
-from copy import deepcopy
+import gc
 import math
-import os, sys, gc
+import os
 import random
-import traceback
+import sys
 import time
+import traceback
+from copy import deepcopy

 import torchaudio
 from tqdm import tqdm

 now_dir = os.getcwd()
 sys.path.append(now_dir)
-import ffmpeg
 import os
-from typing import Generator, List, Tuple, Union
+from typing import List, Tuple, Union

+import ffmpeg
+import librosa
 import numpy as np
 import torch
 import torch.nn.functional as F
 import yaml
-from transformers import AutoModelForMaskedLM, AutoTokenizer
-from tools.audio_sr import AP_BWE
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+from BigVGAN.bigvgan import BigVGAN
 from feature_extractor.cnhubert import CNHubert
+from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
 from module.models import SynthesizerTrn, SynthesizerTrnV3
 from peft import LoraConfig, get_peft_model
-import librosa
-from time import time as ttime
+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
+from transformers import AutoModelForMaskedLM, AutoTokenizer

+from tools.audio_sr import AP_BWE
 from tools.i18n.i18n import I18nAuto, scan_language_list
 from tools.my_utils import load_audio
-from module.mel_processing import spectrogram_torch
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
-from BigVGAN.bigvgan import BigVGAN
-from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
-from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)



 spec_min = -12
 spec_max = 2


 def norm_spec(x):
     return (x - spec_min) / (spec_max - spec_min) * 2 - 1


 def denorm_spec(x):
     return (x + 1) / 2 * (spec_max - spec_min) + spec_min
-mel_fn=lambda x: mel_spectrogram_torch(x, **{

+
+mel_fn = lambda x: mel_spectrogram_torch(
+    x,
+    **{
         "n_fft": 1024,
         "win_size": 1024,
         "hop_size": 256,
@@ -51,8 +63,9 @@ mel_fn=lambda x: mel_spectrogram_torch(x, **{
         "sampling_rate": 24000,
         "fmin": 0,
         "fmax": None,
-        "center": False
-})
+        "center": False,
+    },
+)


 def speed_change(input_audio: np.ndarray, speed: float, sr: int):
@@ -60,15 +73,14 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     raw_audio = input_audio.astype(np.int16).tobytes()

     # Set up the ffmpeg input stream
-    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
+    input_stream = ffmpeg.input("pipe:", format="s16le", acodec="pcm_s16le", ar=str(sr), ac=1)

     # Apply the tempo change
-    output_stream = input_stream.filter('atempo', speed)
+    output_stream = input_stream.filter("atempo", speed)

     # Route the output stream to a pipe
-    out, _ = (
-        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
-        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
+    out, _ = output_stream.output("pipe:", format="s16le", acodec="pcm_s16le").run(
+        input=raw_audio, capture_stdout=True, capture_stderr=True
     )

     # Decode the piped output into a NumPy array
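`speed_change` pipes raw 16-bit PCM through ffmpeg's `atempo` filter, which changes tempo without shifting pitch. An equivalent fluent-style sketch using ffmpeg-python (note that older ffmpeg builds restrict a single `atempo` stage to factors in [0.5, 2.0], so larger changes need chained filters):

    import numpy as np
    import ffmpeg  # ffmpeg-python

    def change_tempo(pcm: np.ndarray, speed: float, sr: int) -> np.ndarray:
        raw = pcm.astype(np.int16).tobytes()
        out, _ = (
            ffmpeg.input("pipe:", format="s16le", acodec="pcm_s16le", ar=str(sr), ac=1)
            .filter("atempo", speed)  # tempo change without pitch shift
            .output("pipe:", format="s16le", acodec="pcm_s16le")
            .run(input=raw, capture_stdout=True, capture_stderr=True)
        )
        return np.frombuffer(out, np.int16)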
@@ -77,14 +89,13 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
     return processed_audio



 resample_transform_dict = {}


 def resample(audio_tensor, sr0, device):
     global resample_transform_dict
     if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
-            sr0, 24000
-        ).to(device)
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
     return resample_transform_dict[sr0](audio_tensor)


@@ -156,11 +167,12 @@ default_v3:
   version: v3
 """

+
 def set_seed(seed: int):
     seed = int(seed)
     seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
     print(f"Set seed to {seed}")
-    os.environ['PYTHONHASHSEED'] = str(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
@@ -178,6 +190,7 @@ def set_seed(seed:int):
             pass
     return seed

+
 class TTS_Config:
     default_configs = {
         "v1": {
@@ -225,7 +238,6 @@ class TTS_Config:
         # "auto_yue",  # multilingual mode: split the input and detect the language of each segment

     def __init__(self, configs: Union[dict, str] = None):
-
         # Set the default path for config files
         configs_base_path: str = "GPT_SoVITS/configs/"
         os.makedirs(configs_base_path, exist_ok=True)
@@ -247,10 +259,9 @@ class TTS_Config:
             self.default_configs[version] = configs.get(version, self.default_configs[version])
         self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))

-
         self.device = self.configs.get("device", torch.device("cpu"))
         if "cuda" in str(self.device) and not torch.cuda.is_available():
-            print(f"Warning: CUDA is not available, set device to CPU.")
+            print("Warning: CUDA is not available, set device to CPU.")
             self.device = torch.device("cpu")

         self.is_half = self.configs.get("is_half", False)
@@ -267,22 +278,20 @@ class TTS_Config:

         self.is_v3_synthesizer: bool = False

-
         if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
-            self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
+            self.t2s_weights_path = self.default_configs[version]["t2s_weights_path"]
             print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
         if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
-            self.vits_weights_path = self.default_configs[version]['vits_weights_path']
+            self.vits_weights_path = self.default_configs[version]["vits_weights_path"]
             print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
         if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
-            self.bert_base_path = self.default_configs[version]['bert_base_path']
+            self.bert_base_path = self.default_configs[version]["bert_base_path"]
             print(f"fall back to default bert_base_path: {self.bert_base_path}")
         if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
-            self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
+            self.cnhuhbert_base_path = self.default_configs[version]["cnhuhbert_base_path"]
             print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")
         self.update_configs()

-
         self.max_sec = None
         self.hz: int = 50
         self.semantic_frame_rate: str = "25hz"
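All four branches above apply the same fallback rule: an empty or nonexistent configured path is replaced by the bundled default. Factored out as a sketch (the helper name is illustrative):

    import os

    def resolve_path(configured: str, default: str) -> str:
        # Empty or nonexistent configured paths fall back to the bundled default
        if (configured in [None, ""]) or (not os.path.exists(configured)):
            print(f"fall back to default path: {default}")
            return default
        return configured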
@@ -293,15 +302,13 @@ class TTS_Config:
         self.win_length: int = 2048
         self.n_speakers: int = 300

-
-
     def _load_configs(self, configs_path: str) -> dict:
         if os.path.exists(configs_path):
             ...
         else:
             print(i18n("路径不存在,使用默认配置"))
             self.save_configs(configs_path)
-        with open(configs_path, 'r', encoding='utf-8') as f:
+        with open(configs_path, "r", encoding="utf-8") as f:
             configs = yaml.load(f, Loader=yaml.FullLoader)

         return configs
@@ -313,7 +320,7 @@ class TTS_Config:

         if configs_path is None:
             configs_path = self.configs_path
-        with open(configs_path, 'w') as f:
+        with open(configs_path, "w") as f:
             yaml.dump(configs, f)

     def update_configs(self):
@@ -334,10 +341,10 @@ class TTS_Config:

     def __str__(self):
         self.configs = self.update_configs()
-        string = "TTS Config".center(100, '-') + '\n'
+        string = "TTS Config".center(100, "-") + "\n"
         for k, v in self.configs.items():
             string += f"{str(k).ljust(20)}: {str(v)}\n"
-        string += "-" * 100 + '\n'
+        string += "-" * 100 + "\n"
         return string

     def __repr__(self):
@@ -368,11 +375,9 @@ class TTS:

         self._init_models()

-        self.text_preprocessor:TextPreprocessor = \
-            TextPreprocessor(self.bert_model,
-                             self.bert_tokenizer,
-                             self.configs.device)
+        self.text_preprocessor: TextPreprocessor = TextPreprocessor(
+            self.bert_model, self.bert_tokenizer, self.configs.device
+        )

         self.prompt_cache: dict = {
             "ref_audio_path": None,
@@ -386,19 +391,18 @@ class TTS:
             "aux_ref_audio_paths": [],
         }

-
         self.stop_flag: bool = False
         self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32

-    def _init_models(self,):
+    def _init_models(
+        self,
+    ):
         self.init_t2s_weights(self.configs.t2s_weights_path)
         self.init_vits_weights(self.configs.vits_weights_path)
         self.init_bert_weights(self.configs.bert_base_path)
         self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
         # self.enable_half_precision(self.configs.is_half)

-
-
     def init_cnhuhbert_weights(self, base_path: str):
         print(f"Loading CNHuBERT weights from {base_path}")
         self.cnhuhbert_model = CNHubert(base_path)
@@ -407,8 +411,6 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.cnhuhbert_model = self.cnhuhbert_model.half()

-
-
     def init_bert_weights(self, base_path: str):
         print(f"Loading BERT weights from {base_path}")
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
@@ -419,7 +421,6 @@ class TTS:
             self.bert_model = self.bert_model.half()

     def init_vits_weights(self, weights_path: str):
-
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
         path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]
@@ -433,9 +434,9 @@ class TTS:
         hps = dict_s2["config"]

         hps["model"]["semantic_frame_rate"] = "25hz"
-        if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
+        if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
             hps["model"]["version"] = "v2"  # v3model,v2sybomls
-        elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
             hps["model"]["version"] = "v1"
         else:
             hps["model"]["version"] = "v2"
@@ -460,7 +461,7 @@ class TTS:
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
                 n_speakers=self.configs.n_speakers,
-                **kwargs
+                **kwargs,
             )
             self.configs.is_v3_synthesizer = False
         else:
@@ -468,7 +469,7 @@ class TTS:
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
                 n_speakers=self.configs.n_speakers,
-                **kwargs
+                **kwargs,
             )
             self.configs.is_v3_synthesizer = True
             self.init_bigvgan()
@@ -476,9 +477,13 @@ class TTS:
         del vits_model.enc_q

         if if_lora_v3 == False:
-            print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+            print(
+                f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
+            )
         else:
-            print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
+            print(
+                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}"
+            )
             lora_rank = dict_s2["lora_rank"]
             lora_config = LoraConfig(
                 target_modules=["to_k", "to_q", "to_v", "to_out.0"],
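For LoRA checkpoints, the flow above wraps `vits_model.cfm` with PEFT, loads the adapter weights non-strictly, then folds them back into the base weights for plain inference. A minimal sketch of the same wrap-load-merge cycle on a stand-in module (the module and target names are hypothetical, and applying PEFT to a plain `nn.Module` is assumed to work as it does for `cfm` here):

    import torch
    from peft import LoraConfig, get_peft_model

    base = torch.nn.Sequential(torch.nn.Linear(8, 8))  # stand-in for vits_model.cfm
    lora_config = LoraConfig(target_modules=["0"], r=4, lora_alpha=8, init_lora_weights=True)
    peft_module = get_peft_model(base, lora_config)
    # ... load adapter weights here with peft_module.load_state_dict(..., strict=False) ...
    merged = peft_module.merge_and_unload()  # plain module again, adapters folded in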
@@ -487,11 +492,12 @@ class TTS:
                 init_lora_weights=True,
             )
             vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
-            print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
+            print(
+                f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
+            )

             vits_model.cfm = vits_model.cfm.merge_and_unload()

-
         vits_model = vits_model.to(self.configs.device)
         vits_model = vits_model.eval()

@@ -499,7 +505,6 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.vits_model = self.vits_model.half()

-
     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
@@ -516,11 +521,13 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()

-
     def init_bigvgan(self):
         if self.bigvgan_model is not None:
             return
-        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
+        self.bigvgan_model = BigVGAN.from_pretrained(
+            "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+            use_cuda_kernel=False,
+        )  # if True, RuntimeError: Ninja is required to load C++ extensions
         # remove weight norm in the model and set to eval mode
         self.bigvgan_model.remove_weight_norm()
         self.bigvgan_model = self.bigvgan_model.eval()
@@ -539,14 +546,13 @@ class TTS:
             print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
             self.sr_model_not_exist = True

-
     def enable_half_precision(self, enable: bool = True, save: bool = True):
-        '''
+        """
         To enable half precision for the TTS model.
         Args:
             enable: bool, whether to enable half precision.

-        '''
+        """
         if str(self.configs.device) == "cpu" and enable:
             print("Half precision is not supported on CPU.")
             return
@@ -579,11 +585,11 @@ class TTS:
             self.bigvgan_model = self.bigvgan_model.float()

     def set_device(self, device: torch.device, save: bool = True):
-        '''
+        """
         To set the device for all models.
         Args:
             device: torch.device, the device to use for all models.
-        '''
+        """
         self.configs.device = device
         if save:
             self.configs.save_configs()
@@ -600,14 +606,13 @@ class TTS:
             if self.sr_model is not None:
                 self.sr_model = self.sr_model.to(device)

-
     def set_ref_audio(self, ref_audio_path: str):
-        '''
+        """
         To set the reference audio for the TTS model,
         including the prompt_semantic and refer_spepc.
         Args:
             ref_audio_path: str, the path of the reference audio.
-        '''
+        """
         self._set_prompt_semantic(ref_audio_path)
         self._set_ref_spec(ref_audio_path)
         self._set_ref_audio_path(ref_audio_path)
@ -631,7 +636,8 @@ class TTS:
|
|||||||
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
|
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
|
||||||
audio = torch.FloatTensor(audio)
|
audio = torch.FloatTensor(audio)
|
||||||
maxx = audio.abs().max()
|
maxx = audio.abs().max()
|
||||||
if(maxx>1):audio/=min(2,maxx)
|
if maxx > 1:
|
||||||
|
audio /= min(2, maxx)
|
||||||
audio_norm = audio
|
audio_norm = audio
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
spec = spectrogram_torch(
|
spec = spectrogram_torch(
|
||||||
@ -654,7 +660,7 @@ class TTS:
|
|||||||
)
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||||
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
|
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
|
||||||
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
||||||
wav16k = torch.from_numpy(wav16k)
|
wav16k = torch.from_numpy(wav16k)
|
||||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||||
@ -665,9 +671,7 @@ class TTS:
|
|||||||
zero_wav_torch = zero_wav_torch.half()
|
zero_wav_torch = zero_wav_torch.half()
|
||||||
|
|
||||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||||
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))[
|
hubert_feature = self.cnhuhbert_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(
|
||||||
"last_hidden_state"
|
|
||||||
].transpose(
|
|
||||||
1, 2
|
1, 2
|
||||||
) # .float()
|
) # .float()
|
||||||
codes = self.vits_model.extract_latent(hubert_feature)
|
codes = self.vits_model.extract_latent(hubert_feature)
|
||||||
@ -696,7 +700,9 @@ class TTS:
|
|||||||
batch = torch.stack(padded_sequences)
|
batch = torch.stack(padded_sequences)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def to_batch(self, data:list,
|
def to_batch(
|
||||||
|
self,
|
||||||
|
data: list,
|
||||||
prompt_data: dict = None,
|
prompt_data: dict = None,
|
||||||
batch_size: int = 5,
|
batch_size: int = 5,
|
||||||
threshold: float = 0.75,
|
threshold: float = 0.75,
|
||||||
@ -739,7 +745,6 @@ class TTS:
|
|||||||
batch_index_list.append([])
|
batch_index_list.append([])
|
||||||
batch_index_list[-1].append(i)
|
batch_index_list[-1].append(i)
|
||||||
|
|
||||||
|
|
||||||
for batch_idx, index_list in enumerate(batch_index_list):
|
for batch_idx, index_list in enumerate(batch_index_list):
|
||||||
item_list = [data[idx] for idx in index_list]
|
item_list = [data[idx] for idx in index_list]
|
||||||
phones_list = []
|
phones_list = []
|
||||||
@ -753,14 +758,14 @@ class TTS:
|
|||||||
all_phones_max_len = 0
|
all_phones_max_len = 0
|
||||||
for item in item_list:
|
for item in item_list:
|
||||||
if prompt_data is not None:
|
if prompt_data is not None:
|
||||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1).to(
|
||||||
.to(dtype=precision, device=device)
|
dtype=precision, device=device
|
||||||
|
)
|
||||||
all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
|
all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"]).to(device)
|
||||||
phones = torch.LongTensor(item["phones"]).to(device)
|
phones = torch.LongTensor(item["phones"]).to(device)
|
||||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||||
else:
|
else:
|
||||||
all_bert_features = item["bert_features"]\
|
all_bert_features = item["bert_features"].to(dtype=precision, device=device)
|
||||||
.to(dtype=precision, device=device)
|
|
||||||
phones = torch.LongTensor(item["phones"]).to(device)
|
phones = torch.LongTensor(item["phones"]).to(device)
|
||||||
all_phones = phones
|
all_phones = phones
|
||||||
# norm_text = item["norm_text"]
|
# norm_text = item["norm_text"]
|
||||||
@ -779,7 +784,6 @@ class TTS:
|
|||||||
all_phones_batch = all_phones_list
|
all_phones_batch = all_phones_list
|
||||||
all_bert_features_batch = all_bert_features_list
|
all_bert_features_batch = all_bert_features_list
|
||||||
|
|
||||||
|
|
||||||
max_len = max(all_bert_max_len, all_phones_max_len)
|
max_len = max(all_bert_max_len, all_phones_max_len)
|
||||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
#### 直接对phones和bert_features进行pad。(padding策略会影响T2S模型生成的结果,但不直接影响复读概率。影响复读概率的主要因素是mask的策略)
|
||||||
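Translation note for the pad comment above: phones and bert_features are padded directly; the padding strategy can change what the T2S model generates, but it does not directly drive the repetition ("re-reading") failure mode. The masking strategy is the main factor there.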
@@ -812,7 +816,7 @@ class TTS:
         return _data, batch_index_list

     def recovery_order(self, data: list, batch_index_list: list) -> list:
-        '''
+        """
         Recovery the order of the audio according to the batch_index_list.

         Args:
@@ -821,7 +825,7 @@ class TTS:

         Returns:
             list (List[torch.Tensor]): the data in the original order.
-        '''
+        """
         length = len(sum(batch_index_list, []))
         _data = [None] * length
         for i, index_list in enumerate(batch_index_list):
@@ -829,10 +833,12 @@ class TTS:
                 _data[index] = data[i][j]
         return _data

-    def stop(self,):
-        '''
+    def stop(
+        self,
+    ):
+        """
         Stop the inference process.
-        '''
+        """
         self.stop_flag = True

     @torch.no_grad()
@@ -933,9 +939,12 @@ class TTS:
         if no_prompt_text and self.configs.is_v3_synthesizer:
             raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")

-        if ref_audio_path in [None, ""] and \
-            ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
-            raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
+        if ref_audio_path in [None, ""] and (
+            (self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])
+        ):
+            raise ValueError(
+                "ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()"
+            )

         ###### setting reference audio and prompt text preprocessing ########
         t0 = time.perf_counter()
@@ -959,23 +968,19 @@ class TTS:

         if not no_prompt_text:
             prompt_text = prompt_text.strip("\n")
-            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
+            if prompt_text[-1] not in splits:
+                prompt_text += "。" if prompt_lang != "en" else "."
             print(i18n("实际输入的参考文本:"), prompt_text)
             if self.prompt_cache["prompt_text"] != prompt_text:
-                phones, bert_features, norm_text = \
-                    self.text_preprocessor.segment_and_extract_feature_for_text(
-                        prompt_text,
-                        prompt_lang,
-                        self.configs.version)
+                phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(
+                    prompt_text, prompt_lang, self.configs.version
+                )
                 self.prompt_cache["prompt_text"] = prompt_text
                 self.prompt_cache["prompt_lang"] = prompt_lang
                 self.prompt_cache["phones"] = phones
                 self.prompt_cache["bert_features"] = bert_features
                 self.prompt_cache["norm_text"] = norm_text
-
-
-

         ###### text preprocessing ########
         t1 = time.perf_counter()
         data: list = None
@@ -986,16 +991,17 @@ class TTS:
                 return

             batch_index_list: list = None
-            data, batch_index_list = self.to_batch(data,
+            data, batch_index_list = self.to_batch(
+                data,
                 prompt_data=self.prompt_cache if not no_prompt_text else None,
                 batch_size=batch_size,
                 threshold=batch_threshold,
                 split_bucket=split_bucket,
                 device=self.configs.device,
-                precision=self.precision
+                precision=self.precision,
             )
         else:
-            print(f'############ {i18n("切分文本")} ############')
+            print(f"############ {i18n('切分文本')} ############")
             texts = self.text_preprocessor.pre_seg_text(text, text_lang, text_split_method)
             data = []
             for i in range(len(texts)):
@@ -1005,9 +1011,11 @@ class TTS:

             def make_batch(batch_texts):
                 batch_data = []
-                print(f'############ {i18n("提取文本Bert特征")} ############')
+                print(f"############ {i18n('提取文本Bert特征')} ############")
                 for text in tqdm(batch_texts):
-                    phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(text, text_lang, self.configs.version)
+                    phones, bert_features, norm_text = self.text_preprocessor.segment_and_extract_feature_for_text(
+                        text, text_lang, self.configs.version
+                    )
                     if phones is None:
                         continue
                     res = {
@@ -1018,17 +1026,17 @@ class TTS:
                     batch_data.append(res)
                 if len(batch_data) == 0:
                     return None
-                batch, _ = self.to_batch(batch_data,
+                batch, _ = self.to_batch(
+                    batch_data,
                     prompt_data=self.prompt_cache if not no_prompt_text else None,
                     batch_size=batch_size,
                     threshold=batch_threshold,
                     split_bucket=False,
                     device=self.configs.device,
-                    precision=self.precision
+                    precision=self.precision,
                 )
                 return batch[0]


         t2 = time.perf_counter()
         try:
             print("############ 推理 ############")
@@ -1057,7 +1065,9 @@ class TTS:
                 if no_prompt_text:
                     prompt = None
                 else:
-                    prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                    prompt = (
+                        self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                    )

                 print(f"############ {i18n('预测语义Token')} ############")
                 pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
@@ -1076,8 +1086,10 @@ class TTS:
                 t4 = time.perf_counter()
                 t_34 += t4 - t3

-                refer_audio_spec:torch.Tensor = [item.to(dtype=self.precision, device=self.configs.device) for item in self.prompt_cache["refer_spec"]]
+                refer_audio_spec: torch.Tensor = [
+                    item.to(dtype=self.precision, device=self.configs.device)
+                    for item in self.prompt_cache["refer_spec"]
+                ]

                 batch_audio_fragment = []

@@ -1100,59 +1112,64 @@ class TTS:
                     # ## vits并行推理 method 2
                     pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                     upsample_rate = math.prod(self.vits_model.upsample_rates)
-                    audio_frag_idx = [pred_semantic_list[i].shape[0]*2*upsample_rate for i in range(0, len(pred_semantic_list))]
+                    audio_frag_idx = [
+                        pred_semantic_list[i].shape[0] * 2 * upsample_rate
+                        for i in range(0, len(pred_semantic_list))
+                    ]
                     audio_frag_end_idx = [sum(audio_frag_idx[: i + 1]) for i in range(0, len(audio_frag_idx))]
-                    all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                    all_pred_semantic = (
+                        torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
+                    )
                     _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                    _batch_audio_fragment = (self.vits_model.decode(
+                    _batch_audio_fragment = self.vits_model.decode(
                         all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                    ).detach()[0, 0, :])
+                    ).detach()[0, 0, :]
                     audio_frag_end_idx.insert(0, 0)
-                    batch_audio_fragment= [_batch_audio_fragment[audio_frag_end_idx[i-1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
+                    batch_audio_fragment = [
+                        _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
+                        for i in range(1, len(audio_frag_end_idx))
+                    ]
                 else:
                     # ## vits串行推理
                     for i, idx in enumerate(tqdm(idx_list)):
                         phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
-                        audio_fragment =(self.vits_model.decode(
+                        _pred_semantic = (
+                            pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
+                        )  # .unsqueeze(0)#mq要多unsqueeze一次
+                        audio_fragment = self.vits_model.decode(
                             _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                        ).detach()[0, 0, :])
-                        batch_audio_fragment.append(
-                            audio_fragment
-                        ) ###试试重建不带上prompt部分
+                        ).detach()[0, 0, :]
+                        batch_audio_fragment.append(audio_fragment)  ###试试重建不带上prompt部分
             else:
                 if parallel_infer:
                     print(f"{i18n('并行合成中')}...")
                     audio_fragments = self.v3_synthesis_batched_infer(
-                        idx_list,
-                        pred_semantic_list,
-                        batch_phones,
-                        speed=speed_factor,
-                        sample_steps=sample_steps
+                        idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
                     )
                     batch_audio_fragment.extend(audio_fragments)
                 else:
                     for i, idx in enumerate(tqdm(idx_list)):
                         phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)) # .unsqueeze(0)#mq要多unsqueeze一次
+                        _pred_semantic = (
+                            pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
+                        )  # .unsqueeze(0)#mq要多unsqueeze一次
                         audio_fragment = self.v3_synthesis(
                             _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                         )
-                        batch_audio_fragment.append(
-                            audio_fragment
-                        )
+                        batch_audio_fragment.append(audio_fragment)

             t5 = time.perf_counter()
             t_45 += t5 - t4
             if return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
-                yield self.audio_postprocess([batch_audio_fragment],
+                yield self.audio_postprocess(
+                    [batch_audio_fragment],
                     output_sr,
                     None,
                     speed_factor,
                     False,
                     fragment_interval,
-                    super_sampling if self.configs.is_v3_synthesizer else False
+                    super_sampling if self.configs.is_v3_synthesizer else False,
                 )
             else:
                 audio.append(batch_audio_fragment)
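Note on the parallel VITS branch above: all sentences' semantic tokens are concatenated and decoded in a single pass, then the waveform is cut back into per-sentence fragments; each sentence contributes shape[0] * 2 * upsample_rate samples, so a cumulative sum of those lengths gives the slice boundaries. A toy illustration of the boundary arithmetic (the numbers are hypothetical, not the project's API):

import math

token_counts = [120, 95, 143]  # semantic tokens per sentence (hypothetical)
upsample_rate = math.prod([10, 6, 2, 2])  # product of the decoder's upsample_rates
frag_lens = [n * 2 * upsample_rate for n in token_counts]  # samples per sentence
ends = [sum(frag_lens[: i + 1]) for i in range(len(frag_lens))]
ends.insert(0, 0)
# fragments[i] = decoded_audio[ends[i] : ends[i + 1]] recovers sentence i.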
@@ -1166,13 +1183,14 @@ class TTS:
             if len(audio) == 0:
                 yield 16000, np.zeros(int(16000), dtype=np.int16)
                 return
-            yield self.audio_postprocess(audio,
+            yield self.audio_postprocess(
+                audio,
                 output_sr,
                 batch_index_list,
                 speed_factor,
                 split_bucket,
                 fragment_interval,
-                super_sampling if self.configs.is_v3_synthesizer else False
+                super_sampling if self.configs.is_v3_synthesizer else False,
             )

         except Exception as e:
@@ -1200,7 +1218,8 @@ class TTS:
             except:
                 pass

-    def audio_postprocess(self,
+    def audio_postprocess(
+        self,
         audio: List[torch.Tensor],
         sr: int,
         batch_index_list: list = None,
@@ -1210,19 +1229,17 @@ class TTS:
         super_sampling: bool = False,
     ) -> Tuple[int, np.ndarray]:
         zero_wav = torch.zeros(
-            int(self.configs.sampling_rate * fragment_interval),
-            dtype=self.precision,
-            device=self.configs.device
+            int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
         )

         for i, batch in enumerate(audio):
             for j, audio_fragment in enumerate(batch):
                 max_audio = torch.abs(audio_fragment).max()  # 简单防止16bit爆音
-                if max_audio>1: audio_fragment/=max_audio
+                if max_audio > 1:
+                    audio_fragment /= max_audio
                 audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                 audio[i][j] = audio_fragment

-
         if split_bucket:
             audio = self.recovery_order(audio, batch_index_list)
         else:
@@ -1238,7 +1255,8 @@ class TTS:
                 if not self.sr_model_not_exist:
                     audio, sr = self.sr_model(audio.unsqueeze(0), sr)
                     max_audio = np.abs(audio).max()
-                    if max_audio > 1: audio /= max_audio
+                    if max_audio > 1:
+                        audio /= max_audio
                     t2 = time.perf_counter()
                     print(f"超采样用时:{t2 - t1:.3f}s")
                 else:
@@ -1254,14 +1272,9 @@ class TTS:

         return sr, audio

-    def v3_synthesis(self,
-        semantic_tokens:torch.Tensor,
-        phones:torch.Tensor,
-        speed:float=1.0,
-        sample_steps:int=32
+    def v3_synthesis(
+        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
     ):

         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
         refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
@@ -1270,7 +1283,7 @@ class TTS:
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
         ref_sr = self.prompt_cache["raw_sr"]
         ref_audio = ref_audio.to(self.configs.device).float()
-        if (ref_audio.shape[0] == 2):
+        if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         if ref_sr != 24000:
             ref_audio = resample(ref_audio, ref_sr, self.configs.device)
@@ -1280,7 +1293,7 @@ class TTS:
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if (T_min > 468):
+        if T_min > 468:
             mel2 = mel2[:, :, -468:]
             fea_ref = fea_ref[:, :, -468:]
             T_min = 468
@@ -1291,13 +1304,16 @@ class TTS:

         cfm_resss = []
         idx = 0
-        while (1):
+        while 1:
             fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
-            if (fea_todo_chunk.shape[-1] == 0): break
+            if fea_todo_chunk.shape[-1] == 0:
+                break
             idx += chunk_len
             fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)

-            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+            cfm_res = self.vits_model.cfm.inference(
+                fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
+            )
             cfm_res = cfm_res[:, :, mel2.shape[2] :]

             mel2 = cfm_res[:, :, -T_min:]
@@ -1307,23 +1323,20 @@ class TTS:
         cfm_res = torch.cat(cfm_resss, 2)
         cfm_res = denorm_spec(cfm_res)

-
         with torch.inference_mode():
             wav_gen = self.bigvgan_model(cfm_res)
             audio = wav_gen[0][0]  # .cpu().detach().numpy()

         return audio

-    def v3_synthesis_batched_infer(self,
+    def v3_synthesis_batched_infer(
+        self,
         idx_list: List[int],
         semantic_tokens_list: List[torch.Tensor],
         batch_phones: List[torch.Tensor],
         speed: float = 1.0,
-        sample_steps:int=32
+        sample_steps: int = 32,
     ) -> List[torch.Tensor]:

         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
         refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)
@@ -1332,7 +1345,7 @@ class TTS:
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
         ref_sr = self.prompt_cache["raw_sr"]
         ref_audio = ref_audio.to(self.configs.device).float()
-        if (ref_audio.shape[0] == 2):
+        if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         if ref_sr != 24000:
             ref_audio = resample(ref_audio, ref_sr, self.configs.device)
@@ -1342,7 +1355,7 @@ class TTS:
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if (T_min > 468):
+        if T_min > 468:
             mel2 = mel2[:, :, -468:]
             fea_ref = fea_ref[:, :, -468:]
             T_min = 468
@@ -1350,7 +1363,6 @@ class TTS:

         mel2 = mel2.to(self.precision)

-
         # #### batched inference
         overlapped_len = 12
         feat_chunks = []
@@ -1359,7 +1371,9 @@ class TTS:

         for i, idx in enumerate(idx_list):
             phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
-            semantic_tokens = semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)  # .unsqueeze(0)#mq要多unsqueeze一次
+            semantic_tokens = (
+                semantic_tokens_list[i][-idx:].unsqueeze(0).unsqueeze(0)
+            )  # .unsqueeze(0)#mq要多unsqueeze一次
             feat, _ = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
             feat_list.append(feat)
             feat_lens.append(feat.shape[2])
@@ -1375,7 +1389,8 @@ class TTS:
                 pos = pos - overlapped_len
             chunk = feats_padded[:, :, pos : pos + chunk_len]
             pos += chunk_len
-            if (chunk.shape[-1] == 0): break
+            if chunk.shape[-1] == 0:
+                break

             # padding for the last chunk
             padding_len = chunk_len - chunk.shape[2]
@@ -1383,26 +1398,24 @@ class TTS:
                 chunk = F.pad(chunk, (0, padding_len), "constant", 0)
             feat_chunks.append(chunk)

-
-
         feat_chunks = torch.cat(feat_chunks, 0)
         bs = feat_chunks.shape[0]
         fea_ref = fea_ref.repeat(bs, 1, 1)
         fea = torch.cat([fea_ref, feat_chunks], 2).transpose(2, 1)
-        pred_spec = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+        pred_spec = self.vits_model.cfm.inference(
+            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
+        )
        pred_spec = pred_spec[:, :, -chunk_len:]
        dd = pred_spec.shape[1]
        pred_spec = pred_spec.permute(1, 0, 2).contiguous().view(dd, -1).unsqueeze(0)
        # pred_spec = pred_spec[..., :-padding_len]

-
        pred_spec = denorm_spec(pred_spec)

        with torch.no_grad():
            wav_gen = self.bigvgan_model(pred_spec)
            audio = wav_gen[0][0]  # .cpu().detach().numpy()

-
        audio_fragments = []
        upsample_rate = 256
        pos = 0
@@ -1421,16 +1434,13 @@ class TTS:
            audio_fragments.append(audio_fragment)
            audio = audio[feat_len * upsample_rate :]

-
        return audio_fragments

-    def sola_algorithm(self,
+    def sola_algorithm(
+        self,
        audio_fragments: List[torch.Tensor],
        overlap_len: int,
    ):

        for i in range(len(audio_fragments) - 1):
            f1 = audio_fragments[i]
            f2 = audio_fragments[i + 1]
@@ -1444,11 +1454,10 @@ class TTS:

            f2_ = f2[idx:]
            window = torch.hann_window((overlap_len - idx) * 2, device=f1.device, dtype=f1.dtype)
-            f2_[:(overlap_len-idx)] = window[:(overlap_len-idx)]*f2_[:(overlap_len-idx)] + window[(overlap_len-idx):]*f1[-(overlap_len-idx):]
+            f2_[: (overlap_len - idx)] = (
+                window[: (overlap_len - idx)] * f2_[: (overlap_len - idx)]
+                + window[(overlap_len - idx) :] * f1[-(overlap_len - idx) :]
+            )
            audio_fragments[i + 1] = f2_

-
        return torch.cat(audio_fragments, 0)
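Note on sola_algorithm above: adjacent fragments are stitched by finding, within the overlap, the lag with the best correlation, then crossfading the aligned samples with the two halves of a Hann window. A self-contained sketch of just the crossfade step, assuming 1-D tensors and a precomputed best offset idx in [0, overlap_len):

import torch


def crossfade(f1: torch.Tensor, f2: torch.Tensor, overlap_len: int, idx: int) -> torch.Tensor:
    # Skip the part of f2 that precedes the best-matching offset.
    f2_ = f2[idx:]
    n = overlap_len - idx
    # hann_window(2n) split in half gives complementary fade-in/fade-out ramps.
    window = torch.hann_window(n * 2, device=f1.device, dtype=f1.dtype)
    f2_[:n] = window[:n] * f2_[:n] + window[n:] * f1[-n:]
    return f2_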
@@ -1,8 +1,9 @@
-import os, sys
+import os
+import sys
 import threading

 from tqdm import tqdm

 now_dir = os.getcwd()
 sys.path.append(now_dir)

@@ -21,13 +22,15 @@ from tools.i18n.i18n import I18nAuto, scan_language_list
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
-punctuation = set(['!', '?', '…', ',', '.', '-'])
+punctuation = set(["!", "?", "…", ",", ".", "-"])

+
 def get_first(text: str) -> str:
     pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
     text = re.split(pattern, text)[0].strip()
     return text

+
 def merge_short_text_in_array(texts: str, threshold: int) -> list:
     if (len(texts)) < 2:
         return texts
@@ -38,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
         if len(text) >= threshold:
             result.append(text)
             text = ""
-    if (len(text) > 0):
+    if len(text) > 0:
         if len(result) == 0:
             result.append(text)
         else:
@@ -46,23 +49,19 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
     return result


-
-
-
 class TextPreprocessor:
-    def __init__(self, bert_model:AutoModelForMaskedLM,
-                 tokenizer:AutoTokenizer, device:torch.device):
+    def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
         self.bert_model = bert_model
         self.tokenizer = tokenizer
         self.device = device
         self.bert_lock = threading.RLock()

     def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-        print(f'############ {i18n("切分文本")} ############')
+        print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        print(f'############ {i18n("提取文本Bert特征")} ############')
+        print(f"############ {i18n('提取文本Bert特征')} ############")
         for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
             if phones is None or norm_text == "":
@@ -79,7 +78,7 @@ class TextPreprocessor:
         text = text.strip("\n")
         if len(text) == 0:
             return []
-        if (text[0] not in splits and len(get_first(text)) < 4):
+        if text[0] not in splits and len(get_first(text)) < 4:
             text = "。" + text if lang != "en" else "." + text
         print(i18n("实际输入的目标文本:"))
         print(text)
@@ -95,18 +94,18 @@ class TextPreprocessor:
         _texts = merge_short_text_in_array(_texts, 5)
         texts = []

-
         for text in _texts:
             # 解决输入目标文本的空行导致报错的问题
-            if (len(text.strip()) == 0):
+            if len(text.strip()) == 0:
                 continue
             if not re.sub("\W+", "", text):
                 # 检测一下,如果是纯符号,就跳过。
                 continue
-            if (text[-1] not in splits): text += "。" if lang != "en" else "."
+            if text[-1] not in splits:
+                text += "。" if lang != "en" else "."

             # 解决句子过长导致Bert报错的问题
-            if (len(text) > 510):
+            if len(text) > 510:
                 texts.extend(split_big_text(text))
             else:
                 texts.append(text)
@@ -115,7 +114,9 @@ class TextPreprocessor:
         print(texts)
         return texts

-    def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
+    def segment_and_extract_feature_for_text(
+        self, text: str, language: str, version: str = "v1"
+    ) -> Tuple[list, torch.Tensor, str]:
         return self.get_phones_and_bert(text, language, version)

     def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
@@ -126,15 +127,15 @@ class TextPreprocessor:
             while "  " in formattext:
                 formattext = formattext.replace("  ", " ")
             if language == "all_zh":
-                if re.search(r'[A-Za-z]', formattext):
-                    formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                if re.search(r"[A-Za-z]", formattext):
+                    formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                     formattext = chinese.mix_text_normalize(formattext)
                     return self.get_phones_and_bert(formattext, "zh", version)
                 else:
                     phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
                     bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-            elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
-                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+            elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
+                formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
                 return self.get_phones_and_bert(formattext, "yue", version)
             else:
@@ -178,14 +179,13 @@ class TextPreprocessor:
                 bert_list.append(bert)
             bert = torch.cat(bert_list, dim=1)
             phones = sum(phones_list, [])
-            norm_text = ''.join(norm_text_list)
+            norm_text = "".join(norm_text_list)

         if not final and len(phones) < 6:
             return self.get_phones_and_bert("." + text, language, version, final=True)

         return phones, bert, norm_text

-
     def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
         with torch.no_grad():
             inputs = self.tokenizer(text, return_tensors="pt")
@@ -219,7 +219,6 @@ class TextPreprocessor:

         return feature

-
     def filter_text(self, texts):
         _text = []
         if all(text in [None, " ", "\n", ""] for text in texts):
@@ -231,9 +230,8 @@ class TextPreprocessor:
                 _text.append(text)
         return _text

-
     def replace_consecutive_punctuation(self, text):
-        punctuations = ''.join(re.escape(p) for p in punctuation)
-        pattern = f'([{punctuations}])([{punctuations}])+'
-        result = re.sub(pattern, r'\1', text)
+        punctuations = "".join(re.escape(p) for p in punctuation)
+        pattern = f"([{punctuations}])([{punctuations}])+"
+        result = re.sub(pattern, r"\1", text)
         return result
@@ -1,40 +1,56 @@
 import re
 from typing import Callable

-punctuation = set(['!', '?', '…', ',', '.', '-'," "])
+punctuation = set(["!", "?", "…", ",", ".", "-", " "])
 METHODS = dict()

+
 def get_method(name: str) -> Callable:
     method = METHODS.get(name, None)
     if method is None:
         raise ValueError(f"Method {name} not found")
     return method

+
 def get_method_names() -> list:
     return list(METHODS.keys())

+
 def register_method(name):
     def decorator(func):
         METHODS[name] = func
         return func

     return decorator

-splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
+splits = {
+    ",",
+    "。",
+    "?",
+    "!",
+    ",",
+    ".",
+    "?",
+    "!",
+    "~",
+    ":",
+    ":",
+    "—",
+    "…",
+}

+
 def split_big_text(text, max_len=510):
     # 定义全角和半角标点符号
     punctuation = "".join(splits)

     # 切割文本
-    segments = re.split('([' + punctuation + '])', text)
+    segments = re.split("([" + punctuation + "])", text)

     # 初始化结果列表和当前片段
     result = []
-    current_segment = ''
+    current_segment = ""

     for segment in segments:
         # 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
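Note on the hunk above: register_method implements a plain name-to-function registry, so a text splitter can be selected by string key at runtime. Usage, following the file's own convention (the splitter name here is hypothetical):

@register_method("cut_none")
def cut_none(inp):
    # Trivial splitter: return the input unchanged as a single segment.
    return inp


method = get_method("cut_none")  # get_method raises ValueError for unknown names
print(method("hello world"))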
@@ -51,7 +67,6 @@ def split_big_text(text, max_len=510):
     return result


-
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if todo_text[-1] not in splits:
@@ -123,6 +138,7 @@ def cut2(inp):
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

+
 # 按中文句号。切
 @register_method("cut3")
 def cut3(inp):
@@ -131,26 +147,28 @@ def cut3(inp):
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

+
 # 按英文句号.切
 @register_method("cut4")
 def cut4(inp):
     inp = inp.strip("\n")
-    opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
+    opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
     opts = [item for item in opts if not set(item).issubset(punctuation)]
     return "\n".join(opts)

+
 # 按标点符号切
 # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
 @register_method("cut5")
 def cut5(inp):
     inp = inp.strip("\n")
-    punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
+    punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
     mergeitems = []
     items = []

     for i, char in enumerate(inp):
         if char in punds:
-            if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+            if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                 items.append(char)
             else:
                 items.append(char)
@@ -166,8 +184,6 @@ def cut5(inp):
     return "\n".join(opt)


-
-if __name__ == '__main__':
+if __name__ == "__main__":
     method = get_method("cut5")
     print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))

@@ -1,5 +1,13 @@
-import os, sys
+import os
+import sys
+
 now_dir = os.getcwd()
 sys.path.insert(0, now_dir)
 from text.g2pw import G2PWPinyin
-g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
+
+g2pw = G2PWPinyin(
+    model_dir="GPT_SoVITS/text/G2PWModel",
+    model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+    v_to_u=False,
+    neutral_tone_with_five=True,
+)
@@ -3,7 +3,6 @@
 import argparse
 from typing import Optional
 from my_utils import load_audio
-from text import cleaned_text_to_sequence
 import torch
 import torchaudio

@@ -33,6 +32,7 @@ default_config = {
     "EOS": 1024,
 }

+
 def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
     config = dict_s1["config"]
     config["model"]["dropout"] = float(config["model"]["dropout"])
@@ -41,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
     t2s_model = t2s_model.eval()
     return t2s_model

+
 @torch.jit.script
 def logits_to_probs(
     logits,
@@ -57,21 +58,15 @@ def logits_to_probs(
     if previous_tokens is not None and repetition_penalty != 1.0:
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=1, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
+        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
         logits.scatter_(dim=1, index=previous_tokens, src=score)

     if top_p is not None and top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
+        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
         sorted_indices_to_remove = cum_probs > top_p
         sorted_indices_to_remove[:, 0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=1, index=sorted_indices, src=sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
         logits = logits.masked_fill(indices_to_remove, -float("Inf"))

     logits = logits / max(temperature, 1e-5)
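Note on the top-p branch above: it is a standard nucleus filter, merely compacted by the formatter. Sort logits descending, cumulate the softmax mass, mask every token past the top_p threshold (always keeping the best one), and scatter the mask back to original positions. The same filter as a stand-alone function, outside TorchScript:

import torch


def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cum_probs > top_p  # True marks tokens outside the nucleus
    sorted_mask[:, 0] = False  # keep at least one option
    mask = sorted_mask.scatter(dim=1, index=sorted_indices, src=sorted_mask)
    return logits.masked_fill(mask, -float("Inf"))


filtered = top_p_filter(torch.randn(1, 10), top_p=0.9)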
@@ -84,12 +79,14 @@ def logits_to_probs(
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs

+
 @torch.jit.script
 def multinomial_sample_one_no_sync(probs_sort):
     # Does multinomial sampling without a cuda synchronization
     q = torch.randn_like(probs_sort)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

+
 @torch.jit.script
 def sample(
     logits,
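Note on multinomial_sample_one_no_sync above: it avoids torch.multinomial, which forces a device-to-host synchronization, by dividing the probabilities by random noise and taking an argmax. With exponential noise this is the exact Gumbel-max-style trick; the diffed code draws torch.randn_like noise instead, so treat the variant below as the textbook form rather than a drop-in replacement:

import torch


def sample_one_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # argmax(probs / q) with q ~ Exp(1) is an exact multinomial draw, no sync needed.
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)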
@ -100,7 +97,12 @@ def sample(
|
|||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
):
|
):
|
||||||
probs = logits_to_probs(
|
probs = logits_to_probs(
|
||||||
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
|
logits=logits,
|
||||||
|
previous_tokens=previous_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
)
|
)
|
||||||
idx_next = multinomial_sample_one_no_sync(probs)
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
return idx_next, probs
|
return idx_next, probs
|
||||||
@@ -158,6 +160,7 @@ class DictToAttrRecursive(dict):
         except KeyError:
             raise AttributeError(f"Attribute {item} not found")

+
 @torch.jit.script
 class T2SMLP:
     def __init__(self, w1, b1, w2, b2):
@@ -171,6 +174,7 @@ class T2SMLP:
         x = F.linear(x, self.w2, self.b2)
         return x

+
 @torch.jit.script
 class T2SBlock:
     def __init__(
@@ -245,9 +249,7 @@ class T2SBlock:
                 x_item = x[i, idx, :].unsqueeze(0)
                 attn_item = attn[i, idx, :].unsqueeze(0)
                 x_item = x_item + attn_item
-                x_item = F.layer_norm(
-                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-                )
+                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
                 x_item = x_item + self.mlp.forward(x_item)
                 x_item = F.layer_norm(
                     x_item,
@@ -260,9 +262,7 @@ class T2SBlock:
             x = self.to_mask(x, padding_mask)
         else:
             x = x + attn
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
             x,
@@ -294,9 +294,7 @@ class T2SBlock:
         attn = F.linear(attn, self.out_w, self.out_b)

         x = x + attn
-        x = F.layer_norm(
-            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
-        )
+        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
         x = x + self.mlp.forward(x)
         x = F.layer_norm(
             x,
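
# ---- Editor's sketch (illustration only, not part of the commit) ----
# T2SBlock is a hand-rolled post-norm transformer block: the residual is
# added first and LayerNorm is applied with explicit weight/bias tensors via
# the functional API (a TorchScript-friendly pattern that avoids nn.Module
# attribute lookups). The skeleton of one step, with hypothetical arguments:
import torch
import torch.nn.functional as F

def post_norm_step(x, attn_out, mlp, norm_w, norm_b, eps: float, hidden_dim: int):
    x = x + attn_out                                        # attention residual
    x = F.layer_norm(x, [hidden_dim], norm_w, norm_b, eps)  # post-norm
    x = x + mlp(x)                                          # feed-forward residual
    return F.layer_norm(x, [hidden_dim], norm_w, norm_b, eps)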
@@ -307,14 +305,14 @@ class T2SBlock:
         )
         return x, k_cache, v_cache

+
 @torch.jit.script
 class T2STransformer:
     def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
         self.num_blocks: int = num_blocks
         self.blocks = blocks

-    def process_prompt(
-        self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
+    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
         k_cache: list[torch.Tensor] = []
         v_cache: list[torch.Tensor] = []
         for i in range(self.num_blocks):
@@ -323,21 +321,19 @@ class T2STransformer:
             v_cache.append(v_cache_)
         return x, k_cache, v_cache

-    def decode_next_token(
-        self, x:torch.Tensor,
-        k_cache: list[torch.Tensor],
-        v_cache: list[torch.Tensor]):
+    def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
         for i in range(self.num_blocks):
             x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
         return x, k_cache, v_cache

+
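# ---- Editor's sketch (illustration only, not part of the commit) ----
# T2STransformer splits inference into a prefill pass (process_prompt, which
# builds one K/V cache tensor per block) and a per-token pass
# (decode_next_token, which consumes and updates those caches). A minimal
# driver loop over such an interface might look like this; the model object,
# predict/embed helpers, and eos_id are hypothetical:
import torch

def greedy_decode(model, prompt_embeds: torch.Tensor, attn_mask: torch.Tensor, max_steps: int = 100):
    x, k_cache, v_cache = model.process_prompt(prompt_embeds, attn_mask)  # prefill
    tokens = []
    for _ in range(max_steps):
        logits = model.predict(x[:, -1])          # hypothetical output head
        next_tok = int(torch.argmax(logits, dim=-1))
        if next_tok == model.eos_id:              # hypothetical stop token
            break
        tokens.append(next_tok)
        x = model.embed(next_tok)                 # hypothetical re-embedding
        x, k_cache, v_cache = model.decode_next_token(x, k_cache, v_cache)
    return tokens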
 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
         # dict_s2 = torch.load(vits_path,map_location="cpu")
         dict_s2 = torch.load(vits_path)
         self.hps = dict_s2["config"]
-        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
         else:
             self.hps["model"]["version"] = "v2"
@@ -348,7 +344,7 @@ class VitsModel(nn.Module):
             self.hps.data.filter_length // 2 + 1,
             self.hps.train.segment_size // self.hps.data.hop_length,
             n_speakers=self.hps.data.n_speakers,
-            **self.hps.model
+            **self.hps.model,
         )
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@@ -360,10 +356,11 @@ class VitsModel(nn.Module):
             self.hps.data.sampling_rate,
             self.hps.data.hop_length,
             self.hps.data.win_length,
-            center=False
+            center=False,
         )
         return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]

+
 class T2SModel(nn.Module):
     def __init__(self, raw_t2s: Text2SemanticLightningModule):
         super(T2SModel, self).__init__()
@@ -393,12 +390,7 @@ class T2SModel(nn.Module):

         for i in range(self.num_layers):
             layer = h.layers[i]
-            t2smlp = T2SMLP(
-                layer.linear1.weight,
-                layer.linear1.bias,
-                layer.linear2.weight,
-                layer.linear2.bias
-            )
+            t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)

             block = T2SBlock(
                 self.num_head,
@@ -413,7 +405,7 @@ class T2SModel(nn.Module):
                 layer.norm1.eps,
                 layer.norm2.weight,
                 layer.norm2.bias,
-                layer.norm2.eps
+                layer.norm2.eps,
             )

             blocks.append(block)
@@ -427,7 +419,15 @@ class T2SModel(nn.Module):
         self.top_k = int(raw_t2s.config["inference"]["top_k"])
         self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])

-    def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
+    def forward(
+        self,
+        prompts: LongTensor,
+        ref_seq: LongTensor,
+        text_seq: LongTensor,
+        ref_bert: torch.Tensor,
+        text_bert: torch.Tensor,
+        top_k: LongTensor,
+    ):
         bert = torch.cat([ref_bert.T, text_bert.T], 1)
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
@@ -438,7 +438,6 @@ class T2SModel(nn.Module):

         early_stop_num = self.early_stop_num

-
         # [1,N,512] [1,N]
         # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
         y = prompts
@@ -465,11 +464,13 @@ class T2SModel(nn.Module):
             (x_len, 0),
             value=False,
         )
-        xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
-            .unsqueeze(0)\
-            .expand(bsz*self.num_head, -1, -1)\
-            .view(bsz, self.num_head, src_len, src_len)\
-            .to(device=x.device, dtype=torch.bool)
+        xy_attn_mask = (
+            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
+            .unsqueeze(0)
+            .expand(bsz * self.num_head, -1, -1)
+            .view(bsz, self.num_head, src_len, src_len)
+            .to(device=x.device, dtype=torch.bool)
+        )

         idx = 0
         top_k = int(top_k)
@@ -481,7 +482,9 @@ class T2SModel(nn.Module):
         samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
         y = torch.concat([y, samples], dim=1)
         y_emb = self.ar_audio_embedding(y[:, -1:])
-        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+            :, y_len + idx
+        ].to(dtype=y_emb.dtype, device=y_emb.device)

         stop = False
         # for idx in range(1, 50):
@@ -491,7 +494,7 @@ class T2SModel(nn.Module):
             xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])

-            if(idx<11):###至少预测出10个token不然不给停止(0.4s)
+            if idx < 11:  ###至少预测出10个token不然不给停止(0.4s)
                 logits = logits[:, :-1]

             samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
@@ -508,18 +511,20 @@ class T2SModel(nn.Module):
                 break

             y_emb = self.ar_audio_embedding(y[:, -1:])
-            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
+            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
+                :, y_len + idx
+            ].to(dtype=y_emb.dtype, device=y_emb.device)

         y[0, -1] = 0

         return y[:, -idx:].unsqueeze(0)

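# ---- Editor's sketch (illustration only, not part of the commit) ----
# The reformatted chain above builds a boolean attention mask that lets the
# text prompt rows attend freely while the audio rows are causal, then
# expands one copy per attention head. The same shape gymnastics in
# isolation, with small made-up sizes:
import torch

bsz, num_head, x_len, y_len = 1, 2, 3, 4
src_len = x_len + y_len
x_attn_mask_pad = torch.zeros(x_len, src_len, dtype=torch.bool)              # prompt rows: attend everywhere
y_attn_mask = torch.ones(y_len, src_len).triu(x_len + 1).bool()              # audio rows: causal
xy_attn_mask = (
    torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
    .unsqueeze(0)
    .expand(bsz * num_head, -1, -1)
    .view(bsz, num_head, src_len, src_len)
)
assert xy_attn_mask.shape == (1, 2, 7, 7)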
-bert_path = os.environ.get(
-    "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
-)
+bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path


 @torch.jit.script
 def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
     phone_level_feature = []
@@ -530,17 +535,21 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
     # [sum(word2ph), 1024]
     return phone_level_feature

+
 class MyBertModel(torch.nn.Module):
     def __init__(self, bert_model):
         super(MyBertModel, self).__init__()
         self.bert = bert_model

-    def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
+    ):
         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
         # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
         res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
         return build_phone_level_feature(res, word2ph)

+
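# ---- Editor's sketch (illustration only, not part of the commit) ----
# build_phone_level_feature up-samples one BERT vector per character into one
# vector per phoneme: character i is repeated word2ph[i] times. The loop in
# the commit can be pictured as a repeat_interleave (an equivalent sketch,
# not the repo's exact code):
import torch

def phone_level_feature(res: torch.Tensor, word2ph: torch.Tensor) -> torch.Tensor:
    # res: [num_chars, 1024]; word2ph: [num_chars] phoneme counts per character
    return torch.repeat_interleave(res, word2ph.to(torch.long), dim=0)  # [sum(word2ph), 1024]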
 class SSLModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -550,6 +559,7 @@ class SSLModel(torch.nn.Module):
         ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
         return ssl_content

+
 class ExportSSLModel(torch.nn.Module):
     def __init__(self, ssl: SSLModel):
         super().__init__()
@@ -563,6 +573,7 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio, src_sr, dst_sr).float()
         return audio

+
 def export_bert(output_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)

@@ -570,33 +581,34 @@ def export_bert(output_path):
     ref_bert_inputs = tokenizer(text, return_tensors="pt")
     word2ph = []
     for c in text:
-        if c in [',','。',':','?',",",".","?"]:
+        if c in [",", "。", ":", "?", ",", ".", "?"]:
             word2ph.append(1)
         else:
             word2ph.append(2)
-    ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
+    ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()

     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
     my_bert_model = MyBertModel(bert_model)

     ref_bert_inputs = {
-        'input_ids': ref_bert_inputs['input_ids'],
-        'attention_mask': ref_bert_inputs['attention_mask'],
-        'token_type_ids': ref_bert_inputs['token_type_ids'],
-        'word2ph': ref_bert_inputs['word2ph']
+        "input_ids": ref_bert_inputs["input_ids"],
+        "attention_mask": ref_bert_inputs["attention_mask"],
+        "token_type_ids": ref_bert_inputs["token_type_ids"],
+        "word2ph": ref_bert_inputs["word2ph"],
     }

-    torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
-    torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
-    torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
-    torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
+    torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)

     my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
     output_path = os.path.join(output_path, "bert_model.pt")
     my_bert_model.save(output_path)
-    print('#### exported bert ####')
+    print("#### exported bert ####")

+
-def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
+def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
     if not os.path.exists(output_path):
         os.makedirs(output_path)
         print(f"目录已创建: {output_path}")
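
# ---- Editor's sketch (illustration only, not part of the commit) ----
# export_bert traces the wrapped BERT with keyword example inputs, after
# marking the sequence dimension as dynamic so the traced graph is not
# specialized to a single input length. The same pattern on a toy module
# (hypothetical names and output path):
import torch

class Doubler(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

example = {"x": torch.randn(1, 16)}
torch._dynamo.mark_dynamic(example["x"], 1)  # dim 1 may vary at runtime
traced = torch.jit.trace(Doubler(), example_kwarg_inputs=example)
traced.save("doubler.pt")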
@@ -609,18 +621,19 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
         s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
         ssl_path = os.path.join(output_path, "ssl_model.pt")
         torch.jit.script(s).save(ssl_path)
-        print('#### exported ssl ####')
+        print("#### exported ssl ####")
         export_bert(output_path)
     else:
         s = ExportSSLModel(ssl)

     print(f"device: {device}")

-    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
     ref_seq = torch.LongTensor([ref_seq_id]).to(device)
     ref_bert = ref_bert_T.T.to(ref_seq.device)
-    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
+        "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
+    )
     text_seq = torch.LongTensor([text_seq_id]).to(device)
     text_bert = text_bert_T.T.to(text_seq.device)

@@ -634,12 +647,12 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
     # dict_s1 = torch.load(gpt_path, map_location=device)
     dict_s1 = torch.load(gpt_path)
     raw_t2s = get_raw_t2s_model(dict_s1).to(device)
-    print('#### get_raw_t2s_model ####')
+    print("#### get_raw_t2s_model ####")
     print(raw_t2s.config)
     t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
     t2s = torch.jit.script(t2s_m).to(device)
-    print('#### script t2s_m ####')
+    print("#### script t2s_m ####")

     print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
     gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
@@ -658,19 +671,13 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be

     with torch.no_grad():
         gpt_sovits_export = torch.jit.trace(
-            gpt_sovits,
-            example_inputs=(
-                ssl_content,
-                ref_audio_sr,
-                ref_seq,
-                text_seq,
-                ref_bert,
-                text_bert,
-                top_k))
+            gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
+        )

     gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
     gpt_sovits_export.save(gpt_sovits_path)
-    print('#### exported gpt_sovits ####')
+    print("#### exported gpt_sovits ####")

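# ---- Editor's sketch (illustration only, not part of the commit) ----
# Note the division of labor above: the autoregressive T2S model is
# *scripted* (its token loop is data-dependent control flow, which tracing
# cannot capture), while the combined GPT_SoVITS graph is *traced* with
# concrete example tensors. A compact illustration of that rule of thumb:
import torch

class Loopy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # data-dependent loop -> must be scripted, not traced
        while x.sum() < 10:
            x = x + 1
        return x

scripted = torch.jit.script(Loopy())  # keeps the while-loop in the graph
traced = torch.jit.trace(torch.nn.Linear(4, 4), torch.randn(1, 4))  # pure dataflow is fine to trace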

 @torch.jit.script
 def parse_audio(ref_audio):
@@ -678,10 +685,12 @@ def parse_audio(ref_audio):
     ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float()  # .to(ref_audio.device)
     return ref_audio_16k, ref_audio_sr

+
 @torch.jit.script
 def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
     return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()

+
 class GPT_SoVITS(nn.Module):
     def __init__(self, t2s: T2SModel, vits: VitsModel):
         super().__init__()
@@ -710,12 +719,11 @@ class GPT_SoVITS(nn.Module):

 def test():
     parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
-    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
-    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
-    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
-    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
-    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
+    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
+    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
+    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
+    parser.add_argument("--output_path", required=True, help="Path to the output directory")

     args = parser.parse_args()
     gpt_path = args.gpt_model
@@ -726,7 +734,7 @@ def test():
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
     # bert = MyBertModel(bert_model)
-    my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
+    my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")

     # dict_s1 = torch.load(gpt_path, map_location="cuda")
     # raw_t2s = get_raw_t2s_model(dict_s1)
@@ -740,95 +748,97 @@ def test():

     # ssl = ExportSSLModel(SSLModel()).to('cuda')
     # ssl.eval()
-    ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
+    ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")

     # gpt_sovits = GPT_SoVITS(t2s,vits)
-    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
+    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")

-    ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
+    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
     ref_seq = torch.LongTensor([ref_seq_id])
     ref_bert = ref_bert_T.T.to(ref_seq.device)
     # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
     text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."

-    text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
+    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")

     test_bert = tokenizer(text, return_tensors="pt")
     word2ph = []
     for c in text:
-        if c in [',','。',':','?',"?",",","."]:
+        if c in [",", "。", ":", "?", "?", ",", "."]:
             word2ph.append(1)
         else:
             word2ph.append(2)
-    test_bert['word2ph'] = torch.Tensor(word2ph).int()
+    test_bert["word2ph"] = torch.Tensor(word2ph).int()

     test_bert = my_bert(
-        test_bert['input_ids'].to('cuda'),
-        test_bert['attention_mask'].to('cuda'),
-        test_bert['token_type_ids'].to('cuda'),
-        test_bert['word2ph'].to('cuda')
+        test_bert["input_ids"].to("cuda"),
+        test_bert["attention_mask"].to("cuda"),
+        test_bert["token_type_ids"].to("cuda"),
+        test_bert["word2ph"].to("cuda"),
     )

     text_seq = torch.LongTensor([text_seq_id])
     text_bert = text_bert_T.T.to(text_seq.device)

-    print('text_bert:',text_bert.shape,text_bert)
-    print('test_bert:',test_bert.shape,test_bert)
-    print(torch.allclose(text_bert.to('cuda'),test_bert))
+    print("text_bert:", text_bert.shape, text_bert)
+    print("test_bert:", test_bert.shape, test_bert)
+    print(torch.allclose(text_bert.to("cuda"), test_bert))

-    print('text_seq:',text_seq.shape)
-    print('text_bert:',text_bert.shape,text_bert.type())
+    print("text_seq:", text_seq.shape)
+    print("text_bert:", text_bert.shape, text_bert.type())

     # [1,N]
-    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
-    print('ref_audio:',ref_audio.shape)
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
+    print("ref_audio:", ref_audio.shape)

     ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
-    print('start ssl')
+    print("start ssl")
     ssl_content = ssl(ref_audio)

-    print('start gpt_sovits:')
-    print('ssl_content:',ssl_content.shape)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-    print('ref_seq:',ref_seq.shape)
-    ref_seq=ref_seq.to('cuda')
-    print('text_seq:',text_seq.shape)
-    text_seq=text_seq.to('cuda')
-    print('ref_bert:',ref_bert.shape)
-    ref_bert=ref_bert.to('cuda')
-    print('text_bert:',text_bert.shape)
-    text_bert=text_bert.to('cuda')
+    print("start gpt_sovits:")
+    print("ssl_content:", ssl_content.shape)
+    print("ref_audio_sr:", ref_audio_sr.shape)
+    print("ref_seq:", ref_seq.shape)
+    ref_seq = ref_seq.to("cuda")
+    print("text_seq:", text_seq.shape)
+    text_seq = text_seq.to("cuda")
+    print("ref_bert:", ref_bert.shape)
+    ref_bert = ref_bert.to("cuda")
+    print("text_bert:", text_bert.shape)
+    text_bert = text_bert.to("cuda")

-    top_k = torch.LongTensor([5]).to('cuda')
+    top_k = torch.LongTensor([5]).to("cuda")

     with torch.no_grad():
         audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
-    print('start write wav')
+    print("start write wav")
     soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)

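# ---- Editor's sketch (illustration only, not part of the commit) ----
# test() validates the exported BERT by running the same text through the
# traced module and the eager pipeline and comparing with torch.allclose.
# That parity check is the standard way to sanity-test a TorchScript export
# (toy version with a stand-in module):
import torch

model = torch.nn.Linear(8, 8).eval()
example = torch.randn(2, 8)
traced = torch.jit.trace(model, example)
with torch.no_grad():
    eager_out = model(example)
    traced_out = traced(example)
# a small tolerance absorbs kernel-level nondeterminism between the two paths
assert torch.allclose(eager_out, traced_out, atol=1e-6)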
 import text
 import json


-def export_symbel(version='v2'):
-    if version=='v1':
+def export_symbel(version="v2"):
+    if version == "v1":
         symbols = text._symbol_to_id_v1
-        with open(f"onnx/symbols_v1.json", "w") as file:
+        with open("onnx/symbols_v1.json", "w") as file:
             json.dump(symbols, file, indent=4)
     else:
         symbols = text._symbol_to_id_v2
-        with open(f"onnx/symbols_v2.json", "w") as file:
+        with open("onnx/symbols_v2.json", "w") as file:
             json.dump(symbols, file, indent=4)


 def main():
     parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
-    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
-    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
-    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
-    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
-    parser.add_argument('--output_path', required=True, help="Path to the output directory")
-    parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
-    parser.add_argument('--device', help="Device to use")
+    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
+    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
+    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
+    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
+    parser.add_argument("--output_path", required=True, help="Path to the output directory")
+    parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
+    parser.add_argument("--device", help="Device to use")

     args = parser.parse_args()
     export(
@@ -841,7 +851,9 @@ def main():
         export_bert_and_ssl=args.export_common_model,
     )

+
+
 import inference_webui

 if __name__ == "__main__":
     inference_webui.is_half = False
     inference_webui.dtype = torch.float32

@@ -6,16 +6,16 @@ from export_torch_script import (
     spectrogram_torch,
 )
 from f5_tts.model.backbones.dit import DiT
-from feature_extractor import cnhubert
 from inference_webui import get_phones_and_bert
 import librosa
 from module import commons
-from module.mel_processing import mel_spectrogram_torch, spectral_normalize_torch
+from module.mel_processing import mel_spectrogram_torch
 from module.models_onnx import CFM, SynthesizerTrnV3
 import numpy as np
 import torch._dynamo.config
 import torchaudio
-import logging, uvicorn
+import logging
+import uvicorn
 import torch
 import soundfile
 from librosa.filters import mel as librosa_mel_fn
@@ -32,7 +32,6 @@ now_dir = os.getcwd()


 class MelSpectrgram(torch.nn.Module):
-
     def __init__(
         self,
         dtype,
@@ -48,9 +47,7 @@ class MelSpectrgram(torch.nn.Module):
     ):
         super().__init__()
         self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
         self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
         self.n_fft: int = n_fft
         self.hop_size: int = hop_size
@@ -172,9 +169,7 @@ class ExportCFM(torch.nn.Module):
     ):
         T_min = fea_ref.size(2)
         fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-        cfm_res = self.cfm(
-            fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps
-        )
+        cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
         cfm_res = cfm_res[:, :, mel2.shape[2] :]
         mel2 = cfm_res[:, :, -T_min:]
         fea_ref = fea_todo_chunk[:, :, -T_min:]
@@ -198,6 +193,7 @@ mel_fn = lambda x: mel_spectrogram_torch(
 spec_min = -12
 spec_max = 2

+
 @torch.jit.script
 def norm_spec(x):
     spec_min = -12
@@ -212,7 +208,6 @@ def denorm_spec(x):


 class ExportGPTSovitsHalf(torch.nn.Module):
-
     def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
         super().__init__()
         self.hps = hps
@@ -255,18 +250,14 @@ class ExportGPTSovitsHalf(torch.nn.Module):
             center=False,
         ).to(ssl_content.dtype)

-
         codes = self.vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0)
         # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

-        pred_semantic = self.t2s_m(
-            prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
-        )
+        pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
         # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

-
         ge = self.vq_model.create_ge(refer)
         # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

@@ -293,6 +284,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):

         return fea_ref, fea_todo, mel2

+
 class GPTSoVITSV3(torch.nn.Module):
     def __init__(self, gpt_sovits_half, cfm, bigvgan):
         super().__init__()
@@ -313,7 +305,9 @@ class GPTSoVITSV3(torch.nn.Module):
     ):
         # current_time = datetime.now()
         # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
-        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+        fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
+            ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+        )
         chunk_len = 934 - fea_ref.shape[2]
         wav_gen_list = []
         idx = 0
@@ -331,7 +325,13 @@ class GPTSoVITSV3(torch.nn.Module):
             # 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
             complete_len = chunk_len - fea_todo_chunk.shape[-1]
             if complete_len != 0:
-                fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype)], 2)
+                fea_todo_chunk = torch.cat(
+                    [
+                        fea_todo_chunk,
+                        torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
+                    ],
+                    2,
+                )

             cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
             idx += chunk_len
@@ -343,13 +343,13 @@ class GPTSoVITSV3(torch.nn.Module):
         wav_gen = torch.cat(wav_gen_list, 2)
         return wav_gen[0][0][:wav_gen_length]

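# ---- Editor's sketch (illustration only, not part of the commit) ----
# GPTSoVITSV3.forward vocodes in fixed-size windows: each chunk is zero-padded
# up to chunk_len so every CFM/BigVGAN call sees the same shape (friendly to
# traced graphs), and the padding is trimmed from the final waveform via
# wav_gen_length. The control flow in isolation (decode_fn and the 256x
# upsampling factor are assumptions):
import torch

def chunked_decode(features: torch.Tensor, chunk_len: int, decode_fn):
    # features: [1, C, T]; decode_fn: hypothetical fixed-shape decoder
    outs = []
    for idx in range(0, features.shape[-1], chunk_len):
        chunk = features[:, :, idx : idx + chunk_len]
        pad = chunk_len - chunk.shape[-1]
        if pad != 0:  # last window: pad to the fixed shape
            chunk = torch.cat([chunk, torch.zeros(chunk.shape[0], chunk.shape[1], pad)], 2)
        outs.append(decode_fn(chunk))
    return torch.cat(outs, 2)[:, :, : features.shape[-1] * 256]  # trim padding (assumed 256x upsampling)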
 def init_bigvgan():
     global bigvgan_model
     from BigVGAN import bigvgan

     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
-        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x"
-        % (now_dir,),
+        "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
         use_cuda_kernel=False,
     )  # if True, RuntimeError: Ninja is required to load C++ extensions
     # remove weight norm in the model and set to eval mode
@@ -467,10 +467,7 @@ def export_cfm(
     cfm = e_cfm.cfm

     B, T = mu.size(0), mu.size(1)
-    x = (
-        torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype)
-        * temperature
-    )
+    x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
     print("x:", x.shape, x.dtype)
     prompt_len = prompt.size(-1)
     prompt_x = torch.zeros_like(x, dtype=mu.dtype)
@@ -565,11 +562,7 @@ def export():
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
     wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
     codes = sovits.vq_model.extract_latent(ssl_content)
     prompt_semantic = codes[0, 0]
     prompt = prompt_semantic.unsqueeze(0).to(device)
@@ -626,10 +619,7 @@ def export():
         "create_ge": refer,
     }

-    trace_vq_model = torch.jit.trace_module(
-        sovits.vq_model, inputs, optimize=True
-    )
+    trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
     trace_vq_model.save("onnx/ad/vq_model.pt")

     print(fea_ref.shape, fea_ref.dtype, ge.shape)
@@ -714,9 +704,7 @@ def export():

         idx += chunk_len

-        cfm_res, fea_ref, mel2 = export_cfm_(
-            fea_ref, fea_todo_chunk, mel2, sample_steps
-        )
+        cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
         cfm_resss.append(cfm_res)
         continue

@@ -726,9 +714,7 @@ def export():
     with torch.inference_mode():
         cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
         torch._dynamo.mark_dynamic(cmf_res_rand, 2)
-        bigvgan_model_ = torch.jit.trace(
-            bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)
-        )
+        bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
         bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
         wav_gen = bigvgan_model(cmf_res)
         print("wav_gen:", wav_gen.shape, wav_gen.dtype)
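
# ---- Editor's sketch (illustration only, not part of the commit) ----
# torch.jit.trace_module (used above for vq_model) traces *several* methods
# of one module in a single call: the dict maps method names to their example
# inputs, so the saved ScriptModule exposes each entry point. A toy version
# with hypothetical method names:
import torch

class TwoHeads(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z / 2

m = TwoHeads()
traced = torch.jit.trace_module(m, {"forward": torch.randn(1, 4), "decode": torch.randn(1, 4)})
z = traced(torch.ones(1, 4))   # both traced methods are callable
x = traced.decode(z)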
@@ -748,7 +734,6 @@ def test_export(
     bigvgan,
     output,
 ):
-
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
@@ -773,11 +758,7 @@ def test_export(
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
     wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()

     ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
     ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
@@ -799,8 +780,18 @@ def test_export(

     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     logger.info("start inference %s", current_time)
-    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
-    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
+    print(
+        ssl_content.shape,
+        ref_audio_32k.shape,
+        phoneme_ids0.shape,
+        phoneme_ids1.shape,
+        bert1.shape,
+        bert2.shape,
+        top_k.shape,
+    )
+    fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
+        ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
+    )
     chunk_len = 934 - fea_ref.shape[2]
     print(fea_ref.shape, fea_todo.shape, mel2.shape)

@@ -812,7 +803,6 @@ def test_export(
     wav_gen_length = fea_todo.shape[2] * 256

     while 1:
-
         current_time = datetime.now()
         print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
         fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
@@ -861,7 +851,6 @@ def test_export1(
     gpt_sovits_v3,
     output,
 ):
-
     # hps = sovits.hps
     ref_wav_path = "onnx/ad/ref.wav"
     speed = 1.0
@@ -886,11 +875,7 @@ def test_export1(
     wav16k = wav16k.to(device)
     zero_wav_torch = zero_wav_torch.to(device)
     wav16k = torch.cat([wav16k, zero_wav_torch])
-    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
-        "last_hidden_state"
-    ].transpose(
-        1, 2
-    )  # .float()
+    ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
     print("ssl_content:", ssl_content.shape, ssl_content.dtype)

     ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
@@ -913,7 +898,15 @@ def test_export1(

     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     logger.info("start inference %s", current_time)
-    print(ssl_content.shape, ref_audio_32k.shape, phoneme_ids0.shape, phoneme_ids1.shape, bert1.shape, bert2.shape, top_k.shape)
+    print(
+        ssl_content.shape,
+        ref_audio_32k.shape,
+        phoneme_ids0.shape,
+        phoneme_ids1.shape,
+        bert1.shape,
+        bert2.shape,
+        top_k.shape,
+    )
     wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
     print("wav_gen:", wav_gen.shape, wav_gen.dtype)

@@ -929,7 +922,6 @@ import time


 def test_():
-
     sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")

     # cfm = ExportCFM(sovits.cfm)
@@ -942,7 +934,7 @@ def test_():

     cfm.eval()

-    logger.info(f"cfm ok")
+    logger.info("cfm ok")

     dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
     # v2 的 gpt 也可以用
@@ -957,17 +949,14 @@ def test_():
     t2s_m = torch.jit.script(t2s_m)
     t2s_m.eval()
     # t2s_m.top_k = 15
-    logger.info(f"t2s_m ok")
+    logger.info("t2s_m ok")

-    vq_model: torch.jit.ScriptModule = torch.jit.load(
-        "onnx/ad/vq_model.pt", map_location=device
-    )
+    vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
     # vq_model = torch.jit.optimize_for_inference(vq_model)
     # vq_model = vq_model.half().to(device)
     vq_model.eval()
     # vq_model = sovits.vq_model
-    logger.info(f"vq_model ok")
+    logger.info("vq_model ok")

     # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
     # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
@@ -975,7 +964,7 @@ def test_():
     # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
     # gpt_sovits_v3_half.eval()
     gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
-    logger.info(f"gpt_sovits_v3_half ok")
+    logger.info("gpt_sovits_v3_half ok")

     # init_bigvgan()
     # global bigvgan_model
@@ -985,7 +974,7 @@ def test_():
     bigvgan_model = bigvgan_model.cuda()
     bigvgan_model.eval()

-    logger.info(f"bigvgan ok")
+    logger.info("bigvgan ok")

     gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
     gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
@@ -1020,6 +1009,7 @@ def test_():
     # "out2.wav",
     # )

+
 def test_export_gpt_sovits_v3():
     gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
     # test_export1(
@@ -11,7 +11,6 @@ from __future__ import annotations

 import torch
 from torch import nn
-import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from x_transformers.x_transformers import RotaryEmbedding
@@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (

 from module.commons import sequence_mask

+
 class TextEmbedding(nn.Module):
     def __init__(self, text_dim, conv_layers=0, conv_mult=2):
         super().__init__()
@@ -143,9 +143,7 @@ class DiT(nn.Module):
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
-
     ):
-
         x = x0.transpose(2, 1)
         cond = cond0.transpose(2, 1)
         text = text0.transpose(2, 1)
@@ -391,6 +391,7 @@ class Attention(nn.Module):

 # Attention processor

+
 # from torch.nn.attention import SDPBackend
 # torch.backends.cuda.enable_flash_sdp(True)
 class AttnProcessor:
@@ -545,6 +546,7 @@ class JointAttnProcessor:

 # DiT Block

+
 class DiTBlock(nn.Module):
     def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
         super().__init__()
@@ -1,6 +1,3 @@
 from . import cnhubert, whisper_enc

-content_module_map = {
-    'cnhubert': cnhubert,
-    'whisper': whisper_enc
-}
+content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
@@ -1,14 +1,11 @@
-import time

-import librosa
 import torch
-import torch.nn.functional as F
-import soundfile as sf
 import os
 from transformers import logging as tf_logging

 tf_logging.set_verbosity_error()

 import logging

 logging.getLogger("numba").setLevel(logging.WARNING)

 from transformers import (
@@ -27,17 +24,15 @@ class CNHubert(nn.Module):
         super().__init__()
         if base_path is None:
             base_path = cnhubert_base_path
-        if os.path.exists(base_path):...
-        else:raise FileNotFoundError(base_path)
+        if os.path.exists(base_path):
+            ...
+        else:
+            raise FileNotFoundError(base_path)
         self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
-        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            base_path, local_files_only=True
-        )
+        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)

     def forward(self, x):
-        input_values = self.feature_extractor(
-            x, return_tensors="pt", sampling_rate=16000
-        ).input_values.to(x.device)
+        input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
         feats = self.model(input_values)["last_hidden_state"]
         return feats

@@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
     feature_len = mel.shape[-1] // 2
     assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
     with torch.no_grad():
-        feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
-            :1, :feature_len, :
-        ].transpose(1, 2)
+        feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
     return feature

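# ---- Editor's sketch (illustration only, not part of the commit) ----
# CNHubert pairs a Wav2Vec2FeatureExtractor (raw waveform -> normalized model
# input) with a HubertModel; the SSL feature is the last hidden state. Usage
# in isolation (assumes the local HuBERT checkpoint directory used above):
import torch
from transformers import HubertModel, Wav2Vec2FeatureExtractor

base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"  # assumed local path
extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
model = HubertModel.from_pretrained(base_path, local_files_only=True).eval()

wav_16k = torch.zeros(16000)  # one second of silence at 16 kHz, stand-in input
inputs = extractor(wav_16k, return_tensors="pt", sampling_rate=16000).input_values
with torch.no_grad():
    feats = model(inputs)["last_hidden_state"]  # [1, frames, 768]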
@@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
 
 i18n = I18nAuto()
 
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
+
+def synthesize(
+    GPT_model_path,
+    SoVITS_model_path,
+    ref_audio_path,
+    ref_text_path,
+    ref_language,
+    target_text_path,
+    target_language,
+    output_path,
+):
     # Read reference text
-    with open(ref_text_path, 'r', encoding='utf-8') as file:
+    with open(ref_text_path, "r", encoding="utf-8") as file:
         ref_text = file.read()
 
     # Read target text
-    with open(target_text_path, 'r', encoding='utf-8') as file:
+    with open(target_text_path, "r", encoding="utf-8") as file:
         target_text = file.read()
 
     # Change model weights
@@ -21,11 +31,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
     change_sovits_weights(sovits_path=SoVITS_model_path)
 
     # Synthesize audio
-    synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
-                                   prompt_text=ref_text,
-                                   prompt_language=i18n(ref_language),
-                                   text=target_text,
-                                   text_language=i18n(target_language), top_p=1, temperature=1)
+    synthesis_result = get_tts_wav(
+        ref_wav_path=ref_audio_path,
+        prompt_text=ref_text,
+        prompt_language=i18n(ref_language),
+        text=target_text,
+        text_language=i18n(target_language),
+        top_p=1,
+        temperature=1,
+    )
 
     result_list = list(synthesis_result)
 
@@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
     sf.write(output_wav_path, last_audio_data, last_sampling_rate)
     print(f"Audio saved to {output_wav_path}")
 
 
 def main():
     parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
-    parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
-    parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
-    parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
-    parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
-    parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
-    parser.add_argument('--target_text', required=True, help="Path to the target text file")
-    parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
-    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
+    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
+    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
+    parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
+    parser.add_argument(
+        "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
+    )
+    parser.add_argument("--target_text", required=True, help="Path to the target text file")
+    parser.add_argument(
+        "--target_language",
+        required=True,
+        choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
+        help="Language of the target text",
+    )
+    parser.add_argument("--output_path", required=True, help="Path to the output directory")
 
     args = parser.parse_args()
 
-    synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
+    synthesize(
+        args.gpt_model,
+        args.sovits_model,
+        args.ref_audio,
+        args.ref_text,
+        args.ref_language,
+        args.target_text,
+        args.target_language,
+        args.output_path,
+    )
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
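The reflowed signature also makes synthesize() easy to call from Python rather than through argparse. A hedged sketch; every path below is a placeholder, and the language strings must match the choices registered in main():

    from GPT_SoVITS.inference_cli import synthesize

    synthesize(
        GPT_model_path="GPT_weights_v2/my_voice-e15.ckpt",      # placeholder
        SoVITS_model_path="SoVITS_weights_v2/my_voice_e8.pth",  # placeholder
        ref_audio_path="ref/sample.wav",
        ref_text_path="ref/sample.txt",
        ref_language="中文",
        target_text_path="ref/target.txt",
        target_language="中文",
        output_path="output",
    )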
@@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
 import soundfile as sf
+
 from tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto()
 
 from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
     def __init__(self):
         super().__init__()
 
-        self.setWindowTitle('GPT-SoVITS GUI')
+        self.setWindowTitle("GPT-SoVITS GUI")
         self.setGeometry(800, 450, 950, 850)
 
         self.setStyleSheet("""
@@ -65,7 +66,8 @@ class GPTSoVITSGUI(QMainWindow):
 
         license_text = (
             "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
-            "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+            "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
+        )
         license_label = QLabel(license_text)
         license_label.setWordWrap(True)
 
@@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
         self.output_text = QTextEdit()
         self.output_text.setReadOnly(True)
 
-        self.add_drag_drop_events([
+        self.add_drag_drop_events(
+            [
                 self.GPT_model_input,
                 self.SoVITS_model_input,
                 self.ref_audio_input,
                 self.ref_text_input,
                 self.target_text_input,
                 self.output_input,
-        ])
+            ]
+        )
 
         self.synthesize_button = QPushButton("合成")
         self.synthesize_button.clicked.connect(self.synthesize)
@@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
     def upload_ref_text(self):
        file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
        if file_path:
-            with open(file_path, 'r', encoding='utf-8') as file:
+            with open(file_path, "r", encoding="utf-8") as file:
                 content = file.read()
             self.ref_text_input.setText(content)
 
     def upload_target_text(self):
         file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
         if file_path:
-            with open(file_path, 'r', encoding='utf-8') as file:
+            with open(file_path, "r", encoding="utf-8") as file:
                 content = file.read()
             self.target_text_input.setText(content)
 
@@ -284,11 +288,13 @@ class GPTSoVITSGUI(QMainWindow):
             change_sovits_weights(sovits_path=SoVITS_model_path)
             self.SoVITS_Path = SoVITS_model_path
 
-        synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
-                                       prompt_text=ref_text,
-                                       prompt_language=language_combobox,
-                                       text=target_text,
-                                       text_language=target_language_combobox)
+        synthesis_result = get_tts_wav(
+            ref_wav_path=ref_audio_path,
+            prompt_text=ref_text,
+            prompt_language=language_combobox,
+            text=target_text,
+            text_language=target_language_combobox,
+        )
 
         result_list = list(synthesis_result)
 
@@ -303,7 +309,7 @@ class GPTSoVITSGUI(QMainWindow):
         self.output_text.append("处理结果:\n" + result)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     app = QApplication(sys.argv)
     mainWin = GPTSoVITSGUI()
     mainWin.show()
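Both the CLI and this GUI consume get_tts_wav the same way: it is a generator of (sampling_rate, audio) chunks, and only the final chunk is written to disk. A minimal sketch of that pattern (output path illustrative):

    import soundfile as sf

    # synthesis_result comes from get_tts_wav(...) as in the hunks above
    result_list = list(synthesis_result)
    if result_list:
        last_sampling_rate, last_audio_data = result_list[-1]
        sf.write("output/out.wav", last_audio_data, last_sampling_rate)  # path illustrative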
File diff suppressed because it is too large
@@ -1,14 +1,19 @@
-'''
+"""
 按中英混合识别
 按日英混合识别
 多语种启动切分识别语种
 全部按中文识别
 全部按英文识别
 全部按日文识别
-'''
+"""
 
+import json
+import logging
+import os
 import random
-import os, re, logging, json
+import re
 import sys
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
@@ -20,13 +25,14 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
 logging.getLogger("asyncio").setLevel(logging.ERROR)
 logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
 logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-import pdb
 import torch
 
 try:
     import gradio.analytics as analytics
 
     analytics.version_check = lambda: None
-except:...
+except:
+    ...
 
 
 infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@@ -44,10 +50,10 @@ bert_path = os.environ.get("bert_path", None)
 version = model_version = os.environ.get("version", "v2")
 
 import gradio as gr
-from TTS_infer_pack.TTS import TTS, TTS_Config, NO_PROMPT_ERROR
 from TTS_infer_pack.text_segmentation_method import get_method
+from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
 
 from tools.i18n.i18n import I18nAuto, scan_language_list
-from inference_webui import DictToAttrRecursive
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -87,7 +93,7 @@ dict_language_v2 = {
     i18n("多语种混合"): "auto",  # 多语种启动切分识别语种
     i18n("多语种混合(粤语)"): "auto_yue",  # 多语种启动切分识别语种
 }
-dict_language = dict_language_v1 if version =='v1' else dict_language_v2
+dict_language = dict_language_v1 if version == "v1" else dict_language_v2
 
 cut_method = {
     i18n("不切"): "cut0",
@@ -117,19 +123,30 @@ gpt_path = tts_config.t2s_weights_path
 sovits_path = tts_config.vits_weights_path
 version = tts_config.version
 
-def inference(text, text_lang,
-              ref_audio_path,
-              aux_ref_audio_paths,
-              prompt_text,
-              prompt_lang, top_k,
-              top_p, temperature,
-              text_split_method, batch_size,
-              speed_factor, ref_text_free,
-              split_bucket,fragment_interval,
-              seed, keep_random, parallel_infer,
-              repetition_penalty, sample_steps, super_sampling,
+
+def inference(
+    text,
+    text_lang,
+    ref_audio_path,
+    aux_ref_audio_paths,
+    prompt_text,
+    prompt_lang,
+    top_k,
+    top_p,
+    temperature,
+    text_split_method,
+    batch_size,
+    speed_factor,
+    ref_text_free,
+    split_bucket,
+    fragment_interval,
+    seed,
+    keep_random,
+    parallel_infer,
+    repetition_penalty,
+    sample_steps,
+    super_sampling,
 ):
     seed = -1 if keep_random else seed
     actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
     inputs = {
@@ -158,11 +175,12 @@ def inference(text, text_lang,
         for item in tts_pipeline.run(inputs):
             yield item, actual_seed
     except NO_PROMPT_ERROR:
-        gr.Warning(i18n('V3不支持无参考文本模式,请填写参考文本!'))
+        gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
 
 
 def custom_sort_key(s):
     # 使用正则表达式提取字符串中的数字部分和非数字部分
-    parts = re.split('(\d+)', s)
+    parts = re.split("(\d+)", s)
     # 将数字部分转换为整数,非数字部分保持不变
     parts = [int(part) if part.isdigit() else part for part in parts]
     return parts
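Why custom_sort_key exists: plain lexicographic sorting puts checkpoint "e10" before "e2". A self-contained restatement of the same idea; note it uses a raw-string regex, since the un-raw "(\d+)" pattern kept above triggers an invalid-escape-sequence warning on newer Pythons:

    import re

    def natural_key(s):
        parts = re.split(r"(\d+)", s)
        return [int(p) if p.isdigit() else p for p in parts]

    names = ["model_e10.ckpt", "model_e2.ckpt", "model_e1.ckpt"]
    print(sorted(names))                   # ['model_e1.ckpt', 'model_e10.ckpt', 'model_e2.ckpt']
    print(sorted(names, key=natural_key))  # ['model_e1.ckpt', 'model_e2.ckpt', 'model_e10.ckpt']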
@@ -170,59 +188,76 @@ def custom_sort_key(s):
 
 def change_choices():
     SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
-    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
+    return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
+        "choices": sorted(GPT_names, key=custom_sort_key),
+        "__type__": "update",
+    }
 
 
 path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
-pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
-pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
+pretrained_sovits_name = [
+    "GPT_SoVITS/pretrained_models/s2G488k.pth",
+    "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
+    path_sovits_v3,
+]
+pretrained_gpt_name = [
+    "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
+    "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
+    "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+]
 
 _ = [[], []]
 for i in range(3):
-    if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
-    if os.path.exists(pretrained_sovits_name[i]):_[-1].append(pretrained_sovits_name[i])
+    if os.path.exists(pretrained_gpt_name[i]):
+        _[0].append(pretrained_gpt_name[i])
+    if os.path.exists(pretrained_sovits_name[i]):
+        _[-1].append(pretrained_sovits_name[i])
 pretrained_gpt_name, pretrained_sovits_name = _
 
 
-if os.path.exists(f"./weight.json"):
+if os.path.exists("./weight.json"):
     pass
 else:
-    with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file)
+    with open("./weight.json", "w", encoding="utf-8") as file:
+        json.dump({"GPT": {}, "SoVITS": {}}, file)
 
-with open(f"./weight.json", 'r', encoding="utf-8") as file:
+with open("./weight.json", "r", encoding="utf-8") as file:
     weight_data = file.read()
     weight_data = json.loads(weight_data)
-    gpt_path = os.environ.get(
-        "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name))
-    sovits_path = os.environ.get(
-        "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name))
+    gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
+    sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
     if isinstance(gpt_path, list):
         gpt_path = gpt_path[0]
     if isinstance(sovits_path, list):
         sovits_path = sovits_path[0]
 
 
 SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
 GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
 for path in SoVITS_weight_root + GPT_weight_root:
     os.makedirs(path, exist_ok=True)
 
 
 def get_weights_names(GPT_weight_root, SoVITS_weight_root):
     SoVITS_names = [i for i in pretrained_sovits_name]
     for path in SoVITS_weight_root:
         for name in os.listdir(path):
-            if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
+            if name.endswith(".pth"):
+                SoVITS_names.append("%s/%s" % (path, name))
     GPT_names = [i for i in pretrained_gpt_name]
     for path in GPT_weight_root:
         for name in os.listdir(path):
-            if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
+            if name.endswith(".ckpt"):
+                GPT_names.append("%s/%s" % (path, name))
     return SoVITS_names, GPT_names
 
 
 SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
 
 
-from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
+from process_ckpt import get_sovits_version_from_path_fast
 
 
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
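The weight.json touched above is a small registry mapping model version to the last-used checkpoint path per model family; the os.environ.get fallbacks then pick the pretrained defaults for versions that were never saved. Its shape after one run looks roughly like this (paths illustrative):

    import json

    example = {
        "GPT": {"v2": "GPT_weights_v2/my_voice-e15.ckpt"},
        "SoVITS": {"v2": "SoVITS_weights_v2/my_voice_e8_s200.pth"},
    }
    print(json.dumps(example, ensure_ascii=False, indent=2))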
@@ -231,18 +266,21 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
         gr.Warning(info)
         raise FileExistsError(info)
-    dict_language = dict_language_v1 if version =='v1' else dict_language_v2
+    dict_language = dict_language_v1 if version == "v1" else dict_language_v2
     if prompt_language is not None and text_language is not None:
         if prompt_language in list(dict_language.keys()):
-            prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
+            prompt_text_update, prompt_language_update = (
+                {"__type__": "update"},
+                {"__type__": "update", "value": prompt_language},
+            )
         else:
-            prompt_text_update = {'__type__':'update', 'value':''}
-            prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
+            prompt_text_update = {"__type__": "update", "value": ""}
+            prompt_language_update = {"__type__": "update", "value": i18n("中文")}
         if text_language in list(dict_language.keys()):
-            text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
+            text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
         else:
-            text_update = {'__type__':'update', 'value':''}
-            text_language_update = {'__type__':'update', 'value':i18n("中文")}
+            text_update = {"__type__": "update", "value": ""}
+            text_language_update = {"__type__": "update", "value": i18n("中文")}
     if model_version == "v3":
         visible_sample_steps = True
         visible_inp_refs = False
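The {"__type__": "update", ...} dicts built here are the raw-dict form of gr.update(); Gradio treats the two interchangeably, which is what lets this refactor reflow them freely. A quick equivalence check, assuming a Gradio version where gr.update() still returns a plain dict:

    import gradio as gr

    u = gr.update(interactive=False)
    assert u["__type__"] == "update" and u["interactive"] is False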
@@ -250,45 +288,93 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
         visible_sample_steps = False
         visible_inp_refs = True
     # prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
-    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "interactive": visible_sample_steps,"value":32},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "interactive": True if model_version!="v3"else False},{"__type__": "update", "value":i18n("模型加载中,请等待"),"interactive":False}
+    yield (
+        {"__type__": "update", "choices": list(dict_language.keys())},
+        {"__type__": "update", "choices": list(dict_language.keys())},
+        prompt_text_update,
+        prompt_language_update,
+        text_update,
+        text_language_update,
+        {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
+        {"__type__": "update", "visible": visible_inp_refs},
+        {"__type__": "update", "interactive": True if model_version != "v3" else False},
+        {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
+    )
 
     tts_pipeline.init_vits_weights(sovits_path)
-    yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "interactive": visible_sample_steps,"value":32},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "interactive": True if model_version!="v3"else False},{"__type__": "update", "value":i18n("合成语音"),"interactive":True}
+    yield (
+        {"__type__": "update", "choices": list(dict_language.keys())},
+        {"__type__": "update", "choices": list(dict_language.keys())},
+        prompt_text_update,
+        prompt_language_update,
+        text_update,
+        text_language_update,
+        {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
+        {"__type__": "update", "visible": visible_inp_refs},
+        {"__type__": "update", "interactive": True if model_version != "v3" else False},
+        {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
+    )
     with open("./weight.json") as f:
         data = f.read()
         data = json.loads(data)
         data["SoVITS"][version] = sovits_path
-    with open("./weight.json","w")as f:f.write(json.dumps(data))
+    with open("./weight.json", "w") as f:
+        f.write(json.dumps(data))
 
 
 with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     gr.Markdown(
-        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+        + "<br>"
+        + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
     )
 
     with gr.Column():
         # with gr.Group():
         gr.Markdown(value=i18n("模型切换"))
         with gr.Row():
-            GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
-            SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
+            GPT_dropdown = gr.Dropdown(
+                label=i18n("GPT模型列表"),
+                choices=sorted(GPT_names, key=custom_sort_key),
+                value=gpt_path,
+                interactive=True,
+            )
+            SoVITS_dropdown = gr.Dropdown(
+                label=i18n("SoVITS模型列表"),
+                choices=sorted(SoVITS_names, key=custom_sort_key),
+                value=sovits_path,
+                interactive=True,
+            )
             refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
             refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
 
 
     with gr.Row():
         with gr.Column():
             gr.Markdown(value=i18n("*请上传并填写参考信息"))
             with gr.Row():
                 inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
-                inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple", visible=True if model_version!="v3"else False)
+                inp_refs = gr.File(
+                    label=i18n("辅参考音频(可选多个,或不选)"),
+                    file_count="multiple",
+                    visible=True if model_version != "v3" else False,
+                )
             prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
             with gr.Row():
                 prompt_language = gr.Dropdown(
                     label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
                 )
                 with gr.Column():
-                    ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True if model_version!="v3"else False, show_label=True)
-                    gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
+                    ref_text_free = gr.Checkbox(
+                        label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
+                        value=False,
+                        interactive=True if model_version != "v3" else False,
+                        show_label=True,
+                    )
+                    gr.Markdown(
+                        i18n("使用无参考文本模式时建议使用微调的GPT")
+                        + "<br>"
+                        + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
+                    )
 
         with gr.Column():
             gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
@@ -297,41 +383,65 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
             )
 
 
     with gr.Group():
         gr.Markdown(value=i18n("推理设置"))
         with gr.Row():
 
             with gr.Column():
                 with gr.Row():
-                    batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
-                    sample_steps = gr.Radio(label=i18n("采样步数(仅对V3生效)"),value=32,choices=[4,8,16,32],visible=True)
+                    batch_size = gr.Slider(
+                        minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
+                    )
+                    sample_steps = gr.Radio(
+                        label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
+                    )
                 with gr.Row():
-                    fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
-                    speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="语速",value=1.0,interactive=True)
+                    fragment_interval = gr.Slider(
+                        minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
+                    )
+                    speed_factor = gr.Slider(
+                        minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
+                    )
                 with gr.Row():
                     top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
                     top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
                 with gr.Row():
-                    temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
-                    repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
+                    temperature = gr.Slider(
+                        minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
+                    )
+                    repetition_penalty = gr.Slider(
+                        minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
+                    )
 
             with gr.Column():
                 with gr.Row():
                     how_to_cut = gr.Dropdown(
                         label=i18n("怎么切"),
-                        choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
+                        choices=[
+                            i18n("不切"),
+                            i18n("凑四句一切"),
+                            i18n("凑50字一切"),
+                            i18n("按中文句号。切"),
+                            i18n("按英文句号.切"),
+                            i18n("按标点符号切"),
+                        ],
                         value=i18n("凑四句一切"),
-                        interactive=True, scale=1
+                        interactive=True,
+                        scale=1,
+                    )
+                    super_sampling = gr.Checkbox(
+                        label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
                     )
-                    super_sampling = gr.Checkbox(label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True)
 
                 with gr.Row():
                     parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
-                    split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
+                    split_bucket = gr.Checkbox(
+                        label=i18n("数据分桶(并行推理时会降低一点计算量)"),
+                        value=True,
+                        interactive=True,
+                        show_label=True,
+                    )
 
                 with gr.Row():
                     seed = gr.Number(label=i18n("随机种子"), value=-1)
                     keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
 
@@ -340,33 +450,71 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             inference_button = gr.Button(i18n("合成语音"), variant="primary")
             stop_infer = gr.Button(i18n("终止合成"), variant="primary")
 
 
    inference_button.click(
         inference,
         [
-            text,text_language, inp_ref, inp_refs,
-            prompt_text, prompt_language,
-            top_k, top_p, temperature,
-            how_to_cut, batch_size,
-            speed_factor, ref_text_free,
-            split_bucket,fragment_interval,
-            seed, keep_random, parallel_infer,
-            repetition_penalty, sample_steps, super_sampling,
+            text,
+            text_language,
+            inp_ref,
+            inp_refs,
+            prompt_text,
+            prompt_language,
+            top_k,
+            top_p,
+            temperature,
+            how_to_cut,
+            batch_size,
+            speed_factor,
+            ref_text_free,
+            split_bucket,
+            fragment_interval,
+            seed,
+            keep_random,
+            parallel_infer,
+            repetition_penalty,
+            sample_steps,
+            super_sampling,
         ],
         [output, seed],
     )
     stop_infer.click(tts_pipeline.stop, [], [])
-    SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free,inference_button])#
+    SoVITS_dropdown.change(
+        change_sovits_weights,
+        [SoVITS_dropdown, prompt_language, text_language],
+        [
+            prompt_language,
+            text_language,
+            prompt_text,
+            prompt_language,
+            text,
+            text_language,
+            sample_steps,
+            inp_refs,
+            ref_text_free,
+            inference_button,
+        ],
+    )  #
     GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
 
     with gr.Group():
-        gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
+        gr.Markdown(
+            value=i18n(
+                "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
+            )
+        )
         with gr.Row():
             text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
             with gr.Column():
                 _how_to_cut = gr.Radio(
                     label=i18n("怎么切"),
-                    choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
+                    choices=[
+                        i18n("不切"),
+                        i18n("凑四句一切"),
+                        i18n("凑50字一切"),
+                        i18n("按中文句号。切"),
+                        i18n("按英文句号.切"),
+                        i18n("按标点符号切"),
+                    ],
                     value=i18n("凑四句一切"),
                     interactive=True,
                 )
@@ -382,7 +530,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
     gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     app.queue().launch(  # concurrency_count=511, max_size=1022
         server_name="0.0.0.0",
         inbrowser=True,
@@ -18,7 +18,7 @@ class Encoder(nn.Module):
         p_dropout=0.0,
         window_size=4,
         isflow=False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.hidden_channels = hidden_channels
@@ -56,9 +56,7 @@ class Encoder(nn.Module):
             )
             self.norm_layers_2.append(LayerNorm(hidden_channels))
         if isflow:
-            cond_layer = torch.nn.Conv1d(
-                kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
-            )
+            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
             self.gin_channels = kwargs["gin_channels"]
@@ -74,9 +72,7 @@ class Encoder(nn.Module):
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
-                    x, g_l, torch.IntTensor([self.hidden_channels])
-                )
+                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
             y = self.attn_layers[i](x, x, attn_mask)
             y = self.drop(y)
             x = self.norm_layers_1[i](x + y)
@@ -99,7 +95,7 @@ class Decoder(nn.Module):
         p_dropout=0.0,
         proximal_bias=False,
         proximal_init=True,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.hidden_channels = hidden_channels
@@ -131,9 +127,7 @@ class Decoder(nn.Module):
             )
             self.norm_layers_0.append(LayerNorm(hidden_channels))
             self.encdec_attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
-                )
+                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
             )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
@@ -153,9 +147,7 @@ class Decoder(nn.Module):
         x: decoder input
         h: encoder output
         """
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-            device=x.device, dtype=x.dtype
-        )
+        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
         for i in range(self.n_layers):
@@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
         if window_size is not None:
             n_heads_rel = 1 if heads_share else n_heads
             rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
+            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
 
         nn.init.xavier_uniform_(self.conv_q.weight)
         nn.init.xavier_uniform_(self.conv_k.weight)
@@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
 
         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
         if self.window_size is not None:
-            assert (
-                t_s == t_t
-            ), "Relative attention is only available for self-attention."
+            assert t_s == t_t, "Relative attention is only available for self-attention."
             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
+            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
             scores_local = self._relative_position_to_absolute_position(rel_logits)
             scores = scores + scores_local
         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
             if self.block_length is not None:
-                assert (
-                    t_s == t_t
-                ), "Local attention is only available for self-attention."
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
+                assert t_s == t_t, "Local attention is only available for self-attention."
+                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                 scores = scores.masked_fill(block_mask == 0, -1e4)
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
         if self.window_size is not None:
             relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = (
-            output.transpose(2, 3).contiguous().view(b, d, t_t)
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn
 
     def _matmul_with_relative_values(self, x, y):
@@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
             )
         else:
             padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
         return used_relative_embeddings
 
     def _relative_position_to_absolute_position(self, x):
@@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
 
         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-        )
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
 
         # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
         return x_final
 
     def _absolute_position_to_relative_position(self, x):
@@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
         """
         batch, heads, length, _ = x.size()
         # padd along column
-        x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-        )
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
 
 
 def weight_norm_modules(module, name="weight", dim=0):
-    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
-        module, Depthwise_Separable_TransposeConv1D
-    ):
+    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
         module.weight_norm()
         return module
     else:
@@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
 
 
 def remove_weight_norm_modules(module, name="weight"):
-    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
-        module, Depthwise_Separable_TransposeConv1D
-    ):
+    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
         module.remove_weight_norm()
     else:
         remove_weight_norm(module, name)
@@ -567,7 +523,7 @@ class FFT(nn.Module):
         proximal_bias=False,
         proximal_init=True,
         isflow=False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.hidden_channels = hidden_channels
@@ -579,9 +535,7 @@ class FFT(nn.Module):
         self.proximal_bias = proximal_bias
         self.proximal_init = proximal_init
         if isflow:
-            cond_layer = torch.nn.Conv1d(
-                kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
-            )
+            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
             self.gin_channels = kwargs["gin_channels"]
@@ -622,18 +576,14 @@ class FFT(nn.Module):
         if g is not None:
             g = self.cond_layer(g)
 
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
-            device=x.device, dtype=x.dtype
-        )
+        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
         x = x * x_mask
         for i in range(self.n_layers):
             if g is not None:
                 x = self.cond_pre(x)
                 cond_offset = i * 2 * self.hidden_channels
                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-                x = commons.fused_add_tanh_sigmoid_multiply(
-                    x, g_l, torch.IntTensor([self.hidden_channels])
-                )
+                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
             y = self.self_attn_layers[i](x, x, self_attn_mask)
             y = self.drop(y)
             x = self.norm_layers_0[i](x + y)
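The _relative_position_to_absolute_position bodies being reflowed above all implement the same pad/reshape/slice trick for relative-position attention: pad the [b, h, l, 2l-1] relative logits, flatten, pad the tail so the reshape skews each row by one, then slice out the absolute-position block. Pulled out as a standalone function, the trick reads:

    import torch
    import torch.nn.functional as F

    def rel_to_abs(x):
        # x: [batch, heads, length, 2*length - 1] relative logits
        b, h, l, _ = x.size()
        x = F.pad(x, (0, 1))                 # pad one column  -> [b, h, l, 2l]
        x = x.view(b, h, l * 2 * l)          # flatten the last two axes
        x = F.pad(x, (0, l - 1))             # pad the tail so the reshape skews rows
        x = x.view(b, h, l + 1, 2 * l - 1)   # one extra (garbage) row absorbs the skew
        return x[:, :, :l, l - 1 :]          # absolute logits: [b, h, l, l]

    print(rel_to_abs(torch.randn(1, 2, 4, 7)).shape)  # torch.Size([1, 2, 4, 4])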
@ -7,6 +7,7 @@ from module import commons
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
def __init__(self, channels, eps=1e-5):
|
def __init__(self, channels, eps=1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -43,7 +44,7 @@ class Encoder(nn.Module):
|
|||||||
p_dropout=0.0,
|
p_dropout=0.0,
|
||||||
window_size=4,
|
window_size=4,
|
||||||
isflow=True,
|
isflow=True,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_channels = hidden_channels
|
self.hidden_channels = hidden_channels
|
||||||
@ -65,13 +66,9 @@ class Encoder(nn.Module):
|
|||||||
if self.gin_channels != 0:
|
if self.gin_channels != 0:
|
||||||
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
||||||
# vits2 says 3rd block, so idx is 2 by default
|
# vits2 says 3rd block, so idx is 2 by default
|
||||||
self.cond_layer_idx = (
|
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||||
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
|
||||||
)
|
|
||||||
logging.debug(self.gin_channels, self.cond_layer_idx)
|
logging.debug(self.gin_channels, self.cond_layer_idx)
|
||||||
assert (
|
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
|
||||||
self.cond_layer_idx < self.n_layers
|
|
||||||
), "cond_layer_idx should be less than n_layers"
|
|
||||||
self.drop = nn.Dropout(p_dropout)
|
self.drop = nn.Dropout(p_dropout)
|
||||||
self.attn_layers = nn.ModuleList()
|
self.attn_layers = nn.ModuleList()
|
||||||
self.norm_layers_1 = nn.ModuleList()
|
self.norm_layers_1 = nn.ModuleList()
|
||||||
@ -121,7 +118,9 @@ class Encoder(nn.Module):
|
|||||||
def forward(self, x, x_mask):
|
def forward(self, x, x_mask):
|
||||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||||
x = x * x_mask
|
x = x * x_mask
|
||||||
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
|
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
|
||||||
|
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
|
||||||
|
):
|
||||||
y = attn_layers(x, x, attn_mask)
|
y = attn_layers(x, x, attn_mask)
|
||||||
y = self.drop(y)
|
y = self.drop(y)
|
||||||
x = norm_layers_1(x + y)
|
x = norm_layers_1(x + y)
|
||||||
@@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
         if window_size is not None:
             n_heads_rel = 1 if heads_share else n_heads
             rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
+            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

         nn.init.xavier_uniform_(self.conv_q.weight)
         nn.init.xavier_uniform_(self.conv_k.weight)
@@ -224,7 +217,7 @@ class MultiHeadAttention(nn.Module):
             value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
             output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)

-        output = (output.transpose(2, 3).contiguous().view(b, d, -1))
+        output = output.transpose(2, 3).contiguous().view(b, d, -1)
         return output, p_attn

     def _matmul_with_relative_values(self, x, y):
@@ -258,9 +251,7 @@ class MultiHeadAttention(nn.Module):
             relative_embeddings,
             commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
         )
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
+        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
         return used_relative_embeddings

     def _relative_position_to_absolute_position(self, x):
@@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):

         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(
-            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
-        )
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

         # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
         return x_final

     def _absolute_position_to_relative_position(self, x):
@@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
         """
         batch, heads, length, _ = x.size()
         # padd along column
-        x = F.pad(
-            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
-        )
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
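The relative-position hunks above all lean on one pad-and-reshape trick: an [l, 2l-1] grid of relative-position scores is padded, flattened, re-padded, and re-viewed so that entry (i, j) of the result picks out the score of query i at offset j - i. A tiny numeric check of that trick (batch and heads of 1, assuming the same pad layout as the code):

import torch
import torch.nn.functional as F

b, h, l = 1, 1, 4
# rel[..., i, k] holds the score of query i at relative offset k - (l - 1)
rel = torch.arange(l * (2 * l - 1), dtype=torch.float).view(b, h, l, 2 * l - 1)

x = F.pad(rel, (0, 1))              # pad one column: [b, h, l, 2l]
x_flat = x.view(b, h, l * 2 * l)    # flatten time into one axis
x_flat = F.pad(x_flat, (0, l - 1))  # skew with (l - 1) trailing zeros
absolute = x_flat.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1 :]

for i in range(l):
    for j in range(l):
        assert absolute[0, 0, i, j] == rel[0, 0, i, (j - i) + (l - 1)]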
@@ -395,12 +380,6 @@ class MRTE(nn.Module):

         ssl_enc = self.c_pre(ssl_enc * ssl_mask)
         text_enc = self.text_pre(text * text_mask)
-        x = (
-            self.cross_attention(
-                ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-            )
-            + ssl_enc
-            + ge
-        )
+        x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
         x = self.c_post(x * ssl_mask)
         return x
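The MRTE hunk collapses the module's fusion onto one line: the SSL content features attend over the text encoding, and the result is summed with the original content features and the global speaker embedding ge. A shape-level sketch with torch's stock MultiheadAttention standing in for the module's own attention (the real module works channel-first on [b, d, t] and applies masks; dimensions here are hypothetical):

import torch
import torch.nn as nn

d, t_ssl, t_text = 192, 50, 30
cross_attention = nn.MultiheadAttention(d, num_heads=2, batch_first=True)

ssl_enc = torch.randn(1, t_ssl, d)    # content features (queries)
text_enc = torch.randn(1, t_text, d)  # text features (keys/values)
ge = torch.randn(1, 1, d)             # global speaker embedding, broadcast over time

fused, _ = cross_attention(ssl_enc, text_enc, text_enc)
x = fused + ssl_enc + ge              # the same residual fusion as the one-liner above
print(x.shape)                        # torch.Size([1, 50, 192])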
@@ -28,9 +28,7 @@ def intersperse(lst, item):
 def kl_divergence(m_p, logs_p, m_q, logs_q):
     """KL(P||Q)"""
     kl = (logs_q - logs_p) - 0.5
-    kl += (
-        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
-    )
+    kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
     return kl

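The kl_divergence one-liner above is the closed-form KL between two diagonal Gaussians parameterized by mean and log standard deviation: KL(P||Q) = (logs_q - logs_p) - 1/2 + (exp(2*logs_p) + (m_p - m_q)^2) / (2*exp(2*logs_q)). A quick check against torch.distributions (the values are arbitrary):

import torch
from torch.distributions import Normal, kl_divergence as torch_kl

m_p, logs_p = torch.tensor(0.3), torch.tensor(-0.2)
m_q, logs_q = torch.tensor(-0.1), torch.tensor(0.4)

kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)

ref = torch_kl(Normal(m_p, logs_p.exp()), Normal(m_q, logs_q.exp()))
assert torch.allclose(kl, ref)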
@@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
 def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
     position = torch.arange(length, dtype=torch.float)
     num_timescales = channels // 2
-    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
-        num_timescales - 1
-    )
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
     inv_timescales = min_timescale * torch.exp(
         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
     )
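get_timing_signal_1d above builds the Transformer-style sinusoidal position signal: half the channels get sines and half get cosines over a geometric progression of timescales between min_timescale and max_timescale. A self-contained sketch of the core computation (the repo's version additionally pads odd channel counts and reshapes to [1, channels, length]):

import math
import torch

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)  # [length, channels]

print(timing_signal_1d(8, 4).shape)  # torch.Size([8, 4])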
@@ -30,6 +30,7 @@
 # SOFTWARE.

 """Core vector quantization implementation."""
+
 import typing as tp

 from einops import rearrange, repeat
@@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
     ):
         super().__init__()
         self.decay = decay
-        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
-            uniform_init if not kmeans_init else torch.zeros
-        )
+        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(codebook_size, dim)

         self.codebook_size = codebook_size
@@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
         # broadcast_tensors(self.buffers())

     def replace_(self, samples, mask):
-        modified_codebook = torch.where(
-            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
-        )
+        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
         self.embed.data.copy_(modified_codebook)

     def expire_codes_(self, batch_samples):
@@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):

     def quantize(self, x):
         embed = self.embed.t()
-        dist = -(
-            x.pow(2).sum(1, keepdim=True)
-            - 2 * x @ embed
-            + embed.pow(2).sum(0, keepdim=True)
-        )
+        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
         embed_ind = dist.max(dim=-1).indices
         return embed_ind

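quantize() above scores every codebook entry with the expanded squared distance ||x - e||^2 = ||x||^2 - 2*x.e + ||e||^2, negates it, and takes the argmax, which avoids materializing the full difference tensor. Equivalence check against a direct nearest-neighbour search:

import torch

x = torch.randn(7, 16)       # flattened inputs, [n, dim]
embed = torch.randn(32, 16)  # codebook, [codebook_size, dim]

e = embed.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ e + e.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices

ref = torch.cdist(x, embed).argmin(dim=-1)  # direct pairwise distances
assert torch.equal(embed_ind, ref)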
@@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
             embed_sum = x.t() @ embed_onehot
             ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
             cluster_size = (
-                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
-                * self.cluster_size.sum()
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
             )
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
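The cluster_size expression above applies Laplace smoothing and then rescales back to the original total count: it nudges every code's EMA usage away from zero before it is used as a divisor, so codes that have not been hit recently do not produce Inf/NaN averages. A sketch of the effect; laplace_smoothing is reimplemented here under its usual definition, since the actual helper lives elsewhere in the file:

import torch

def laplace_smoothing(x, n_categories, epsilon=1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)

cluster_size = torch.tensor([120.0, 0.0, 3.0])  # EMA usage counts per code
smoothed = laplace_smoothing(cluster_size, 3) * cluster_size.sum()
print(smoothed)  # strictly positive everywhere, total mass ~unchanged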
@@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
         _codebook_dim: int = default(codebook_dim, dim)

         requires_projection = _codebook_dim != dim
-        self.project_in = (
-            nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
-        )
-        self.project_out = (
-            nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
-        )
+        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

         self.epsilon = epsilon
         self.commitment_weight = commitment_weight
@@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):

     def __init__(self, *, num_quantizers, **kwargs):
         super().__init__()
-        self.layers = nn.ModuleList(
-            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
-        )
+        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

-    def forward(
-        self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
-    ):
+    def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
         quantized_out = 0.0
         residual = x

@@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
         return quantized_out, out_indices, out_losses, out_quantized

-    def encode(
-        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
-    ) -> torch.Tensor:
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
         residual = x
         all_indices = []
         n_q = n_q or len(self.layers)
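forward() and encode() above drive a residual quantization loop: each layer quantizes whatever the previous layers failed to explain, and the per-layer reconstructions sum back together. A toy sketch of that control flow, with scalar rounding at shrinking step sizes standing in for the learned codebooks:

import torch

def residual_vq(x, layers):
    quantized_out, residual, all_indices = 0.0, x, []
    for quantize in layers:
        q, idx = quantize(residual)
        residual = residual - q          # pass the leftover to the next layer
        quantized_out = quantized_out + q
        all_indices.append(idx)
    return quantized_out, all_indices

def make_layer(step):
    return lambda r: (torch.round(r / step) * step, torch.round(r / step).long())

x = torch.tensor([0.37, -1.22])
out, codes = residual_vq(x, [make_layer(s) for s in (1.0, 0.1, 0.01)])
print(out)  # tensor([ 0.3700, -1.2200]): each stage refines the residual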
@@ -1,24 +1,18 @@
-import time
-import logging
 import os
 import random
 import traceback
-import numpy as np
 import torch
 import torch.utils.data
 from tqdm import tqdm

-from module import commons
 from module.mel_processing import spectrogram_torch, spec_to_mel_torch
 from text import cleaned_text_to_sequence
-from utils import load_wav_to_torch, load_filepaths_and_text
 import torch.nn.functional as F
-from functools import lru_cache
-import requests
-from scipy.io import wavfile
-from io import BytesIO
 from tools.my_utils import load_audio
-version = os.environ.get('version',None)
+
+version = os.environ.get("version", None)


 # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
     """
@@ -43,7 +37,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):

         for line in lines:
             tmp = line.split("\t")
-            if (len(tmp) != 4):
+            if len(tmp) != 4:
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]

@@ -51,7 +45,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
-        if (leng < min_num):
+        if leng < min_num:
             self.audiopaths_sid_text = []
             for _ in range(max(2, int(min_num / leng))):
                 self.audiopaths_sid_text += tmp
@@ -76,7 +70,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         for audiopath in tqdm(self.audiopaths_sid_text):
             try:
                 phoneme = self.phoneme_data[audiopath][0]
-                phoneme = phoneme.split(' ')
+                phoneme = phoneme.split(" ")
                 phoneme_ids = cleaned_text_to_sequence(phoneme, version)
             except Exception:
                 print(f"{audiopath} not in self.phoneme_data !")
@@ -111,7 +105,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
         with torch.no_grad():
             ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
-            if (ssl.shape[-1] != spec.shape[-1]):
+            if ssl.shape[-1] != spec.shape[-1]:
                 typee = ssl.dtype
                 ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
             ssl.requires_grad = False
@@ -129,8 +123,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         audio = torch.FloatTensor(audio_array)  # /32768
         audio_norm = audio
         audio_norm = audio_norm.unsqueeze(0)
-        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
-                                 center=False)
+        spec = spectrogram_torch(
+            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
+        )
         spec = torch.squeeze(spec, 0)
         return spec, audio_norm

@@ -146,8 +141,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         return len(self.audiopaths_sid_text)

     def random_slice(self, ssl, wav, mel):
-        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
-            "first", ssl.shape, wav.shape)
+        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)

         len_mel = mel.shape[1]
         if self.val:
@@ -168,11 +162,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             mel = mel[:, :sep_point]

         assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
-            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
+            ssl.shape,
+            wav.shape,
+            wav2.shape,
+            mel.shape,
+            sep_point,
+            self.hop_length,
+            sep_point * self.hop_length,
+            dir,
+        )
         return reference_mel, ssl, wav2, mel
-class TextAudioSpeakerCollate():
-    """ Zero-pads model inputs and targets
-    """
+
+
+class TextAudioSpeakerCollate:
+    """Zero-pads model inputs and targets"""

     def __init__(self, return_ids=False):
         self.return_ids = return_ids
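The loader above expects a filelist in which every line carries exactly four tab-separated fields and silently skips anything else; only the first two fields (audio path and space-separated phonemes) are actually consumed here. A minimal parsing sketch; the example line is illustrative, not taken from the repo:

phoneme_data = {}
line = "vocal/0001.wav\tAH0 B IY1\tspeaker0\textra"  # hypothetical filelist line
tmp = line.split("\t")
if len(tmp) == 4:
    phoneme_data[tmp[0]] = [tmp[1]]

phoneme = phoneme_data["vocal/0001.wav"][0].split(" ")
print(phoneme)  # ['AH0', 'B', 'IY1']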
@@ -184,9 +187,7 @@ class TextAudioSpeakerCollate():
         batch: [text_normalized, spec_normalized, wav_normalized, sid]
         """
         # Right zero-pad all one-hot text sequences to max input length
-        _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]),
-            dim=0, descending=True)
+        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)

         max_ssl_len = max([x[0].size(2) for x in batch])
         max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -230,6 +231,8 @@ class TextAudioSpeakerCollate():
             text_lengths[i] = text.size(0)

         return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
+
+
 class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
     """
     1) loads audio, speaker_id, text pairs
@@ -253,7 +256,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

         for line in lines:
             tmp = line.split("\t")
-            if (len(tmp) != 4):
+            if len(tmp) != 4:
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]

@@ -261,7 +264,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
-        if (leng < min_num):
+        if leng < min_num:
             self.audiopaths_sid_text = []
             for _ in range(max(2, int(min_num / leng))):
                 self.audiopaths_sid_text += tmp
@@ -286,7 +289,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
         for audiopath in tqdm(self.audiopaths_sid_text):
             try:
                 phoneme = self.phoneme_data[audiopath][0]
-                phoneme = phoneme.split(' ')
+                phoneme = phoneme.split(" ")
                 phoneme_ids = cleaned_text_to_sequence(phoneme, version)
             except Exception:
                 print(f"{audiopath} not in self.phoneme_data !")
@@ -322,6 +325,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
         self.sampling_rate_mel = 24000
         self.mel_fmin = 0
         self.mel_fmax = None
+
     def norm_spec(self, x):
         return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

@@ -332,7 +336,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
         spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
         with torch.no_grad():
             ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
-            if (ssl.shape[-1] != spec.shape[-1]):
+            if ssl.shape[-1] != spec.shape[-1]:
                 typee = ssl.dtype
                 ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
             ssl.requires_grad = False
@@ -351,19 +355,29 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
         audio = torch.FloatTensor(audio_array)  # /32768
         audio_norm = audio
         audio_norm = audio_norm.unsqueeze(0)
-        audio_array24 = load_audio(filename,24000)  # load_audio already returns audio normalized to [-1, 1], no extra /32768 needed; resampling here could be GPU-accelerated
+        audio_array24 = load_audio(
+            filename, 24000
+        )  # load_audio already returns audio normalized to [-1, 1], no extra /32768 needed; resampling here could be GPU-accelerated
         audio24 = torch.FloatTensor(audio_array24)  # /32768
         audio_norm24 = audio24
         audio_norm24 = audio_norm24.unsqueeze(0)

-        spec = spectrogram_torch(audio_norm, self.filter_length,
-                                 self.sampling_rate, self.hop_length, self.win_length,
-                                 center=False)
+        spec = spectrogram_torch(
+            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
+        )
         spec = torch.squeeze(spec, 0)

-        spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
-        mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
+        spec1 = spectrogram_torch(
+            audio_norm24,
+            self.filter_length_mel,
+            self.sampling_rate_mel,
+            self.hop_length_mel,
+            self.win_length_mel,
+            center=False,
+        )
+        mel = spec_to_mel_torch(
+            spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
+        )
         mel = torch.squeeze(mel, 0)
         mel = self.norm_spec(mel)
         # print(1111111,spec.shape,mel.shape)
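The random_slice asserts and the get_audio reformats above both lean on the frame arithmetic of spectrogram_torch: with (n_fft - hop_size)/2 reflect-padding on each side and center=False, the STFT produces roughly samples // hop_length frames, which is what makes the SSL-feature and waveform lengths directly comparable. A quick check of that bookkeeping, under hypothetical filter/hop values:

import torch

n_fft = win = 2048
hop = 640
samples = hop * 100  # exactly 100 hops of audio
y = torch.randn(1, samples)

y = torch.nn.functional.pad(
    y.unsqueeze(1), (int((n_fft - hop) / 2), int((n_fft - hop) / 2)), mode="reflect"
).squeeze(1)
spec = torch.stft(
    y, n_fft, hop_length=hop, win_length=win,
    window=torch.hann_window(win), center=False, return_complex=True,
)
print(spec.shape[-1], samples // hop)  # 100 100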
@@ -379,9 +393,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

     def __len__(self):
         return len(self.audiopaths_sid_text)
-class TextAudioSpeakerCollateV3():
-    """ Zero-pads model inputs and targets
-    """
+
+
+class TextAudioSpeakerCollateV3:
+    """Zero-pads model inputs and targets"""

     def __init__(self, return_ids=False):
         self.return_ids = return_ids
@@ -394,9 +409,7 @@ class TextAudioSpeakerCollateV3():
         """
         # ssl, spec, wav,mel, text
         # Right zero-pad all one-hot text sequences to max input length
-        _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]),
-            dim=0, descending=True)
+        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
         # (ssl, spec,mel, text)
         max_ssl_len = max([x[0].size(2) for x in batch])

@@ -456,6 +469,8 @@ class TextAudioSpeakerCollateV3():

         # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
         return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
+
+
 class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
     """
     1) loads audio, speaker_id, text pairs
@@ -479,7 +494,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):

         for line in lines:
             tmp = line.split("\t")
-            if (len(tmp) != 4):
+            if len(tmp) != 4:
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]

@@ -487,7 +502,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
-        if (leng < min_num):
+        if leng < min_num:
             self.audiopaths_sid_text = []
             for _ in range(max(2, int(min_num / leng))):
                 self.audiopaths_sid_text += tmp
@@ -512,7 +527,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
         for audiopath in tqdm(self.audiopaths_sid_text):
             try:
                 phoneme = self.phoneme_data[audiopath][0]
-                phoneme = phoneme.split(' ')
+                phoneme = phoneme.split(" ")
                 phoneme_ids = cleaned_text_to_sequence(phoneme, version)
             except Exception:
                 print(f"{audiopath} not in self.phoneme_data !")
@@ -548,6 +563,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
         self.sampling_rate_mel = 24000
         self.mel_fmin = 0
         self.mel_fmax = None
+
     def norm_spec(self, x):
         return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

@@ -558,7 +574,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
         spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
         with torch.no_grad():
             ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
-            if (ssl.shape[-1] != spec.shape[-1]):
+            if ssl.shape[-1] != spec.shape[-1]:
                 typee = ssl.dtype
                 ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
             ssl.requires_grad = False
@@ -577,19 +593,29 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
         audio = torch.FloatTensor(audio_array)  # /32768
         audio_norm = audio
         audio_norm = audio_norm.unsqueeze(0)
-        audio_array24 = load_audio(filename,24000)  # load_audio already returns audio normalized to [-1, 1], no extra /32768 needed; resampling here could be GPU-accelerated
+        audio_array24 = load_audio(
+            filename, 24000
+        )  # load_audio already returns audio normalized to [-1, 1], no extra /32768 needed; resampling here could be GPU-accelerated
         audio24 = torch.FloatTensor(audio_array24)  # /32768
         audio_norm24 = audio24
         audio_norm24 = audio_norm24.unsqueeze(0)

-        spec = spectrogram_torch(audio_norm, self.filter_length,
-                                 self.sampling_rate, self.hop_length, self.win_length,
-                                 center=False)
+        spec = spectrogram_torch(
+            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
+        )
         spec = torch.squeeze(spec, 0)

-        spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
-        mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
+        spec1 = spectrogram_torch(
+            audio_norm24,
+            self.filter_length_mel,
+            self.sampling_rate_mel,
+            self.hop_length_mel,
+            self.win_length_mel,
+            center=False,
+        )
+        mel = spec_to_mel_torch(
+            spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
+        )
         mel = torch.squeeze(mel, 0)
         mel = self.norm_spec(mel)
         # print(1111111,spec.shape,mel.shape)
@@ -605,9 +631,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):

     def __len__(self):
         return len(self.audiopaths_sid_text)
-class TextAudioSpeakerCollateV3b():
-    """ Zero-pads model inputs and targets
-    """
+
+
+class TextAudioSpeakerCollateV3b:
+    """Zero-pads model inputs and targets"""

     def __init__(self, return_ids=False):
         self.return_ids = return_ids
@@ -620,9 +647,7 @@ class TextAudioSpeakerCollateV3b():
         """
         # ssl, spec, wav,mel, text
         # Right zero-pad all one-hot text sequences to max input length
-        _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]),
-            dim=0, descending=True)
+        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
         # (ssl, spec,mel, text)
         max_ssl_len = max([x[0].size(2) for x in batch])

@@ -679,9 +704,21 @@ class TextAudioSpeakerCollateV3b():
             text_padded[i, : text.size(0)] = text
             text_lengths[i] = text.size(0)

-        return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
+        return (
+            ssl_padded,
+            spec_padded,
+            mel_padded,
+            ssl_lengths,
+            spec_lengths,
+            text_padded,
+            text_lengths,
+            wav_padded,
+            wav_lengths,
+            mel_lengths,
+        )
         # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths


 class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
     """
     Maintain similar input lengths in a batch.
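DistributedBucketSampler, whose docstring begins above, exists to keep utterances of similar length in the same batch so that the right-zero-padding done by the collate classes wastes as little compute as possible. A conceptual sketch of length bucketing; the boundaries and lengths here are made up:

import random

lengths = {f"utt{i}": random.randint(50, 400) for i in range(20)}
boundaries = [0, 100, 200, 300, 400]

buckets = [[] for _ in range(len(boundaries) - 1)]
for name, ln in lengths.items():
    for b in range(len(boundaries) - 1):
        if boundaries[b] < ln <= boundaries[b + 1]:
            buckets[b].append(name)
            break

for b, bucket in enumerate(buckets):
    print(f"bucket ({boundaries[b]}, {boundaries[b + 1]}]: {len(bucket)} utterances")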
@@ -1,7 +1,6 @@
 import math

 import torch
-from torch.nn import functional as F


 def feature_loss(fmap_r, fmap_g):
@@ -66,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
         torch.exp(-2 * logs) * ((z - m) ** 2)
     )  # neg normal likelihood w/o the constant term
     l = l - torch.sum(logdet)  # log jacobian determinant
-    l = l / torch.sum(
-        torch.ones_like(z) * mask
-    )  # averaging across batch, channel and time axes
+    l = l / torch.sum(torch.ones_like(z) * mask)  # averaging across batch, channel and time axes
     l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
     return l
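mle_loss above is the per-element negative log-likelihood of z under N(m, exp(logs)) corrected by the flow's log-determinant; the constant 0.5*log(2*pi) is added at the end, after averaging over the masked elements. A sanity sketch with logdet = 0 and a full mask, checked against torch.distributions:

import math
import torch
from torch.distributions import Normal

z = torch.randn(2, 3, 5)
m = torch.zeros_like(z)
logs = torch.full_like(z, -0.3)
mask = torch.ones_like(z)

l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m) ** 2))
l = l / torch.sum(torch.ones_like(z) * mask)
l = l + 0.5 * math.log(2 * math.pi)

ref = -Normal(m, logs.exp()).log_prob(z).mean()
assert torch.allclose(l, ref, atol=1e-5)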
@@ -1,16 +1,5 @@
-import math
-import os
-import random
 import torch
-from torch import nn
-import torch.nn.functional as F
 import torch.utils.data
-import numpy as np
-import librosa
-import librosa.util as librosa_util
-from librosa.util import normalize, pad_center, tiny
-from scipy.signal import get_window
-from scipy.io.wavfile import read
 from librosa.filters import mel as librosa_mel_fn

 MAX_WAV_VALUE = 32768.0
@@ -58,9 +47,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
     dtype_device = str(y.dtype) + "_" + str(y.device)
     wnsize_dtype_device = str(win_size) + "_" + dtype_device
     if wnsize_dtype_device not in hann_window:
-        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-            dtype=y.dtype, device=y.device
-        )
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

     y = torch.nn.functional.pad(
         y.unsqueeze(1),
@@ -90,20 +77,14 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
     dtype_device = str(spec.dtype) + "_" + str(spec.device)
     fmax_dtype_device = str(fmax) + "_" + dtype_device
     if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
-        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-            dtype=spec.dtype, device=spec.device
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
     spec = spectral_normalize_torch(spec)
     return spec


-def mel_spectrogram_torch(
-    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
-):
+def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
     if torch.min(y) < -1.0:
         print("min value is ", torch.min(y))
     if torch.max(y) > 1.0:
@@ -114,16 +95,10 @@ def mel_spectrogram_torch(
     fmax_dtype_device = str(fmax) + "_" + dtype_device
     wnsize_dtype_device = str(win_size) + "_" + dtype_device
     if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
-        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-            dtype=y.dtype, device=y.device
-        )
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
     if wnsize_dtype_device not in hann_window:
-        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-            dtype=y.dtype, device=y.device
-        )
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

     y = torch.nn.functional.pad(
         y.unsqueeze(1),
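Both reformatted helpers above memoize expensive constructions, the librosa mel filterbank and the Hann window, in module-level dicts keyed by size, dtype, and device, so repeated calls with the same configuration become plain dictionary lookups. A stripped-down sketch of the same caching pattern:

import torch

_hann_cache = {}

def cached_hann(win_size, dtype, device):
    key = f"{win_size}_{dtype}_{device}"
    if key not in _hann_cache:
        _hann_cache[key] = torch.hann_window(win_size).to(dtype=dtype, device=device)
    return _hann_cache[key]

w1 = cached_hann(1024, torch.float32, torch.device("cpu"))
w2 = cached_hann(1024, torch.float32, torch.device("cpu"))
assert w1 is w2  # the second call allocates nothing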
@@ -1,9 +1,7 @@
 import warnings

 warnings.filterwarnings("ignore")
-import copy
 import math
-import os
-import pdb

 import torch
 from torch import nn
@@ -13,16 +11,18 @@ from module import commons
 from module import modules
 from module import attentions
 from f5_tts.model import DiT
-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from module.commons import init_weights, get_padding
 from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
+
 # from text import symbols
 from text import symbols as symbols_v1
 from text import symbols2 as symbols_v2
 from torch.cuda.amp import autocast
-import contextlib,random
+import contextlib
+import random


 class StochasticDurationPredictor(nn.Module):
@@ -48,29 +48,21 @@ class StochasticDurationPredictor(nn.Module):
         self.flows = nn.ModuleList()
         self.flows.append(modules.ElementwiseAffine(2))
         for i in range(n_flows):
-            self.flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
+            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
             self.flows.append(modules.Flip())

         self.post_pre = nn.Conv1d(1, filter_channels, 1)
         self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.post_convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
+        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
         self.post_flows = nn.ModuleList()
         self.post_flows.append(modules.ElementwiseAffine(2))
         for i in range(4):
-            self.post_flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
+            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
             self.post_flows.append(modules.Flip())

         self.pre = nn.Conv1d(in_channels, filter_channels, 1)
         self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
+        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

@@ -91,10 +83,7 @@ class StochasticDurationPredictor(nn.Module):
             h_w = self.post_pre(w)
             h_w = self.post_convs(h_w, x_mask)
             h_w = self.post_proj(h_w) * x_mask
-            e_q = (
-                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
-                * x_mask
-            )
+            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
             z_q = e_q
             for flow in self.post_flows:
                 z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -102,13 +91,8 @@ class StochasticDurationPredictor(nn.Module):
             z_u, z1 = torch.split(z_q, [1, 1], 1)
             u = torch.sigmoid(z_u) * x_mask
             z0 = (w - u) * x_mask
-            logdet_tot_q += torch.sum(
-                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
-            )
-            logq = (
-                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
-                - logdet_tot_q
-            )
+            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
+            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

             logdet_tot = 0
             z0, logdet = self.log_flow(z0, x_mask)
@@ -117,18 +101,12 @@ class StochasticDurationPredictor(nn.Module):
             for flow in flows:
                 z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                 logdet_tot = logdet_tot + logdet
-            nll = (
-                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
-                - logdet_tot
-            )
+            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
             return nll + logq  # [b]
         else:
             flows = list(reversed(self.flows))
             flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
-            z = (
-                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
-                * noise_scale
-            )
+            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
             for flow in flows:
                 z = flow(z, x_mask, g=x, reverse=reverse)
             z0, z1 = torch.split(z, [1, 1], 1)
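The nll one-liner above is the masked standard-normal negative log-density of z per batch element, minus the accumulated flow log-determinant: the usual change-of-variables objective for a normalizing flow. Identity check with logdet_tot = 0 and a full mask:

import math
import torch
from torch.distributions import Normal

z = torch.randn(4, 2, 10)
x_mask = torch.ones(4, 1, 10)

nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
ref = -Normal(0.0, 1.0).log_prob(z).sum([1, 2])
assert torch.allclose(nll, ref, atol=1e-5)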
@@ -137,9 +115,7 @@ class StochasticDurationPredictor(nn.Module):


 class DurationPredictor(nn.Module):
-    def __init__(
-        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
-    ):
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
         super().__init__()

         self.in_channels = in_channels
@@ -149,13 +125,9 @@ class DurationPredictor(nn.Module):
         self.gin_channels = gin_channels

         self.drop = nn.Dropout(p_dropout)
-        self.conv_1 = nn.Conv1d(
-            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.norm_1 = modules.LayerNorm(filter_channels)
-        self.conv_2 = nn.Conv1d(
-            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
+        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.norm_2 = modules.LayerNorm(filter_channels)
         self.proj = nn.Conv1d(filter_channels, 1, 1)

@@ -238,24 +210,20 @@ class TextEncoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

     def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)

         y = self.ssl_proj(y * y_mask) * y_mask

         y = self.encoder_ssl(y * y_mask, y_mask)

-        text_mask = torch.unsqueeze(
-            commons.sequence_mask(text_lengths, text.size(1)), 1
-        ).to(y.dtype)
+        text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
         if test == 1:
             text[:, :] = 0
         text = self.text_embedding(text).transpose(1, 2)
         text = self.encoder_text(text * text_mask, text_mask)
         y = self.mrte(y, y_mask, text, text_mask, ge)
         y = self.encoder2(y * y_mask, y_mask)
-        if(speed!=1):
+        if speed != 1:
             y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
             y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
         stats = self.proj(y) * y_mask
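The y_mask/text_mask one-liners above all build [batch, 1, T] float masks out of per-utterance lengths. A minimal stand-in for commons.sequence_mask showing the expected behaviour (this reimplementation is for illustration; the repo's helper should be equivalent):

import torch

def sequence_mask(lengths, max_length=None):
    if max_length is None:
        max_length = lengths.max()
    x = torch.arange(max_length, dtype=lengths.dtype, device=lengths.device)
    return x.unsqueeze(0) < lengths.unsqueeze(1)

lengths = torch.tensor([3, 5])
mask = torch.unsqueeze(sequence_mask(lengths, 5), 1).float()
print(mask.shape)  # torch.Size([2, 1, 5])
print(mask[0, 0])  # tensor([1., 1., 1., 0., 0.])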
@@ -360,9 +328,7 @@ class PosteriorEncoder(nn.Module):
     def forward(self, x, x_lengths, g=None):
         if g != None:
             g = g.detach()
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
@@ -372,14 +338,9 @@ class PosteriorEncoder(nn.Module):


 class Encoder(nn.Module):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 hidden_channels,
-                 kernel_size,
-                 dilation_rate,
-                 n_layers,
-                 gin_channels=0):
+    def __init__(
+        self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -394,7 +355,7 @@ class Encoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

     def forward(self, x, x_lengths, g=None):
-        if(g!=None):
+        if g != None:
             g = g.detach()
         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
@@ -402,6 +363,7 @@ class Encoder(nn.Module):
         stats = self.proj(x) * x_mask
         return stats, x_mask

+
 class WNEncoder(nn.Module):
     def __init__(
         self,
@@ -434,9 +396,7 @@ class WNEncoder(nn.Module):
         self.norm = modules.LayerNorm(out_channels)

     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         out = self.proj(x) * x_mask
@@ -459,9 +419,7 @@ class Generator(torch.nn.Module):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
+        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
         resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

         self.ups = nn.ModuleList()
@@ -481,9 +439,7 @@ class Generator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))

         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@@ -636,9 +592,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         periods = [2, 3, 5, 7, 11]

         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
+        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
         self.discriminators = nn.ModuleList(discs)

     def forward(self, y, y_hat):
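The Generator hunks above keep HiFi-GAN-style channel bookkeeping: each upsampling stage halves the width, so the ResBlocks at stage i see upsample_initial_channel // 2**(i+1) channels, and conv_post consumes the final ch. A quick tabulation with illustrative values (the rates below are hypothetical, not read from the repo's configs):

upsample_initial_channel = 512
upsample_rates = [8, 8, 2, 2]  # illustrative only

for i in range(len(upsample_rates)):
    ch = upsample_initial_channel // (2 ** (i + 1))
    print(f"stage {i}: x{upsample_rates[i]} upsample -> {ch} channels")
# stage 0: 256, stage 1: 128, stage 2: 64, stage 3: 32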
@@ -738,10 +692,7 @@ class Quantizer(torch.nn.Module):
         super(Quantizer, self).__init__()
         assert embed_dim % n_code_groups == 0
         self.quantizer_modules = nn.ModuleList(
-            [
-                Quantizer_module(n_codes, embed_dim // n_code_groups)
-                for _ in range(n_code_groups)
-            ]
+            [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
         )
         self.n_code_groups = n_code_groups
         self.embed_dim = embed_dim
@@ -759,9 +710,7 @@ class Quantizer(torch.nn.Module):
             z_q.append(_z_q)
             min_indicies.append(_min_indicies)  # B * T,
         z_q = torch.cat(z_q, -1).reshape(xin.shape)
-        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
-            (z_q - xin.detach()) ** 2
-        )
+        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
         z_q = xin + (z_q - xin).detach()
         z_q = z_q.transpose(1, 2)
         codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
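The loss line above pairs a codebook term ((z_q.detach() - xin)**2, weighted 0.25) with a commitment term ((z_q - xin.detach())**2), and the line after it, z_q = xin + (z_q - xin).detach(), is the straight-through estimator: the forward value is the quantized tensor, while gradients flow to the encoder input as if quantization were the identity. A tiny demonstration with rounding standing in for the codebook lookup:

import torch

x = torch.randn(3, requires_grad=True)
z_q = torch.round(x)  # stand-in for the nearest-codeword lookup

st = x + (z_q - x).detach()
st.sum().backward()
print(torch.allclose(st, z_q))  # True: forward sees the quantized values
print(x.grad)                   # tensor([1., 1., 1.]): gradient bypasses round()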
@@ -801,13 +750,9 @@ class CodePredictor(nn.Module):
         self.p_dropout = p_dropout

         self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
-        self.ref_enc = modules.MelStyleEncoder(
-            ssl_dim, style_vector_dim=hidden_channels
-        )
+        self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

-        self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
+        self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)

         self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
         self.n_q = n_q
@@ -820,9 +765,7 @@ class CodePredictor(nn.Module):
         x = x + g
         x = self.encoder(x * x_mask, x_mask)
         x = self.out_proj(x * x_mask) * x_mask
-        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
-            2, 3
-        )
+        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
         target = codes[1:].transpose(0, 1)
         if not infer:
             logits = logits.reshape(-1, self.dims)
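
The hunk above reshapes the projector output to (B, n_q - 1, T, dims) and flattens it when not inferring; the training objective itself sits outside this hunk. A sketch assuming a per-code cross-entropy loss (the flattening suggests it, but the loss choice is my assumption):

    import torch
    import torch.nn.functional as F

    def code_prediction_loss(x, codes, n_q, dims):
        # x: (B, (n_q - 1) * dims, T) projector output; codes: (n_q, B, T) long targets
        logits = x.reshape(x.shape[0], n_q - 1, dims, x.shape[-1]).transpose(2, 3)
        target = codes[1:].transpose(0, 1)  # residual codebooks 1..n_q-1
        return F.cross_entropy(logits.reshape(-1, dims), target.reshape(-1))
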
@@ -871,7 +814,7 @@ class SynthesizerTrn(nn.Module):
         semantic_frame_rate=None,
         freeze_quantizer=None,
         version="v2",
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.spec_channels = spec_channels
@@ -923,12 +866,10 @@ class SynthesizerTrn(nn.Module):
             16,
             gin_channels=gin_channels,
         )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
+        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

         # self.version=os.environ.get("version","v1")
-        if(self.version=="v1"):
+        if self.version == "v1":
             self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
         else:
             self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
@@ -945,10 +886,8 @@ class SynthesizerTrn(nn.Module):
         self.freeze_quantizer = freeze_quantizer

     def forward(self, ssl, y, y_lengths, text, text_lengths):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        if(self.version=="v1"):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
+        if self.version == "v1":
             ge = self.ref_enc(y * y_mask, y_mask)
         else:
             ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
@@ -959,24 +898,16 @@ class SynthesizerTrn(nn.Module):
                     self.ssl_proj.eval()
                     self.quantizer.eval()
                 ssl = self.ssl_proj(ssl)
-                quantized, codes, commit_loss, quantized_list = self.quantizer(
-                    ssl, layers=[0]
-                )
+                quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])

         if self.semantic_frame_rate == "25hz":
-            quantized = F.interpolate(
-                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
-            )
+            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, y_lengths, text, text_lengths, ge
-        )
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
         z_p = self.flow(z, y_mask, g=ge)

-        z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
-        )
+        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
         o = self.dec(z_slice, g=ge)
         return (
             o,
@@ -989,10 +920,8 @@ class SynthesizerTrn(nn.Module):
         )

     def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        if(self.version=="v1"):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
+        if self.version == "v1":
             ge = self.ref_enc(y * y_mask, y_mask)
         else:
             ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
@@ -1000,13 +929,9 @@ class SynthesizerTrn(nn.Module):
         ssl = self.ssl_proj(ssl)
         quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
         if self.semantic_frame_rate == "25hz":
-            quantized = F.interpolate(
-                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
-            )
+            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, y_lengths, text, text_lengths, ge, test=test
-        )
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

         z = self.flow(z_p, y_mask, g=ge, reverse=True)
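
The infer path above draws the latent from the text-conditioned prior and then inverts the coupling flow. A minimal sketch of just the sampling rule, assuming the same (m_p, logs_p) parameterization as the diff:

    import torch

    def sample_prior(m_p, logs_p, noise_scale=0.5):
        # reparameterized draw from N(m_p, exp(logs_p)^2); a noise_scale
        # below 1 shrinks the variance, trading diversity for stability
        return m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
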
@@ -1020,15 +945,14 @@ class SynthesizerTrn(nn.Module):
            ge = None
            if refer is not None:
                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
-                refer_mask = torch.unsqueeze(
-                    commons.sequence_mask(refer_lengths, refer.size(2)), 1
-                ).to(refer.dtype)
-                if (self.version == "v1"):
+                refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
+                if self.version == "v1":
                    ge = self.ref_enc(refer * refer_mask, refer_mask)
                else:
                    ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
            return ge
-        if(type(refer)==list):
+
+        if type(refer) == list:
            ges = []
            for _refer in refer:
                ge = get_ge(_refer)
@@ -1042,12 +966,8 @@ class SynthesizerTrn(nn.Module):

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
-            quantized = F.interpolate(
-                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
-            )
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, y_lengths, text, text_lengths, ge,speed
-        )
+            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -1059,11 +979,10 @@ class SynthesizerTrn(nn.Module):
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)


class CFM(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,dit
-    ):
+    def __init__(self, in_channels, dit):
        super().__init__()
        self.sigma_min = 1e-6

@@ -1089,14 +1008,27 @@ class CFM(torch.nn.Module):
            t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
            d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
            # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
+            v_pred = self.estimator(
+                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
+            ).transpose(2, 1)
            if inference_cfg_rate > 1e-5:
-                neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
+                neg = self.estimator(
+                    x,
+                    prompt_x,
+                    x_lens,
+                    t_tensor,
+                    d_tensor,
+                    mu,
+                    use_grad_ckpt=False,
+                    drop_audio_cond=True,
+                    drop_text=True,
+                ).transpose(2, 1)
                v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
            x = x + d * v_pred
            t = t + d
            x[:, :, :prompt_len] = 0
        return x

    def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
        b, _, t = x1.shape
        t = torch.rand([b], device=mu.device, dtype=x1.dtype)
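
The loop above performs one Euler step of the flow-matching ODE with classifier-free guidance: the estimator is run once with conditioning and once with it dropped, and the conditional velocity is pushed away from the unconditional one. A standalone sketch of the update (illustrative names, not the DiT estimator API):

    import torch

    def euler_cfg_step(x, t, d, v_cond, v_uncond, cfg_rate):
        # guided velocity, then one explicit Euler step of size d
        v = v_cond + (v_cond - v_uncond) * cfg_rate
        return x + d * v, t + d
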
@@ -1132,16 +1064,19 @@ class CFM(torch.nn.Module):

        return loss


def set_no_grad(net_g):
    for name, param in net_g.named_parameters():
        param.requires_grad = False


class SynthesizerTrnV3(nn.Module):
    """
    Synthesizer for Training
    """

-    def __init__(self,
+    def __init__(
+        self,
        spec_channels,
        segment_size,
        inter_channels,
@@ -1163,8 +1098,8 @@ class SynthesizerTrnV3(nn.Module):
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v3",
-        **kwargs):
+        **kwargs,
+    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
@@ -1187,7 +1122,9 @@ class SynthesizerTrnV3(nn.Module):

        self.model_dim = 512
        self.use_sdp = use_sdp
-        self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
+        self.enc_p = TextEncoder(
+            inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
        # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
        self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)  ###Rollback
        # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
@@ -1196,35 +1133,32 @@ class SynthesizerTrnV3(nn.Module):
        # gin_channels=gin_channels)
        # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        ssl_dim = 768
-        assert semantic_frame_rate in ['25hz', "50hz"]
+        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
-        if semantic_frame_rate == '25hz':
+        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

-        self.quantizer = ResidualVectorQuantizer(
-            dimension=ssl_dim,
-            n_q=1,
-            bins=1024
-        )
+        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer
        inter_channels2 = 512
-        self.bridge=nn.Sequential(
-            nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
-            nn.LeakyReLU()
-        )
+        self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
        self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
        self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
-        self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
+        self.cfm = CFM(
+            100,
+            DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
+        )  # text_dim is condition feature dim
        if self.freeze_quantizer == True:
            set_no_grad(self.ssl_proj)
            set_no_grad(self.quantizer)
            set_no_grad(self.enc_p)

-    def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
+    def forward(
+        self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
+    ):  # ssl_lengths no need now
        with autocast(enabled=False):
            y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
@@ -1235,14 +1169,14 @@ class SynthesizerTrnV3(nn.Module):
                self.quantizer.eval()
                self.enc_p.eval()
            ssl = self.ssl_proj(ssl)
-            quantized, codes, commit_loss, quantized_list = self.quantizer(
-                ssl, layers=[0]
-            )
+            quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
            x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
            fea = self.bridge(x)
            fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  ##BCT
-            fea, y_mask_ = self.wns1(fea, mel_lengths, ge)##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
+            fea, y_mask_ = self.wns1(
+                fea, mel_lengths, ge
+            )  ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
            B = ssl.shape[0]
            prompt_len_max = mel_lengths * 2 / 3
            prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
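
The last context lines above implement prompt masking for CFM training: each batch item keeps a random prefix of its mel sequence, up to two thirds of its length, as the acoustic prompt. A sketch of just that sampling step, under the same convention:

    import torch

    def sample_prompt_lengths(mel_lengths, frac=2 / 3):
        # mel_lengths: (B,) long tensor of target mel lengths
        prompt_len_max = mel_lengths.float() * frac
        return (torch.rand_like(prompt_len_max) * prompt_len_max).floor().long()
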
@@ -1256,7 +1190,7 @@ class SynthesizerTrnV3(nn.Module):
    def decode_encp(self, codes, text, refer, ge=None, speed=1):
        # print(2333333,refer.shape)
        # ge=None
-        if(ge==None):
+        if ge == None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@@ -1269,7 +1203,7 @@ class SynthesizerTrnV3(nn.Module):
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
-        if self.semantic_frame_rate == '25hz':
+        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
        fea = self.bridge(x)
@@ -1283,12 +1217,14 @@ class SynthesizerTrnV3(nn.Module):
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)


class SynthesizerTrnV3b(nn.Module):
    """
    Synthesizer for Training
    """

-    def __init__(self,
+    def __init__(
+        self,
        spec_channels,
        segment_size,
        inter_channels,
@@ -1309,8 +1245,8 @@ class SynthesizerTrnV3b(nn.Module):
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
-        **kwargs):
+        **kwargs,
+    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
@@ -1332,40 +1268,45 @@ class SynthesizerTrnV3b(nn.Module):

        self.model_dim = 512
        self.use_sdp = use_sdp
-        self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
+        self.enc_p = TextEncoder(
+            inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
        # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
        self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)  ###Rollback
-        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
-                             upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
-        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
-                                      gin_channels=gin_channels)
+        self.dec = Generator(
+            inter_channels,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            gin_channels=gin_channels,
+        )
+        self.enc_q = PosteriorEncoder(
+            spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
+        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        ssl_dim = 768
-        assert semantic_frame_rate in ['25hz', "50hz"]
+        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
-        if semantic_frame_rate == '25hz':
+        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

-        self.quantizer = ResidualVectorQuantizer(
-            dimension=ssl_dim,
-            n_q=1,
-            bins=1024
-        )
+        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer

        inter_channels2 = 512
-        self.bridge=nn.Sequential(
-            nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
-            nn.LeakyReLU()
-        )
+        self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
        self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
        self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
-        self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
+        self.cfm = CFM(
+            100,
+            DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
+        )  # text_dim is condition feature dim

    def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths):  # ssl_lengths no need now
        with autocast(enabled=False):
@@ -1379,9 +1320,7 @@ class SynthesizerTrnV3b(nn.Module):
                self.ssl_proj.eval()
                self.quantizer.eval()
            ssl = self.ssl_proj(ssl)
-            quantized, codes, commit_loss, quantized_list = self.quantizer(
-                ssl, layers=[0]
-            )
+            quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
            x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
@@ -1399,13 +1338,23 @@ class SynthesizerTrnV3b(nn.Module):
        mel = mel[:, :, :minn]
        fea = fea[:, :, :minn]
        cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea)  # fea==cond,y_lengths==target_mel_lengths#ge not need
-        return commit_loss,cfm_loss,F.mse_loss(learned_mel, mel),o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized
+        return (
+            commit_loss,
+            cfm_loss,
+            F.mse_loss(learned_mel, mel),
+            o,
+            ids_slice,
+            y_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            quantized,
+        )

    @torch.no_grad()
    def decode_encp(self, codes, text, refer, ge=None):
        # print(2333333,refer.shape)
        # ge=None
-        if(ge==None):
+        if ge == None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@@ -1414,7 +1363,7 @@ class SynthesizerTrnV3b(nn.Module):
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
-        if self.semantic_frame_rate == '25hz':
+        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        fea = self.bridge(x)
@@ -1,4 +1,3 @@
-import copy
 import math
 from typing import Optional
 import torch
@@ -11,14 +10,14 @@ from module import attentions_onnx as attentions

 from f5_tts.model import DiT

-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from module.commons import init_weights, get_padding
 from module.quantize import ResidualVectorQuantizer

 # from text import symbols
 from text import symbols as symbols_v1
 from text import symbols2 as symbols_v2
-from torch.cuda.amp import autocast


 class StochasticDurationPredictor(nn.Module):
@@ -44,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
-            self.flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
+            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.post_convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
+        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
-            self.post_flows.append(
-                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
-            )
+            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
-        self.convs = modules.DDSConv(
-            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
-        )
+        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

@@ -87,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
-            e_q = (
-                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
-                * x_mask
-            )
+            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -98,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
-            logdet_tot_q += torch.sum(
-                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
-            )
-            logq = (
-                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
-                - logdet_tot_q
-            )
+            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
+            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
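
The reflowed logq line above is a change-of-variables likelihood: a diagonal Gaussian log-density of the base noise, summed over masked positions, minus the accumulated log-determinant of the posterior flows. As a sketch:

    import math
    import torch

    def gaussian_logq(e_q, logdet_tot_q, x_mask):
        # log q(w) = log N(e_q; 0, I) over masked positions minus the
        # accumulated log|det J| of the flows applied to the sample
        return torch.sum(-0.5 * (math.log(2 * math.pi) + e_q**2) * x_mask, [1, 2]) - logdet_tot_q
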
@@ -113,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
-            nll = (
-                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
-                - logdet_tot
-            )
+            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
-            z = (
-                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
-                * noise_scale
-            )
+            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
@@ -133,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):


class DurationPredictor(nn.Module):
-    def __init__(
-        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
-    ):
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
@@ -145,13 +120,9 @@ class DurationPredictor(nn.Module):
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
-        self.conv_1 = nn.Conv1d(
-            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = modules.LayerNorm(filter_channels)
-        self.conv_2 = nn.Conv1d(
-            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
+        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

@@ -246,7 +217,7 @@ class TextEncoder(nn.Module):
        y = self.mrte(y, y_mask, text, text_mask, ge)

        y = self.encoder2(y * y_mask, y_mask)
-        if(speed!=1):
+        if speed != 1:
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

@@ -333,9 +304,7 @@ class PosteriorEncoder(nn.Module):
    def forward(self, x, x_lengths, g=None):
        if g != None:
            g = g.detach()
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
@@ -345,14 +314,9 @@ class PosteriorEncoder(nn.Module):


class Encoder(nn.Module):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 hidden_channels,
-                 kernel_size,
-                 dilation_rate,
-                 n_layers,
-                 gin_channels=0):
+    def __init__(
+        self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
+    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -367,7 +331,7 @@ class Encoder(nn.Module):
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_lengths, g=None):
-        if(g!=None):
+        if g != None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
@@ -375,6 +339,7 @@ class Encoder(nn.Module):
        stats = self.proj(x) * x_mask
        return stats, x_mask

+
class WNEncoder(nn.Module):
    def __init__(
        self,
@@ -407,9 +372,7 @@ class WNEncoder(nn.Module):
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
-            x.dtype
-        )
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        out = self.proj(x) * x_mask
@@ -432,9 +395,7 @@ class Generator(torch.nn.Module):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
+        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
@@ -454,9 +415,7 @@ class Generator(torch.nn.Module):
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
@@ -609,9 +568,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
+        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
@@ -711,10 +668,7 @@ class Quantizer(torch.nn.Module):
        super(Quantizer, self).__init__()
        assert embed_dim % n_code_groups == 0
        self.quantizer_modules = nn.ModuleList(
-            [
-                Quantizer_module(n_codes, embed_dim // n_code_groups)
-                for _ in range(n_code_groups)
-            ]
+            [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
        )
        self.n_code_groups = n_code_groups
        self.embed_dim = embed_dim
@@ -732,9 +686,7 @@ class Quantizer(torch.nn.Module):
            z_q.append(_z_q)
            min_indicies.append(_min_indicies)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
-        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
-            (z_q - xin.detach()) ** 2
-        )
+        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@@ -774,13 +726,9 @@ class CodePredictor(nn.Module):
        self.p_dropout = p_dropout

        self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
-        self.ref_enc = modules.MelStyleEncoder(
-            ssl_dim, style_vector_dim=hidden_channels
-        )
+        self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

-        self.encoder = attentions.Encoder(
-            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
-        )
+        self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)

        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
        self.n_q = n_q
@@ -793,9 +741,7 @@ class CodePredictor(nn.Module):
        x = x + g
        x = self.encoder(x * x_mask, x_mask)
        x = self.out_proj(x * x_mask) * x_mask
-        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
-            2, 3
-        )
+        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
        target = codes[1:].transpose(0, 1)
        if not infer:
            logits = logits.reshape(-1, self.dims)
@@ -844,7 +790,7 @@ class SynthesizerTrn(nn.Module):
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v2",
-        **kwargs
+        **kwargs,
    ):
        super().__init__()
        self.spec_channels = spec_channels
@@ -896,9 +842,7 @@ class SynthesizerTrn(nn.Module):
        #     16,
        #     gin_channels=gin_channels,
        # )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
+        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        # self.version=os.environ.get("version","v1")
        if self.version == "v1":
@@ -925,7 +869,7 @@ class SynthesizerTrn(nn.Module):

    def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
        refer_mask = torch.ones_like(refer[:1, :1, :])
-        if (self.version == "v1"):
+        if self.version == "v1":
            ge = self.ref_enc(refer * refer_mask, refer_mask)
        else:
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@@ -935,9 +879,7 @@ class SynthesizerTrn(nn.Module):
        dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
        quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)

-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, text, ge, speed
-        )
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

@@ -951,11 +893,9 @@ class SynthesizerTrn(nn.Module):
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)


class CFM(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,dit
-    ):
+    def __init__(self, in_channels, dit):
        super().__init__()
        # self.sigma_min = 1e-6

@@ -965,7 +905,14 @@ class CFM(torch.nn.Module):

        # self.criterion = torch.nn.MSELoss()

-    def forward(self, mu:torch.Tensor, x_lens:torch.LongTensor, prompt:torch.Tensor, n_timesteps:torch.LongTensor, temperature:float=1.0):
+    def forward(
+        self,
+        mu: torch.Tensor,
+        x_lens: torch.LongTensor,
+        prompt: torch.Tensor,
+        n_timesteps: torch.LongTensor,
+        temperature: float = 1.0,
+    ):
        """Forward diffusion"""
        B, T = mu.size(0), mu.size(1)
        x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
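
This forward starts the flow-matching sampler from pure Gaussian noise shaped like the target mel. A minimal sketch of a fixed-step Euler sampler built around an assumed estimator(x, t, mu) velocity model (the repo's estimator takes more arguments; this is not its API):

    import torch

    def euler_sample(estimator, mu, prompt, n_timesteps, temperature=1.0):
        # mu: (B, T, C) condition; prompt: (B, C, T_p) mel prefix to pin
        B, T = mu.size(0), mu.size(1)
        x = torch.randn(B, prompt.size(1), T, device=mu.device, dtype=mu.dtype) * temperature
        prompt_len = prompt.size(-1)
        n = int(n_timesteps)
        d = 1.0 / n
        t = 0.0
        for _ in range(n):
            x[:, :, :prompt_len] = prompt      # keep the reference prefix fixed
            v = estimator(x, t, mu)            # assumed signature, see above
            x = x + d * v                      # one Euler step
            t = t + d
        return x
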
@@ -999,22 +946,26 @@ def set_no_grad(net_g):
    for name, param in net_g.named_parameters():
        param.requires_grad = False


@torch.jit.script_if_tracing
def compile_codes_length(codes):
    y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
    return y_lengths1 * 2.5 * 1.5


@torch.jit.script_if_tracing
def compile_ref_length(refer):
    refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
    return refer_lengths


class SynthesizerTrnV3(nn.Module):
    """
    Synthesizer for Training
    """

-    def __init__(self,
+    def __init__(
+        self,
        spec_channels,
        segment_size,
        inter_channels,
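
torch.jit.script_if_tracing makes these helpers scripted rather than traced when the module goes through torch.jit.trace, so the exported graph recomputes lengths from the live input instead of freezing the example's shape as a constant. A minimal illustration with a hypothetical dynamic_length helper:

    import torch

    @torch.jit.script_if_tracing
    def dynamic_length(x: torch.Tensor) -> torch.Tensor:
        # recomputed at runtime inside a traced graph, not baked in
        return torch.tensor([x.size(2)], dtype=torch.long, device=x.device)
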
@@ -1036,8 +987,8 @@ class SynthesizerTrnV3(nn.Module):
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v3",
-        **kwargs):
+        **kwargs,
+    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
@@ -1060,7 +1011,9 @@ class SynthesizerTrnV3(nn.Module):

        self.model_dim = 512
        self.use_sdp = use_sdp
-        self.enc_p = TextEncoder(inter_channels,hidden_channels,filter_channels,n_heads,n_layers,kernel_size,p_dropout)
+        self.enc_p = TextEncoder(
+            inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
        # self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
        self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)  ###Rollback
        # self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
@@ -1069,29 +1022,24 @@ class SynthesizerTrnV3(nn.Module):
        # gin_channels=gin_channels)
        # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        ssl_dim = 768
-        assert semantic_frame_rate in ['25hz', "50hz"]
+        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
-        if semantic_frame_rate == '25hz':
+        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

-        self.quantizer = ResidualVectorQuantizer(
-            dimension=ssl_dim,
-            n_q=1,
-            bins=1024
-        )
+        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        freeze_quantizer
        inter_channels2 = 512
-        self.bridge=nn.Sequential(
-            nn.Conv1d(inter_channels, inter_channels2, 1, stride=1),
-            nn.LeakyReLU()
-        )
+        self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
        self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
        self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
-        self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim
+        self.cfm = CFM(
+            100,
+            DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
+        )  # text_dim is condition feature dim
        if freeze_quantizer == True:
            set_no_grad(self.ssl_proj)
            set_no_grad(self.quantizer)
@@ -1104,11 +1052,10 @@ class SynthesizerTrnV3(nn.Module):
        return ge

    def forward(self, codes, text, ge, speed=1):
-
        y_lengths1 = compile_codes_length(codes)

        quantized = self.quantizer.decode(codes)
-        if self.semantic_frame_rate == '25hz':
+        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
        fea = self.bridge(x)
@@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):

         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
-        self.conv_layers.append(
-            nn.Conv1d(
-                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
-            )
-        )
+        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
         self.norm_layers.append(LayerNorm(hidden_channels))
         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):

@@ -156,9 +152,7 @@ class WN(torch.nn.Module):
         self.drop = nn.Dropout(p_dropout)

         if gin_channels != 0:
-            cond_layer = torch.nn.Conv1d(
-                gin_channels, 2 * hidden_channels * n_layers, 1
-            )
+            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

         for i in range(n_layers):

@@ -479,9 +473,7 @@ class ConvFlow(nn.Module):

         self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
         self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
-        self.proj = nn.Conv1d(
-            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
-        )
+        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
         self.proj.weight.data.zero_()
         self.proj.bias.data.zero_()

@@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
         h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

         unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
-        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
-            self.filter_channels
-        )
+        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
         unnormalized_derivatives = h[..., 2 * self.num_bins :]

         x1, logabsdet = piecewise_rational_quadratic_transform(
@@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
         self.w_ks = nn.Linear(d_model, n_head * d_k)
         self.w_vs = nn.Linear(d_model, n_head * d_v)

-        self.attention = ScaledDotProductAttention(
-            temperature=np.power(d_model, 0.5), dropout=dropout
-        )
+        self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)

         self.fc = nn.Linear(n_head * d_v, d_model)
         self.dropout = nn.Dropout(dropout)

@@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
         output, attn = self.attention(q, k, v, mask=slf_mask)

         output = output.view(n_head, sz_b, len_x, d_v)
-        output = (
-            output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
-        )  # b x lq x (n*dv)
+        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)  # b x lq x (n*dv)

         output = self.fc(output)
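The permute/view pair above is the standard way to merge attention heads that were folded into the batch dimension. A self-contained check of the same reshape:

```python
import torch

n_head, sz_b, len_x, d_v = 2, 3, 5, 4
output = torch.randn(n_head * sz_b, len_x, d_v)  # heads stacked into the batch dim

output = output.view(n_head, sz_b, len_x, d_v)   # un-fold the head dimension
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
assert output.shape == (sz_b, len_x, n_head * d_v)  # b x lq x (n*dv)
```

The `.contiguous()` call is required because `permute` only changes strides, and `.view` needs a contiguous buffer.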
@@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
         if mask is not None:
             mask = (mask.int() == 0).squeeze(1)
         max_len = x.shape[1]
-        slf_attn_mask = (
-            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
-        )
+        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None

         # spectral
         x = self.spectral(x)

@@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
         mu = self.fc1(enc_out)
         logvar = self.fc2(enc_out)
         posterior = D.Normal(mu, torch.exp(logvar))
-        kl_divergence = D.kl_divergence(
-            posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
-        )
+        kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
         loss_kl = kl_divergence.mean()

         z = posterior.rsample()
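A compact sketch of the VAE step above: a reparameterized posterior sampled with `rsample` so gradients reach `mu`/`logvar`, plus a KL penalty against a standard normal. Note that `torch.exp(logvar)` is passed as the *scale* (std) argument of `D.Normal`, so the second head effectively predicts log-std rather than log-variance:

```python
import torch
import torch.distributions as D

mu, logvar = torch.zeros(2, 8), torch.zeros(2, 8)  # toy encoder outputs

posterior = D.Normal(mu, torch.exp(logvar))        # exp(logvar) used as std here
prior = D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
loss_kl = D.kl_divergence(posterior, prior).mean()
z = posterior.rsample()                            # differentiable sample
```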
@@ -825,9 +807,7 @@ class ActNorm(nn.Module):

     def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
         if x_mask is None:
-            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
-                device=x.device, dtype=x.dtype
-            )
+            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
         x_len = torch.sum(x_mask, [1, 2])
         if not self.initialized:
             self.initialize(x, x_mask)

@@ -856,9 +836,7 @@ class ActNorm(nn.Module):
         v = m_sq - (m**2)
         logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

-        bias_init = (
-            (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
-        )
+        bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
         logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

         self.bias.data.copy_(bias_init)

@@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
         self.n_split = n_split
         self.no_jacobian = no_jacobian

-        w_init = torch.linalg.qr(
-            torch.FloatTensor(self.n_split, self.n_split).normal_()
-        )[0]
+        w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
         if torch.det(w_init) < 0:
             w_init[:, 0] = -1 * w_init[:, 0]
         self.weight = nn.Parameter(w_init)

@@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
         x_len = torch.sum(x_mask, [1, 2])

         x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
-        x = (
-            x.permute(0, 1, 3, 2, 4)
-            .contiguous()
-            .view(b, self.n_split, c // self.n_split, t)
-        )
+        x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)

         if reverse:
             if hasattr(self, "weight_inv"):
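The QR-based initialization above gives the invertible 1x1 convolution a random orthogonal weight, as in Glow-style flows: the starting log-determinant is zero, the inverse is just the transpose, and flipping one column's sign forces det = +1. A standalone check:

```python
import torch

n_split = 4
w_init = torch.linalg.qr(torch.FloatTensor(n_split, n_split).normal_())[0]  # Q of QR
if torch.det(w_init) < 0:
    w_init[:, 0] = -1 * w_init[:, 0]  # a single column sign flip moves det from -1 to +1
assert torch.allclose(w_init @ w_init.T, torch.eye(n_split), atol=1e-5)
assert torch.det(w_init) > 0
```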
@@ -31,32 +31,15 @@ class MRTE(nn.Module):
         text_enc = self.text_pre(text * text_mask)
         if test != None:
             if test == 0:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ssl_enc
-                    + ge
-                )
+                x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
             elif test == 1:
                 x = ssl_enc + ge
             elif test == 2:
-                x = (
-                    self.cross_attention(
-                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
-                    )
-                    + ge
-                )
+                x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
             else:
                 raise ValueError("test should be 0,1,2")
         else:
-            x = (
-                self.cross_attention(
-                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
-                )
-                + ssl_enc
-                + ge
-            )
+            x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
         x = self.c_post(x * ssl_mask)
         return x
@@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
         model_embedding_size=256,
     ):
         super(SpeakerEncoder, self).__init__()
-        self.lstm = nn.LSTM(
-            mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
-        )
+        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
         self.linear = nn.Linear(model_hidden_size, model_embedding_size)
         self.relu = nn.ReLU()

@@ -7,7 +7,6 @@
 """Residual vector quantizer implementation."""

 from dataclasses import dataclass, field
-import math
 import typing as tp

 import torch
@@ -88,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
             raise ValueError(
                 f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
             )
-        quantized, codes, commit_loss, quantized_list = self.vq(
-            x, n_q=n_q, layers=layers
-        )
+        quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
         return quantized, codes, torch.mean(commit_loss), quantized_list

-    def encode(
-        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
-    ) -> torch.Tensor:
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
         """Encode a given input tensor with the specified sample rate at the given bandwidth.
         The RVQ encode method sets the appropriate number of quantizer to use
         and returns indices for each quantizer.
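`self.vq` hides the residual-VQ loop itself; the idea is that each stage quantizes whatever the earlier stages failed to explain. A toy sketch of that loop (not the repo's implementation; codebooks and shapes are made up):

```python
import torch

def rvq_encode(x, codebooks):
    """x: (B, T, C); codebooks: list of (K, C) tensors. Illustrative only."""
    residual, codes = x, []
    for cb in codebooks:
        dists = torch.cdist(residual, cb.unsqueeze(0).expand(x.size(0), -1, -1))
        idx = dists.argmin(-1)         # nearest codeword per frame: (B, T)
        residual = residual - cb[idx]  # later stages only see what is left over
        codes.append(idx)
    return torch.stack(codes)          # (n_q, B, T), one index map per quantizer

codes = rvq_encode(torch.randn(2, 10, 8), [torch.randn(16, 8) for _ in range(4)])
print(codes.shape)  # torch.Size([4, 2, 10])
```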
@@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
         min_bin_width=min_bin_width,
         min_bin_height=min_bin_height,
         min_derivative=min_derivative,
-        **spline_kwargs
+        **spline_kwargs,
     )
     return outputs, logabsdet

@@ -175,8 +175,7 @@ def rational_quadratic_spline(

         theta_one_minus_theta = root * (1 - root)
         denominator = input_delta + (
-            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-            * theta_one_minus_theta
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
         )
         derivative_numerator = input_delta.pow(2) * (
             input_derivatives_plus_one * root.pow(2)

@@ -190,12 +189,9 @@ def rational_quadratic_spline(
         theta = (inputs - input_cumwidths) / input_bin_widths
         theta_one_minus_theta = theta * (1 - theta)

-        numerator = input_heights * (
-            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
-        )
+        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
         denominator = input_delta + (
-            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
-            * theta_one_minus_theta
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
         )
         outputs = input_cumheights + numerator / denominator
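The `numerator`/`denominator` pair implements the monotone rational-quadratic spline of Durkan et al. (Neural Spline Flows). With bin position, bin slope, and knot derivatives defined as below, the forward map evaluated above is:

```latex
\theta = \frac{x - x_k}{x_{k+1} - x_k}, \qquad
s_k = \frac{y_{k+1} - y_k}{x_{k+1} - x_k}, \qquad
g(x) = y_k +
\frac{(y_{k+1} - y_k)\left[s_k \theta^2 + \delta_k\, \theta(1-\theta)\right]}
     {s_k + (\delta_{k+1} + \delta_k - 2 s_k)\, \theta(1-\theta)}
```

In the code, `input_heights` is $y_{k+1}-y_k$, `input_delta` is $s_k$, and `input_derivatives`/`input_derivatives_plus_one` are $\delta_k$/$\delta_{k+1}$; the inverse branch (the `root` hunk above) solves the same rational quadratic for $\theta$.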
@@ -1,23 +1,22 @@
-from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
-from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 import torch
 import torchaudio
-from torch import nn
+from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 from feature_extractor import cnhubert
+from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
+from torch import nn

 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path
 ssl_model = cnhubert.get_model()
-from text import cleaned_text_to_sequence
-import soundfile
-from tools.my_utils import load_audio
-import os
 import json
+import os

+import soundfile
+from text import cleaned_text_to_sequence


 def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
-    hann_window = torch.hann_window(win_size).to(
-        dtype=y.dtype, device=y.device
-    )
+    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
     y = torch.nn.functional.pad(
         y.unsqueeze(1),
         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
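`spectrogram_torch` pads by `(n_fft - hop_size) / 2` on each side and runs the STFT with `center=False`, which makes the frame count come out to exactly `len(y) // hop_size`. A sketch with assumed sizes (reflect padding is the usual VITS choice; the pad mode is not visible in the hunk above):

```python
import torch

n_fft, hop_size, win_size = 1024, 256, 1024  # illustrative values
y = torch.randn(1, 8192)

pad = int((n_fft - hop_size) / 2)
y_p = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
spec = torch.stft(
    y_p,
    n_fft,
    hop_length=hop_size,
    win_length=win_size,
    window=torch.hann_window(win_size),
    center=False,
    return_complex=True,
).abs()
print(spec.shape)  # torch.Size([1, 513, 32]), i.e. len(y) // hop_size frames
```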
@@ -135,9 +134,7 @@ class T2SModel(nn.Module):
         if dynamo:
             export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
             onnx_encoder_export_output = torch.onnx.dynamo_export(
-                self.onnx_encoder,
-                (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
-                export_options=export_options
+                self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
             )
             onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
             return

@@ -155,7 +152,7 @@ class T2SModel(nn.Module):
                 "text_bert": {0: "text_length"},
                 "ssl_content": {2: "ssl_length"},
             },
-            opset_version=16
+            opset_version=16,
         )
         x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

@@ -170,7 +167,7 @@ class T2SModel(nn.Module):
                 "prompts": {1: "prompts_length"},
             },
             verbose=False,
-            opset_version=16
+            opset_version=16,
         )
         y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

@@ -188,7 +185,7 @@ class T2SModel(nn.Module):
                 "ix_example": {1: "ix_example_length"},
             },
             verbose=False,
-            opset_version=16
+            opset_version=16,
         )
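All three exports rely on `dynamic_axes` so the traced graph accepts variable-length inputs; the `opset_version=16` edits above only add a trailing comma. A minimal export with the same pattern (module and names are illustrative):

```python
import torch

class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

torch.onnx.export(
    Doubler(),
    (torch.randn(1, 768, 320),),
    "doubler.onnx",
    input_names=["ssl_content"],
    output_names=["out"],
    dynamic_axes={"ssl_content": {2: "ssl_length"}},  # dim 2 may vary at runtime
    opset_version=16,
)
```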
@@ -197,7 +194,7 @@ class VitsModel(nn.Module):
         super().__init__()
         dict_s2 = torch.load(vits_path, map_location="cpu")
         self.hps = dict_s2["config"]
-        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
         else:
             self.hps["model"]["version"] = "v2"

@@ -208,7 +205,7 @@ class VitsModel(nn.Module):
             self.hps.data.filter_length // 2 + 1,
             self.hps.train.segment_size // self.hps.data.hop_length,
             n_speakers=self.hps.data.n_speakers,
-            **self.hps.model
+            **self.hps.model,
         )
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)

@@ -220,7 +217,7 @@ class VitsModel(nn.Module):
             self.hps.data.sampling_rate,
             self.hps.data.hop_length,
             self.hps.data.win_length,
-            center=False
+            center=False,
         )
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]

@@ -236,12 +233,16 @@ class GptSoVits(nn.Module):
         audio = self.vits(text_seq, pred_semantic, ref_audio)
         if debug:
             import onnxruntime

             sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
-            audio1 = sess.run(None, {
+            audio1 = sess.run(
+                None,
+                {
                     "text_seq": text_seq.detach().cpu().numpy(),
                     "pred_semantic": pred_semantic.detach().cpu().numpy(),
-                "ref_audio" : ref_audio.detach().cpu().numpy()
-            })
+                    "ref_audio": ref_audio.detach().cpu().numpy(),
+                },
+            )
             return audio, audio1
         return audio

@@ -260,7 +261,7 @@ class GptSoVits(nn.Module):
             "ref_audio": {1: "audio_length"},
         },
         opset_version=17,
-        verbose=False
+        verbose=False,
     )
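A hedged sketch of what the `debug` branch above is doing: run the exported graph with onnxruntime and compare against the eager output. The tolerances are assumptions, and note that recent onnxruntime versions register the CPU backend as `"CPUExecutionProvider"`, so the bare `"CPU"` string seen in the diff may be rejected there:

```python
import numpy as np
import onnxruntime

def check_onnx_parity(onnx_path, text_seq, pred_semantic, ref_audio, audio_eager):
    sess = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    outputs = sess.run(
        None,
        {
            "text_seq": text_seq.detach().cpu().numpy(),
            "pred_semantic": pred_semantic.detach().cpu().numpy(),
            "ref_audio": ref_audio.detach().cpu().numpy(),
        },
    )
    # assuming the graph has a single audio output, as exported above
    np.testing.assert_allclose(audio_eager.detach().cpu().numpy(), outputs[0], rtol=1e-3, atol=1e-4)
```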
@@ -278,8 +279,61 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
+    ref_seq = torch.LongTensor(
+        [
+            cleaned_text_to_sequence(
+                [
+                    "n",
+                    "i2",
+                    "h",
+                    "ao3",
+                    ",",
+                    "w",
+                    "o3",
+                    "sh",
+                    "i4",
+                    "b",
+                    "ai2",
+                    "y",
+                    "e4",
+                ],
+                version=vits_model,
+            )
+        ]
+    )
+    text_seq = torch.LongTensor(
+        [
+            cleaned_text_to_sequence(
+                [
+                    "w",
+                    "o3",
+                    "sh",
+                    "i4",
+                    "b",
+                    "ai2",
+                    "y",
+                    "e4",
+                    "w",
+                    "o3",
+                    "sh",
+                    "i4",
+                    "b",
+                    "ai2",
+                    "y",
+                    "e4",
+                    "w",
+                    "o3",
+                    "sh",
+                    "i4",
+                    "b",
+                    "ai2",
+                    "y",
+                    "e4",
+                ],
+                version=vits_model,
+            )
+        ]
+    )
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()

@@ -326,7 +380,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
     }

     MoeVSConfJson = json.dumps(MoeVSConf)
-    with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
+    with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
         json.dump(MoeVSConf, MoeVsConfFile, indent=4)
@@ -12,15 +12,13 @@ if "_CUDA_VISIBLE_DEVICES" in os.environ:
 opt_dir = os.environ.get("opt_dir")
 bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
 import torch

 is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
-version = os.environ.get('version', None)
-import sys, numpy as np, traceback, pdb
+version = os.environ.get("version", None)
+import traceback
 import os.path
-from glob import glob
-from tqdm import tqdm
 from text.cleaner import clean_text
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-import numpy as np
 from tools.my_utils import clean_path

 # inp_text=sys.argv[1]

@@ -56,8 +54,10 @@ if os.path.exists(txt_path) == False:
         # device = "mps"
     else:
         device = "cpu"
-    if os.path.exists(bert_pretrained_dir):...
-    else:raise FileNotFoundError(bert_pretrained_dir)
+    if os.path.exists(bert_pretrained_dir):
+        ...
+    else:
+        raise FileNotFoundError(bert_pretrained_dir)
     tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
     if is_half == True:

@@ -89,9 +89,7 @@ if os.path.exists(txt_path) == False:
             name = clean_path(name)
             name = os.path.basename(name)
             print(name)
-            phones, word2ph, norm_text = clean_text(
-                text.replace("%", "-").replace("¥", ","), lan, version
-            )
+            phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
             path_bert = "%s/%s.pt" % (bert_dir, name)
             if os.path.exists(path_bert) == False and lan == "zh":
                 bert_feature = get_bert_feature(norm_text, word2ph)

@@ -131,9 +129,7 @@ if os.path.exists(txt_path) == False:
             wav_name, spk_name, language, text = line.split("|")
             # todo.append([name,text,"zh"])
             if language in language_v1_to_language_v2.keys():
-                todo.append(
-                    [wav_name, text, language_v1_to_language_v2.get(language, language)]
-                )
+                todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
             else:
                 print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
         except:
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-

-import sys,os
+import sys
+import os

 inp_text = os.environ.get("inp_text")
 inp_wav_dir = os.environ.get("inp_wav_dir")
 exp_name = os.environ.get("exp_name")

@@ -9,14 +11,18 @@ all_parts= os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 from feature_extractor import cnhubert

 opt_dir = os.environ.get("opt_dir")
 cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
 import torch

 is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()

-import pdb,traceback,numpy as np,logging
+import traceback
+import numpy as np
 from scipy.io import wavfile
 import librosa

 now_dir = os.getcwd()
 sys.path.append(now_dir)
 from tools.my_utils import load_audio, clean_path

@@ -34,6 +40,8 @@ from tools.my_utils import load_audio,clean_path

 from time import time as ttime
 import shutil


 def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
     dir = os.path.dirname(path)
     name = os.path.basename(path)

@@ -42,6 +50,7 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
     torch.save(fea, tmp_path)
     shutil.move(tmp_path, "%s/%s" % (dir, name))


 hubert_dir = "%s/4-cnhubert" % (opt_dir)
 wav32dir = "%s/5-wav32k" % (opt_dir)
 os.makedirs(opt_dir, exist_ok=True)
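`my_save` works around `torch.save` failing on non-ASCII (e.g. Chinese) destination paths on some platforms: serialize to an ASCII-named scratch file first, then `shutil.move` the result into place. The temp-file naming below is an assumption; the hunks above only show the save-and-move tail of the helper:

```python
import os
import shutil
import torch

def my_save(fea, path):
    dir = os.path.dirname(path)
    name = os.path.basename(path)
    tmp_path = "%s.pth" % os.getpid()  # plain-ASCII scratch name (illustrative)
    torch.save(fea, tmp_path)          # safe: no non-ASCII characters involved
    shutil.move(tmp_path, "%s/%s" % (dir, name))  # the rename handles Unicode fine
```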
@@ -58,15 +67,18 @@ else:
     device = "cpu"
 model = cnhubert.get_model()
 # is_half=False
-if(is_half==True):
+if is_half == True:
     model = model.half().to(device)
 else:
     model = model.to(device)

 nan_fails = []


 def name2go(wav_name, wav_path):
     hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
-    if(os.path.exists(hubert_path)):return
+    if os.path.exists(hubert_path):
+        return
     tmp_audio = load_audio(wav_path, 32000)
     tmp_max = np.abs(tmp_audio).max()
     if tmp_max > 2.2:

@@ -74,11 +86,9 @@ def name2go(wav_name,wav_path):
         return
     tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
     tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
-    tmp_audio = librosa.resample(
-        tmp_audio32b, orig_sr=32000, target_sr=16000
-    )  # not a resampling issue
+    tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # not a resampling issue
     tensor_wav16 = torch.from_numpy(tmp_audio)
-    if (is_half == True):
+    if is_half == True:
         tensor_wav16 = tensor_wav16.half().to(device)
     else:
         tensor_wav16 = tensor_wav16.to(device)

@@ -94,6 +104,7 @@ def name2go(wav_name,wav_path):
     )
     my_save(ssl, hubert_path)


 with open(inp_text, "r", encoding="utf8") as f:
     lines = f.read().strip("\n").split("\n")

@@ -102,7 +113,7 @@ for line in lines[int(i_part)::int(all_parts)]:
     # wav_name,text=line.split("\t")
     wav_name, spk_name, language, text = line.split("|")
     wav_name = clean_path(wav_name)
-    if (inp_wav_dir != "" and inp_wav_dir != None):
+    if inp_wav_dir != "" and inp_wav_dir != None:
         wav_name = os.path.basename(wav_name)
         wav_path = "%s/%s" % (inp_wav_dir, wav_name)

@@ -113,7 +124,7 @@ for line in lines[int(i_part)::int(all_parts)]:
     except:
         print(line, traceback.format_exc())

-if(len(nan_fails)>0 and is_half==True):
+if len(nan_fails) > 0 and is_half == True:
     is_half = False
     model = model.float()
     for wav in nan_fails:
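The driver loop shards the filelist across worker processes with an extended slice, `lines[int(i_part)::int(all_parts)]`: worker `i` of `n` takes every `n`-th line starting at index `i`. The values arrive via environment variables as strings, hence the `int(...)` casts:

```python
lines = ["a", "b", "c", "d", "e", "f", "g"]
i_part, all_parts = "1", "3"  # as read from os.environ

print(lines[int(i_part) :: int(all_parts)])  # ['b', 'e']: shards are disjoint and cover every line
```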
@@ -10,8 +10,10 @@ opt_dir = os.environ.get("opt_dir")
 pretrained_s2G = os.environ.get("pretrained_s2G")
 s2config_path = os.environ.get("s2config_path")

-if os.path.exists(pretrained_s2G):...
-else:raise FileNotFoundError(pretrained_s2G)
+if os.path.exists(pretrained_s2G):
+    ...
+else:
+    raise FileNotFoundError(pretrained_s2G)
 # version=os.environ.get("version","v2")
 size = os.path.getsize(pretrained_s2G)
 if size < 82978 * 1024:

@@ -25,23 +27,22 @@ elif size < 700 * 1024 * 1024:
 else:
     version = "v3"
 import torch

 is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
-import math, traceback
-import multiprocessing
-import sys, pdb
+import traceback
+import sys

 now_dir = os.getcwd()
 sys.path.append(now_dir)
-from random import shuffle
-import torch.multiprocessing as mp
-from glob import glob
-from tqdm import tqdm
-import logging, librosa, utils
+import logging
+import utils
 if version != "v3":
     from module.models import SynthesizerTrn
 else:
     from module.models import SynthesizerTrnV3 as SynthesizerTrn
 from tools.my_utils import clean_path

 logging.getLogger("numba").setLevel(logging.WARNING)
 # from config import pretrained_s2G

@@ -70,7 +71,7 @@ if os.path.exists(semantic_path) == False:
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         version=version,
-        **hps.model
+        **hps.model,
     )
     if is_half == True:
         vq_model = vq_model.half().to(device)
@@ -1,12 +1,14 @@
 import traceback
 from collections import OrderedDict
 from time import time as ttime
-import shutil,os
+import shutil
+import os
 import torch
 from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto()


 def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
     dir = os.path.dirname(path)
     name = os.path.basename(path)

@@ -14,22 +16,27 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
     torch.save(fea, tmp_path)
     shutil.move(tmp_path, "%s/%s" % (dir, name))

-'''
+"""
 00:v1
 01:v2
 02:v3
 03:v3lora


-'''
+"""
 from io import BytesIO


 def my_save2(fea, path):
     bio = BytesIO()
     torch.save(fea, bio)
     bio.seek(0)
     data = bio.getvalue()
-    data = b'03' + data[2:]###temp for v3lora only, todo
-    with open(path, "wb") as f: f.write(data)
+    data = b"03" + data[2:]  ###temp for v3lora only, todo
+    with open(path, "wb") as f:
+        f.write(data)


 def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
     try:
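`my_save2` exploits the fact that `torch.save` writes a zip archive whose first two bytes are the magic `b"PK"`: it overwrites them with a version tag (`b"03"` for v3-LoRA per the table above) so later code can sniff the format without unpickling anything, and `load_sovits_new` further below restores the magic before loading. A round-trip sketch:

```python
import torch
from io import BytesIO

bio = BytesIO()
torch.save({"demo": torch.zeros(1)}, bio)
data = b"03" + bio.getvalue()[2:]     # stamp the header: no longer a valid zip

assert data[:2] in (b"00", b"01", b"02", b"03")  # cheap two-byte version sniff
restored = b"PK" + data[2:]           # put the zip magic back
ckpt = torch.load(BytesIO(restored))  # loads normally again
```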
@@ -50,11 +57,12 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
     except:
         return traceback.format_exc()


 head2version = {
-    b'00':["v1","v1",False],
-    b'01':["v2","v2",False],
-    b'02':["v2","v3",False],
-    b'03':["v2","v3",True],
+    b"00": ["v1", "v1", False],
+    b"01": ["v2", "v2", False],
+    b"02": ["v2", "v3", False],
+    b"03": ["v2", "v3", True],
 }
 hash_pretrained_dict = {
     "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False],  # s2G488k.pth#sovits_v1_pretrained

@@ -62,29 +70,35 @@ hash_pretrained_dict={
     "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False],  # s2G2333K.pth#sovits_v2_pretrained
 }
 import hashlib


 def get_hash_from_file(sovits_path):
-    with open(sovits_path,"rb")as f:data=f.read(8192)
+    with open(sovits_path, "rb") as f:
+        data = f.read(8192)
     hash_md5 = hashlib.md5()
     hash_md5.update(data)
     return hash_md5.hexdigest()


 def get_sovits_version_from_path_fast(sovits_path):
     ###1-if it is pretrained sovits models, by hash
     hash = get_hash_from_file(sovits_path)
     if hash in hash_pretrained_dict:
         return hash_pretrained_dict[hash]
     ###2-new weights or old weights, by head
-    with open(sovits_path,"rb")as f:version=f.read(2)
+    with open(sovits_path, "rb") as f:
+        version = f.read(2)
     if version != b"PK":
         return head2version[version]
     ###3-old weights, by file size
     if_lora_v3 = False
     size = os.path.getsize(sovits_path)
-    '''
+    """
     v1weights:about 82942KB
         half thr:82978KB
     v2weights:about 83014KB
     v3weights:about 750MB
-    '''
+    """
     if size < 82978 * 1024:
         model_version = version = "v1"
     elif size < 700 * 1024 * 1024:

@@ -94,11 +108,12 @@ def get_sovits_version_from_path_fast(sovits_path):
         model_version = "v3"
     return version, model_version, if_lora_v3


 def load_sovits_new(sovits_path):
     f = open(sovits_path, "rb")
     meta = f.read(2)
     if meta != "PK":
-        data = b'PK' + f.read()
+        data = b"PK" + f.read()
         bio = BytesIO()
         bio.write(data)
         bio.seek(0)
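A usage sketch of the fast path above, which never unpickles the checkpoint: it tries the md5 of the first 8 KiB (known pretrained files), then the 2-byte header tag, then the file size. The path and the LoRA note are illustrative assumptions:

```python
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast("SoVITS_weights/demo.pth")
if if_lora_v3:
    # v3 LoRA checkpoints carry only adapter weights and need the pretrained v3 base model
    print("LoRA weights for", model_version)
else:
    print("symbol set:", version, "generator:", model_version)
```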
@@ -1,31 +1,28 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
 import os
-import pdb

 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
 import argparse
 import logging
+import platform
 from pathlib import Path

-import torch, platform
-from pytorch_lightning import seed_everything
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
-from pytorch_lightning.strategies import DDPStrategy
+import torch
 from AR.data.data_module import Text2SemanticDataModule
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from AR.utils.io import load_yaml_config
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
+from pytorch_lightning.strategies import DDPStrategy

 logging.getLogger("numba").setLevel(logging.WARNING)
 logging.getLogger("matplotlib").setLevel(logging.WARNING)
 torch.set_float32_matmul_precision("high")
-from AR.utils import get_newest_ckpt

 from collections import OrderedDict
-from time import time as ttime
-import shutil
+from AR.utils import get_newest_ckpt
 from process_ckpt import my_save

@@ -37,7 +34,7 @@ class my_model_ckpt(ModelCheckpoint):
         if_save_every_weights,
         half_weights_save_dir,
         exp_name,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.if_save_latest = if_save_latest

@@ -50,10 +47,7 @@ class my_model_ckpt(ModelCheckpoint):
         # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
         if self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
-            if (
-                self._every_n_epochs >= 1
-                and (trainer.current_epoch + 1) % self._every_n_epochs == 0
-            ):
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 if (
                     self.if_save_latest == True
                 ):  #### if only the latest ckpt is kept, clean up all earlier ckpts after saving the next one

@@ -75,7 +69,7 @@ class my_model_ckpt(ModelCheckpoint):
                     to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
                     # torch.save(
                     # print(os.environ)
-                    if(os.environ.get("LOCAL_RANK","0")=="0"):
+                    if os.environ.get("LOCAL_RANK", "0") == "0":
                         my_save(
                             to_save_od,
                             "%s/%s-e%s.ckpt"

@@ -123,9 +117,9 @@ def main(args):
         devices=-1 if torch.cuda.is_available() else 1,
         benchmark=False,
         fast_dev_run=False,
-        strategy = DDPStrategy(
-            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
-        ) if torch.cuda.is_available() else "auto",
+        strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
+        if torch.cuda.is_available()
+        else "auto",
         precision=config["train"]["precision"],
         logger=logger,
         num_sanity_val_steps=0,

@@ -133,9 +127,7 @@ def main(args):
         use_distributed_sampler=False,  # a very simple change, but it fixes the inconsistent step counts seen with the custom bucket_sampler!
     )

-    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
-        config, output_dir
-    )
+    model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)

     data_module: Text2SemanticDataModule = Text2SemanticDataModule(
         config,
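The strategy expression above folds backend selection into one line: NCCL is CUDA-only and unavailable on Windows, so Windows falls back to gloo and CPU-only runs to `"auto"`. The same choice, unrolled as a sketch:

```python
import platform

import torch
from pytorch_lightning.strategies import DDPStrategy

def pick_strategy():
    if not torch.cuda.is_available():
        return "auto"  # let Lightning decide when there is no GPU
    backend = "nccl" if platform.system() != "Windows" else "gloo"
    return DDPStrategy(process_group_backend=backend)
```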
@@ -1,36 +1,41 @@
 import warnings

 warnings.filterwarnings("ignore")
-import utils, os
+import os

+import utils

 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import logging

 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.cuda.amp import GradScaler, autocast
 from torch.nn import functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
-import logging, traceback

 logging.getLogger("matplotlib").setLevel(logging.INFO)
 logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
 from random import randint
-from module import commons

+from module import commons
 from module.data_utils import (
-    TextAudioSpeakerLoader,
-    TextAudioSpeakerCollate,
     DistributedBucketSampler,
+    TextAudioSpeakerCollate,
+    TextAudioSpeakerLoader,
 )
-from module.models import (
-    SynthesizerTrn,
-    MultiPeriodDiscriminator,
-)
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
 from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from module.models import (
+    MultiPeriodDiscriminator,
+    SynthesizerTrn,
+)
 from process_ckpt import savee

 torch.backends.cudnn.benchmark = False
@@ -46,7 +51,6 @@ device = "cpu"  # non-CUDA devices; to be added once mps is optimized


 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
     else:

@@ -128,19 +132,27 @@ def run(rank, n_gpus, hps):
     # batch_size=1, pin_memory=True,
     # drop_last=False, collate_fn=collate_fn)

-    net_g = SynthesizerTrn(
+    net_g = (
+        SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
             **hps.model,
-    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
+        ).cuda(rank)
+        if torch.cuda.is_available()
+        else SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
             n_speakers=hps.data.n_speakers,
             **hps.model,
         ).to(device)
+    )

-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+    net_d = (
+        MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        if torch.cuda.is_available()
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+    )
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
@@ -213,37 +225,55 @@ def run(rank, n_gpus, hps):
             # traceback.print_exc()
             epoch_str = 1
             global_step = 0
-    if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+    if (
+        hps.train.pretrained_s2G != ""
+        and hps.train.pretrained_s2G != None
+        and os.path.exists(hps.train.pretrained_s2G)
+    ):
         if rank == 0:
             logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-        print("loaded pretrained %s" % hps.train.pretrained_s2G,
+        print(
+            "loaded pretrained %s" % hps.train.pretrained_s2G,
             net_g.module.load_state_dict(
                 torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                 strict=False,
-            ) if torch.cuda.is_available() else net_g.load_state_dict(
+            )
+            if torch.cuda.is_available()
+            else net_g.load_state_dict(
                 torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                 strict=False,
-            )
+            ),
         )  ## test: do not load the optimizer
-    if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
+    if (
+        hps.train.pretrained_s2D != ""
+        and hps.train.pretrained_s2D != None
+        and os.path.exists(hps.train.pretrained_s2D)
+    ):
         if rank == 0:
             logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
-        print("loaded pretrained %s" % hps.train.pretrained_s2D,
+        print(
+            "loaded pretrained %s" % hps.train.pretrained_s2D,
             net_d.module.load_state_dict(
-                torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
-            ) if torch.cuda.is_available() else net_d.load_state_dict(
-                torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+                torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
             )
+            if torch.cuda.is_available()
+            else net_d.load_state_dict(
+                torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+            ),
         )

     # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
+        optim_g,
+        gamma=hps.train.lr_decay,
+        last_epoch=-1,
     )
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-        optim_d, gamma=hps.train.lr_decay, last_epoch=-1
+        optim_d,
+        gamma=hps.train.lr_decay,
+        last_epoch=-1,
     )
     for _ in range(epoch_str):
         scheduler_g.step()
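Two details in the loading block above are worth unpacking: under DDP the wrapped weights live on `net_g.module` (hence the CUDA/CPU split), and `strict=False` lets a pretrained checkpoint with missing or extra keys load whatever matches. The return value reports what was skipped. A sketch; the file name is an assumption and `net_g` is the model from the surrounding script:

```python
import torch

state = torch.load("pretrained_s2G.pth", map_location="cpu")["weight"]
target = net_g.module if hasattr(net_g, "module") else net_g  # DDP wrapper vs bare module
missing, unexpected = target.load_state_dict(state, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
```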
@@ -285,9 +315,7 @@ def run(rank, n_gpus, hps):
     print("training done")


-def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
+def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
     net_g, net_d = nets
     optim_g, optim_d = optims
     # scheduler_g, scheduler_d = schedulers

@@ -311,17 +339,38 @@ def train_and_evaluate(
         text_lengths,
     ) in enumerate(tqdm(train_loader)):
         if torch.cuda.is_available():
-            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                rank, non_blocking=True
+            spec, spec_lengths = (
+                spec.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
-            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-                rank, non_blocking=True
+            y, y_lengths = (
+                y.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                y_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
             ssl = ssl.cuda(rank, non_blocking=True)
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                rank, non_blocking=True
+            text, text_lengths = (
+                text.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
             )
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@@ -350,9 +399,7 @@ def train_and_evaluate(
                     hps.data.mel_fmin,
                     hps.data.mel_fmax,
                 )
-                y_mel = commons.slice_segments(
-                    mel, ids_slice, hps.train.segment_size // hps.data.hop_length
-                )
+                y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
                 y_hat_mel = mel_spectrogram_torch(
                     y_hat.squeeze(1),
                     hps.data.filter_length,

@@ -364,15 +411,14 @@ def train_and_evaluate(
                     hps.data.mel_fmax,
                 )

-            y = commons.slice_segments(
-                y, ids_slice * hps.data.hop_length, hps.train.segment_size
-            )  # slice
+            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice

             # Discriminator
             y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
             with autocast(enabled=False):
                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
-                    y_d_hat_r, y_d_hat_g
+                    y_d_hat_r,
+                    y_d_hat_g,
                 )
                 loss_disc_all = loss_disc
         optim_d.zero_grad()
@ -405,7 +451,8 @@ def train_and_evaluate(
|
|||||||
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
||||||
logger.info(
|
logger.info(
|
||||||
"Train Epoch: {} [{:.0f}%]".format(
|
"Train Epoch: {} [{:.0f}%]".format(
|
||||||
epoch, 100.0 * batch_idx / len(train_loader)
|
epoch,
|
||||||
|
100.0 * batch_idx / len(train_loader),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info([x.item() for x in losses] + [global_step, lr])
|
logger.info([x.item() for x in losses] + [global_step, lr])
|
||||||
@ -433,21 +480,33 @@ def train_and_evaluate(
|
|||||||
try: ###Some people installed the wrong version of matplotlib.
|
try: ###Some people installed the wrong version of matplotlib.
|
||||||
image_dict = {
|
image_dict = {
|
||||||
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
||||||
y_mel[0].data.cpu().numpy()
|
y_mel[0].data.cpu().numpy(),
|
||||||
),
|
),
|
||||||
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
||||||
y_hat_mel[0].data.cpu().numpy()
|
y_hat_mel[0].data.cpu().numpy(),
|
||||||
),
|
),
|
||||||
"all/mel": utils.plot_spectrogram_to_numpy(
|
"all/mel": utils.plot_spectrogram_to_numpy(
|
||||||
mel[0].data.cpu().numpy()
|
mel[0].data.cpu().numpy(),
|
||||||
),
|
),
|
||||||
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
||||||
stats_ssl[0].data.cpu().numpy()
|
stats_ssl[0].data.cpu().numpy(),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
except:pass
|
except:
|
||||||
if image_dict:utils.summarize(writer=writer,global_step=global_step,images=image_dict,scalars=scalar_dict,)
|
pass
|
||||||
else:utils.summarize(writer=writer,global_step=global_step,scalars=scalar_dict,)
|
if image_dict:
|
||||||
|
utils.summarize(
|
||||||
|
writer=writer,
|
||||||
|
global_step=global_step,
|
||||||
|
images=image_dict,
|
||||||
|
scalars=scalar_dict,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
utils.summarize(
|
||||||
|
writer=writer,
|
||||||
|
global_step=global_step,
|
||||||
|
scalars=scalar_dict,
|
||||||
|
)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||||
if hps.train.if_save_latest == 0:
|
if hps.train.if_save_latest == 0:
|
||||||
@ -457,7 +516,8 @@ def train_and_evaluate(
|
|||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
|
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||||
|
"G_{}.pth".format(global_step),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
@ -466,7 +526,8 @@ def train_and_evaluate(
|
|||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
|
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||||
|
"D_{}.pth".format(global_step),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -476,7 +537,8 @@ def train_and_evaluate(
|
|||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
|
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||||
|
"G_{}.pth".format(233333333333),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
utils.save_checkpoint(
|
utils.save_checkpoint(
|
||||||
@ -485,7 +547,8 @@ def train_and_evaluate(
|
|||||||
hps.train.learning_rate,
|
hps.train.learning_rate,
|
||||||
epoch,
|
epoch,
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
|
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
||||||
|
"D_{}.pth".format(233333333333),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||||
@ -540,10 +603,24 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
ssl = ssl.to(device)
|
ssl = ssl.to(device)
|
||||||
text, text_lengths = text.to(device), text_lengths.to(device)
|
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||||
for test in [0, 1]:
|
for test in [0, 1]:
|
||||||
y_hat, mask, *_ = generator.module.infer(
|
y_hat, mask, *_ = (
|
||||||
ssl, spec, spec_lengths, text, text_lengths, test=test
|
generator.module.infer(
|
||||||
) if torch.cuda.is_available() else generator.infer(
|
ssl,
|
||||||
ssl, spec, spec_lengths, text, text_lengths, test=test
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
test=test,
|
||||||
|
)
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else generator.infer(
|
||||||
|
ssl,
|
||||||
|
spec,
|
||||||
|
spec_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
test=test,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
||||||
|
|
||||||
@ -568,19 +645,19 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
image_dict.update(
|
image_dict.update(
|
||||||
{
|
{
|
||||||
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
||||||
y_hat_mel[0].cpu().numpy()
|
y_hat_mel[0].cpu().numpy(),
|
||||||
)
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
audio_dict.update(
|
audio_dict.update(
|
||||||
{f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
|
{
|
||||||
|
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
image_dict.update(
|
image_dict.update(
|
||||||
{
|
{
|
||||||
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
|
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
|
||||||
mel[0].cpu().numpy()
|
},
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
||||||
|
|
||||||
|
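Note on the branch reformatted above: under DistributedDataParallel the underlying network hides behind `.module`, which is why the evaluate path switches between `generator.module.infer` and `generator.infer`. A minimal sketch of an unwrap helper that collapses that branch follows; the helper name is ours for illustration and is not part of this commit.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    # DDP wraps the real network; reach through .module only when present.
    # (Illustrative helper, not used by the commit itself.)
    return model.module if isinstance(model, DDP) else model


# y_hat, mask, *_ = unwrap(generator).infer(ssl, spec, spec_lengths, text, text_lengths, test=test)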
@@ -1,36 +1,41 @@
 import warnings

 warnings.filterwarnings("ignore")
-import utils, os
+import os
+
+import utils
+
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import logging
+
 import torch
-from torch.nn import functional as F
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
-import logging, traceback

 logging.getLogger("matplotlib").setLevel(logging.INFO)
 logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
 from random import randint
-from module import commons

+from module import commons
+from module.data_utils import (
+    DistributedBucketSampler,
+)
+from module.data_utils import (
+    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
+)
 from module.data_utils import (
     TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-    DistributedBucketSampler,
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
-    MultiPeriodDiscriminator,
 )
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
 from process_ckpt import savee

 torch.backends.cudnn.benchmark = False
@@ -46,7 +51,6 @@ device = "cpu"  # devices other than CUDA; to be added once mps is optimized


-
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
     else:
@@ -128,17 +132,21 @@ def run(rank, n_gpus, hps):
    # batch_size=1, pin_memory=True,
    # drop_last=False, collate_fn=collate_fn)

-    net_g = SynthesizerTrn(
+    net_g = (
+        SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
-    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
+        ).cuda(rank)
+        if torch.cuda.is_available()
+        else SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
-    ).to(device)
+        ).to(device)
+    )

    # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
    # for name, param in net_g.named_parameters():
@@ -186,17 +194,24 @@ def run(rank, n_gpus, hps):
            # traceback.print_exc()
        epoch_str = 1
        global_step = 0
-        if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+        if (
+            hps.train.pretrained_s2G != ""
+            and hps.train.pretrained_s2G != None
+            and os.path.exists(hps.train.pretrained_s2G)
+        ):
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print("loaded pretrained %s" % hps.train.pretrained_s2G,
+            print(
+                "loaded pretrained %s" % hps.train.pretrained_s2G,
                net_g.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
-                ) if torch.cuda.is_available() else net_g.load_state_dict(
+                )
+                if torch.cuda.is_available()
+                else net_g.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
-                )
+                ),
            )  ## test: do not load the optimizer
        # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
        #     if rank == 0:
@@ -212,9 +227,7 @@ def run(rank, n_gpus, hps):
    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
-    )
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
    #     optim_d, gamma=hps.train.lr_decay, last_epoch=-1
    # )
@@ -260,7 +273,16 @@ def run(rank, n_gpus, hps):


 def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
+    rank,
+    epoch,
+    hps,
+    nets,
+    optims,
+    schedulers,
+    scaler,
+    loaders,
+    logger,
+    writers,
 ):
    net_g, net_d = nets
    optim_g, optim_d = optims
@@ -284,19 +306,33 @@ def train_and_evaluate(
    #     text,
    #     text_lengths,
    # ) in enumerate(tqdm(train_loader)):
-    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
+    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
+        tqdm(train_loader)
+    ):
        if torch.cuda.is_available():
-            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                rank, non_blocking=True
-            )
-            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
-                rank, non_blocking=True
+            spec, spec_lengths = (
+                spec.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
            )
+            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
            ssl = ssl.cuda(rank, non_blocking=True)
            ssl.requires_grad = False
            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                rank, non_blocking=True
+            text, text_lengths = (
+                text.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
            )
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@@ -307,7 +343,17 @@ def train_and_evaluate(
            text, text_lengths = text.to(device), text_lengths.to(device)

        with autocast(enabled=hps.train.fp16_run):
-            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
+            cfm_loss = net_g(
+                ssl,
+                spec,
+                mel,
+                ssl_lengths,
+                spec_lengths,
+                text,
+                text_lengths,
+                mel_lengths,
+                use_grad_ckpt=hps.train.grad_ckpt,
+            )
            loss_gen_all = cfm_loss
            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
@@ -318,12 +364,15 @@ def train_and_evaluate(

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
-                lr = optim_g.param_groups[0]['lr']
+                lr = optim_g.param_groups[0]["lr"]
                # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
                losses = [cfm_loss]
-                logger.info('Train Epoch: {} [{:.0f}%]'.format(
-                    epoch,
-                    100. * batch_idx / len(train_loader)))
+                logger.info(
+                    "Train Epoch: {} [{:.0f}%]".format(
+                        epoch,
+                        100.0 * batch_idx / len(train_loader),
+                    )
+                )
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
@@ -337,7 +386,8 @@ def train_and_evaluate(
                    writer=writer,
                    global_step=global_step,
                    # images=image_dict,
-                    scalars=scalar_dict)
+                    scalars=scalar_dict,
+                )

        # if global_step % hps.train.eval_interval == 0:
        #     # evaluate(hps, net_g, eval_loader, writer_eval)
@@ -347,7 +397,6 @@ def train_and_evaluate(
        #     #     if keep_ckpts > 0:
        #     #         utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)

-
        global_step += 1
    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        if hps.train.if_save_latest == 0:
@@ -357,7 +406,8 @@ def train_and_evaluate(
                hps.train.learning_rate,
                epoch,
                os.path.join(
-                    "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(global_step)
+                    "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                    "G_{}.pth".format(global_step),
                ),
            )
            # utils.save_checkpoint(
@@ -376,7 +426,8 @@ def train_and_evaluate(
                hps.train.learning_rate,
                epoch,
                os.path.join(
-                    "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "G_{}.pth".format(233333333333)
+                    "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                    "G_{}.pth".format(233333333333),
                ),
            )
            # utils.save_checkpoint(
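The v3 loop above is a standard mixed-precision setup: forward under `autocast`, backward through a `GradScaler`. A self-contained sketch of that step, using placeholder names (`model`, `optimizer`, `x`) rather than the commit's:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 1).cuda() if torch.cuda.is_available() else torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=torch.cuda.is_available())

x = torch.randn(4, 8, device=next(model.parameters()).device)
with autocast(enabled=torch.cuda.is_available()):
    loss = model(x).pow(2).mean()  # forward runs in fp16 where supported
optimizer.zero_grad()
scaler.scale(loss).backward()      # scale to keep fp16 gradients from underflowing
scaler.step(optimizer)             # unscales, then skips the step on inf/NaN gradients
scaler.update()

The unscale/step/update ordering matters: the scaler must see the gradients before the optimizer mutates the parameters, which is why the loop calls `scaler.scale(loss).backward()` instead of a plain `loss.backward()`.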
@@ -1,38 +1,45 @@
 import warnings

 warnings.filterwarnings("ignore")
-import utils, os
+import os
+
+import utils
+
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+import logging
+
 import torch
-from torch.nn import functional as F
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.multiprocessing as mp
-import torch.distributed as dist, traceback
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import autocast, GradScaler
 from tqdm import tqdm
-import logging, traceback

 logging.getLogger("matplotlib").setLevel(logging.INFO)
 logging.getLogger("h5py").setLevel(logging.INFO)
 logging.getLogger("numba").setLevel(logging.INFO)
+from collections import OrderedDict as od
 from random import randint

 from module import commons
-from peft import LoraConfig, PeftModel, get_peft_model
+from module.data_utils import (
+    DistributedBucketSampler,
+)
+from module.data_utils import (
+    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
+)
 from module.data_utils import (
     TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-    DistributedBucketSampler,
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
-    MultiPeriodDiscriminator,
 )
-from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
-from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+from peft import LoraConfig, get_peft_model
 from process_ckpt import savee
-from collections import OrderedDict as od

 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = False
 ### fp32 is faster on A100 anyway, so let's try tf32
@@ -46,7 +53,6 @@ device = "cpu"  # devices other than CUDA; to be added once mps is optimized


-
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
     else:
@@ -131,12 +137,15 @@ def run(rank, n_gpus, hps):
        lora_alpha=lora_rank,
        init_lora_weights=True,
    )
-    def get_model(hps):return SynthesizerTrn(
+
+    def get_model(hps):
+        return SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        )
+
    def get_optim(net_g):
        return torch.optim.AdamW(
            filter(lambda p: p.requires_grad, net_g.parameters()),  ### by default all layers share the same lr
@@ -144,12 +153,14 @@ def run(rank, n_gpus, hps):
            betas=hps.train.betas,
            eps=hps.train.eps,
        )
+
    def model2cuda(net_g, rank):
        if torch.cuda.is_available():
            net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
        else:
            net_g = net_g.to(device)
        return net_g
+
    try:  # auto-resume if a checkpoint can be loaded
        net_g = get_model(hps)
        net_g.cfm = get_peft_model(net_g.cfm, lora_config)
@@ -168,14 +179,19 @@ def run(rank, n_gpus, hps):
        epoch_str = 1
        global_step = 0
        net_g = get_model(hps)
-        if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
+        if (
+            hps.train.pretrained_s2G != ""
+            and hps.train.pretrained_s2G != None
+            and os.path.exists(hps.train.pretrained_s2G)
+        ):
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print("loaded pretrained %s" % hps.train.pretrained_s2G,
+            print(
+                "loaded pretrained %s" % hps.train.pretrained_s2G,
                net_g.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
-                )
+                ),
            )
        net_g.cfm = get_peft_model(net_g.cfm, lora_config)
        net_g = model2cuda(net_g, rank)
@@ -189,9 +205,7 @@ def run(rank, n_gpus, hps):
    # print(no_grad_names)
    # os._exit(233333)

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
-    )
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
    for _ in range(epoch_str):
        scheduler_g.step()

@@ -230,9 +244,8 @@ def run(rank, n_gpus, hps):
        scheduler_g.step()
    print("training done")

-def train_and_evaluate(
-    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
-):
+
+def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    net_g, net_d = nets
    optim_g, optim_d = optims
    # scheduler_g, scheduler_d = schedulers
@@ -244,18 +257,32 @@ def train_and_evaluate(
    global global_step

    net_g.train()
-    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
+    for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
+        tqdm(train_loader)
+    ):
        if torch.cuda.is_available():
-            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-                rank, non_blocking=True
-            )
-            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
-                rank, non_blocking=True
+            spec, spec_lengths = (
+                spec.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
            )
+            mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
            ssl = ssl.cuda(rank, non_blocking=True)
            ssl.requires_grad = False
-            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-                rank, non_blocking=True
+            text, text_lengths = (
+                text.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.cuda(
+                    rank,
+                    non_blocking=True,
+                ),
            )
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
@@ -265,7 +292,17 @@ def train_and_evaluate(
            text, text_lengths = text.to(device), text_lengths.to(device)

        with autocast(enabled=hps.train.fp16_run):
-            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
+            cfm_loss = net_g(
+                ssl,
+                spec,
+                mel,
+                ssl_lengths,
+                spec_lengths,
+                text,
+                text_lengths,
+                mel_lengths,
+                use_grad_ckpt=hps.train.grad_ckpt,
+            )
            loss_gen_all = cfm_loss
            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
@@ -276,18 +313,17 @@ def train_and_evaluate(

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
-                lr = optim_g.param_groups[0]['lr']
+                lr = optim_g.param_groups[0]["lr"]
                losses = [cfm_loss]
-                logger.info('Train Epoch: {} [{:.0f}%]'.format(
-                    epoch,
-                    100. * batch_idx / len(train_loader)))
+                logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
-                    scalars=scalar_dict)
+                    scalars=scalar_dict,
+                )

        global_step += 1
    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
@@ -297,9 +333,7 @@ def train_and_evaluate(
                optim_g,
                hps.train.learning_rate,
                epoch,
-                os.path.join(
-                    save_root, "G_{}.pth".format(global_step)
-                ),
+                os.path.join(save_root, "G_{}.pth".format(global_step)),
            )
        else:
            utils.save_checkpoint(
@@ -307,9 +341,7 @@ def train_and_evaluate(
                optim_g,
                hps.train.learning_rate,
                epoch,
-                os.path.join(
-                    save_root, "G_{}.pth".format(233333333333)
-                ),
+                os.path.join(save_root, "G_{}.pth".format(233333333333)),
            )
        if rank == 0 and hps.train.if_save_every_weights == True:
            if hasattr(net_g, "module"):
@@ -332,7 +364,8 @@ def train_and_evaluate(
                        hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                        epoch,
                        global_step,
-                        hps,lora_rank=lora_rank
+                        hps,
+                        lora_rank=lora_rank,
                    ),
                )
            )
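For context on the LoRA path above: `get_peft_model` swaps the matched sub-modules of `net_g.cfm` for low-rank adapter versions, so only the adapters require gradients. A minimal sketch with a toy module; `target_modules` here is illustrative and must name real sub-modules of whatever network is wrapped, it is not taken from the commit.

import torch
from peft import LoraConfig, get_peft_model

# Toy module standing in for net_g.cfm.
base = torch.nn.Sequential()
base.add_module("proj", torch.nn.Linear(16, 16))

lora_config = LoraConfig(
    r=8,                      # the commit reads this from hps as lora_rank
    lora_alpha=8,             # the commit sets lora_alpha = lora_rank
    init_lora_weights=True,
    target_modules=["proj"],  # assumed name; match your own sub-modules
)
wrapped = get_peft_model(base, lora_config)
wrapped.print_trainable_parameters()  # only the LoRA adapters require grad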
@@ -3,19 +3,25 @@ import re

 # silence jieba logging
 import jieba

 jieba.setLogLevel(logging.CRITICAL)

 # relocate the fast_langdetect model cache
 from pathlib import Path

 import fast_langdetect

-fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
+fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
+    fast_langdetect.infer.LangDetectConfig(
+        cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
+    )
+)
+

 from split_lang import LangSplitter


 def full_en(text):
-    pattern = r'^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
+    pattern = r"^[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
    return bool(re.match(pattern, text))


@@ -34,7 +40,7 @@ def full_cjk(text):
        (0x2EBF0, 0x2EE5D),  # CJK Extension H
    ]

-    pattern = r'[0-9、-〜。!?.!?… ]+$'
+    pattern = r"[0-9、-〜。!?.!?… ]+$"

    cjk_text = ""
    for char in text:
@@ -53,28 +59,28 @@ def split_jako(tag_lang,item):

    lang_list: list[dict] = []
    tag = 0
-    for match in re.finditer(pattern, item['text']):
+    for match in re.finditer(pattern, item["text"]):
        if match.start() > tag:
-            lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
+            lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})

        tag = match.end()
-        lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
+        lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})

-    if tag < len(item['text']):
-        lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
+    if tag < len(item["text"]):
+        lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})

    return lang_list


 def merge_lang(lang_list, item):
-    if lang_list and item['lang'] == lang_list[-1]['lang']:
-        lang_list[-1]['text'] += item['text']
+    if lang_list and item["lang"] == lang_list[-1]["lang"]:
+        lang_list[-1]["text"] += item["text"]
    else:
        lang_list.append(item)
    return lang_list


-class LangSegmenter():
+class LangSegmenter:
    # default filter, based on the four languages gsv currently supports
    DEFAULT_LANG_MAP = {
        "zh": "zh",
@@ -87,7 +93,6 @@ class LangSegmenter():
        "en": "en",
    }

-
    def getTexts(text):
        lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
        substr = lang_splitter.split_by_lang(text=text)
@@ -95,18 +100,18 @@ class LangSegmenter():
        lang_list: list[dict] = []

        for _, item in enumerate(substr):
-            dict_item = {'lang':item.lang,'text':item.text}
+            dict_item = {"lang": item.lang, "text": item.text}

            # handle short English stretches misdetected as another language
-            if full_en(dict_item['text']):
-                dict_item['lang'] = 'en'
+            if full_en(dict_item["text"]):
+                dict_item["lang"] = "en"
                lang_list = merge_lang(lang_list, dict_item)
                continue

            # handle Japanese embedded in non-Japanese text (CJK excluded)
            ja_list: list[dict] = []
-            if dict_item['lang'] != 'ja':
-                ja_list = split_jako('ja',dict_item)
+            if dict_item["lang"] != "ja":
+                ja_list = split_jako("ja", dict_item)

            if not ja_list:
                ja_list.append(dict_item)
@@ -115,8 +120,8 @@ class LangSegmenter():
            ko_list: list[dict] = []
            temp_list: list[dict] = []
            for _, ko_item in enumerate(ja_list):
-                if ko_item["lang"] != 'ko':
-                    ko_list = split_jako('ko',ko_item)
+                if ko_item["lang"] != "ko":
+                    ko_list = split_jako("ko", ko_item)

                if ko_list:
                    temp_list.extend(ko_list)
@@ -126,10 +131,10 @@ class LangSegmenter():
            # no Japanese/Korean embedded in other-language text
            if len(temp_list) == 1:
                # unknown language: check whether it is CJK
-                if dict_item['lang'] == 'x':
-                    cjk_text = full_cjk(dict_item['text'])
+                if dict_item["lang"] == "x":
+                    cjk_text = full_cjk(dict_item["text"])
                    if cjk_text:
-                        dict_item = {'lang':'zh','text':cjk_text}
+                        dict_item = {"lang": "zh", "text": cjk_text}
                        lang_list = merge_lang(lang_list, dict_item)
                    continue
                else:
@@ -139,10 +144,10 @@ class LangSegmenter():
            # Japanese/Korean embedded in other-language text
            for _, temp_item in enumerate(temp_list):
                # unknown language: check whether it is CJK
-                if temp_item['lang'] == 'x':
-                    cjk_text = full_cjk(dict_item['text'])
+                if temp_item["lang"] == "x":
+                    cjk_text = full_cjk(dict_item["text"])
                    if cjk_text:
-                        dict_item = {'lang':'zh','text':cjk_text}
+                        dict_item = {"lang": "zh", "text": cjk_text}
                        lang_list = merge_lang(lang_list, dict_item)
                else:
                    lang_list = merge_lang(lang_list, temp_item)
@@ -155,4 +160,3 @@ if __name__ == "__main__":

    text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
    print(LangSegmenter.getTexts(text))
-
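For reference, `getTexts` returns a list of `{"lang", "text"}` dicts with adjacent same-language runs merged. A short usage sketch; the import path is assumed from the repo layout and the printed segmentation is illustrative, not captured from a run:

from text.LangSegmenter import LangSegmenter  # assumed import path

segments = LangSegmenter.getTexts("Hello世界、こんにちは")
for seg in segments:
    print(seg["lang"], "->", seg["text"])
# e.g. en -> Hello / zh -> 世界、 / ja -> こんにちは  (illustrative)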
@@ -10,18 +10,19 @@ from text import symbols2 as symbols_v2
 _symbol_to_id_v1 = {s: i for i, s in enumerate(symbols_v1.symbols)}
 _symbol_to_id_v2 = {s: i for i, s in enumerate(symbols_v2.symbols)}


 def cleaned_text_to_sequence(cleaned_text, version=None):
-    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
-    '''
-    if version is None:version=os.environ.get('version', 'v2')
+    """
+    if version is None:
+        version = os.environ.get("version", "v2")
    if version == "v1":
        phones = [_symbol_to_id_v1[symbol] for symbol in cleaned_text]
    else:
        phones = [_symbol_to_id_v2[symbol] for symbol in cleaned_text]

    return phones
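The function above is a plain table lookup. A toy sketch of the same mapping with a made-up symbol table (the real tables live in `text.symbols` and `text.symbols2`):

symbols = ["_", ",", "AA", "B"]  # made-up table, not the real one
symbol_to_id = {s: i for i, s in enumerate(symbols)}

def to_sequence(phonemes):
    # Mirrors cleaned_text_to_sequence: one integer ID per known symbol.
    return [symbol_to_id[ph] for ph in phonemes]

print(to_sequence(["AA", ",", "B"]))  # -> [2, 1, 3]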
@@ -1,6 +1,5 @@
 # reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py

-import sys
 import re
 import cn2an
 import ToJyutping
@@ -99,9 +98,7 @@ def replace_punctuation(text):

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

-    replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
-    )
+    replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text

@@ -116,6 +113,8 @@ def text_normalize(text):


 punctuation_set = set(punctuation)
+
+
 def jyuping_to_initials_finals_tones(jyuping_syllables):
    initials_finals = []
    tones = []
@@ -162,10 +161,12 @@ def jyuping_to_initials_finals_tones(jyuping_syllables):
    ### tweaked into consonant + toned vowel
    phones = []
    for a, b in zip(initials_finals, tones):
-        if(b not in [-1,0]):### prefix Y so Cantonese onsets don't collide with Mandarin; skip punctuation
+        if b not in [-1, 0]:  ### prefix Y so Cantonese onsets don't collide with Mandarin; skip punctuation
            todo = "%s%s" % (a, b)
-        else:todo=a
-        if(todo not in punctuation_set):todo="Y%s"%todo
+        else:
+            todo = a
+        if todo not in punctuation_set:
+            todo = "Y%s" % todo
        phones.append(todo)

    # return initials_finals, tones, word2ph
@@ -1,5 +1,4 @@
 import os
-import pdb
 import re

 import cn2an
@@ -17,7 +16,9 @@ pinyin_to_symbol_map = {
    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
 }

-import jieba_fast, logging
+import jieba_fast
+import logging
+
 jieba_fast.setLogLevel(logging.CRITICAL)
 import jieba_fast.posseg as psg

@@ -49,9 +50,7 @@ def replace_punctuation(text):

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

-    replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
-    )
+    replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text

@@ -62,17 +61,15 @@ def replace_punctuation_with_en(text):

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

-    replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
-    )
+    replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text


 def replace_consecutive_punctuation(text):
-    punctuations = ''.join(re.escape(p) for p in punctuation)
-    pattern = f'([{punctuations}])([{punctuations}])+'
-    result = re.sub(pattern, r'\1', text)
+    punctuations = "".join(re.escape(p) for p in punctuation)
+    pattern = f"([{punctuations}])([{punctuations}])+"
+    result = re.sub(pattern, r"\1", text)
    return result


@@ -87,9 +84,7 @@ def _get_initials_finals(word):
    initials = []
    finals = []
    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
-    orig_finals = lazy_pinyin(
-        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
-    )
+    orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    for c, v in zip(orig_initials, orig_finals):
        initials.append(c)
        finals.append(v)
@@ -1,10 +1,9 @@
 import os
-import pdb
 import re

 import cn2an
 from pypinyin import lazy_pinyin, Style
-from pypinyin.contrib.tone_convert import to_normal, to_finals_tone3, to_initials, to_finals
+from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

 from text.symbols import punctuation
 from text.tone_sandhi import ToneSandhi
@@ -18,7 +17,9 @@ pinyin_to_symbol_map = {
    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
 }

-import jieba_fast, logging
+import jieba_fast
+import logging
+
 jieba_fast.setLogLevel(logging.CRITICAL)
 import jieba_fast.posseg as psg

@@ -28,8 +29,14 @@ is_g2pw = True#True if is_g2pw_str.lower() == 'true' else False
 if is_g2pw:
    # print("currently using g2pw for pinyin inference")
    from text.g2pw import G2PWPinyin, correct_pronunciation

    parent_directory = os.path.dirname(current_file_path)
-    g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source=os.environ.get("bert_path","GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),v_to_u=False, neutral_tone_with_five=True)
+    g2pw = G2PWPinyin(
+        model_dir="GPT_SoVITS/text/G2PWModel",
+        model_source=os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"),
+        v_to_u=False,
+        neutral_tone_with_five=True,
+    )

 rep_map = {
    ":": ",",
@@ -58,9 +65,7 @@ def replace_punctuation(text):

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

-    replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
-    )
+    replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text

@@ -77,9 +82,7 @@ def _get_initials_finals(word):
    finals = []

    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
-    orig_finals = lazy_pinyin(
-        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
-    )
+    orig_finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)

    for c, v in zip(orig_initials, orig_finals):
        initials.append(c)
@@ -87,31 +90,66 @@ def _get_initials_finals(word):
    return initials, finals


-must_erhua = {
-    "小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"
-}
+must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿", "媳妇儿"}
 not_erhua = {
-    "虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
-    "拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
-    "流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
-    "孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
-    "狗儿", "少儿"
+    "虐儿",
+    "为儿",
+    "护儿",
+    "瞒儿",
+    "救儿",
+    "替儿",
+    "有儿",
+    "一儿",
+    "我儿",
+    "俺儿",
+    "妻儿",
+    "拐儿",
+    "聋儿",
+    "乞儿",
+    "患儿",
+    "幼儿",
+    "孤儿",
+    "婴儿",
+    "婴幼儿",
+    "连体儿",
+    "脑瘫儿",
+    "流浪儿",
+    "体弱儿",
+    "混血儿",
+    "蜜雪儿",
+    "舫儿",
+    "祖儿",
+    "美儿",
+    "应采儿",
+    "可儿",
+    "侄儿",
+    "孙儿",
+    "侄孙儿",
+    "女儿",
+    "男儿",
+    "红孩儿",
+    "花儿",
+    "虫儿",
+    "马儿",
+    "鸟儿",
+    "猪儿",
+    "猫儿",
+    "狗儿",
+    "少儿",
 }
-def _merge_erhua(initials: list[str],
-                 finals: list[str],
-                 word: str,
-                 pos: str) -> list[list[str]]:
+
+
+def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) -> list[list[str]]:
    """
    Do erhua.
    """
    # fix er1
    for i, phn in enumerate(finals):
-        if i == len(finals) - 1 and word[i] == "儿" and phn == 'er1':
-            finals[i] = 'er2'
+        if i == len(finals) - 1 and word[i] == "儿" and phn == "er1":
+            finals[i] = "er2"

    # pronunciation
-    if word not in must_erhua and (word in not_erhua or
-                                   pos in {"a", "j", "nr"}):
+    if word not in must_erhua and (word in not_erhua or pos in {"a", "j", "nr"}):
        return initials, finals

    # return directly for cases like "……"
@@ -124,9 +162,13 @@ def _merge_erhua(initials: list[str],
    new_initials = []
    new_finals = []
    for i, phn in enumerate(finals):
-        if i == len(finals) - 1 and word[i] == "儿" and phn in {
-            "er2", "er5"
-        } and word[-2:] not in not_erhua and new_finals:
+        if (
+            i == len(finals) - 1
+            and word[i] == "儿"
+            and phn in {"er2", "er5"}
+            and word[-2:] not in not_erhua
+            and new_finals
+        ):
            phn = "er" + new_finals[-1][-1]

        new_initials.append(initials[i])
@@ -171,7 +213,7 @@ def _g2p(segments):
            sub_finals = []
            now_word_length = pre_word_length + len(word)

-            if pos == 'eng':
+            if pos == "eng":
                pre_word_length = now_word_length
                continue

@@ -259,18 +301,18 @@ def replace_punctuation_with_en(text):

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

-    replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
-    )
+    replaced_text = re.sub(r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text


 def replace_consecutive_punctuation(text):
-    punctuations = ''.join(re.escape(p) for p in punctuation)
-    pattern = f'([{punctuations}])([{punctuations}])+'
-    result = re.sub(pattern, r'\1', text)
+    punctuations = "".join(re.escape(p) for p in punctuation)
+    pattern = f"([{punctuations}])([{punctuations}])+"
+    result = re.sub(pattern, r"\1", text)
    return result


 def text_normalize(text):
    # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
    tx = TextNormalizer()
@@ -283,6 +325,7 @@ def text_normalize(text):
    dest_text = replace_consecutive_punctuation(dest_text)
    return dest_text

+
 # text normalization that does not strip English
 def mix_text_normalize(text):
    # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
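The rewrapped condition above implements the erhua merge: a final 儿 rendered as `er2` or `er5` inherits the tone digit of the preceding final. A toy trace of that rule under made-up values; real finals come from pypinyin:

finals = ["ua1", "er2"]  # e.g. 花儿 -> hua1 + er2 (illustrative)
word = "花儿"
if word[-1] == "儿" and finals[-1] in {"er2", "er5"}:
    finals[-1] = "er" + finals[-2][-1]  # inherit the tone digit of the previous final
print(finals)  # ['ua1', 'er1']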
@@ -19,7 +19,8 @@ special = [


 def clean_text(text, language, version=None):
-    if version is None:version=os.environ.get('version', 'v2')
+    if version is None:
+        version = os.environ.get("version", "v2")
     if version == "v1":
         symbols = symbols_v1.symbols
         language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
@@ -27,7 +28,7 @@ def clean_text(text, language, version=None):
         symbols = symbols_v2.symbols
         language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}

-    if(language not in language_module_map):
+    if language not in language_module_map:
         language = "en"
         text = " "
     for special_s, special_l, target_symbol in special:
@@ -45,17 +46,18 @@ def clean_text(text, language, version=None):
     elif language == "en":
         phones = language_module.g2p(norm_text)
         if len(phones) < 4:
-            phones = [','] + phones
+            phones = [","] + phones
         word2ph = None
     else:
         phones = language_module.g2p(norm_text)
         word2ph = None
-    phones = ['UNK' if ph not in symbols else ph for ph in phones]
+    phones = ["UNK" if ph not in symbols else ph for ph in phones]
     return phones, word2ph, norm_text


 def clean_special(text, language, special_s, target_symbol, version=None):
-    if version is None:version=os.environ.get('version', 'v2')
+    if version is None:
+        version = os.environ.get("version", "v2")
     if version == "v1":
         symbols = symbols_v1.symbols
         language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
@@ -81,8 +83,9 @@ def clean_special(text, language, special_s, target_symbol, version=None):


 def text_to_sequence(text, language, version=None):
-    version = os.environ.get('version',version)
-    if version is None:version='v2'
+    version = os.environ.get("version", version)
+    if version is None:
+        version = "v2"
     phones = clean_text(text)
     return cleaned_text_to_sequence(phones, version)
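A hedged usage sketch for the cleaner above (the module path text.cleaner is an assumption based on the repo's text package layout; note also that text_to_sequence calls clean_text(text) without a language argument, which looks inconsistent with clean_text's signature):

from text.cleaner import clean_text  # assumption: repo layout

phones, word2ph, norm_text = clean_text("Hello world!", "en", version="v2")
print(phones)    # phoneme list; symbols outside the table are mapped to "UNK"
print(word2ph)   # None for English (no per-word phone counts)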
@@ -9,17 +9,17 @@ import unicodedata
 # Suffix measurement-unit replacement table
 measurement_map = {
     "m": ["meter", "meters"],
-    'km': ["kilometer", "kilometers"],
+    "km": ["kilometer", "kilometers"],
     "km/h": ["kilometer per hour", "kilometers per hour"],
     "ft": ["feet", "feet"],
     "L": ["liter", "liters"],
     "tbsp": ["tablespoon", "tablespoons"],
-    'tsp': ["teaspoon", "teaspoons"],
+    "tsp": ["teaspoon", "teaspoons"],
     "h": ["hour", "hours"],
     "min": ["minute", "minutes"],
     "s": ["second", "seconds"],
     "°C": ["degree celsius", "degrees celsius"],
-    "°F": ["degree fahrenheit", "degrees fahrenheit"]
+    "°F": ["degree fahrenheit", "degrees fahrenheit"],
 }


@@ -27,37 +27,38 @@ measurement_map = {
 _inflect = inflect.engine()

 # Convert numeric ordinals
-_ordinal_number_re = re.compile(r'\b([0-9]+)\. ')
+_ordinal_number_re = re.compile(r"\b([0-9]+)\. ")

 # I've heard that \d might actually work a bit better for matching digits here

-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")

 # Time recognition
-_time_re = re.compile(r'\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b')
+_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")

 # Suffix measurement-unit recognition
-_measurement_re = re.compile(r'\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b')
+_measurement_re = re.compile(r"\b([0-9]+(\.[0-9]+)?(m|km|km/h|ft|L|tbsp|tsp|h|min|s|°C|°F))\b")

 # £ before/after recognition (a version matching either side was attempted, but it failed for unknown reasons ┭┮﹏┭┮)
-_pounds_re_start = re.compile(r'£([0-9\.\,]*[0-9]+)')
-_pounds_re_end = re.compile(r'([0-9\.\,]*[0-9]+)£')
+_pounds_re_start = re.compile(r"£([0-9\.\,]*[0-9]+)")
+_pounds_re_end = re.compile(r"([0-9\.\,]*[0-9]+)£")

 # $ before/after recognition
-_dollars_re_start = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_dollars_re_end = re.compile(r'([(0-9\.\,]*[0-9]+)\$')
+_dollars_re_start = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_dollars_re_end = re.compile(r"([(0-9\.\,]*[0-9]+)\$")

 # Decimal recognition
-_decimal_number_re = re.compile(r'([0-9]+\.\s*[0-9]+)')
+_decimal_number_re = re.compile(r"([0-9]+\.\s*[0-9]+)")

 # Fraction recognition (form "3/4")
-_fraction_re = re.compile(r'([0-9]+/[0-9]+)')
+_fraction_re = re.compile(r"([0-9]+/[0-9]+)")

 # Ordinal recognition
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")

 # Number handling
-_number_re = re.compile(r'[0-9]+')
+_number_re = re.compile(r"[0-9]+")
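A quick self-contained check of two of the patterns above, copied verbatim from the table:

import re

_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")
_fraction_re = re.compile(r"([0-9]+/[0-9]+)")

assert _time_re.search("meet at 13:30")   # 24-hour times up to 23:59
assert _fraction_re.search("add 3/4 cup") # simple numerator/denominator form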
 def _convert_ordinal(m):
     """
@@ -70,8 +71,10 @@ def _convert_ordinal(m):
     ordinal = _inflect.ordinal(m.group(1))
     return ordinal + ", "


 def _remove_commas(m):
-    return m.group(1).replace(',', '')
+    return m.group(1).replace(",", "")


 def _expand_time(m):
     """
@@ -82,12 +85,12 @@ def _expand_time(m):
     output: "one o'clock p.m. / four o'clock am. / one thirty p.m."
     """
     hours, minutes = map(int, m.group(1, 2))
-    period = 'a.m.' if hours < 12 else 'p.m.'
+    period = "a.m." if hours < 12 else "p.m."
     if hours > 12:
         hours -= 12

     hour_word = _inflect.number_to_words(hours)
-    minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ''
+    minute_word = _inflect.number_to_words(minutes) if minutes != 0 else ""

     if minutes == 0:
         return f"{hour_word} o'clock {period}"
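The tail of _expand_time falls outside the hunk above; the minutes branch in the sketch below is an assumption, but it is consistent with the function's docstring ("one thirty p.m." for 13:30). A standalone version, assuming the inflect package:

import re
import inflect

_inflect = inflect.engine()
_time_re = re.compile(r"\b([01]?[0-9]|2[0-3]):([0-5][0-9])\b")

def expand_time(m):
    hours, minutes = map(int, m.group(1, 2))
    period = "a.m." if hours < 12 else "p.m."
    if hours > 12:
        hours -= 12
    hour_word = _inflect.number_to_words(hours)
    if minutes == 0:
        return f"{hour_word} o'clock {period}"
    return f"{hour_word} {_inflect.number_to_words(minutes)} {period}"  # assumed tail

print(_time_re.sub(expand_time, "13:30"))  # -> one thirty p.m.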
@@ -103,7 +106,7 @@ def _expand_measurement(m):
     sign = m.group(3)
     ptr = 1
     # Couldn't find a convenient way to pull out the number and didn't feel like reworking the regex; 1.2 is read as plural anyway, so just drop the "."
-    num = int(m.group(1).replace(sign, '').replace(".",''))
+    num = int(m.group(1).replace(sign, "").replace(".", ""))
     decimal_part = m.group(2)
     # Closes the loophole in the check above; e.g. the 0.1 case is excluded here
     if decimal_part == None and num == 1:
@@ -116,23 +119,24 @@ def _expand_pounds(m):
     No particularly standard reference was found; handled the same way as dollars, and the two could really be merged
     """
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' pounds' # Unexpected format
+        return match + " pounds"  # Unexpected format
     pounds = int(parts[0]) if parts[0] else 0
-    pence = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
+    pence = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
     if pounds and pence:
-        pound_unit = 'pound' if pounds == 1 else 'pounds'
-        penny_unit = 'penny' if pence == 1 else 'pence'
-        return '%s %s and %s %s' % (pounds, pound_unit, pence, penny_unit)
+        pound_unit = "pound" if pounds == 1 else "pounds"
+        penny_unit = "penny" if pence == 1 else "pence"
+        return "%s %s and %s %s" % (pounds, pound_unit, pence, penny_unit)
     elif pounds:
-        pound_unit = 'pound' if pounds == 1 else 'pounds'
-        return '%s %s' % (pounds, pound_unit)
+        pound_unit = "pound" if pounds == 1 else "pounds"
+        return "%s %s" % (pounds, pound_unit)
     elif pence:
-        penny_unit = 'penny' if pence == 1 else 'pence'
-        return '%s %s' % (pence, penny_unit)
+        penny_unit = "penny" if pence == 1 else "pence"
+        return "%s %s" % (pence, penny_unit)
     else:
-        return 'zero pounds'
+        return "zero pounds"


 def _expand_dollars(m):
     """
@@ -142,23 +146,24 @@ def _expand_dollars(m):
     output: "thirty-two dollars and thirty cents" / "six dollars and twenty-four cents"
     """
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' dollars' # Unexpected format
+        return match + " dollars"  # Unexpected format
     dollars = int(parts[0]) if parts[0] else 0
-    cents = int(parts[1].ljust(2, '0')) if len(parts) > 1 and parts[1] else 0
+    cents = int(parts[1].ljust(2, "0")) if len(parts) > 1 and parts[1] else 0
     if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s and %s %s' % (dollars, dollar_unit, cents, cent_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s and %s %s" % (dollars, dollar_unit, cents, cent_unit)
     elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
     elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
     else:
-        return 'zero dollars'
+        return "zero dollars"
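The subtle point in both currency expanders above is the ljust call: "$6.2" means six dollars and twenty cents, so the fractional part is right-padded to two digits before integer conversion. Checked standalone:

parts = "6.2".split(".")
cents = int(parts[1].ljust(2, "0"))  # "2" -> "20" -> 20
assert cents == 20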
 # Decimal handling
 def _expand_decimal_number(m):
@@ -168,11 +173,11 @@ def _expand_decimal_number(m):
     output: "thirteen point two three four"
     """
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     words = []
     # Iterate over each character in the string
     for char in parts[1]:
-        if char == '.':
+        if char == ".":
             words.append("point")
         else:
             words.append(char)
@@ -196,39 +201,41 @@ def _expend_fraction(m):
     | 3/2 | three halves |
     """
     match = m.group(0)
-    numerator, denominator = map(int, match.split('/'))
+    numerator, denominator = map(int, match.split("/"))

     numerator_part = _inflect.number_to_words(numerator)
     if denominator == 2:
         if numerator == 1:
-            denominator_part = 'half'
+            denominator_part = "half"
         else:
-            denominator_part = 'halves'
+            denominator_part = "halves"
     elif denominator == 1:
-        return f'{numerator_part}'
+        return f"{numerator_part}"
     else:
         denominator_part = _inflect.ordinal(_inflect.number_to_words(denominator))
         if numerator > 1:
-            denominator_part += 's'
+            denominator_part += "s"

+    return f"{numerator_part} {denominator_part}"

-    return f'{numerator_part} {denominator_part}'
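The fraction logic above, checked standalone with inflect: an ordinal denominator, pluralized when the numerator exceeds one, with the half/halves special case for denominator 2.

import inflect

_inflect = inflect.engine()

num, den = 3, 4
word = _inflect.number_to_words(num) + " " + _inflect.ordinal(_inflect.number_to_words(den)) + "s"
print(word)  # -> three fourths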
 def _expand_ordinal(m):
     return _inflect.number_to_words(m.group(0))


 def _expand_number(m):
     num = int(m.group(0))
     if num > 1000 and num < 3000:
         if num == 2000:
-            return 'two thousand'
+            return "two thousand"
         elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
+            return "two thousand " + _inflect.number_to_words(num % 100)
         elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
+            return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
     else:
-        return _inflect.number_to_words(num, andword='')
+        return _inflect.number_to_words(num, andword="")
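The year-style branch of _expand_number, checked standalone: numbers strictly between 1000 and 3000 are read in digit pairs, so 1984 comes out as "nineteen eighty-four" rather than "one thousand nine hundred eighty-four".

import inflect

_inflect = inflect.engine()
print(_inflect.number_to_words(1984, andword="", zero="oh", group=2).replace(", ", " "))
# -> nineteen eighty-four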
 def normalize(text):
@@ -238,7 +245,7 @@ def normalize(text):
     """

     text = re.sub(_ordinal_number_re, _convert_ordinal, text)
-    text = re.sub(r'(?<!\d)-|-(?!\d)', ' minus ', text)
+    text = re.sub(r"(?<!\d)-|-(?!\d)", " minus ", text)
     text = re.sub(_comma_number_re, _remove_commas, text)
     text = re.sub(_time_re, _expand_time, text)
     text = re.sub(_measurement_re, _expand_measurement, text)
@@ -251,19 +258,20 @@ def normalize(text):
     text = re.sub(_ordinal_re, _expand_ordinal, text)
     text = re.sub(_number_re, _expand_number, text)

-    text = ''.join(char for char in unicodedata.normalize('NFD', text)
-                   if unicodedata.category(char) != 'Mn')  # Strip accents
+    text = "".join(
+        char for char in unicodedata.normalize("NFD", text) if unicodedata.category(char) != "Mn"
+    )  # Strip accents

     text = re.sub("%", " percent", text)
     text = re.sub("[^ A-Za-z'.,?!\-]", "", text)
     text = re.sub(r"(?i)i\.e\.", "that is", text)
     text = re.sub(r"(?i)e\.g\.", "for example", text)
     # Added splitting of all-uppercase words
-    text = re.sub(r'(?<!^)(?<![\s])([A-Z])', r' \1', text)
+    text = re.sub(r"(?<!^)(?<![\s])([A-Z])", r" \1", text)
     return text


-if __name__ == '__main__':
+if __name__ == "__main__":
     # It might be worth displaying the segmentation result (read-only, or editable without affecting the actual text passed to TTS)
     # and letting the user confirm it before it goes to TTS, so they can check for non-standard input
     print(normalize("1. test ordinal number 1st"))
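An end-to-end sketch of the normalizer above; the import path is grounded in the diff itself (text.en_normalization.expend), and the expected output is approximate since the pipeline applies many passes in sequence:

from text.en_normalization.expend import normalize

print(normalize("The $5.50 coffee at 8:30"))
# -> roughly "The five dollars and fifty cents coffee at eight thirty a.m."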
@@ -8,10 +8,10 @@ from text.symbols import punctuation

 from text.symbols2 import symbols

-import unicodedata
 from builtins import str as unicode
 from text.en_normalization.expend import normalize
 from nltk.tokenize import TweetTokenizer

 word_tokenize = TweetTokenizer().tokenize
 from nltk import pos_tag

@@ -122,9 +122,9 @@ def replace_phs(phs):


 def replace_consecutive_punctuation(text):
-    punctuations = ''.join(re.escape(p) for p in punctuation)
-    pattern = f'([{punctuations}\s])([{punctuations}])+'
-    result = re.sub(pattern, r'\1', text)
+    punctuations = "".join(re.escape(p) for p in punctuation)
+    pattern = f"([{punctuations}\s])([{punctuations}])+"
+    result = re.sub(pattern, r"\1", text)
     return result


@@ -183,6 +183,7 @@ def read_dict_new():

     return g2p_dict


 def hot_reload_hot(g2p_dict):
     with open(CMU_DICT_HOT_PATH) as f:
         line = f.readline()
@@ -259,9 +260,12 @@ class en_G2p(G2p):
             del self.cmu[word.lower()]

         # Fix heteronyms
-        self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
-        self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
+        self.homograph2features["read"] = (["R", "IY1", "D"], ["R", "EH1", "D"], "VBP")
+        self.homograph2features["complex"] = (
+            ["K", "AH0", "M", "P", "L", "EH1", "K", "S"],
+            ["K", "AA1", "M", "P", "L", "EH0", "K", "S"],
+            "JJ",
+        )

     def __call__(self, text):
         # tokenization
@@ -280,7 +284,7 @@ class en_G2p(G2p):
             elif len(word) == 1:
                 # Fix for "A" read on its own; the original-case o_word is needed here to check for uppercase
                 if o_word == "A":
-                    pron = ['EY1']
+                    pron = ["EY1"]
                 else:
                     pron = self.cmu[word][0]
             # Original g2p_en heteronym handling
@@ -302,7 +306,6 @@ class en_G2p(G2p):

         return prons[:-1]


     def qryword(self, o_word):
         word = o_word.lower()

@@ -320,7 +323,7 @@ class en_G2p(G2p):
             for w in word:
                 # Fix for "a" read on its own; no uppercase case exists here
                 if w == "a":
-                    phones.extend(['EY1'])
+                    phones.extend(["EY1"])
                 elif not w.isalpha():
                     phones.extend([w])
                 else:
@@ -331,16 +334,16 @@ class en_G2p(G2p):
         if re.match(r"^([a-z]+)('s)$", word):
             phones = self.qryword(word[:-2])[:]
             # 's after a voiceless consonant (P T K F TH HH) is pronounced ['S']
-            if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
-                phones.extend(['S'])
+            if phones[-1] in ["P", "T", "K", "F", "TH", "HH"]:
+                phones.extend(["S"])
             # 's after a sibilant (S Z SH ZH CH JH) is pronounced ['IH1', 'Z'] or ['AH0', 'Z']
-            elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
-                phones.extend(['AH0', 'Z'])
+            elif phones[-1] in ["S", "Z", "SH", "ZH", "CH", "JH"]:
+                phones.extend(["AH0", "Z"])
             # 's after a voiced consonant (B D G DH V M N NG L R W Y) is pronounced ['Z']
             # AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
             # ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 (vowel endings): 's is pronounced ['Z']
             else:
-                phones.extend(['Z'])
+                phones.extend(["Z"])
         return phones

     # Try word segmentation to handle compound words
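The possessive-'s rule above, distilled into a standalone helper: the suffix phoneme depends on the final phoneme of the stem (the standard English voicing rule).

def s_suffix(last_phone: str) -> list:
    if last_phone in ["P", "T", "K", "F", "TH", "HH"]:
        return ["S"]           # voiceless ending
    if last_phone in ["S", "Z", "SH", "ZH", "CH", "JH"]:
        return ["AH0", "Z"]    # sibilant ending
    return ["Z"]               # voiced consonant or vowel ending

assert s_suffix("T") == ["S"]          # cat's
assert s_suffix("S") == ["AH0", "Z"]   # horse's
assert s_suffix("G") == ["Z"]          # dog's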
@@ -15,6 +15,7 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """

 from typing import Dict
 from typing import List
 from typing import Tuple
@@ -23,10 +24,11 @@ import numpy as np

 from .utils import tokenize_and_map

-ANCHOR_CHAR = '▁'
+ANCHOR_CHAR = "▁"


-def prepare_onnx_input(tokenizer,
+def prepare_onnx_input(
+    tokenizer,
     labels: List[str],
     char2phonemes: Dict[str, List[int]],
     chars: List[str],
@@ -34,10 +36,12 @@ def prepare_onnx_input(tokenizer,
     query_ids: List[int],
     use_mask: bool = False,
     window_size: int = None,
-    max_len: int=512) -> Dict[str, np.array]:
+    max_len: int = 512,
+) -> Dict[str, np.array]:
     if window_size is not None:
         truncated_texts, truncated_query_ids = _truncate_texts(
-            window_size=window_size, texts=texts, query_ids=query_ids)
+            window_size=window_size, texts=texts, query_ids=query_ids
+        )
     input_ids = []
     token_type_ids = []
     attention_masks = []
@@ -50,33 +54,27 @@ def prepare_onnx_input(tokenizer,
         query_id = (truncated_query_ids if window_size else query_ids)[idx]

         try:
-            tokens, text2token, token2text = tokenize_and_map(
-                tokenizer=tokenizer, text=text)
+            tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
         except Exception:
             print(f'warning: text "{text}" is invalid')
             return {}

         text, query_id, tokens, text2token, token2text = _truncate(
-            max_len=max_len,
-            text=text,
-            query_id=query_id,
-            tokens=tokens,
-            text2token=text2token,
-            token2text=token2text)
+            max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
+        )

-        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
+        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]

-        input_id = list(
-            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
+        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
         token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
         attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

         query_char = text[query_id]
-        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
-            if use_mask else [1] * len(labels)
+        phoneme_mask = (
+            [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
+        )
         char_id = chars.index(query_char)
-        position_id = text2token[
-            query_id] + 1  # [CLS] token locate at first place
+        position_id = text2token[query_id] + 1  # [CLS] token locate at first place

         input_ids.append(input_id)
         token_type_ids.append(token_type_id)
@@ -86,18 +84,17 @@ def prepare_onnx_input(tokenizer,
         position_ids.append(position_id)

     outputs = {
-        'input_ids': np.array(input_ids).astype(np.int64),
-        'token_type_ids': np.array(token_type_ids).astype(np.int64),
-        'attention_masks': np.array(attention_masks).astype(np.int64),
-        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
-        'char_ids': np.array(char_ids).astype(np.int64),
-        'position_ids': np.array(position_ids).astype(np.int64),
+        "input_ids": np.array(input_ids).astype(np.int64),
+        "token_type_ids": np.array(token_type_ids).astype(np.int64),
+        "attention_masks": np.array(attention_masks).astype(np.int64),
+        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
+        "char_ids": np.array(char_ids).astype(np.int64),
+        "position_ids": np.array(position_ids).astype(np.int64),
     }
     return outputs


-def _truncate_texts(window_size: int, texts: List[str],
-                    query_ids: List[int]) -> Tuple[List[str], List[int]]:
+def _truncate_texts(window_size: int, texts: List[str], query_ids: List[int]) -> Tuple[List[str], List[int]]:
     truncated_texts = []
     truncated_query_ids = []
     for text, query_id in zip(texts, query_ids):
@@ -111,12 +108,9 @@ def _truncate_texts(window_size: int, texts: List[str],
     return truncated_texts, truncated_query_ids


-def _truncate(max_len: int,
-              text: str,
-              query_id: int,
-              tokens: List[str],
-              text2token: List[int],
-              token2text: List[Tuple[int]]):
+def _truncate(
+    max_len: int, text: str, query_id: int, tokens: List[str], text2token: List[int], token2text: List[Tuple[int]]
+):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
         return (text, query_id, tokens, text2token, token2text)
@@ -137,14 +131,16 @@ def _truncate(max_len: int,
     start = token2text[token_start][0]
     end = token2text[token_end - 1][1]

-    return (text[start:end], query_id - start, tokens[token_start:token_end], [
-        i - token_start if i is not None else None
-        for i in text2token[start:end]
-    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
+    return (
+        text[start:end],
+        query_id - start,
+        tokens[token_start:token_end],
+        [i - token_start if i is not None else None for i in text2token[start:end]],
+        [(s - start, e - start) for s, e in token2text[token_start:token_end]],
+    )


-def get_phoneme_labels(polyphonic_chars: List[List[str]]
-                       ) -> Tuple[List[str], Dict[str, List[int]]]:
+def get_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
     labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
@@ -154,13 +150,11 @@ def get_phoneme_labels(polyphonic_chars: List[List[str]]
     return labels, char2phonemes


-def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
-                            ) -> Tuple[List[str], Dict[str, List[int]]]:
-    labels = sorted(
-        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
+def get_char_phoneme_labels(polyphonic_chars: List[List[str]]) -> Tuple[List[str], Dict[str, List[int]]]:
+    labels = sorted(list(set([f"{char} {phoneme}" for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
         if char not in char2phonemes:
             char2phonemes[char] = []
-        char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
+        char2phonemes[char].append(labels.index(f"{char} {phoneme}"))
     return labels, char2phonemes
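get_phoneme_labels above, exercised on a toy polyphonic table: labels becomes the sorted set of candidate phonemes, and char2phonemes maps each character to the label indices it may take (the model classifies over these indices).

polyphonic_chars = [["行", "xing2"], ["行", "hang2"], ["了", "le5"]]

labels = sorted(list(set([ph for ch, ph in polyphonic_chars])))
char2phonemes = {}
for ch, ph in polyphonic_chars:
    char2phonemes.setdefault(ch, []).append(labels.index(ph))

print(labels)         # ['hang2', 'le5', 'xing2']
print(char2phonemes)  # {'行': [2, 0], '了': [1]}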
@@ -17,17 +17,25 @@ PP_FIX_DICT_PATH = os.path.join(current_file_path, "polyphonic-fix.rep")


 class G2PWPinyin(Pinyin):
-    def __init__(self, model_dir='G2PWModel/', model_source=None,
+    def __init__(
+        self,
+        model_dir="G2PWModel/",
+        model_source=None,
         enable_non_tradional_chinese=True,
-        v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
+        v_to_u=False,
+        neutral_tone_with_five=False,
+        tone_sandhi=False,
+        **kwargs,
+    ):
         self._g2pw = G2PWOnnxConverter(
             model_dir=model_dir,
-            style='pinyin',
+            style="pinyin",
             model_source=model_source,
             enable_non_tradional_chinese=enable_non_tradional_chinese,
         )
         self._converter = Converter(
-            self._g2pw, v_to_u=v_to_u,
+            self._g2pw,
+            v_to_u=v_to_u,
             neutral_tone_with_five=neutral_tone_with_five,
             tone_sandhi=tone_sandhi,
         )
@@ -37,31 +45,25 @@ class G2PWPinyin(Pinyin):


 class Converter(UltimateConverter):
-    def __init__(self, g2pw_instance, v_to_u=False,
-                 neutral_tone_with_five=False,
-                 tone_sandhi=False, **kwargs):
+    def __init__(self, g2pw_instance, v_to_u=False, neutral_tone_with_five=False, tone_sandhi=False, **kwargs):
         super(Converter, self).__init__(
-            v_to_u=v_to_u,
-            neutral_tone_with_five=neutral_tone_with_five,
-            tone_sandhi=tone_sandhi, **kwargs)
+            v_to_u=v_to_u, neutral_tone_with_five=neutral_tone_with_five, tone_sandhi=tone_sandhi, **kwargs
+        )

         self._g2pw = g2pw_instance

     def convert(self, words, style, heteronym, errors, strict, **kwargs):
         pys = []
         if RE_HANS.match(words):
-            pys = self._to_pinyin(words, style=style, heteronym=heteronym,
-                                  errors=errors, strict=strict)
+            pys = self._to_pinyin(words, style=style, heteronym=heteronym, errors=errors, strict=strict)
             post_data = self.post_pinyin(words, heteronym, pys)
             if post_data is not None:
                 pys = post_data

-            pys = self.convert_styles(
-                pys, words, style, heteronym, errors, strict)
+            pys = self.convert_styles(pys, words, style, heteronym, errors, strict)

         else:
-            py = self.handle_nopinyin(words, style=style, errors=errors,
-                                      heteronym=heteronym, strict=strict)
+            py = self.handle_nopinyin(words, style=style, errors=errors, heteronym=heteronym, strict=strict)
             if py:
                 pys.extend(py)

@@ -73,13 +75,11 @@ class Converter(UltimateConverter):
         g2pw_pinyin = self._g2pw(han)

         if not g2pw_pinyin:  # fall back to the original pypinyin logic for characters g2pw does not support
-            return super(Converter, self).convert(
-                han, Style.TONE, heteronym, errors, strict, **kwargs)
+            return super(Converter, self).convert(han, Style.TONE, heteronym, errors, strict, **kwargs)

         for i, item in enumerate(g2pw_pinyin[0]):
             if item is None:  # fall back to the original pypinyin logic for characters g2pw does not support
-                py = super(Converter, self).convert(
-                    han[i], Style.TONE, heteronym, errors, strict, **kwargs)
+                py = super(Converter, self).convert(han[i], Style.TONE, heteronym, errors, strict, **kwargs)
                 pinyins.extend(py)
             else:
                 pinyins.append([to_tone(item)])
@@ -104,7 +104,7 @@ def _remove_dup_and_empty(lst_list):
         if lst:
             new_lst_list.append(lst)
         else:
-            new_lst_list.append([''])
+            new_lst_list.append([""])

     return new_lst_list

@@ -130,14 +130,14 @@ def read_dict():
     with open(PP_DICT_PATH, encoding="utf-8") as f:
         line = f.readline()
         while line:
-            key, value_str = line.split(':')
+            key, value_str = line.split(":")
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
     with open(PP_FIX_DICT_PATH, encoding="utf-8") as f:
         line = f.readline()
         while line:
-            key, value_str = line.split(':')
+            key, value_str = line.split(":")
             value = eval(value_str.strip())
             polyphonic_dict[key.strip()] = value
             line = f.readline()
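A hedged usage sketch for the wrapper above: G2PWPinyin plugs the g2pW ONNX model into pypinyin's Pinyin front end, so the usual lazy_pinyin interface applies (the model_source path below is an assumption; it should point at a BERT tokenizer such as bert-base-chinese).

from pypinyin import Style

g2pw = G2PWPinyin(
    model_dir="G2PWModel/",
    model_source="bert-base-chinese",  # assumption
    neutral_tone_with_five=True,
)
print(g2pw.lazy_pinyin("银行行长", style=Style.TONE3))  # heteronym 行 resolved in context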
@@ -2,44 +2,43 @@
 # This code is modified from https://github.com/GitYCC/g2pW

 import warnings

 warnings.filterwarnings("ignore")
 import json
 import os
-import zipfile,requests
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Tuple
+import zipfile
+from typing import Any, Dict, List, Tuple

 import numpy as np
 import onnxruntime
+import requests

 onnxruntime.set_default_logger_severity(3)
 from opencc import OpenCC
+from pypinyin import Style, pinyin
 from transformers import AutoTokenizer
-from pypinyin import pinyin
-from pypinyin import Style

-from .dataset import get_char_phoneme_labels
-from .dataset import get_phoneme_labels
-from .dataset import prepare_onnx_input
-from .utils import load_config
 from ..zh_normalization.char_convert import tranditional_to_simplified
+from .dataset import get_char_phoneme_labels, get_phoneme_labels, prepare_onnx_input
+from .utils import load_config

-model_version = '1.1'
+model_version = "1.1"


-def predict(session, onnx_input: Dict[str, Any],
-            labels: List[str]) -> Tuple[List[str], List[float]]:
+def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[List[str], List[float]]:
     all_preds = []
     all_confidences = []
-    probs = session.run([], {
-        "input_ids": onnx_input['input_ids'],
-        "token_type_ids": onnx_input['token_type_ids'],
-        "attention_mask": onnx_input['attention_masks'],
-        "phoneme_mask": onnx_input['phoneme_masks'],
-        "char_ids": onnx_input['char_ids'],
-        "position_ids": onnx_input['position_ids']
-    })[0]
+    probs = session.run(
+        [],
+        {
+            "input_ids": onnx_input["input_ids"],
+            "token_type_ids": onnx_input["token_type_ids"],
+            "attention_mask": onnx_input["attention_masks"],
+            "phoneme_mask": onnx_input["phoneme_masks"],
+            "char_ids": onnx_input["char_ids"],
+            "position_ids": onnx_input["position_ids"],
+        },
+    )[0]

     preds = np.argmax(probs, axis=1).tolist()
     max_probs = []
@@ -51,7 +50,7 @@ def predict(session, onnx_input: Dict[str, Any],
     return all_preds, all_confidences


-def download_and_decompress(model_dir: str='G2PWModel/'):
+def download_and_decompress(model_dir: str = "G2PWModel/"):
     if not os.path.exists(model_dir):
         parent_directory = os.path.dirname(model_dir)
         zip_dir = os.path.join(parent_directory, "G2PWModel_1.1.zip")
@@ -61,7 +60,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
         modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"  # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
         with requests.get(modelscope_url, stream=True) as r:
             r.raise_for_status()
-            with open(zip_dir, 'wb') as f:
+            with open(zip_dir, "wb") as f:
                 for chunk in r.iter_content(chunk_size=8192):
                     if chunk:
                         f.write(chunk)
@@ -74,12 +73,15 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):

     return model_dir


 class G2PWOnnxConverter:
-    def __init__(self,
-                 model_dir: str='G2PWModel/',
-                 style: str='bopomofo',
-                 model_source: str = None,
-                 enable_non_tradional_chinese: bool=False):
+    def __init__(
+        self,
+        model_dir: str = "G2PWModel/",
+        style: str = "bopomofo",
+        model_source: str = None,
+        enable_non_tradional_chinese: bool = False,
+    ):
         uncompress_path = download_and_decompress(model_dir)

         sess_options = onnxruntime.SessionOptions()
@@ -87,41 +89,59 @@ class G2PWOnnxConverter:
         sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
         sess_options.intra_op_num_threads = 2
         try:
-            self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+            self.session_g2pW = onnxruntime.InferenceSession(
+                os.path.join(uncompress_path, "g2pW.onnx"),
+                sess_options=sess_options,
+                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            )
         except:
-            self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
-        self.config = load_config(
-            config_path=os.path.join(uncompress_path, 'config.py'),
-            use_default=True)
+            self.session_g2pW = onnxruntime.InferenceSession(
+                os.path.join(uncompress_path, "g2pW.onnx"),
+                sess_options=sess_options,
+                providers=["CPUExecutionProvider"],
+            )
+        self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)

         self.model_source = model_source if model_source else self.config.model_source
         self.enable_opencc = enable_non_tradional_chinese

         self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)

-        polyphonic_chars_path = os.path.join(uncompress_path,
-                                             'POLYPHONIC_CHARS.txt')
-        monophonic_chars_path = os.path.join(uncompress_path,
-                                             'MONOPHONIC_CHARS.txt')
+        polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
+        monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
         self.polyphonic_chars = [
-            line.split('\t')
-            for line in open(polyphonic_chars_path, encoding='utf-8').read()
-            .strip().split('\n')
+            line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
         ]
         self.non_polyphonic = {
-            '一', '不', '和', '咋', '嗲', '剖', '差', '攢', '倒', '難', '奔', '勁', '拗',
-            '肖', '瘙', '誒', '泊', '听', '噢'
+            "一",
+            "不",
+            "和",
+            "咋",
+            "嗲",
+            "剖",
+            "差",
+            "攢",
+            "倒",
+            "難",
+            "奔",
+            "勁",
+            "拗",
+            "肖",
+            "瘙",
+            "誒",
+            "泊",
+            "听",
+            "噢",
         }
-        self.non_monophonic = {'似', '攢'}
+        self.non_monophonic = {"似", "攢"}
         self.monophonic_chars = [
-            line.split('\t')
-            for line in open(monophonic_chars_path, encoding='utf-8').read()
-            .strip().split('\n')
+            line.split("\t") for line in open(monophonic_chars_path, encoding="utf-8").read().strip().split("\n")
         ]
-        self.labels, self.char2phonemes = get_char_phoneme_labels(
-            polyphonic_chars=self.polyphonic_chars
-        ) if self.config.use_char_phoneme else get_phoneme_labels(
-            polyphonic_chars=self.polyphonic_chars)
+        self.labels, self.char2phonemes = (
+            get_char_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
+            if self.config.use_char_phoneme
+            else get_phoneme_labels(polyphonic_chars=self.polyphonic_chars)
+        )

         self.chars = sorted(list(self.char2phonemes.keys()))

@@ -130,41 +150,29 @@ class G2PWOnnxConverter:
             if char in self.polyphonic_chars_new:
                 self.polyphonic_chars_new.remove(char)

-        self.monophonic_chars_dict = {
-            char: phoneme
-            for char, phoneme in self.monophonic_chars
-        }
+        self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
         for char in self.non_monophonic:
             if char in self.monophonic_chars_dict:
                 self.monophonic_chars_dict.pop(char)

-        self.pos_tags = [
-            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
-        ]
+        self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]

-        with open(
-                os.path.join(uncompress_path,
-                             'bopomofo_to_pinyin_wo_tune_dict.json'),
-                'r',
-                encoding='utf-8') as fr:
+        with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
             self.bopomofo_convert_dict = json.load(fr)
         self.style_convert_func = {
-            'bopomofo': lambda x: x,
-            'pinyin': self._convert_bopomofo_to_pinyin,
+            "bopomofo": lambda x: x,
+            "pinyin": self._convert_bopomofo_to_pinyin,
         }[style]

-        with open(
-                os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
-                'r',
-                encoding='utf-8') as fr:
+        with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
             self.char_bopomofo_dict = json.load(fr)

         if self.enable_opencc:
-            self.cc = OpenCC('s2tw')
+            self.cc = OpenCC("s2tw")

     def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
         tone = bopomofo[-1]
-        assert tone in '12345'
+        assert tone in "12345"
         component = self.bopomofo_convert_dict.get(bopomofo[:-1])
         if component:
             return component + tone
@@ -184,8 +192,7 @@ class G2PWOnnxConverter:
             translated_sentences.append(translated_sent)
             sentences = translated_sentences

-        texts, query_ids, sent_ids, partial_results = self._prepare_data(
-            sentences=sentences)
+        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
         if len(texts) == 0:
             # no polyphonic words in the sentences
             return partial_results
@@ -198,14 +205,12 @@ class G2PWOnnxConverter:
             texts=texts,
             query_ids=query_ids,
             use_mask=self.config.use_mask,
-            window_size=None)
+            window_size=None,
+        )

-        preds, confidences = predict(
-            session=self.session_g2pW,
-            onnx_input=onnx_input,
-            labels=self.labels)
+        preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
         if self.config.use_char_phoneme:
-            preds = [pred.split(' ')[1] for pred in preds]
+            preds = [pred.split(" ")[1] for pred in preds]

         results = partial_results
         for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
@@ -213,15 +218,12 @@ class G2PWOnnxConverter:

         return results

-    def _prepare_data(
-            self, sentences: List[str]
-    ) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
+    def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
         texts, query_ids, sent_ids, partial_results = [], [], [], []
         for sent_id, sent in enumerate(sentences):
             # pypinyin works better for Simplified Chinese than Traditional Chinese
             sent_s = tranditional_to_simplified(sent)
-            pypinyin_result = pinyin(
-                sent_s, neutral_tone_with_five=True, style=Style.TONE3)
+            pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
             partial_result = [None] * len(sent)
             for i, char in enumerate(sent):
                 if char in self.polyphonic_chars_new:
@@ -229,8 +231,7 @@ class G2PWOnnxConverter:
                     query_ids.append(i)
                     sent_ids.append(sent_id)
                 elif char in self.monophonic_chars_dict:
-                    partial_result[i] = self.style_convert_func(
-                        self.monophonic_chars_dict[char])
+                    partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
                 elif char in self.char_bopomofo_dict:
                     partial_result[i] = pypinyin_result[i][0]
                     # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
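A hedged sketch of using the converter above directly: it consumes a list of sentences, fetching the model files on first run via download_and_decompress, and returns one list of per-character pinyin per sentence (model_source below is an assumption, as above).

conv = G2PWOnnxConverter(model_dir="G2PWModel/", style="pinyin", model_source="bert-base-chinese")
print(conv(["上海银行行长"]))  # characters outside g2pW's coverage fall back to pypinyin upstream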
@@ -15,6 +15,7 @@
 Credits
     This code is modified from https://github.com/GitYCC/g2pW
 """

 import os
 import re

@@ -24,14 +25,14 @@ def wordize_and_map(text: str):
     index_map_from_text_to_word = []
     index_map_from_word_to_text = []
     while len(text) > 0:
-        match_space = re.match(r'^ +', text)
+        match_space = re.match(r"^ +", text)
         if match_space:
             space_str = match_space.group(0)
             index_map_from_text_to_word += [None] * len(space_str)
             text = text[len(space_str) :]
             continue

-        match_en = re.match(r'^[a-zA-Z0-9]+', text)
+        match_en = re.match(r"^[a-zA-Z0-9]+", text)
         if match_en:
             en_word = match_en.group(0)

@@ -63,15 +64,14 @@ def tokenize_and_map(tokenizer, text: str):
     for word, (word_start, word_end) in zip(words, word2text):
         word_tokens = tokenizer.tokenize(word)

-        if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
+        if len(word_tokens) == 0 or word_tokens == ["[UNK]"]:
             index_map_from_token_to_text.append((word_start, word_end))
-            tokens.append('[UNK]')
+            tokens.append("[UNK]")
         else:
             current_word_start = word_start
             for word_token in word_tokens:
-                word_token_len = len(re.sub(r'^##', '', word_token))
-                index_map_from_token_to_text.append(
-                    (current_word_start, current_word_start + word_token_len))
+                word_token_len = len(re.sub(r"^##", "", word_token))
+                index_map_from_token_to_text.append((current_word_start, current_word_start + word_token_len))
                 current_word_start = current_word_start + word_token_len
                 tokens.append(word_token)

@@ -85,49 +85,47 @@ def tokenize_and_map(tokenizer, text: str):

 def _load_config(config_path: os.PathLike):
     import importlib.util
-    spec = importlib.util.spec_from_file_location('__init__', config_path)
+
+    spec = importlib.util.spec_from_file_location("__init__", config_path)
     config = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(config)
     return config


 default_config_dict = {
-    'manual_seed': 1313,
-    'model_source': 'bert-base-chinese',
-    'window_size': 32,
-    'num_workers': 2,
-    'use_mask': True,
-    'use_char_phoneme': False,
-    'use_conditional': True,
-    'param_conditional': {
-        'affect_location': 'softmax',
-        'bias': True,
-        'char-linear': True,
-        'pos-linear': False,
-        'char+pos-second': True,
-        'char+pos-second_lowrank': False,
-        'lowrank_size': 0,
-        'char+pos-second_fm': False,
-        'fm_size': 0,
-        'fix_mode': None,
-        'count_json': 'train.count.json'
+    "manual_seed": 1313,
+    "model_source": "bert-base-chinese",
+    "window_size": 32,
+    "num_workers": 2,
+    "use_mask": True,
+    "use_char_phoneme": False,
+    "use_conditional": True,
+    "param_conditional": {
+        "affect_location": "softmax",
+        "bias": True,
+        "char-linear": True,
+        "pos-linear": False,
+        "char+pos-second": True,
+        "char+pos-second_lowrank": False,
+        "lowrank_size": 0,
+        "char+pos-second_fm": False,
+        "fm_size": 0,
+        "fix_mode": None,
+        "count_json": "train.count.json",
     },
-    'lr': 5e-5,
-    'val_interval': 200,
-    'num_iter': 10000,
-    'use_focal': False,
-    'param_focal': {
-        'alpha': 0.0,
-        'gamma': 0.7
-    },
-    'use_pos': True,
-    'param_pos ': {
-        'weight': 0.1,
-        'pos_joint_training': True,
-        'train_pos_path': 'train.pos',
-        'valid_pos_path': 'dev.pos',
-        'test_pos_path': 'test.pos'
-    }
+    "lr": 5e-5,
+    "val_interval": 200,
+    "num_iter": 10000,
+    "use_focal": False,
+    "param_focal": {"alpha": 0.0, "gamma": 0.7},
+    "use_pos": True,
+    "param_pos ": {
+        "weight": 0.1,
+        "pos_joint_training": True,
+        "train_pos_path": "train.pos",
+        "valid_pos_path": "dev.pos",
+        "test_pos_path": "test.pos",
+    },
 }
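The subword bookkeeping in tokenize_and_map above hinges on one detail: WordPiece continuation tokens carry a "##" prefix, which must be stripped before counting how many source characters a token covers. Checked standalone:

import re

word_token = "##ing"
word_token_len = len(re.sub(r"^##", "", word_token))
assert word_token_len == 3  # "ing" covers three characters of the source word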
@ -2,43 +2,51 @@
import re
import os
import hashlib

try:
    import pyopenjtalk

    current_file_path = os.path.dirname(__file__)

    # Work around model/dictionary loading failures on Windows (防止win下无法读取模型)
    if os.name == "nt":
        python_dir = os.getcwd()
        OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
        if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
            if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
                OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
            else:
                import shutil

                if not os.path.exists("TEMP"):
                    os.mkdir("TEMP")
                if not os.path.exists(os.path.join("TEMP", "ja")):
                    os.mkdir(os.path.join("TEMP", "ja"))
                if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
                    shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
                shutil.copytree(
                    pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
                    os.path.join("TEMP", "ja", "open_jtalk_dic"),
                )
                OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
            pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")

        if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
            if current_file_path[: len(python_dir)].upper() == python_dir.upper():
                current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
            else:
                if not os.path.exists("TEMP"):
                    os.mkdir("TEMP")
                if not os.path.exists(os.path.join("TEMP", "ja")):
                    os.mkdir(os.path.join("TEMP", "ja"))
                if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
                    os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
                shutil.copyfile(
                    os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
                    os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
                )
            current_file_path = os.path.join("TEMP", "ja")

    def get_hash(fp: str) -> str:
        hash_md5 = hashlib.md5()
        with open(fp, "rb") as f:

@ -51,21 +59,26 @@ try:
    USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")
    # If there is no user dictionary, build one; if there is, check its md5 and rebuild when it differs
    # (如果没有用户词典,就生成一个;如果有,就检查md5,如果不一样,就重新生成)
    if os.path.exists(USERDIC_CSV_PATH):
        if (
            not os.path.exists(USERDIC_BIN_PATH)
            or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
        ):
            pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
            with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
                f.write(get_hash(USERDIC_CSV_PATH))

        if os.path.exists(USERDIC_BIN_PATH):
            pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)
except Exception:
    # print(e)
    import pyopenjtalk

    # failed to load user dictionary, ignore.
    pass

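The hunk above cuts off inside `get_hash`. A typical chunked MD5 reader consistent with its header (a sketch under that assumption, not necessarily the file's exact body) would be:

import hashlib

def get_hash(fp: str) -> str:
    hash_md5 = hashlib.md5()
    with open(fp, "rb") as f:
        # Read in chunks so large dictionary files are not held in memory at once.
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
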
from text.symbols import punctuation

# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
    r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"

@ -123,9 +136,9 @@ def post_replace_ph(ph):


def replace_consecutive_punctuation(text):
    punctuations = "".join(re.escape(p) for p in punctuation)
    pattern = f"([{punctuations}])([{punctuations}])+"
    result = re.sub(pattern, r"\1", text)
    return result

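A quick standalone check of the collapse rule above, inlining the same `punctuation` list so the snippet runs on its own:

import re

punctuation = ["!", "?", "…", ",", "."]
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
# Runs of punctuation collapse to the first mark:
print(re.sub(pattern, r"\1", "すごい!!?"))  # すごい!
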
@ -165,6 +178,7 @@ def text_normalize(text):
    text = replace_consecutive_punctuation(text)
    return text


# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
    """Extract phoneme + prosody symbol sequence from input full-context labels.

@ -241,6 +255,7 @@ def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):

    return phones


# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def _numeric_feature_by_regex(regex, s):
    match = re.search(regex, s)

@ -248,6 +263,7 @@ def _numeric_feature_by_regex(regex, s):
        return -50
    return int(match.group(1))


def g2p(norm_text, with_prosody=True):
    phones = preprocess_jap(norm_text, with_prosody)
    phones = [post_replace_ph(i) for i in phones]

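`_numeric_feature_by_regex` pulls one signed integer feature out of an HTS full-context label, with -50 as the "feature absent" sentinel. Illustrative use (the label fragment below is made up for the demo, not a real label):

import re

def _numeric_feature_by_regex(regex, s):
    match = re.search(regex, s)
    if match is None:
        return -50
    return int(match.group(1))

label = "/A:-1+2+3"  # hypothetical fragment of a full-context label
print(_numeric_feature_by_regex(r"/A:([0-9\-]+)\+", label))  # -1
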
@ -9,39 +9,43 @@ import importlib
import os

# Work around model/dictionary loading failures on Windows (防止win下无法读取模型)
if os.name == "nt":

    class win_G2p(G2p):
        def check_mecab(self):
            super().check_mecab()
            spam_spec = importlib.util.find_spec("eunjeon")
            non_found = spam_spec is None
            if non_found:
                print("you have to install eunjeon. install it...")
            else:
                installpath = spam_spec.submodule_search_locations[0]
                if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                    import sys
                    from eunjeon import Mecab as _Mecab

                    class Mecab(_Mecab):
                        def get_dicpath(installpath):
                            if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", installpath)):
                                import shutil

                                python_dir = os.getcwd()
                                if installpath[: len(python_dir)].upper() == python_dir.upper():
                                    dicpath = os.path.join(os.path.relpath(installpath, python_dir), "data", "mecabrc")
                                else:
                                    if not os.path.exists("TEMP"):
                                        os.mkdir("TEMP")
                                    if not os.path.exists(os.path.join("TEMP", "ko")):
                                        os.mkdir(os.path.join("TEMP", "ko"))
                                    if os.path.exists(os.path.join("TEMP", "ko", "ko_dict")):
                                        shutil.rmtree(os.path.join("TEMP", "ko", "ko_dict"))

                                    shutil.copytree(
                                        os.path.join(installpath, "data"), os.path.join("TEMP", "ko", "ko_dict")
                                    )
                                    dicpath = os.path.join("TEMP", "ko", "ko_dict", "mecabrc")
                            else:
                                dicpath = os.path.abspath(os.path.join(installpath, "data/mecabrc"))
                            return dicpath

                        def __init__(self, dicpath=get_dicpath(installpath)):

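Both the Japanese and Korean branches guard on the same idea: MeCab/OpenJTalk dictionaries can fail to load from paths containing non-ASCII characters on Windows, so such paths get relocated under TEMP. The check itself, in isolation:

import re

def is_ascii_safe(path: str) -> bool:
    # Mirrors the guard used above: only ASCII letters, digits, and path punctuation allowed.
    return re.match(r"^[A-Za-z0-9_/\\:.\-]*$", path) is not None

print(is_ascii_safe(r"C:\tools\mecab\data"))  # True
print(is_ascii_safe(r"C:\用户\mecab\data"))   # False -> triggers the TEMP copy path
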
@ -55,10 +59,14 @@ if os.name == 'nt':
from text.symbols2 import symbols

# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = (
    "군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통"
)

# List of (hangul, hangul divided) pairs:
_hangul_divided = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        # ('ㄳ', 'ㄱㅅ'),  # g2pk2, A Syllable-ending Rule
        # ('ㄵ', 'ㄴㅈ'),
        # ('ㄶ', 'ㄴㅎ'),

@ -70,79 +78,86 @@ _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
        # ('ㄿ', 'ㄹㅍ'),
        # ('ㅀ', 'ㄹㅎ'),
        # ('ㅄ', 'ㅂㅅ'),
        ("ㅘ", "ㅗㅏ"),
        ("ㅙ", "ㅗㅐ"),
        ("ㅚ", "ㅗㅣ"),
        ("ㅝ", "ㅜㅓ"),
        ("ㅞ", "ㅜㅔ"),
        ("ㅟ", "ㅜㅣ"),
        ("ㅢ", "ㅡㅣ"),
        ("ㅑ", "ㅣㅏ"),
        ("ㅒ", "ㅣㅐ"),
        ("ㅕ", "ㅣㅓ"),
        ("ㅖ", "ㅣㅔ"),
        ("ㅛ", "ㅣㅗ"),
        ("ㅠ", "ㅣㅜ"),
    ]
]

# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("a", "에이"),
        ("b", "비"),
        ("c", "시"),
        ("d", "디"),
        ("e", "이"),
        ("f", "에프"),
        ("g", "지"),
        ("h", "에이치"),
        ("i", "아이"),
        ("j", "제이"),
        ("k", "케이"),
        ("l", "엘"),
        ("m", "엠"),
        ("n", "엔"),
        ("o", "오"),
        ("p", "피"),
        ("q", "큐"),
        ("r", "아르"),
        ("s", "에스"),
        ("t", "티"),
        ("u", "유"),
        ("v", "브이"),
        ("w", "더블유"),
        ("x", "엑스"),
        ("y", "와이"),
        ("z", "제트"),
    ]
]

# List of (ipa, lazy ipa) pairs:
_ipa_to_lazy_ipa = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("t͡ɕ", "ʧ"),
        ("d͡ʑ", "ʥ"),
        ("ɲ", "n^"),
        ("ɕ", "ʃ"),
        ("ʷ", "w"),
        ("ɭ", "l`"),
        ("ʎ", "ɾ"),
        ("ɣ", "ŋ"),
        ("ɰ", "ɯ"),
        ("ʝ", "j"),
        ("ʌ", "ə"),
        ("ɡ", "g"),
        ("\u031a", "#"),
        ("\u0348", "="),
        ("\u031e", ""),
        ("\u0320", ""),
        ("\u0339", ""),
    ]
]


def fix_g2pk2_error(text):
    new_text = ""
    i = 0
    while i < len(text) - 4:
        if (text[i : i + 3] == "ㅇㅡㄹ" or text[i : i + 3] == "ㄹㅡㄹ") and text[i + 3] == " " and text[i + 4] == "ㄹ":
            new_text += text[i : i + 3] + " " + "ㄴ"
            i += 5
        else:
            new_text += text[i]

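These (compiled pattern, replacement) tables are applied sequentially, so ordering matters: the affricate t͡ɕ must be rewritten before the bare ɕ, or the longer match would be clobbered. A tiny self-contained demonstration using an excerpt of `_ipa_to_lazy_ipa`:

import re

table = [(re.compile("t͡ɕ"), "ʧ"), (re.compile("ɕ"), "ʃ")]  # excerpt only
text = "t͡ɕa ɕi"
for regex, replacement in table:
    text = re.sub(regex, replacement, text)
print(text)  # ʧa ʃi
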
@ -166,20 +181,20 @@ def divide_hangul(text):


def hangul_number(num, sino=True):
    """Reference https://github.com/Kyubyong/g2pK"""
    num = re.sub(",", "", num)

    if num == "0":
        return "영"
    if not sino and num == "20":
        return "스무"

    digits = "123456789"
    names = "일이삼사오육칠팔구"
    digit2name = {d: n for d, n in zip(digits, names)}

    modifiers = "한 두 세 네 다섯 여섯 일곱 여덟 아홉"
    decimals = "열 스물 서른 마흔 쉰 예순 일흔 여든 아흔"
    digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
    digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}

@ -188,75 +203,75 @@ def hangul_number(num, sino=True):
        i = len(num) - i - 1
        if sino:
            if i == 0:
                name = digit2name.get(digit, "")
            elif i == 1:
                name = digit2name.get(digit, "") + "십"
                name = name.replace("일십", "십")
        else:
            if i == 0:
                name = digit2mod.get(digit, "")
            elif i == 1:
                name = digit2dec.get(digit, "")
        if digit == "0":
            if i % 4 == 0:
                last_three = spelledout[-min(3, len(spelledout)) :]
                if "".join(last_three) == "":
                    spelledout.append("")
                    continue
            else:
                spelledout.append("")
                continue
        if i == 2:
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 3:
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 4:
            name = digit2name.get(digit, "") + "만"
            name = name.replace("일만", "만")
        elif i == 5:
            name = digit2name.get(digit, "") + "십"
            name = name.replace("일십", "십")
        elif i == 6:
            name = digit2name.get(digit, "") + "백"
            name = name.replace("일백", "백")
        elif i == 7:
            name = digit2name.get(digit, "") + "천"
            name = name.replace("일천", "천")
        elif i == 8:
            name = digit2name.get(digit, "") + "억"
        elif i == 9:
            name = digit2name.get(digit, "") + "십"
        elif i == 10:
            name = digit2name.get(digit, "") + "백"
        elif i == 11:
            name = digit2name.get(digit, "") + "천"
        elif i == 12:
            name = digit2name.get(digit, "") + "조"
        elif i == 13:
            name = digit2name.get(digit, "") + "십"
        elif i == 14:
            name = digit2name.get(digit, "") + "백"
        elif i == 15:
            name = digit2name.get(digit, "") + "천"
        spelledout.append(name)
    return "".join(elem for elem in spelledout)


def number_to_hangul(text):
    """Reference https://github.com/Kyubyong/g2pK"""
    tokens = set(re.findall(r"(\d[\d,]*)([\uac00-\ud71f]+)", text))
    for token in tokens:
        num, classifier = token
        if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
            spelledout = hangul_number(num, sino=False)
        else:
            spelledout = hangul_number(num, sino=True)
        text = text.replace(f"{num}{classifier}", f"{spelledout}{classifier}")
    # digit by digit for remaining digits
    digits = "0123456789"
    names = "영일이삼사오육칠팔구"
    for d, n in zip(digits, names):
        text = text.replace(d, n)
    return text

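As a sanity check on the sino branch of `hangul_number`, here is the two-digit case in isolation (the "일십 -> 십" elision mirrors the Chinese-style "一十 -> 十" rule):

digits = "123456789"
names = "일이삼사오육칠팔구"
digit2name = dict(zip(digits, names))

def two_digit_sino(num: str) -> str:
    # Tens place gets 십, and a leading 일 ("one") is elided, as in the loop above.
    name = (digit2name.get(num[0], "") + "십").replace("일십", "십")
    return name + digit2name.get(num[1], "")

print(two_digit_sino("25"))  # 이십오
print(two_digit_sino("10"))  # 십
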
@ -265,19 +280,23 @@ def number_to_hangul(text):
def korean_to_lazy_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text = re.sub("[\uac00-\ud7af]+", lambda x: ko_pron.romanise(x.group(0), "ipa").split("] ~ [")[0], text)
    for regex, replacement in _ipa_to_lazy_ipa:
        text = re.sub(regex, replacement, text)
    return text


_g2p = G2p()


def korean_to_ipa(text):
    text = latin_to_hangul(text)
    text = number_to_hangul(text)
    text = _g2p(text)
    text = fix_g2pk2_error(text)
    text = korean_to_lazy_ipa(text)
    return text.replace("ʧ", "tʃ").replace("ʥ", "dʑ")


def post_replace_ph(ph):
    rep_map = {

@ -301,12 +320,13 @@ def post_replace_ph(ph):
        ph = "停"
    return ph


def g2p(text):
    text = latin_to_hangul(text)
    text = _g2p(text)
    text = divide_hangul(text)
    text = fix_g2pk2_error(text)
    text = re.sub(r"([\u3131-\u3163])$", r"\1.", text)
    # text = "".join([post_replace_ph(i) for i in text])
    text = [post_replace_ph(i) for i in text]
    return text

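One non-obvious step in `g2p` above is the final substitution, which appends a "." when the string ends on a bare jamo (U+3131 to U+3163). Standalone check:

import re

print(re.sub(r"([\u3131-\u3163])$", r"\1.", "ㅇㅏㄴㄴㅕㅇ"))  # 'ㅇㅏㄴㄴㅕㅇ.'
print(re.sub(r"([\u3131-\u3163])$", r"\1.", "abc"))          # 'abc' (unchanged)
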
@ -1,5 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "…", ",", "."]  # "@" denotes the SP pause (@是SP停顿)
punctuation.append("-")

@ -1,5 +1,3 @@
# punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿
punctuation = ["!", "?", "…", ",", "."]  # "@" denotes the SP pause (@是SP停顿)
punctuation.append("-")

@ -396,10 +394,390 @@ arpa = {
    "SH",
}

ko_symbols = "ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ空停"
# ko_symbols='ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '

yue_symbols={'Yeot3', 'Yip1', 'Yyu3', 'Yeng4', 'Yut5', 'Yaan5', 'Ym5', 'Yaan6', 'Yang1', 'Yun4', 'Yon2', 'Yui5', 'Yun2', 'Yat3', 'Ye', 'Yeot1', 'Yoeng5', 'Yoek2', 'Yam2', 'Yeon6', 'Yu6', 'Yiu3', 'Yaang6', 'Yp5', 'Yai4', 'Yoek4', 'Yit6', 'Yam5', 'Yoeng6', 'Yg1', 'Yk3', 'Yoe4', 'Yam3', 'Yc', 'Yyu4', 'Yyut1', 'Yiu4', 'Ying3', 'Yip3', 'Yaap3', 'Yau3', 'Yan4', 'Yau1', 'Yap4', 'Yk6', 'Yok3', 'Yai1', 'Yeot6', 'Yan2', 'Yoek6', 'Yt1', 'Yoi1', 'Yit5', 'Yn4', 'Yaau3', 'Yau4', 'Yuk6', 'Ys', 'Yuk', 'Yin6', 'Yung6', 'Ya', 'You', 'Yaai5', 'Yau5', 'Yoi3', 'Yaak3', 'Yaat3', 'Ying2', 'Yok5', 'Yeng2', 'Yyut3', 'Yam1', 'Yip5', 'You1', 'Yam6', 'Yaa5', 'Yi6', 'Yek4', 'Yyu2', 'Yuk5', 'Yaam1', 'Yang2', 'Yai', 'Yiu6', 'Yin4', 'Yok4', 'Yot3', 'Yui2', 'Yeoi5', 'Yyun6', 'Yyu5', 'Yoi5', 'Yeot2', 'Yim4', 'Yeoi2', 'Yaan1', 'Yang6', 'Yong1', 'Yaang4', 'Yung5', 'Yeon1', 'Yin2', 'Ya3', 'Yaang3', 'Yg', 'Yk2', 'Yaau5', 'Yut1', 'Yt5', 'Yip4', 'Yung4', 'Yj', 'Yong3', 'Ya1', 'Yg6', 'Yaau6', 'Yit3', 'Yun3', 'Ying1', 'Yn2', 'Yg4', 'Yl', 'Yp3', 'Yn3', 'Yak1', 'Yang5', 'Yoe6', 'You2', 'Yap2', 'Yak2', 'Yt3', 'Yot5', 'Yim2', 'Yi1', 'Yn6', 'Yaat5', 'Yaam3', 'Yoek5', 'Ye3', 'Yeon4', 'Yaa2', 'Yu3', 'Yim6', 'Ym', 'Yoe3', 'Yaai2', 'Ym2', 'Ya6', 'Yeng6', 'Yik4', 'Yot4', 'Yaai4', 'Yyun3', 'Yu1', 'Yoeng1', 'Yaap2', 'Yuk3', 'Yoek3', 'Yeng5', 'Yeoi1', 'Yiu2', 'Yok1', 'Yo1', 'Yoek1', 'Yoeng2', 'Yeon5', 'Yiu1', 'Yoeng4', 'Yuk2', 'Yat4', 'Yg5', 'Yut4', 'Yan6', 'Yin3', 'Yaa6', 'Yap1', 'Yg2', 'Yoe5', 'Yt4', 'Ya5', 'Yo4', 'Yyu1', 'Yak3', 'Yeon2', 'Yong4', 'Ym1', 'Ye2', 'Yaang5', 'Yoi2', 'Yeng3', 'Yn', 'Yyut4', 'Yau', 'Yaak2', 'Yaan4', 'Yek2', 'Yin1', 'Yi5', 'Yoe2', 'Yei5', 'Yaat6', 'Yak5', 'Yp6', 'Yok6', 'Yei2', 'Yaap1', 'Yyut5', 'Yi4', 'Yim1', 'Yk5', 'Ye4', 'Yok2', 'Yaam6', 'Yat2', 'Yon6', 'Yei3', 'Yyu6', 'Yeot5', 'Yk4', 'Yai6', 'Yd', 'Yg3', 'Yei6', 'Yau2', 'Yok', 'Yau6', 'Yung3', 'Yim5', 'Yut6', 'Yit1', 'Yon3', 'Yat1', 'Yaam2', 'Yyut2', 'Yui6', 'Yt2', 'Yek6', 'Yt', 'Ye6', 'Yang3', 'Ying6', 'Yaau1', 'Yeon3', 'Yng', 'Yh', 'Yang4', 'Ying5', 'Yaap6', 'Yoeng3', 'Yyun4', 'You3', 'Yan5', 'Yat5', 'Yot1', 'Yun1', 'Yi3', 'Yaa1', 'Yaap4', 'You6', 'Yaang2', 'Yaap5', 'Yaa3', 'Yaak6', 'Yeng1', 'Yaak1', 'Yo5', 'Yoi4', 'Yam4', 'Yik1', 'Ye1', 'Yai5', 'Yung1', 'Yp2', 'Yui4', 'Yaak4', 'Yung2', 'Yak4', 'Yaat4', 'Yeoi4', 'Yut2', 'Yin5', 'Yaau4', 'Yap6', 'Yb', 'Yaam4', 'Yw', 'Yut3', 'Yong2', 'Yt6', 'Yaai6', 'Yap5', 'Yik5', 'Yun6', 'Yaam5', 'Yun5', 'Yik3', 'Ya2', 'Yyut6', 'Yon4', 'Yk1', 'Yit4', 'Yak6', 'Yaan2', 'Yuk1', 'Yai2', 'Yik2', 'Yaat2', 'Yo3', 'Ykw', 'Yn5', 'Yaa', 'Ye5', 'Yu4', 'Yei1', 'Yai3', 'Yyun5', 'Yip2', 'Yaau2', 'Yiu5', 'Ym4', 'Yeoi6', 'Yk', 'Ym6', 'Yoe1', 'Yeoi3', 'Yon', 'Yuk4', 'Yaai3', 'Yaa4', 'Yot6', 'Yaang1', 'Yei4', 'Yek1', 'Yo', 'Yp', 'Yo6', 'Yp4', 'Yan3', 'Yoi', 'Yap3', 'Yek3', 'Yim3', 'Yz', 'Yot2', 'Yoi6', 'Yit2', 'Yu5', 'Yaan3', 'Yan1', 'Yon5', 'Yp1', 'Yong5', 'Ygw', 'Yak', 'Yat6', 'Ying4', 'Yu2', 'Yf', 'Ya4', 'Yon1', 'You4', 'Yik6', 'Yui1', 'Yaat1', 'Yeot4', 'Yi2', 'Yaai1', 'Yek5', 'Ym3', 'Yong6', 'You5', 'Yyun1', 'Yn1', 'Yo2', 'Yip6', 'Yui3', 'Yaak5', 'Yyun2'}

# symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) + list(ko_symbols)  # +list(yue_symbols)  # appending yue here directly scrambles its order (直接这么加yue顺序乱了)
symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa)

@ -411,9 +789,9 @@ symbols+=sorted(list(yue_symbols))##新加的yue统一摆在后头#已查过开
# print(len(symbols))
if __name__ == "__main__":
    print(len(symbols))
"""
Cantonese (粤语):
732-353=379
Korean + Cantonese (韩文+粤语):
732-322=410
"""

@ -510,12 +510,7 @@ class ToneSandhi:
        # e.g. 走了, 看着, 去过
        elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
            finals[-1] = finals[-1][:-1] + "5"
        elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"} and word not in self.must_not_neural_tone_words:
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 桌上, 地下, 家里
        elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:

@ -525,25 +520,18 @@ class ToneSandhi:
            finals[-1] = finals[-1][:-1] + "5"
        # 个 used as a classifier (个做量词)
        elif (
            ge_idx >= 1 and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
        ) or word == "个":
            finals[ge_idx] = finals[ge_idx][:-1] + "5"
        else:
            if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
                finals[-1] = finals[-1][:-1] + "5"

        word_list = self._split_word(word)
        finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
        for i, word in enumerate(word_list):
            # conventional neural in Chinese
            if word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
                finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
        finals = sum(finals_list, [])
        return finals

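Every branch of this sandhi pass ends in the same rewrite: the tone digit of the affected syllable's final is replaced by 5 (neutral tone). In isolation, for 桌上 from the comment above (finals shown in TONE3 style; values illustrative):

finals = ["uo1", "ang4"]  # finals of 桌/上
finals[-1] = finals[-1][:-1] + "5"
print(finals)  # ['uo1', 'ang5'] -> the last syllable is now neutral tone
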
@ -561,9 +549,7 @@ class ToneSandhi:

    def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
        # "一" in number sequences, e.g. 一零零, 二一零
        if word.find("一") != -1 and all([item.isnumeric() for item in word if item != "一"]):
            return finals
        # "一" between reduplication words should be yi5, e.g. 看一看
        elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:

@ -697,13 +683,10 @@ class ToneSandhi:
        return new_seg

    # the first and the second words are all_tone_three
    def _merge_continuous_three_tones(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)

@ -715,10 +698,7 @@ class ToneSandhi:
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication needs to be _neural_sandhi
                if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:

@ -732,13 +712,10 @@ class ToneSandhi:
        return len(word) == 2 and word[0] == word[1]

    # the last char of first word and the first char of second word is tone_three
    def _merge_continuous_three_tones_2(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)

@ -750,10 +727,7 @@ class ToneSandhi:
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication needs to be _neural_sandhi
                if not self._is_reduplication(seg[i - 1][0]) and len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:

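Both merge helpers rely on pypinyin's TONE3-style finals to detect runs of third tones. For reference (assuming pypinyin is installed, as this module already imports it):

from pypinyin import Style, lazy_pinyin

print(lazy_pinyin("你好", neutral_tone_with_five=True, style=Style.FINALS_TONE3))
# expected: ['i3', 'ao3'] -> both finals carry tone 3, so the pair is a merge candidate
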
File diff suppressed because one or more lines are too long
@ -21,25 +21,29 @@ from .num import verbalize_digit
def _time_num2str(num_string: str) -> str:
    """A special case for verbalizing number in time."""
    result = num2str(num_string.lstrip("0"))
    if num_string.startswith("0"):
        result = DIGITS["0"] + result
    return result


# Time-of-day expressions (时刻表达式)
RE_TIME = re.compile(
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
)

# Time ranges, e.g. 8:30-12:30 (时间范围)
RE_TIME_RANGE = re.compile(
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
    r"(~|-)"
    r"([0-1]?[0-9]|2[0-3])"
    r":([0-5][0-9])"
    r"(:([0-5][0-9]))?"
)


def replace_time(match) -> str:

@ -62,31 +66,33 @@ def replace_time(match) -> str:
    second_2 = match.group(9)

    result = f"{num2str(hour)}点"
    if minute.lstrip("0"):
        if int(minute) == 30:
            result += "半"
        else:
            result += f"{_time_num2str(minute)}分"
    if second and second.lstrip("0"):
        result += f"{_time_num2str(second)}秒"

    if is_range:
        result += "至"
        result += f"{num2str(hour_2)}点"
        if minute_2.lstrip("0"):
            if int(minute) == 30:
                result += "半"
            else:
                result += f"{_time_num2str(minute_2)}分"
        if second_2 and second_2.lstrip("0"):
            result += f"{_time_num2str(second_2)}秒"

    return result


RE_DATE = re.compile(
    r"(\d{4}|\d{2})年"
    r"((0?[1-9]|1[0-2])月)?"
    r"(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"
)


def replace_date(match) -> str:

@ -110,8 +116,7 @@ def replace_date(match) -> str:


# YY/MM/DD or YY-MM-DD dates separated by "/" or "-" (用 / 或者 - 分隔的日期)
RE_DATE2 = re.compile(r"(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])")


def replace_date2(match) -> str:

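The time pattern is three concatenated raw strings; minute 30 reads as 半 ("half"). Note that in the range branch the second `int(minute) == 30` check re-tests `minute` rather than `minute_2`; that is carried over unchanged from both sides of this diff. The pattern itself, rebuilt standalone:

import re

RE_TIME = re.compile(r"([0-1]?[0-9]|2[0-3])" r":([0-5][0-9])" r"(:([0-5][0-9]))?")
m = RE_TIME.search("会议8:30:15开始")
print(m.group(1), m.group(2), m.group(4))  # 8 30 15
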
@ -18,10 +18,7 @@ from pypinyin.constants import SUPPORT_UCS4

# Full-width / half-width conversion (全角半角转换)
# Full-width -> half-width map for ASCII letters (num: 52)
F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}

# Half-width -> full-width map for ASCII letters
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}

@ -37,26 +34,29 @@ F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation}
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}

# Space (num: 1)
F2H_SPACE = {"\u3000": " "}
H2F_SPACE = {" ": "\u3000"}

# Runs of characters that are not pinyin-bearing hanzi; usable for NSW (non-standard word) extraction
if SUPPORT_UCS4:
    RE_NSW = re.compile(
        r"(?:[^"
        r"\u3007"  # 〇
        r"\u3400-\u4dbf"  # CJK Extension A: [3400-4DBF]
        r"\u4e00-\u9fff"  # CJK Basic: [4E00-9FFF]
        r"\uf900-\ufaff"  # CJK Compatibility: [F900-FAFF]
        r"\U00020000-\U0002A6DF"  # CJK Extension B: [20000-2A6DF]
        r"\U0002A703-\U0002B73F"  # CJK Extension C: [2A700-2B73F]
        r"\U0002B740-\U0002B81D"  # CJK Extension D: [2B740-2B81D]
        r"\U0002F80A-\U0002FA1F"  # CJK Compatibility Extension: [2F800-2FA1F]
        r"])+"
    )
else:
    RE_NSW = re.compile(  # pragma: no cover
        r"(?:[^"
        r"\u3007"  # 〇
        r"\u3400-\u4dbf"  # CJK Extension A: [3400-4DBF]
        r"\u4e00-\u9fff"  # CJK Basic: [4E00-9FFF]
        r"\uf900-\ufaff"  # CJK Compatibility: [F900-FAFF]
        r"])+"
    )

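The 65248 offset is the fixed distance between half-width ASCII and its full-width forms (e.g. 'A' U+0041 vs 'Ａ' U+FF21), so these maps can drive str.translate directly:

import string

F2H_ASCII_LETTERS = {ord(char) + 65248: ord(char) for char in string.ascii_letters}
print("ＧＰＴ－SoVITS".translate(F2H_ASCII_LETTERS))  # GPT－SoVITS (letters narrowed; '－' is not a letter)
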
@ -15,23 +15,26 @@
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""

import re
from collections import OrderedDict
from typing import List

DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
UNITS = OrderedDict(
    {
        1: "十",
        2: "百",
        3: "千",
        4: "万",
        8: "亿",
    }
)

COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"

# Fraction expressions (分数表达式)
RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")


def replace_frac(match) -> str:

@ -52,7 +55,7 @@ def replace_frac(match) -> str:


# Percentage expressions (百分数表达式)
RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")


def replace_percentage(match) -> str:

@ -72,7 +75,7 @@ def replace_percentage(match) -> str:

# Integer expressions (整数表达式)
# Integers with a minus sign, e.g. -10 (带负号的整数)
RE_INTEGER = re.compile(r"(-)" r"(\d+)")


def replace_negative_num(match) -> str:

@ -92,7 +95,7 @@ def replace_negative_num(match) -> str:

# IDs: unsigned integers (编号-无符号整形)
# 00078
RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")


def replace_default_num(match):

@ -110,15 +113,11 @@ def replace_default_num(match):
# RE_ASMD = re.compile(
#     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
    r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
)

asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"}


def replace_asmd(match) -> str:
    """

|
|||||||
|
|
||||||
|
|
||||||
# 次方专项
|
# 次方专项
|
||||||
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
|
RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")
|
||||||
|
|
||||||
power_map = {
|
power_map = {
|
||||||
'⁰': '0',
|
"⁰": "0",
|
||||||
'¹': '1',
|
"¹": "1",
|
||||||
'²': '2',
|
"²": "2",
|
||||||
'³': '3',
|
"³": "3",
|
||||||
'⁴': '4',
|
"⁴": "4",
|
||||||
'⁵': '5',
|
"⁵": "5",
|
||||||
'⁶': '6',
|
"⁶": "6",
|
||||||
'⁷': '7',
|
"⁷": "7",
|
||||||
'⁸': '8',
|
"⁸": "8",
|
||||||
'⁹': '9',
|
"⁹": "9",
|
||||||
'ˣ': 'x',
|
"ˣ": "x",
|
||||||
'ʸ': 'y',
|
"ʸ": "y",
|
||||||
'ⁿ': 'n'
|
"ⁿ": "n",
|
||||||
}
|
}
|
||||||
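`power_map` simply transliterates superscript code points to plain characters before the number is verbalized. Minimal use, with an excerpt of the table:

power_map = {"⁰": "0", "¹": "1", "²": "2", "³": "3", "ˣ": "x"}  # excerpt only
print("".join(power_map.get(ch, ch) for ch in "x²+y³"))  # x2+y3
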

def replace_power(match) -> str:
    """
    Args:

@ -166,10 +166,10 @@ def replace_power(match) -> str:


# Number expressions (数字表达式)
# Pure decimals (纯小数)
RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
# Positive integer + classifier (正整数 + 量词)
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")


def replace_positive_quantifier(match) -> str:

@ -220,7 +220,9 @@ RE_RANGE = re.compile(
    [-~]  # range separator (匹配范围分隔符)
    ((-?)((\d+)(\.\d+)?))  # number (integer or decimal, optionally negative) ending the range
    (?![\d\+\-\×÷=])  # lookahead: no further digits or operators after the range
    """,
    re.VERBOSE,
)


def replace_range(match) -> str:

@ -239,7 +241,9 @@ def replace_range(match) -> str:

# "~" read as 至 ("to") expressions (~至表达式)
RE_TO_RANGE = re.compile(
    r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
)


def replace_to_range(match) -> str:
    """

@ -248,71 +252,66 @@ def replace_to_range(match) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str
|
str
|
||||||
"""
|
"""
|
||||||
result = match.group(0).replace('~', '至')
|
result = match.group(0).replace("~", "至")
|
||||||
return result
|
return result
|
||||||


def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
    stripped = value_string.lstrip("0")
    if len(stripped) == 0:
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS["0"], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)


def verbalize_cardinal(value_string: str) -> str:
    if not value_string:
        return ""

    # 000 -> '零' , 0 -> '零'
    value_string = value_string.lstrip("0")
    if len(value_string) == 0:
        return DIGITS["0"]

    result_symbols = _get_value(value_string)
    # verbalized number starting with '一十*' is abbreviated as `十*`
    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS["1"] and result_symbols[1] == UNITS[1]:
        result_symbols = result_symbols[1:]
    return "".join(result_symbols)
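# Hedged illustration (not part of the commit): cardinal reading, including the
# '一十' -> '十' abbreviation handled just above and the all-zero case from the
# comment in the function:
assert verbalize_cardinal("15") == "十五"
assert verbalize_cardinal("000") == "零"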


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result_symbols = [DIGITS[digit] for digit in value_string]
    result = "".join(result_symbols)
    if alt_one:
        result = result.replace("一", "幺")
    return result
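# Hedged illustration (not part of the commit): digit-by-digit reading; alt_one
# swaps 一 for 幺, the conventional reading in phone numbers.
assert verbalize_digit("110") == "一一零"
assert verbalize_digit("110", alt_one=True) == "幺幺零"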


def num2str(value_string: str) -> str:
    integer_decimal = value_string.split(".")
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ""
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        raise ValueError(f"The value string: '{value_string}' has more than one point in it.")

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip("0")
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        result = result if result else "零"
        result += "点" + verbalize_digit(decimal)
    return result
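# Hedged illustration (not part of the commit): num2str joins a cardinal integer
# part and a digit-wise fractional part, with trailing zeros dropped, matching
# the comments in the function above:
assert num2str("3.20") == "三点二"
assert num2str(".22") == "零点二二"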
@ -21,10 +21,8 @@ from .num import verbalize_digit

# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
# 联通:130、131、132、156、155、186、185、176
# 电信:133、153、189、180、181、177
RE_MOBILE_PHONE = re.compile(r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")

# 全国统一的号码400开头
RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")
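# Hedged illustration (not part of the commit): what each pattern targets, per
# the carrier-prefix comments above.
assert RE_MOBILE_PHONE.search("13812345678")  # 11-digit mobile number
assert RE_TELEPHONE.search("010-62345678")  # landline with area code
assert RE_NATIONAL_UNIFORM_NUMBER.search("4001234567")  # 400 hotline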
@ -32,14 +30,12 @@ RE_NATIONAL_UNIFORM_NUMBER = re.compile(r"(400)(-)?\d{3}(-)?\d{4}")


def phone2str(phone_string: str, mobile=True) -> str:
    if mobile:
        sp_parts = phone_string.strip("+").split()
        result = ",".join([verbalize_digit(part, alt_one=True) for part in sp_parts])
        return result
    else:
        sil_parts = phone_string.split("-")
        result = ",".join([verbalize_digit(part, alt_one=True) for part in sil_parts])
        return result
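# Hedged illustration (not part of the commit): mobile numbers are split on
# spaces, landline numbers on hyphens; both are read digit-by-digit with 幺 for 一.
assert phone2str("010-62345678", mobile=False) == "零幺零,六二三四五六七八"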


@ -17,7 +17,7 @@ from .num import num2str

# 温度表达式,温度会影响负号的读法
# -3°C 零下三度
RE_TEMPERATURE = re.compile(r"(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)")
measure_dict = {
    "cm2": "平方厘米",
    "cm²": "平方厘米",
@ -35,7 +35,7 @@ measure_dict = {
    "ml": "毫升",
    "m": "米",
    "mm": "毫米",
    "s": "秒",
}
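# Hedged illustration (not part of the commit): RE_TEMPERATURE captures the
# sign, the value, and the unit, so -3°C can read as 零下三度 per the comment
# above; measure_dict supplies the spoken form of each unit abbreviation.
m = RE_TEMPERATURE.search("今天-3°C")
assert m is not None and m.group(1) == "-" and m.group(2) == "3"
assert measure_dict["mm"] == "毫米"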


@ -56,9 +56,9 @@ from .quantifier import replace_measure
from .quantifier import replace_temperature


class TextNormalizer:
    def __init__(self):
        self.SENTENCE_SPLITOR = re.compile(r"([:、,;。?!,;?!][”’]?)")

    def _split(self, text: str, lang="zh") -> List[str]:
        """Split long text into sentences with sentence-splitting punctuations.
@ -71,66 +71,64 @@ class TextNormalizer():
        if lang == "zh":
            text = text.replace(" ", "")
            # 过滤掉特殊字符
            text = re.sub(r"[——《》【】<>{}()()#&@“”^_|\\]", "", text)
        text = self.SENTENCE_SPLITOR.sub(r"\1\n", text)
        text = text.strip()
        sentences = [sentence.strip() for sentence in re.split(r"\n+", text)]
        return sentences
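    # Hedged illustration (not part of the commit): for Chinese input, spaces
    # and bracket-like characters are stripped first, then the text is cut at
    # sentence-final punctuation, e.g.
    #   TextNormalizer()._split("你好!今天天气不错。")
    #   -> ["你好!", "今天天气不错。"]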

    def _post_replace(self, sentence: str) -> str:
        sentence = sentence.replace("/", "每")
        # sentence = sentence.replace('~', '至')
        # sentence = sentence.replace('~', '至')
        sentence = sentence.replace("①", "一")
        sentence = sentence.replace("②", "二")
        sentence = sentence.replace("③", "三")
        sentence = sentence.replace("④", "四")
        sentence = sentence.replace("⑤", "五")
        sentence = sentence.replace("⑥", "六")
        sentence = sentence.replace("⑦", "七")
        sentence = sentence.replace("⑧", "八")
        sentence = sentence.replace("⑨", "九")
        sentence = sentence.replace("⑩", "十")
        sentence = sentence.replace("α", "阿尔法")
        sentence = sentence.replace("β", "贝塔")
        sentence = sentence.replace("γ", "伽玛").replace("Γ", "伽玛")
        sentence = sentence.replace("δ", "德尔塔").replace("Δ", "德尔塔")
        sentence = sentence.replace("ε", "艾普西龙")
        sentence = sentence.replace("ζ", "捷塔")
        sentence = sentence.replace("η", "依塔")
        sentence = sentence.replace("θ", "西塔").replace("Θ", "西塔")
        sentence = sentence.replace("ι", "艾欧塔")
        sentence = sentence.replace("κ", "喀帕")
        sentence = sentence.replace("λ", "拉姆达").replace("Λ", "拉姆达")
        sentence = sentence.replace("μ", "缪")
        sentence = sentence.replace("ν", "拗")
        sentence = sentence.replace("ξ", "克西").replace("Ξ", "克西")
        sentence = sentence.replace("ο", "欧米克伦")
        sentence = sentence.replace("π", "派").replace("Π", "派")
        sentence = sentence.replace("ρ", "肉")
        sentence = sentence.replace("ς", "西格玛").replace("Σ", "西格玛").replace("σ", "西格玛")
        sentence = sentence.replace("τ", "套")
        sentence = sentence.replace("υ", "宇普西龙")
        sentence = sentence.replace("φ", "服艾").replace("Φ", "服艾")
        sentence = sentence.replace("χ", "器")
        sentence = sentence.replace("ψ", "普赛").replace("Ψ", "普赛")
        sentence = sentence.replace("ω", "欧米伽").replace("Ω", "欧米伽")
        # 兜底数学运算,顺便兼容懒人用语
        sentence = sentence.replace("+", "加")
        sentence = sentence.replace("-", "减")
        sentence = sentence.replace("×", "乘")
        sentence = sentence.replace("÷", "除")
        sentence = sentence.replace("=", "等")
        # re filter special characters, have one more character "-" than line 68
        sentence = re.sub(r"[-——《》【】<=>{}()()#&@“”^_|\\]", "", sentence)
        return sentence
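    # Hedged illustration (not part of the commit): _post_replace spells out
    # symbols the earlier rules leave behind, e.g.
    #   "α+β=γ" -> "阿尔法加贝塔等伽玛"
    #   "①"     -> "一"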

    def normalize_sentence(self, sentence: str) -> str:
        # basic character conversions
        sentence = tranditional_to_simplified(sentence)
        sentence = sentence.translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS).translate(F2H_SPACE)

        # number related NSW verbalization
        sentence = RE_DATE.sub(replace_date, sentence)
@ -161,8 +159,7 @@ class TextNormalizer():

        sentence = RE_INTEGER.sub(replace_negative_num, sentence)
        sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
        sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence)
        sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
        sentence = RE_NUMBER.sub(replace_number, sentence)
        sentence = self._post_replace(sentence)
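        # Hedged illustration (not part of the commit): end to end, assuming the
        # full replace_* pipeline wired up above (parts of it are elided by this
        # diff), an input like "今天-3℃" would come out as "今天零下三度".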