diff --git a/.gitignore b/.gitignore
index 6afb1e50..2d5a9ccc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ output
 logs
 reference
 SoVITS_weights
-GPT_weights
\ No newline at end of file
+GPT_weights
+TEMP
\ No newline at end of file
diff --git a/GPT_SoVITS/AR/data/data_module.py b/GPT_SoVITS/AR/data/data_module.py
index 037484a9..6a217852 100644
--- a/GPT_SoVITS/AR/data/data_module.py
+++ b/GPT_SoVITS/AR/data/data_module.py
@@ -1,7 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
 from pytorch_lightning import LightningDataModule
 from AR.data.bucket_sampler import DistributedBucketSampler
-from AR.data.dataset import Text2SemanticDataset
+from AR.data.dataset import Text2SemanticDataset, Text2SemanticSpeakerDataset
 from torch.utils.data import DataLoader
 
 
@@ -26,12 +26,20 @@ class Text2SemanticDataModule(LightningDataModule):
         pass
 
     def setup(self, stage=None, output_logs=False):
-        self._train_dataset = Text2SemanticDataset(
-            phoneme_path=self.train_phoneme_path,
-            semantic_path=self.train_semantic_path,
-            max_sec=self.config["data"]["max_sec"],
-            pad_val=self.config["data"]["pad_val"],
-        )
+        if 'train_dir' not in self.config:
+            self._train_dataset = Text2SemanticDataset(
+                phoneme_path=self.train_phoneme_path,
+                semantic_path=self.train_semantic_path,
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
+        else:
+            self._train_dataset = Text2SemanticSpeakerDataset(
+                self.config['speaker_dict_path'],
+                self.config['train_dir'],
+                max_sec=self.config["data"]["max_sec"],
+                pad_val=self.config["data"]["pad_val"],
+            )
         self._dev_dataset = self._train_dataset
         # self._dev_dataset = Text2SemanticDataset(
         #     phoneme_path=self.dev_phoneme_path,
diff --git a/GPT_SoVITS/AR/data/dataset.py b/GPT_SoVITS/AR/data/dataset.py
index b1ea69e6..5836bb0d 100644
--- a/GPT_SoVITS/AR/data/dataset.py
+++ b/GPT_SoVITS/AR/data/dataset.py
@@ -293,6 +293,128 @@ class Text2SemanticDataset(Dataset):
             "bert_feature": bert_padded,
         }
 
