Fix assertion error when the default directory is missing

This commit is contained in:
Turing's Cat 2024-02-05 10:08:41 +08:00
parent 9d50702674
commit 5553ebe932

View File

@ -3,6 +3,8 @@ import logging
import os import os
import random import random
import traceback import traceback
from datetime import datetime
import numpy as np import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
@ -19,6 +21,7 @@ from scipy.io import wavfile
from io import BytesIO from io import BytesIO
from my_utils import load_audio from my_utils import load_audio
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79) # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset): class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
@ -28,13 +31,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
def __init__(self, hparams, val=False): def __init__(self, hparams, val=False):
exp_dir = hparams.exp_dir self.exp_dir = hparams.exp_dir
if not os.path.exists("%s/2-name2text.txt" % exp_dir):
exp_dir = "logs/xxx"
self.path2 = "%s/2-name2text.txt" % exp_dir if not os.path.exists("%s/2-name2text.txt" % self.exp_dir): # 如果文件夹为空,寻址最新创建项目
self.path4 = "%s/4-cnhubert" % exp_dir parent_dir = "/".join(os.getcwd().split("/")[:-1])
self.path5 = "%s/5-wav32k" % exp_dir self.find_newest_folder(parent_dir)
self.path2 = "%s/2-name2text.txt" % self.exp_dir
self.path4 = "%s/4-cnhubert" % self.exp_dir
self.path5 = "%s/5-wav32k" % self.exp_dir
f"""-loading self.path2: {self.path2} correctly-""" f"""-loading self.path2: {self.path2} correctly-"""
@ -110,6 +115,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
self.audiopaths_sid_text = audiopaths_sid_text_new self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths self.lengths = lengths
def find_newest_folder(self, directory):
    """Point ``self.exp_dir`` at the newest immediate subfolder of ``directory``.

    "Newest" means the largest modification time as reported by the
    filesystem. Only direct children of ``directory`` are examined; nested
    folders are ignored so that a deeper directory can never override the
    top-level choice.

    ``self.exp_dir`` is set to the subfolder's path joined onto
    ``directory`` (not the bare folder name), so paths later built from it
    do not depend on the current working directory. It is left untouched
    when ``directory`` is not an existing directory, contains no
    subfolders, or cannot be read.

    Args:
        directory: Path of the folder whose subfolders are inspected.

    Returns:
        None. The result is stored in ``self.exp_dir``.
    """
    if not os.path.isdir(directory):
        return

    try:
        with os.scandir(directory) as entries:
            # Collect (modification time, full path) for every direct subfolder.
            candidates = [
                (entry.stat().st_mtime, entry.path)
                for entry in entries
                if entry.is_dir()
            ]
    except OSError:
        # Best effort: report the failure and keep the current exp_dir
        # instead of silently discarding the error.
        logging.warning(
            "locating newest folder fail in %s", directory, exc_info=True
        )
        return

    if candidates:
        # Pick the subfolder with the most recent modification time.
        self.exp_dir = max(candidates, key=lambda item: item[0])[1]
def get_audio_text_speaker_pair(self, audiopath_sid_text): def get_audio_text_speaker_pair(self, audiopath_sid_text):
audiopath, phoneme_ids = audiopath_sid_text audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids) text = torch.FloatTensor(phoneme_ids)
@ -136,7 +155,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
center=False) center=False)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio_norm return spec, audio_norm
@ -153,7 +172,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
def random_slice(self, ssl, wav, mel): def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first", ssl.shape, wav.shape) "first", ssl.shape, wav.shape)
len_mel = mel.shape[1] len_mel = mel.shape[1]
if self.val: if self.val:
@ -174,7 +193,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
mel = mel[:, :sep_point] mel = mel[:, :sep_point]
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir) ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
return reference_mel, ssl, wav2, mel return reference_mel, ssl, wav2, mel
@ -335,4 +354,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
return -1 return -1
def __len__(self): def __len__(self):
return self.num_samples // self.batch_size return self.num_samples // self.batch_size