From 9d50702674533080ce8cc25d3d5bc399f4ae049d Mon Sep 17 00:00:00 2001 From: Turing's Cat Date: Mon, 5 Feb 2024 09:45:16 +0800 Subject: [PATCH 1/2] fix default dir miss assertion error --- GPT_SoVITS/module/data_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py index ff4c4f4..33158d3 100644 --- a/GPT_SoVITS/module/data_utils.py +++ b/GPT_SoVITS/module/data_utils.py @@ -29,9 +29,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): def __init__(self, hparams, val=False): exp_dir = hparams.exp_dir + if not os.path.exists("%s/2-name2text.txt" % exp_dir): + exp_dir = "logs/xxx" + self.path2 = "%s/2-name2text.txt" % exp_dir self.path4 = "%s/4-cnhubert" % exp_dir self.path5 = "%s/5-wav32k" % exp_dir + + f"""-loading self.path2: {self.path2} correctly-""" + assert os.path.exists(self.path2) assert os.path.exists(self.path4) assert os.path.exists(self.path5) From 5553ebe932f5df6b5bd79e1545958aa1628e43a9 Mon Sep 17 00:00:00 2001 From: Turing's Cat Date: Mon, 5 Feb 2024 10:08:41 +0800 Subject: [PATCH 2/2] fix default dir miss assertion error --- GPT_SoVITS/module/data_utils.py | 39 ++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py index 33158d3..340cdc4 100644 --- a/GPT_SoVITS/module/data_utils.py +++ b/GPT_SoVITS/module/data_utils.py @@ -3,6 +3,8 @@ import logging import os import random import traceback +from datetime import datetime + import numpy as np import torch import torch.utils.data @@ -19,6 +21,7 @@ from scipy.io import wavfile from io import BytesIO from my_utils import load_audio + # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79) class TextAudioSpeakerLoader(torch.utils.data.Dataset): """ @@ -28,13 +31,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): """ def __init__(self, hparams, val=False): - exp_dir = hparams.exp_dir - if not os.path.exists("%s/2-name2text.txt" % exp_dir): - exp_dir = "logs/xxx" + self.exp_dir = hparams.exp_dir - self.path2 = "%s/2-name2text.txt" % exp_dir - self.path4 = "%s/4-cnhubert" % exp_dir - self.path5 = "%s/5-wav32k" % exp_dir + if not os.path.exists("%s/2-name2text.txt" % self.exp_dir): # 如果文件夹为空,寻址最新创建项目 + parent_dir = "/".join(os.getcwd().split("/")[:-1]) + self.find_newest_folder(parent_dir) + + self.path2 = "%s/2-name2text.txt" % self.exp_dir + self.path4 = "%s/4-cnhubert" % self.exp_dir + self.path5 = "%s/5-wav32k" % self.exp_dir f"""-loading self.path2: {self.path2} correctly-""" @@ -110,6 +115,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): self.audiopaths_sid_text = audiopaths_sid_text_new self.lengths = lengths + def find_newest_folder(self, directory): + if os.path.exists(directory): + try: + for root, dirs, files in os.walk(directory): + # 获取当前目录下所有子文件夹的修改时间 + folder_times = [datetime.fromtimestamp(os.path.getmtime(root + '/' + d)) for d in dirs] + + if len(folder_times) > 0: + # 根据修改时间选择最新的子文件夹 + self.exp_dir = max(zip(folder_times, dirs), key=lambda x: x[0])[-1] + except: + f"""locating newest folder fail in {directory}""" + return + def get_audio_text_speaker_pair(self, audiopath_sid_text): audiopath, phoneme_ids = audiopath_sid_text text = torch.FloatTensor(phoneme_ids) @@ -136,7 +155,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): audio_norm = audio audio_norm = audio_norm.unsqueeze(0) spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, - center=False) + center=False) spec = torch.squeeze(spec, 0) return spec, audio_norm @@ -153,7 +172,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): def random_slice(self, ssl, wav, mel): assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ( - "first", ssl.shape, wav.shape) + "first", ssl.shape, wav.shape) len_mel = mel.shape[1] if self.val: @@ -174,7 +193,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): mel = mel[:, :sep_point] assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( - ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir) + ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir) return reference_mel, ssl, wav2, mel @@ -335,4 +354,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): return -1 def __len__(self): - return self.num_samples // self.batch_size \ No newline at end of file + return self.num_samples // self.batch_size