Add speaker feature injection and related handling

This commit is contained in:
bwnotfound 2024-02-14 10:13:18 +08:00
parent 7303e108e1
commit 0341f13b92
10 changed files with 241 additions and 43 deletions
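
In outline, the commit replaces the single learned speaker vector (speaker_feat, previously persisted to a hard-coded zyj.pt path) with an indexed lookup: each speaker name is mapped to an integer id, the id is looked up in an nn.Embedding, projected to the decoder embedding size, and added to the audio-token embeddings during training and inference. A minimal sketch of that mechanism, assuming a hypothetical wrapper module (the 100-speaker table and 4096-dim feature come from the diff; everything else is illustrative):

from typing import Optional
import torch
import torch.nn as nn

class SpeakerInjection(nn.Module):
    # Sketch only: mirrors the speaker_emb/speaker_proj pair added to the decoder, not the real class.
    def __init__(self, embedding_dim: int, num_speakers: int = 100, speaker_dim: int = 4096):
        super().__init__()
        self.speaker_emb = nn.Embedding(num_speakers, speaker_dim)
        self.speaker_proj = nn.Linear(speaker_dim, embedding_dim)

    def forward(self, y_emb: torch.Tensor, speaker_ids: Optional[torch.Tensor]) -> torch.Tensor:
        # y_emb: (B, T, embedding_dim) audio-token embeddings; speaker_ids: (B,) long tensor or None.
        if speaker_ids is None:
            return y_emb
        spk = self.speaker_proj(self.speaker_emb(speaker_ids))  # (B, embedding_dim)
        return y_emb + spk.unsqueeze(1)                         # broadcast over the time axis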

.gitignore
View File

@ -9,4 +9,5 @@ output
logs
reference
SoVITS_weights
GPT_weights
TEMP

View File

@ -1,7 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
from pytorch_lightning import LightningDataModule
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from AR.data.dataset import Text2SemanticDataset, Text2SemanticSpeakerDataset
from torch.utils.data import DataLoader
@ -26,12 +26,20 @@ class Text2SemanticDataModule(LightningDataModule):
pass
def setup(self, stage=None, output_logs=False):
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
if 'train_dir' not in self.config:
self._train_dataset = Text2SemanticDataset(
phoneme_path=self.train_phoneme_path,
semantic_path=self.train_semantic_path,
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
else:
self._train_dataset = Text2SemanticSpeakerDataset(
self.config['speaker_dict_path'],
self.config['train_dir'],
max_sec=self.config["data"]["max_sec"],
pad_val=self.config["data"]["pad_val"],
)
self._dev_dataset = self._train_dataset
# self._dev_dataset = Text2SemanticDataset(
# phoneme_path=self.dev_phoneme_path,
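
The branch above takes the multi-speaker path only when the config carries a train_dir entry. A hedged sketch of the extra keys this expects (key names from the diff; values are illustrative — later in this commit open1Bb writes train_dir and speaker_dict_path into the s1 config):

# Illustrative values only; open1Bb sets train_dir to the experiment's s1 directory
# and speaker_dict_path to "<s1_dir>/speaker_dict.json".
config = {
    "train_dir": "logs/my_exp",                            # one subfolder per speaker
    "speaker_dict_path": "logs/my_exp/speaker_dict.json",  # persisted name -> id mapping
    "data": {"max_sec": 54, "pad_val": 1024},              # unchanged by this commit
}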

View File

@ -293,6 +293,128 @@ class Text2SemanticDataset(Dataset):
"bert_feature": bert_padded,
}
class Text2SemanticSpeakerDataset(Dataset):
def __init__(
self,
speaker_dict_path: str,
train_dir: str,
max_sample: int = None,
max_sec: int = 100,
pad_val: int = 1024,
# min value of phoneme/sec
min_ps_ratio: int = 3,
# max value of phoneme/sec
max_ps_ratio: int = 25,
):
super().__init__()
self.speaker_dict = dict()
if os.path.exists(speaker_dict_path):
with open(speaker_dict_path, "r", encoding="utf-8") as f:
try:
self.speaker_dict = json.load(f)
except:
pass
self.speaker2dataset = {}
self.PAD = pad_val
for name in os.listdir(train_dir):
if name in ["logs_s1", os.path.basename(speaker_dict_path)]:
continue
opt_dir = os.path.join(train_dir, name)
speaker_id = self.speaker_dict.get(name, None)
if speaker_id is None:
speaker_id = len(self.speaker_dict)
self.speaker_dict[name] = speaker_id
self.speaker2dataset[speaker_id] = Text2SemanticDataset(
phoneme_path=os.path.join(opt_dir, "2-name2text.txt"),
semantic_path=os.path.join(opt_dir, "6-name2semantic.tsv"),
max_sample=max_sample,
max_sec=max_sec,
pad_val=pad_val,
min_ps_ratio=min_ps_ratio,
max_ps_ratio=max_ps_ratio,
)
with open(speaker_dict_path, "w", encoding="utf-8") as f:
json.dump(self.speaker_dict, f, ensure_ascii=False, indent=4)
def __len__(self):
result = 0
for _, dataset in self.speaker2dataset.items():
result += len(dataset)
return result
def __getitem__(self, idx):
for speaker_id, dataset in self.speaker2dataset.items():
if idx < len(dataset):
result = dataset[idx]
result["speaker_id"] = speaker_id
return result
else:
idx -= len(dataset)
raise IndexError("index out of range")
def get_sample_length(self, idx: int):
for speaker_id, dataset in self.speaker2dataset.items():
if idx < len(dataset):
semantic_ids = dataset.semantic_phoneme[idx][0]
sec = 1.0 * len(semantic_ids) / dataset.hz
return sec
else:
idx -= len(dataset)
raise IndexError("index out of range")
def collate(self, examples: List[Dict]) -> Dict:
sample_index: List[int] = []
phoneme_ids: List[torch.Tensor] = []
phoneme_ids_lens: List[int] = []
semantic_ids: List[torch.Tensor] = []
semantic_ids_lens: List[int] = []
speaker_ids: List[int] = []
# return
for item in examples:
sample_index.append(item["idx"])
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
phoneme_ids_lens.append(item["phoneme_ids_len"])
semantic_ids_lens.append(item["semantic_ids_len"])
speaker_ids.append(item["speaker_id"])
# pad 0
phoneme_ids = batch_sequences(phoneme_ids)
semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
# # convert each batch to torch.tensor
phoneme_ids = torch.tensor(phoneme_ids)
semantic_ids = torch.tensor(semantic_ids)
phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
semantic_ids_lens = torch.tensor(semantic_ids_lens)
bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
bert_padded.zero_()
speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
for idx, item in enumerate(examples):
bert = item["bert_feature"]
if bert is not None:
bert_padded[idx, :, : bert.shape[-1]] = bert
return {
# List[int]
"ids": sample_index,
# torch.Tensor (B, max_phoneme_length)
"phoneme_ids": phoneme_ids,
# torch.Tensor (B)
"phoneme_ids_len": phoneme_ids_lens,
# torch.Tensor (B, max_semantic_ids_length)
"semantic_ids": semantic_ids,
# torch.Tensor (B)
"semantic_ids_len": semantic_ids_lens,
# torch.Tensor (B, 1024, max_phoneme_length)
"bert_feature": bert_padded,
"speaker_ids": speaker_ids,
}
if __name__ == "__main__":
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
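
A usage sketch for Text2SemanticSpeakerDataset, assuming train_dir holds one folder per speaker and each folder already contains 2-name2text.txt and 6-name2semantic.tsv (all paths hypothetical). __getitem__ resolves a flat index by walking the per-speaker sub-datasets in dictionary order and subtracting their lengths, so collate receives a matching speaker_id with every sample:

from torch.utils.data import DataLoader

dataset = Text2SemanticSpeakerDataset(
    speaker_dict_path="logs/my_exp/speaker_dict.json",
    train_dir="logs/my_exp",
    max_sec=54,
    pad_val=1024,
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=dataset.collate)
batch = next(iter(loader))
print(batch["speaker_ids"].shape)  # torch.Size([8]): one speaker index per sample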

View File

@ -34,10 +34,11 @@ class Text2SemanticLightningModule(LightningModule):
self.eval_dir.mkdir(parents=True, exist_ok=True)
for param in self.model.parameters():
param.requires_grad = False
self.model.speaker_proj.weight.requires_grad = True
self.model.speaker_proj.bias.requires_grad = True
self.model.speaker_proj.train()
self.model.speaker_feat.requires_grad = True
self.model.speaker_proj.requires_grad_(True)
self.model.speaker_proj.weight.requires_grad_(True)
self.model.speaker_proj.bias.requires_grad_(True)
self.model.speaker_emb.requires_grad_(True)
self.model.speaker_emb.weight.requires_grad_(True)
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
@ -48,13 +49,13 @@ class Text2SemanticLightningModule(LightningModule):
batch["semantic_ids"],
batch["semantic_ids_len"],
batch["bert_feature"],
speaker_ids=batch["speaker_ids"] if "speaker_ids" in batch else None,
)
self.manual_backward(loss)
if batch_idx > 0 and batch_idx % 4 == 0:
opt.step()
opt.zero_grad()
scheduler.step()
torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
self.log(
"total_loss",

View File

@ -52,12 +52,8 @@ class Text2SemanticDecoder(nn.Module):
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.speaker_proj = nn.Linear(1024, self.embedding_dim)
self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
if not os.path.exists(self.path_speaker):
self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
else:
self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
self.speaker_proj = nn.Linear(4096, self.embedding_dim)
self.speaker_emb = nn.Embedding(100, 4096)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
@ -95,7 +91,7 @@ class Text2SemanticDecoder(nn.Module):
ignore_index=self.EOS,
)
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
def make_input_data(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
x = self.ar_text_position(x)
@ -111,7 +107,8 @@ class Text2SemanticDecoder(nn.Module):
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
if speaker_ids is not None:
y_emb = y_emb + self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@ -149,7 +146,7 @@ class Text2SemanticDecoder(nn.Module):
return xy_pos, xy_attn_mask, targets
def forward(self, x, x_lens, y, y_lens, bert_feature):
def forward(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
"""
x: phoneme_ids
y: semantic_ids
@ -157,7 +154,7 @@ class Text2SemanticDecoder(nn.Module):
reject_y, reject_y_lens = make_reject_y(y, y_lens)
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature, speaker_ids=speaker_ids)
xy_dec, _ = self.h(
(xy_pos, None),
@ -167,7 +164,7 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature, speaker_ids=speaker_ids)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@ -338,7 +335,7 @@ class Text2SemanticDecoder(nn.Module):
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
use_speaker_feat=True,
speaker_ids=None,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -362,8 +359,8 @@ class Text2SemanticDecoder(nn.Module):
"first_infer": 1,
"stage": 0,
}
if use_speaker_feat:
speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
if speaker_ids is not None:
speaker_feat = self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
else:
speaker_feat = 0
for idx in tqdm(range(1500)):
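
The injected term is a single vector per utterance, broadcast across every decoding step. A quick shape walk-through of that path (512 is an assumed embedding_dim; 100 and 4096 come from the diff):

import torch
import torch.nn as nn

speaker_emb = nn.Embedding(100, 4096)
speaker_proj = nn.Linear(4096, 512)
speaker_ids = torch.tensor([3])                # (B,) = (1,)
feat = speaker_proj(speaker_emb(speaker_ids))  # (1, 512)
feat = feat.unsqueeze(1)                       # (1, 1, 512), added to y_emb of shape (1, T, 512)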

View File

@ -3,7 +3,7 @@ train:
epochs: 20
batch_size: 8
save_every_n_epoch: 1
precision: 16-mixed
precision: "32"
gradient_clip: 1.0
optimizer:
lr: 0.01

View File

@ -365,7 +365,10 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6,speaker_id="-1"):
speaker_id = int(speaker_id)
if speaker_id == -1:
speaker_id = None
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
@ -381,7 +384,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 0):
raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
@ -448,6 +451,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
top_p=top_p,
temperature=temperature,
early_stop_num=hz * max_sec,
speaker_ids=torch.LongTensor([speaker_id]).to(device) if speaker_id is not None else None,
)
t3 = ttime()
# print(pred_semantic.shape,idx)
@ -601,6 +605,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
speaker_id = gr.Dropdown(label="说话人idx", choices=[str(i) for i in range(-1, 100)], value="-1")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
@ -632,7 +637,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature,speaker_id],
[output],
)
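
On the inference side the dropdown delivers a string: "-1" means no speaker injection, any other value becomes a one-element LongTensor on the model's device. A toy version of that parsing (the function name is hypothetical; the mapping mirrors the top of get_tts_wav):

def resolve_speaker_id(value: str):
    # "-1" disables injection; otherwise return the integer speaker index.
    sid = int(value)
    return None if sid == -1 else sid

assert resolve_speaker_id("-1") is None
assert resolve_speaker_id("3") == 3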

View File

@ -98,7 +98,7 @@ for line in lines[int(i_part)::int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
if (inp_wav_dir !=None):
if (inp_wav_dir !=None and len(inp_wav_dir) > 0):
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name)

View File

@ -24,24 +24,36 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
)
_max=float(_max)
alpha=float(alpha)
for inp_path in input[int(i_part)::int(all_part)]:
# print(inp_path)
def slice_audio(root_dir, inp_wav_path):
try:
name = os.path.basename(inp_path)
audio = load_audio(inp_path, 32000)
name = os.path.basename(inp_wav_path)
audio = load_audio(inp_wav_path, 32000)
# print(audio.shape)
for chunk, start, end in slicer.slice(audio): # start and end are frame counts
tmp_max = np.abs(chunk).max()
if(tmp_max>1):chunk/=tmp_max
chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
wavfile.write(
"%s/%s_%010d_%010d.wav" % (opt_root, name, start, end),
"%s/%s_%010d_%010d.wav" % (root_dir, name, start, end),
32000,
# chunk.astype(np.float32),
(chunk * 32767).astype(np.int16),
)
except:
print(inp_path,"->fail->",traceback.format_exc())
print(inp_wav_path,"->fail->",traceback.format_exc())
for inp_path in input[int(i_part)::int(all_part)]:
# print(inp_path)
if os.path.isdir(inp_path):
inp_dir_name = os.path.basename(inp_path)
os.makedirs(os.path.join(opt_root, inp_dir_name), exist_ok=True)
for inp_name in os.listdir(inp_path):
inp_wav_path = os.path.join(inp_path, inp_name)
slice_audio(os.path.join(opt_root, inp_dir_name), inp_wav_path)
else:
inp_wav_path = inp_path
slice_audio(opt_root, inp_wav_path)
return "执行完毕,请检查输出文件"
print(slice(*sys.argv[1:]))
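
With this refactor, slice() accepts either a flat folder of audio files or a folder of per-speaker folders; in the latter case the output mirrors the input layout under opt_root. A hedged usage note (argument names follow the signature in the hunk header; folder names are examples):

# If inp contains speakerA/ and speakerB/, slices land in opt_root/speakerA/*.wav and
# opt_root/speakerB/*.wav; a flat input folder still writes directly into opt_root/.
# The script is driven positionally via sys.argv, e.g.:
#   python slice_audio.py <inp> <opt_root> <threshold> <min_length> <min_interval> \
#       <hop_size> <max_sil_kept> <_max> <alpha> <i_part> <all_part>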

View File

@ -195,20 +195,42 @@ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path
from tools.asr.config import asr_dict
def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang):
global p_asr
if(p_asr==None):
asr_inp_dir=my_utils.clean_path(asr_inp_dir)
def run(asr_inp_dir, asr_opt_dir):
cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
cmd += f' -i "{asr_inp_dir}"'
cmd += f' -o "{asr_opt_dir}"'
cmd += f' -s {asr_model_size}'
cmd += f' -l {asr_lang}'
cmd += " -p %s"%("float16"if is_half==True else "float32")
yield "ASR任务开启%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
print(cmd)
p_asr = Popen(cmd, shell=True)
p_asr.wait()
p_asr=None
if(p_asr==None):
# yield "ASR任务开启%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
flag = False
for name in os.listdir(asr_inp_dir):
if os.path.isdir(os.path.join(asr_inp_dir, name)):
os.makedirs(os.path.join(asr_opt_dir, name), exist_ok=True)
run(os.path.join(asr_inp_dir, name), os.path.join(asr_opt_dir, name))
else:
break
else:
flag = True
if flag:
run(asr_inp_dir, asr_opt_dir)
# asr_inp_dir=my_utils.clean_path(asr_inp_dir)
# cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
# cmd += f' -i "{asr_inp_dir}"'
# cmd += f' -o "{asr_opt_dir}"'
# cmd += f' -s {asr_model_size}'
# cmd += f' -l {asr_lang}'
# cmd += " -p %s"%("float16"if is_half==True else "float32")
# print(cmd)
# p_asr = Popen(cmd, shell=True)
# p_asr.wait()
# p_asr=None
yield f"ASR任务完成, 查看终端进行下一步",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
else:
yield "已有正在进行的ASR任务需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
@ -273,7 +295,7 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
data=f.read()
data=yaml.load(data, Loader=yaml.FullLoader)
s1_dir="%s/%s"%(exp_root,exp_name)
os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
# os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
if(is_half==False):
data["train"]["precision"]="32"
batch_size = max(1, batch_size // 2)
@ -287,6 +309,8 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
data["train"]["exp_name"]=exp_name
data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
data["train_dir"] = s1_dir
data["speaker_dict_path"] = "%s/speaker_dict.json"%s1_dir
data["output_dir"]="%s/logs_s1"%s1_dir
os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
@ -352,6 +376,13 @@ def close_slice():
ps1a=[]
def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
global ps1a
if not inp_text.endswith(".list"):
for name in os.listdir(inp_text):
p = os.path.join(inp_text, name)
l = os.path.join(p, name + ".list")
assert os.path.exists(l)
yield from open1a(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,bert_pretrained_dir)
return
inp_text = my_utils.clean_path(inp_text)
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if (ps1a == []):
@ -410,6 +441,13 @@ def close1a():
ps1b=[]
def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
global ps1b
if not inp_text.endswith(".list"):
for name in os.listdir(inp_text):
p = os.path.join(inp_text, name)
l = os.path.join(p, name + ".list")
assert os.path.exists(l)
yield from open1b(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,ssl_pretrained_dir)
return
inp_text = my_utils.clean_path(inp_text)
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if (ps1b == []):
@ -458,6 +496,13 @@ def close1b():
ps1c=[]
def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
global ps1c
if not inp_text.endswith(".list"):
for name in os.listdir(inp_text):
p = os.path.join(inp_text, name)
l = os.path.join(p, name + ".list")
assert os.path.exists(l)
yield from open1c(l,os.path.join(exp_name, name),gpu_numbers,pretrained_s2G_path)
return
inp_text = my_utils.clean_path(inp_text)
if (ps1c == []):
opt_dir="%s/%s"%(exp_root,exp_name)
@ -515,6 +560,13 @@ def close1c():
ps1abc=[]
def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
global ps1abc
if not inp_text.endswith(".list"):
for name in os.listdir(inp_text):
p = os.path.join(inp_text, name)
l = os.path.join(p, name + ".list")
assert os.path.exists(l)
yield from open1abc(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path)
return
inp_text = my_utils.clean_path(inp_text)
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
if (ps1abc == []):
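
All four recursive branches (open1a, open1b, open1c, open1abc) assume the same multi-speaker layout when inp_text is a directory rather than a .list file: each speaker folder must contain a list file named after the folder itself, or the assert above fails. A sketch of that assumption (names illustrative):

# Expected layout for the recursive path:
#   dataset_root/
#       speakerA/
#           speakerA.list
#       speakerB/
#           speakerB.list
# open1a("dataset_root", inp_wav_dir, "my_exp", gpu_numbers, bert_dir) then yields from
# open1a("dataset_root/speakerA/speakerA.list", inp_wav_dir, "my_exp/speakerA", ...), and so on,
# giving each speaker its own experiment subdirectory under exp_name.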