+class Text2SemanticSpeakerDataset(Dataset):
+    def __init__(
+        self,
+        speaker_dict_path: str,
+        train_dir: str,
+        max_sample: int = None,
+        max_sec: int = 100,
+        pad_val: int = 1024,
+        # min value of phoneme/sec
+        min_ps_ratio: int = 3,
+        # max value of phoneme/sec
+        max_ps_ratio: int = 25,
+    ):
+        super().__init__()
+        self.speaker_dict = dict()
+        if os.path.exists(speaker_dict_path):
+            with open(speaker_dict_path, "r", encoding="utf-8") as f:
+                try:
+                    self.speaker_dict = json.load(f)
+                except:
+                    pass
+
+        self.speaker2dataset = {}
+        self.PAD = pad_val
+        for name in os.listdir(train_dir):
+            if name in ["logs_s1", os.path.basename(speaker_dict_path)]:
+                continue
+            opt_dir = os.path.join(train_dir, name)
+            speaker_id = self.speaker_dict.get(name, None)
+            if speaker_id is None:
+                speaker_id = len(self.speaker_dict)
+                self.speaker_dict[name] = speaker_id
+            self.speaker2dataset[speaker_id] = Text2SemanticDataset(
+                phoneme_path=os.path.join(opt_dir, "2-name2text.txt"),
+                semantic_path=os.path.join(opt_dir, "6-name2semantic.tsv"),
+                max_sample=max_sample,
+                max_sec=max_sec,
+                pad_val=pad_val,
+                min_ps_ratio=min_ps_ratio,
+                max_ps_ratio=max_ps_ratio,
+            )
+        with open(speaker_dict_path, "w", encoding="utf-8") as f:
+            json.dump(self.speaker_dict, f, ensure_ascii=False, indent=4)
+
+    def __len__(self):
+        result = 0
+        for _, dataset in self.speaker2dataset.items():
+            result += len(dataset)
+        return result
+
+    def __getitem__(self, idx):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                result = dataset[idx]
+                result["speaker_id"] = speaker_id
+                return result
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def get_sample_length(self, idx: int):
+        for speaker_id, dataset in self.speaker2dataset.items():
+            if idx < len(dataset):
+                semantic_ids = dataset.semantic_phoneme[idx][0]
+                sec = 1.0 * len(semantic_ids) / dataset.hz
+                return sec
+            else:
+                idx -= len(dataset)
+        raise IndexError("index out of range")
+
+    def collate(self, examples: List[Dict]) -> Dict:
+        sample_index: List[int] = []
+        phoneme_ids: List[torch.Tensor] = []
+        phoneme_ids_lens: List[int] = []
+        semantic_ids: List[torch.Tensor] = []
+        semantic_ids_lens: List[int] = []
+        speaker_ids: List[int] = []
+        # return
+
+        for item in examples:
+            sample_index.append(item["idx"])
+            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
+            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
+            phoneme_ids_lens.append(item["phoneme_ids_len"])
+            semantic_ids_lens.append(item["semantic_ids_len"])
+            speaker_ids.append(item["speaker_id"])
+
+        # pad 0
+        phoneme_ids = batch_sequences(phoneme_ids)
+        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
+
+        # # convert each batch to torch.tensor
+        phoneme_ids = torch.tensor(phoneme_ids)
+        semantic_ids = torch.tensor(semantic_ids)
+        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
+        semantic_ids_lens = torch.tensor(semantic_ids_lens)
+        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
+        bert_padded.zero_()
+        speaker_ids = torch.tensor(speaker_ids, dtype=torch.long)
+
+        for idx, item in enumerate(examples):
+            bert = item["bert_feature"]
+            if bert is not None:
+                bert_padded[idx, :, : bert.shape[-1]] = bert
+
+        return {
+            # List[int]
+            "ids": sample_index,
+            # torch.Tensor (B, max_phoneme_length)
+            "phoneme_ids": phoneme_ids,
+            # torch.Tensor (B)
+            "phoneme_ids_len": phoneme_ids_lens,
+            # torch.Tensor (B, max_semantic_ids_length)
+            "semantic_ids": semantic_ids,
+            # torch.Tensor (B)
+            "semantic_ids_len": semantic_ids_lens,
+            # torch.Tensor (B, 1024, max_phoneme_length)
+            "bert_feature": bert_padded,
+            "speaker_ids": speaker_ids,
+        }
+
+
 if __name__ == "__main__":
     root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py
index 39427695..8f8437b4 100644
--- a/GPT_SoVITS/AR/models/t2s_lightning_module.py
+++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py
@@ -34,10 +34,11 @@ class Text2SemanticLightningModule(LightningModule):
             self.eval_dir.mkdir(parents=True, exist_ok=True)
         for param in self.model.parameters():
             param.requires_grad = False
-        self.model.speaker_proj.weight.requires_grad = True
-        self.model.speaker_proj.bias.requires_grad = True
-        self.model.speaker_proj.train()
-        self.model.speaker_feat.requires_grad = True
+        self.model.speaker_proj.requires_grad_(True)
+        self.model.speaker_proj.weight.requires_grad_(True)
+        self.model.speaker_proj.bias.requires_grad_(True)
+        self.model.speaker_emb.requires_grad_(True)
+        self.model.speaker_emb.weight.requires_grad_(True)
 
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
@@ -48,13 +49,13 @@ class Text2SemanticLightningModule(LightningModule):
             batch["semantic_ids"],
             batch["semantic_ids_len"],
             batch["bert_feature"],
