feat: 添加数据集的错误处理提示 (#2758)

Co-authored-by: moomushroom <107208254+moomushroom@users.noreply.github.com>
This commit is contained in:
Mushroomcowisheggs 2026-04-18 17:13:30 +08:00 committed by GitHub
parent 14191901cd
commit 00ce973412
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -67,8 +67,10 @@ class Text2SemanticDataset(Dataset):
) )
) # "%s/3-bert"%exp_dir#bert_dir ) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2) if not os.path.exists(self.path2):
assert os.path.exists(self.path6) raise FileNotFoundError(f"Phoneme data file not found: {self.path2}")
if not os.path.exists(self.path6):
raise FileNotFoundError(f"Semantic data file not found: {self.path6}")
self.phoneme_data = {} self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f: with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
@@ -131,7 +133,7 @@ class Text2SemanticDataset(Dataset):
phoneme, word2ph, text = self.phoneme_data[item_name] phoneme, word2ph, text = self.phoneme_data[item_name]
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
# print(f"{item_name} not in self.phoneme_data !") print(f"Warning: File \"{item_name}\" not in self.phoneme_data! Skipped. ")
num_not_in += 1 num_not_in += 1
continue continue
@@ -152,7 +154,7 @@ class Text2SemanticDataset(Dataset):
phoneme_ids = cleaned_text_to_sequence(phoneme, version) phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except: except:
traceback.print_exc() traceback.print_exc()
# print(f"{item_name} not in self.phoneme_data !") print(f"Warning: Failed to convert phonemes to sequence for file \"{item_name}\"! Skipped. ")
num_not_in += 1 num_not_in += 1
continue continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行 # if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
@@ -228,7 +230,11 @@ class Text2SemanticDataset(Dataset):
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32) # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature = None bert_feature = None
else: else:
assert bert_feature.shape[-1] == len(phoneme_ids) try:
assert bert_feature.shape[-1] == len(phoneme_ids)
except AssertionError:
print(f"AssertionError: The BERT feature dimension ({bert_feature.shape[-1]}) of the file '{item_name}' does not match the length of the phoneme sequence ({len(phoneme_ids)}).")
raise
return { return {
"idx": idx, "idx": idx,
"phoneme_ids": phoneme_ids, "phoneme_ids": phoneme_ids,