mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 22:50:00 +08:00
Add speaker feature injection and related processing
This commit is contained in:
parent
7303e108e1
commit
0341f13b92
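In outline, the commit swaps the single learned speaker vector from the parent commit for a speaker-ID lookup table: each id selects a 4096-dim embedding, which is projected to the decoder width and added to the semantic-token embeddings before the positional encoding. A minimal, self-contained sketch of that injection step (the 100/4096 sizes and the speaker_emb/speaker_proj names mirror the diff below; everything else is illustrative):

import torch
import torch.nn as nn

class SpeakerInjection(nn.Module):
    # Sketch of the mechanism this commit adds: a per-speaker embedding,
    # projected to the decoder width and broadcast over the token axis.
    def __init__(self, embedding_dim, n_speakers=100, spk_dim=4096):
        super().__init__()
        self.speaker_emb = nn.Embedding(n_speakers, spk_dim)    # one row per speaker id
        self.speaker_proj = nn.Linear(spk_dim, embedding_dim)   # map to model width

    def forward(self, y_emb, speaker_ids=None):
        # y_emb: (B, T, embedding_dim) semantic-token embeddings; speaker_ids: (B,) long
        if speaker_ids is None:
            return y_emb                                          # no injection
        spk = self.speaker_proj(self.speaker_emb(speaker_ids))   # (B, embedding_dim)
        return y_emb + spk.unsqueeze(1)                           # broadcast over T

inj = SpeakerInjection(embedding_dim=512)
out = inj(torch.randn(2, 7, 512), torch.tensor([0, 3]))
assert out.shape == (2, 7, 512)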
.gitignore (vendored): 3 changes
@@ -9,4 +9,5 @@ output
 logs
 reference
 SoVITS_weights
-GPT_weights
+GPT_weights
+TEMP
@@ -1,7 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
 from pytorch_lightning import LightningDataModule
 from AR.data.bucket_sampler import DistributedBucketSampler
-from AR.data.dataset import Text2SemanticDataset
+from AR.data.dataset import Text2SemanticDataset, Text2SemanticSpeakerDataset
 from torch.utils.data import DataLoader
 
 
@@ -26,12 +26,20 @@ class Text2SemanticDataModule(LightningDataModule):
         pass
 
     def setup(self, stage=None, output_logs=False):
-        self._train_dataset = Text2SemanticDataset(
-            phoneme_path=self.train_phoneme_path,
-            semantic_path=self.train_semantic_path,
-            max_sec=self.config["data"]["max_sec"],
-            pad_val=self.config["data"]["pad_val"],
-        )
+        if 'train_dir' not in self.config:
+            self._train_dataset = Text2SemanticDataset(
+                phoneme_path=self.train_phoneme_path,
+                semantic_path=self.train_semantic_path,
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
+        else:
+            self._train_dataset = Text2SemanticSpeakerDataset(
+                self.config['speaker_dict_path'],
+                self.config['train_dir'],
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
         self._dev_dataset = self._train_dataset
         # self._dev_dataset = Text2SemanticDataset(
         #     phoneme_path=self.dev_phoneme_path,
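The new branch keys off two extra config entries. A hypothetical excerpt of the s1 training config after this change (only train_dir and speaker_dict_path are new; the other values are placeholders):

# Sketch only: the presence of "train_dir" switches setup() to Text2SemanticSpeakerDataset.
config = {
    "data": {"max_sec": 54, "pad_val": 1024},               # placeholder values
    "train_dir": "logs/my_exp",                              # one subfolder per speaker
    "speaker_dict_path": "logs/my_exp/speaker_dict.json",    # speaker name -> id mapping
}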
@@ -293,6 +293,128 @@ class Text2SemanticDataset(Dataset):
             "bert_feature": bert_padded,
         }
 
+
+class Text2SemanticSpeakerDataset(Dataset):
+    def __init__(
+        self,
+        speaker_dict_path: str,
+        train_dir: str,
+        max_sample: int = None,
+        max_sec: int = 100,
+        pad_val: int = 1024,
+        # min value of phoneme/sec
+        min_ps_ratio: int = 3,
+        # max value of phoneme/sec
+        max_ps_ratio: int = 25,
+    ):
+        super().__init__()
+        self.speaker_dict = dict()
+        if os.path.exists(speaker_dict_path):
+            with open(speaker_dict_path, "r", encoding="utf-8") as f:
+                try:
+                    self.speaker_dict = json.load(f)
+                except:
+                    pass
+
+        self.speaker2dataset = {}
+        self.PAD = pad_val
+        for name in os.listdir(train_dir):
+            if name in ["logs_s1", os.path.basename(speaker_dict_path)]:
+                continue
+            opt_dir = os.path.join(train_dir, name)
+            speaker_id = self.speaker_dict.get(name, None)
+            if speaker_id is None:
+                speaker_id = len(self.speaker_dict)
+                self.speaker_dict[name] = speaker_id
+            self.speaker2dataset[speaker_id] = Text2SemanticDataset(
+                phoneme_path=os.path.join(opt_dir, "2-name2text.txt"),
+                semantic_path=os.path.join(opt_dir, "6-name2semantic.tsv"),
+                max_sample=max_sample,
+                max_sec=max_sec,
+                pad_val=pad_val,
+                min_ps_ratio=min_ps_ratio,
+                max_ps_ratio=max_ps_ratio,
+            )
+        with open(speaker_dict_path, "w", encoding="utf-8") as f:
+            json.dump(self.speaker_dict, f, ensure_ascii=False, indent=4)
+
+    def __len__(self):
+        result = 0
+        for _, dataset in self.speaker2dataset.items():
+            result += len(dataset)
+        return result
+
+    def __getitem__(self, idx):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                result = dataset[idx]
+                result["speaker_id"] = speaker_id
+                return result
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def get_sample_length(self, idx: int):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                semantic_ids = dataset.semantic_phoneme[idx][0]
+                sec = 1.0 * len(semantic_ids) / dataset.hz
+                return sec
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def collate(self, examples: List[Dict]) -> Dict:
+        sample_index: List[int] = []
+        phoneme_ids: List[torch.Tensor] = []
+        phoneme_ids_lens: List[int] = []
+        semantic_ids: List[torch.Tensor] = []
+        semantic_ids_lens: List[int] = []
+        speaker_ids: List[int] = []
+        # return
+
+        for item in examples:
+            sample_index.append(item["idx"])
+            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
+            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
+            phoneme_ids_lens.append(item["phoneme_ids_len"])
+            semantic_ids_lens.append(item["semantic_ids_len"])
+            speaker_ids.append(item["speaker_id"])
+
+        # pad 0
+        phoneme_ids = batch_sequences(phoneme_ids)
+        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
+
+        # # convert each batch to torch.tensor
+        phoneme_ids = torch.tensor(phoneme_ids)
+        semantic_ids = torch.tensor(semantic_ids)
+        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
+        semantic_ids_lens = torch.tensor(semantic_ids_lens)
+        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
+        bert_padded.zero_()
+        speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
+
+        for idx, item in enumerate(examples):
+            bert = item["bert_feature"]
+            if bert != None:
+                bert_padded[idx, :, : bert.shape[-1]] = bert
+
+        return {
+            # List[int]
+            "ids": sample_index,
+            # torch.Tensor (B, max_phoneme_length)
+            "phoneme_ids": phoneme_ids,
+            # torch.Tensor (B)
+            "phoneme_ids_len": phoneme_ids_lens,
+            # torch.Tensor (B, max_semantic_ids_length)
+            "semantic_ids": semantic_ids,
+            # torch.Tensor (B)
+            "semantic_ids_len": semantic_ids_lens,
+            # torch.Tensor (B, 1024, max_phoneme_length)
+            "bert_feature": bert_padded,
+            "speaker_ids": speaker_ids,
+        }
 
 
 
 if __name__ == "__main__":
     root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
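A possible way to exercise the new dataset (paths and folder names below are illustrative; the layout assumed is one preprocessed experiment folder per speaker inside train_dir, each containing 2-name2text.txt and 6-name2semantic.tsv):

from torch.utils.data import DataLoader

# logs/my_exp/
#     alice/2-name2text.txt, alice/6-name2semantic.tsv
#     bob/2-name2text.txt,   bob/6-name2semantic.tsv
dataset = Text2SemanticSpeakerDataset(
    speaker_dict_path="logs/my_exp/speaker_dict.json",   # created/updated on construction
    train_dir="logs/my_exp",
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=dataset.collate)
batch = next(iter(loader))
print(batch["phoneme_ids"].shape, batch["speaker_ids"])  # speaker_ids: (B,) long tensor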
@@ -34,10 +34,11 @@ class Text2SemanticLightningModule(LightningModule):
             self.eval_dir.mkdir(parents=True, exist_ok=True)
         for param in self.model.parameters():
             param.requires_grad = False
-        self.model.speaker_proj.weight.requires_grad = True
-        self.model.speaker_proj.bias.requires_grad = True
-        self.model.speaker_proj.train()
-        self.model.speaker_feat.requires_grad = True
+        self.model.speaker_proj.requires_grad_(True)
+        self.model.speaker_proj.weight.requires_grad_(True)
+        self.model.speaker_proj.bias.requires_grad_(True)
+        self.model.speaker_emb.requires_grad_(True)
+        self.model.speaker_emb.weight.requires_grad_(True)
 
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
@@ -48,13 +49,13 @@ class Text2SemanticLightningModule(LightningModule):
             batch["semantic_ids"],
             batch["semantic_ids_len"],
             batch["bert_feature"],
+            speaker_ids=batch["speaker_ids"] if "speaker_ids" in batch else None,
         )
         self.manual_backward(loss)
         if batch_idx > 0 and batch_idx % 4 == 0:
             opt.step()
             opt.zero_grad()
             scheduler.step()
-            torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
 
         self.log(
             "total_loss",
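The requires_grad block above is standard partial fine-tuning: freeze the whole model, then re-enable gradients only on the speaker modules. A generic, self-contained sketch of the same pattern (the speaker_emb/speaker_proj names mirror the diff; the backbone is a stand-in):

import torch.nn as nn

model = nn.ModuleDict({
    "backbone": nn.Linear(512, 512),        # stands in for the frozen decoder
    "speaker_emb": nn.Embedding(100, 4096),
    "speaker_proj": nn.Linear(4096, 512),
})
for p in model.parameters():
    p.requires_grad = False                  # freeze everything first
model["speaker_emb"].requires_grad_(True)    # then unfreeze only the speaker path
model["speaker_proj"].requires_grad_(True)

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['speaker_emb.weight', 'speaker_proj.weight', 'speaker_proj.bias']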
@@ -52,12 +52,8 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
-        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
-        if not os.path.exists(self.path_speaker):
-            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
-        else:
-            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
+        self.speaker_proj = nn.Linear(4096, self.embedding_dim)
+        self.speaker_emb = nn.Embedding(100, 4096)
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
@@ -95,7 +91,7 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
-    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
+    def make_input_data(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
@@ -111,7 +107,8 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
-        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            y_emb = y_emb + self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         y_pos = self.ar_audio_position(y_emb)
 
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@@ -149,7 +146,7 @@ class Text2SemanticDecoder(nn.Module):
 
         return xy_pos, xy_attn_mask, targets
 
-    def forward(self, x, x_lens, y, y_lens, bert_feature):
+    def forward(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         """
         x: phoneme_ids
         y: semantic_ids
@@ -157,7 +154,7 @@ class Text2SemanticDecoder(nn.Module):
 
         reject_y, reject_y_lens = make_reject_y(y, y_lens)
 
-        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
+        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature, speaker_ids=speaker_ids)
 
         xy_dec, _ = self.h(
             (xy_pos, None),
@@ -167,7 +164,7 @@ class Text2SemanticDecoder(nn.Module):
         logits = self.ar_predict_layer(xy_dec[:, x_len:])
 
         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature, speaker_ids=speaker_ids)
 
         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -338,7 +335,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
-        use_speaker_feat=True,
+        speaker_ids=None,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -362,8 +359,8 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        if use_speaker_feat:
-            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            speaker_feat = self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         else:
             speaker_feat = 0
         for idx in tqdm(range(1500)):
@@ -3,7 +3,7 @@ train:
   epochs: 20
   batch_size: 8
   save_every_n_epoch: 1
-  precision: 16-mixed
+  precision: "32"
   gradient_clip: 1.0
 optimizer:
   lr: 0.01
@@ -365,7 +365,10 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6,speaker_id="-1"):
+    speaker_id = int(speaker_id)
+    if speaker_id == -1:
+        speaker_id = None
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
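The speaker id reaches get_tts_wav as a dropdown string: "-1" disables injection, any other value picks a row of the embedding table. A hedged call sketch from inside this module (paths, texts, and the language labels are placeholders; get_tts_wav is a generator yielding (sampling_rate, int16 waveform) chunks):

gen = get_tts_wav(
    ref_wav_path="ref/speaker3_sample.wav",   # illustrative path
    prompt_text="reference transcript",
    prompt_language=i18n("中文"),              # any label offered by the UI dropdown
    text="text to synthesize",
    text_language=i18n("中文"),
    speaker_id="3",                            # "-1" would skip speaker injection
)
sampling_rate, audio = next(gen)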
@@ -381,7 +384,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     )
     with torch.no_grad():
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 0):
             raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
@@ -448,6 +451,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 top_p=top_p,
                 temperature=temperature,
                 early_stop_num=hz * max_sec,
+                speaker_ids=torch.LongTensor([speaker_id]).to(device) if speaker_id is not None else None,
             )
         t3 = ttime()
         # print(pred_semantic.shape,idx)
@@ -601,6 +605,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
             SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
             refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
+            speaker_id = gr.Dropdown(label="说话人idx", choices=[str(i) for i in range(-1, 100)], value="-1")
             refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
             SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
             GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
@@ -632,7 +637,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
         inference_button.click(
             get_tts_wav,
-            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
+            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature,speaker_id],
             [output],
         )
 
@@ -98,7 +98,7 @@ for line in lines[int(i_part)::int(all_parts)]:
     try:
         # wav_name,text=line.split("\t")
         wav_name, spk_name, language, text = line.split("|")
-        if (inp_wav_dir !=None):
+        if (inp_wav_dir !=None and len(inp_wav_dir) > 0):
             wav_name = os.path.basename(wav_name)
             wav_path = "%s/%s"%(inp_wav_dir, wav_name)
 
@@ -24,24 +24,36 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
     )
     _max=float(_max)
     alpha=float(alpha)
-    for inp_path in input[int(i_part)::int(all_part)]:
-        # print(inp_path)
+
+    def slice_audio(root_dir, inp_wav_path):
         try:
-            name = os.path.basename(inp_path)
-            audio = load_audio(inp_path, 32000)
+            name = os.path.basename(inp_wav_path)
+            audio = load_audio(inp_wav_path, 32000)
             # print(audio.shape)
             for chunk, start, end in slicer.slice(audio):  # start和end是帧数
                 tmp_max = np.abs(chunk).max()
                 if(tmp_max>1):chunk/=tmp_max
                 chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
                 wavfile.write(
-                    "%s/%s_%010d_%010d.wav" % (opt_root, name, start, end),
+                    "%s/%s_%010d_%010d.wav" % (root_dir, name, start, end),
                     32000,
                     # chunk.astype(np.float32),
                     (chunk * 32767).astype(np.int16),
                 )
         except:
-            print(inp_path,"->fail->",traceback.format_exc())
+            print(inp_wav_path,"->fail->",traceback.format_exc())
+
+    for inp_path in input[int(i_part)::int(all_part)]:
+        # print(inp_path)
+        if os.path.isdir(inp_path):
+            inp_dir_name = os.path.basename(inp_path)
+            os.makedirs(os.path.join(opt_root, inp_dir_name), exist_ok=True)
+            for inp_name in os.listdir(inp_path):
+                inp_wav_path = os.path.join(inp_path, inp_name)
+                slice_audio(os.path.join(opt_root, inp_dir_name), inp_wav_path)
+        else:
+            inp_wav_path = inp_path
+            slice_audio(opt_root, inp_wav_path)
     return "执行完毕,请检查输出文件"
 
 print(slice(*sys.argv[1:]))
webui.py: 62 changes
@@ -195,20 +195,42 @@ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path
 from tools.asr.config import asr_dict
 def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang):
     global p_asr
-    if(p_asr==None):
-        asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+    def run(asr_inp_dir, asr_opt_dir):
         cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
         cmd += f' -i "{asr_inp_dir}"'
         cmd += f' -o "{asr_opt_dir}"'
         cmd += f' -s {asr_model_size}'
         cmd += f' -l {asr_lang}'
         cmd += " -p %s"%("float16"if is_half==True else "float32")
 
-        yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
         print(cmd)
         p_asr = Popen(cmd, shell=True)
         p_asr.wait()
         p_asr=None
+    if(p_asr==None):
+        # yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
+        flag = False
+        for name in os.listdir(asr_inp_dir):
+            if os.path.isdir(os.path.join(asr_inp_dir, name)):
+                os.makedirs(os.path.join(asr_opt_dir, name), exist_ok=True)
+                run(os.path.join(asr_inp_dir, name), os.path.join(asr_opt_dir, name))
+            else:
+                break
+        else:
+            flag = True
+        if flag:
+            run(asr_inp_dir, asr_opt_dir)
+        # asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+        # cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
+        # cmd += f' -i "{asr_inp_dir}"'
+        # cmd += f' -o "{asr_opt_dir}"'
+        # cmd += f' -s {asr_model_size}'
+        # cmd += f' -l {asr_lang}'
+        # cmd += " -p %s"%("float16"if is_half==True else "float32")
+
+        # print(cmd)
+        # p_asr = Popen(cmd, shell=True)
+        # p_asr.wait()
+        # p_asr=None
         yield f"ASR任务完成, 查看终端进行下一步",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
     else:
         yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
|
||||
data=f.read()
|
||||
data=yaml.load(data, Loader=yaml.FullLoader)
|
||||
s1_dir="%s/%s"%(exp_root,exp_name)
|
||||
os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
|
||||
# os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
|
||||
if(is_half==False):
|
||||
data["train"]["precision"]="32"
|
||||
batch_size = max(1, batch_size // 2)
|
||||
@ -287,6 +309,8 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
|
||||
data["train"]["exp_name"]=exp_name
|
||||
data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
|
||||
data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
|
||||
data["train_dir"] = s1_dir
|
||||
data["speaker_dict_path"] = "%s/speaker_dict.json"%s1_dir
|
||||
data["output_dir"]="%s/logs_s1"%s1_dir
|
||||
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
|
||||
@ -352,6 +376,13 @@ def close_slice():
|
||||
ps1a=[]
|
||||
def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
|
||||
global ps1a
|
||||
if not inp_text.endswith(".list"):
|
||||
for name in os.listdir(inp_text):
|
||||
p = os.path.join(inp_text, name)
|
||||
l = os.path.join(p, name + ".list")
|
||||
assert os.path.exists(l)
|
||||
yield from open1a(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,bert_pretrained_dir)
|
||||
return
|
||||
inp_text = my_utils.clean_path(inp_text)
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if (ps1a == []):
|
||||
@ -410,6 +441,13 @@ def close1a():
|
||||
ps1b=[]
|
||||
def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
|
||||
global ps1b
|
||||
if not inp_text.endswith(".list"):
|
||||
for name in os.listdir(inp_text):
|
||||
p = os.path.join(inp_text, name)
|
||||
l = os.path.join(p, name + ".list")
|
||||
assert os.path.exists(l)
|
||||
yield from open1b(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,ssl_pretrained_dir)
|
||||
return
|
||||
inp_text = my_utils.clean_path(inp_text)
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if (ps1b == []):
|
||||
@ -458,6 +496,13 @@ def close1b():
|
||||
ps1c=[]
|
||||
def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
|
||||
global ps1c
|
||||
if not inp_text.endswith(".list"):
|
||||
for name in os.listdir(inp_text):
|
||||
p = os.path.join(inp_text, name)
|
||||
l = os.path.join(p, name + ".list")
|
||||
assert os.path.exists(l)
|
||||
yield from open1c(l,os.path.join(exp_name, name),gpu_numbers,pretrained_s2G_path)
|
||||
return
|
||||
inp_text = my_utils.clean_path(inp_text)
|
||||
if (ps1c == []):
|
||||
opt_dir="%s/%s"%(exp_root,exp_name)
|
||||
@ -515,6 +560,13 @@ def close1c():
|
||||
ps1abc=[]
|
||||
def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
|
||||
global ps1abc
|
||||
if not inp_text.endswith(".list"):
|
||||
for name in os.listdir(inp_text):
|
||||
p = os.path.join(inp_text, name)
|
||||
l = os.path.join(p, name + ".list")
|
||||
assert os.path.exists(l)
|
||||
yield from open1abc(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path)
|
||||
return
|
||||
inp_text = my_utils.clean_path(inp_text)
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if (ps1abc == []):
|
||||
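All four one-click entry points now recurse when inp_text is a directory rather than a .list file, so multi-speaker preparation appears to expect one subfolder per speaker, each holding a list file named after the folder. A sketch that builds such a skeleton (folder and speaker names are made up; each annotation line follows the wav_name|spk_name|language|text format used by the prepare scripts, with an illustrative language tag):

import os

root = "dataset_root"   # passed as inp_text to open1a/open1b/open1c/open1abc
for speaker in ["alice", "bob"]:
    spk_dir = os.path.join(root, speaker)
    os.makedirs(spk_dir, exist_ok=True)
    # the webui asserts that "<folder name>.list" exists inside each subfolder
    with open(os.path.join(spk_dir, speaker + ".list"), "w", encoding="utf-8") as f:
        f.write("demo.wav|%s|zh|example transcript line\n" % speaker)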