+            speaker_ids=batch["speaker_ids"] if "speaker_ids" in batch else None,
         )
         self.manual_backward(loss)
         if batch_idx > 0 and batch_idx % 4 == 0:
             opt.step()
             opt.zero_grad()
             scheduler.step()
-            torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
 
         self.log(
             "total_loss",
diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index e35e2f4f..ef0959b1 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -52,12 +52,8 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
-        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
-        if not os.path.exists(self.path_speaker):
-            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
-        else:
-            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
+        self.speaker_proj = nn.Linear(4096, self.embedding_dim)
+        self.speaker_emb = nn.Embedding(100, 4096)
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
@@ -95,7 +91,7 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
-    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
+    def make_input_data(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
@@ -111,7 +107,8 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
-        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            y_emb = y_emb + self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         y_pos = self.ar_audio_position(y_emb)
 
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@@ -149,7 +146,7 @@ class Text2SemanticDecoder(nn.Module):
 
         return xy_pos, xy_attn_mask, targets
 
-    def forward(self, x, x_lens, y, y_lens, bert_feature):
+    def forward(self, x, x_lens, y, y_lens, bert_feature, speaker_ids=None):
         """
         x: phoneme_ids
         y: semantic_ids
@@ -157,7 +154,7 @@ class Text2SemanticDecoder(nn.Module):
 
         reject_y, reject_y_lens = make_reject_y(y, y_lens)
 
-        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
+        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature, speaker_ids=speaker_ids)
 
         xy_dec, _ = self.h(
             (xy_pos, None),
@@ -167,7 +164,7 @@ class Text2SemanticDecoder(nn.Module):
         )
         logits = self.ar_predict_layer(xy_dec[:, x_len:])
         ###### DPO #############
-        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature, speaker_ids=speaker_ids)
 
         reject_xy_dec, _ = self.h(
             (reject_xy_pos, None),
@@ -338,7 +335,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
-        use_speaker_feat=True,
+        speaker_ids=None,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -362,8 +359,8 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        if use_speaker_feat:
-            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        if speaker_ids is not None:
+            speaker_feat = self.speaker_proj(self.speaker_emb(speaker_ids)).unsqueeze(1)
         else:
             speaker_feat = 0
         for idx in tqdm(range(1500)):
diff --git a/GPT_SoVITS/configs/s1longer.yaml b/GPT_SoVITS/configs/s1longer.yaml
index 3f57abd2..9ed69a1d 100644
--- a/GPT_SoVITS/configs/s1longer.yaml
+++ b/GPT_SoVITS/configs/s1longer.yaml
@@ -3,7 +3,7 @@ train:
   epochs: 20
   batch_size: 8
   save_every_n_epoch: 1
-  precision: 16-mixed
+  precision: "32"
   gradient_clip: 1.0
 optimizer:
   lr: 0.01
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 9c5197a7..16254ffc 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -365,7 +365,10 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6,speaker_id="-1"):
+    speaker_id = int(speaker_id)
+    if speaker_id == -1:
+        speaker_id = None
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
@@ -381,7 +384,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     )
     with torch.no_grad():
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
-        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
+        if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 0):
             raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
@@ -448,6 +451,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 top_p=top_p,
                 temperature=temperature,
                 early_stop_num=hz * max_sec,
+                speaker_ids=torch.LongTensor([speaker_id]).to(device) if speaker_id is not None else None,
             )
         t3 = ttime()
         # print(pred_semantic.shape,idx)
@@ -601,6 +605,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
             SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
             refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
+            speaker_id = gr.Dropdown(label="说话人idx", choices=[str(i) for i in range(-1, 100)], value="-1")
             refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
             SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
             GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
@@ -632,7 +637,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 
         inference_button.click(
             get_tts_wav,
-            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature],
+            [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,top_k,top_p,temperature,speaker_id],
             [output],
         )
diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
index b8355dd4..45b68dab 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
@@ -98,7 +98,7 @@ for line in lines[int(i_part)::int(all_parts)]:
     try:
         # wav_name,text=line.split("\t")
         wav_name, spk_name, language, text = line.split("|")
-        if (inp_wav_dir !=None):
+        if (inp_wav_dir !=None and len(inp_wav_dir) > 0):
             wav_name = os.path.basename(wav_name)
             wav_path = "%s/%s"%(inp_wav_dir, wav_name)
 
diff --git a/tools/slice_audio.py b/tools/slice_audio.py
index 46ee408a..98d34f0d 100644
--- a/tools/slice_audio.py
+++ b/tools/slice_audio.py
@@ -24,24 +24,36 @@ def slice(inp,opt_root,threshold,min_length,min_interval,hop_size,max_sil_kept,_
     )
     _max=float(_max)
     alpha=float(alpha)
-    for inp_path in input[int(i_part)::int(all_part)]:
-        # print(inp_path)
+
+    def slice_audio(root_dir, inp_wav_path):
         try:
-            name = os.path.basename(inp_path)
-            audio = load_audio(inp_path, 32000)
+            name = os.path.basename(inp_wav_path)
+            audio = load_audio(inp_wav_path, 32000)
             # print(audio.shape)
             for chunk, start, end in slicer.slice(audio):  # start和end是帧数
                 tmp_max = np.abs(chunk).max()
                 if(tmp_max>1):chunk/=tmp_max
                 chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
                 wavfile.write(
-                    "%s/%s_%010d_%010d.wav" % (opt_root, name, start, end),
+                    "%s/%s_%010d_%010d.wav" % (root_dir, name, start, end),
                     32000,
                     # chunk.astype(np.float32),
                     (chunk * 32767).astype(np.int16),
                 )
         except:
-            print(inp_path,"->fail->",traceback.format_exc())
+            print(inp_wav_path,"->fail->",traceback.format_exc())
+
+    for inp_path in input[int(i_part)::int(all_part)]:
+        # print(inp_path)
+        if os.path.isdir(inp_path):
+            inp_dir_name = os.path.basename(inp_path)
+            os.makedirs(os.path.join(opt_root, inp_dir_name), exist_ok=True)
+            for inp_name in os.listdir(inp_path):
+                inp_wav_path = os.path.join(inp_path, inp_name)
+                slice_audio(os.path.join(opt_root, inp_dir_name), inp_wav_path)
+        else:
+            inp_wav_path = inp_path
+            slice_audio(opt_root, inp_wav_path)
     return "执行完毕,请检查输出文件"
 
 print(slice(*sys.argv[1:]))
diff --git a/webui.py b/webui.py
index fa40c3af..f85787b5 100644
--- a/webui.py
+++ b/webui.py
@@ -195,20 +195,42 @@ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path
 from tools.asr.config import asr_dict
 def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang):
     global p_asr
-    if(p_asr==None):
-        asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+    def run(asr_inp_dir, asr_opt_dir):
         cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
         cmd += f' -i "{asr_inp_dir}"'
         cmd += f' -o "{asr_opt_dir}"'
         cmd += f' -s {asr_model_size}'
         cmd += f' -l {asr_lang}'
         cmd += " -p %s"%("float16"if is_half==True else "float32")
-
-        yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
         print(cmd)
         p_asr = Popen(cmd, shell=True)
         p_asr.wait()
         p_asr=None
+    if(p_asr==None):
+        # yield "ASR任务开启:%s"%cmd,{"__type__":"update","visible":False},{"__type__":"update","visible":True}
+        flag = False
+        for name in os.listdir(asr_inp_dir):
+            if os.path.isdir(os.path.join(asr_inp_dir, name)):
+                os.makedirs(os.path.join(asr_opt_dir, name), exist_ok=True)
+                run(os.path.join(asr_inp_dir, name), os.path.join(asr_opt_dir, name))
+            else:
+                flag = True
+                break
+        # flat input folder (no per-speaker sub-folders): run ASR on the whole folder
+        if flag:
+            run(asr_inp_dir, asr_opt_dir)
+        # asr_inp_dir=my_utils.clean_path(asr_inp_dir)
+        # cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}'
+        # cmd += f' -i "{asr_inp_dir}"'
+        # cmd += f' -o "{asr_opt_dir}"'
+        # cmd += f' -s {asr_model_size}'
+        # cmd += f' -l {asr_lang}'
+        # cmd += " -p %s"%("float16"if is_half==True else "float32")
+
+        # print(cmd)
+        # p_asr = Popen(cmd, shell=True)
+        # p_asr.wait()
+        # p_asr=None
         yield f"ASR任务完成, 查看终端进行下一步",{"__type__":"update","visible":True},{"__type__":"update","visible":False}
     else:
         yield "已有正在进行的ASR任务,需先终止才能开启下一次任务",{"__type__":"update","visible":False},{"__type__":"update","visible":True}
@@ -273,7 +295,7 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
         data=f.read()
         data=yaml.load(data, Loader=yaml.FullLoader)
     s1_dir="%s/%s"%(exp_root,exp_name)
-    os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
+    # os.makedirs("%s/logs_s1"%(s1_dir),exist_ok=True)
     if(is_half==False):
         data["train"]["precision"]="32"
         batch_size = max(1, batch_size // 2)
@@ -287,6 +309,8 @@ def open1Bb(batch_size,total_epoch,exp_name,if_save_latest,if_save_every_weights
     data["train"]["exp_name"]=exp_name
     data["train_semantic_path"]="%s/6-name2semantic.tsv"%s1_dir
     data["train_phoneme_path"]="%s/2-name2text.txt"%s1_dir
+    data["train_dir"] = s1_dir
+    data["speaker_dict_path"] = "%s/speaker_dict.json"%s1_dir
     data["output_dir"]="%s/logs_s1"%s1_dir
 
     os.environ["_CUDA_VISIBLE_DEVICES"]=gpu_numbers.replace("-",",")
@@ -352,6 +376,13 @@ def close_slice():
 ps1a=[]
 def open1a(inp_text,inp_wav_dir,exp_name,gpu_numbers,bert_pretrained_dir):
     global ps1a
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1a(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,bert_pretrained_dir)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1a == []):
@@ -410,6 +441,13 @@ def close1a():
 ps1b=[]
 def open1b(inp_text,inp_wav_dir,exp_name,gpu_numbers,ssl_pretrained_dir):
     global ps1b
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1b(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers,ssl_pretrained_dir)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1b == []):
@@ -458,6 +496,13 @@ def close1b():
 ps1c=[]
 def open1c(inp_text,exp_name,gpu_numbers,pretrained_s2G_path):
     global ps1c
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1c(l,os.path.join(exp_name, name),gpu_numbers,pretrained_s2G_path)
+        return
     inp_text = my_utils.clean_path(inp_text)
     if (ps1c == []):
         opt_dir="%s/%s"%(exp_root,exp_name)
@@ -515,6 +560,13 @@ def close1c():
 ps1abc=[]
 def open1abc(inp_text,inp_wav_dir,exp_name,gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path):
     global ps1abc
+    if not inp_text.endswith(".list"):
+        for name in os.listdir(inp_text):
+            p = os.path.join(inp_text, name)
+            l = os.path.join(p, name + ".list")
+            assert os.path.exists(l)
+            yield from open1abc(l,inp_wav_dir,os.path.join(exp_name, name),gpu_numbers1a,gpu_numbers1Ba,gpu_numbers1c,bert_pretrained_dir,ssl_pretrained_dir,pretrained_s2G_path)
+        return
     inp_text = my_utils.clean_path(inp_text)
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if (ps1abc == []):