Merge 5553ebe932f5df6b5bd79e1545958aa1628e43a9 into 5dfce9a3f0def7f1ee1e075df569b0b2d41df9e3

This commit is contained in:
Rylynn 2024-08-21 13:42:25 -07:00 committed by GitHub
commit d1fd1d1ad5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,8 @@ import logging
import os import os
import random import random
import traceback import traceback
from datetime import datetime
import numpy as np import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
@ -19,6 +21,7 @@ from scipy.io import wavfile
from io import BytesIO from io import BytesIO
from my_utils import load_audio from my_utils import load_audio
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79) # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset): class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
@ -28,10 +31,18 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
""" """
def __init__(self, hparams, val=False): def __init__(self, hparams, val=False):
exp_dir = hparams.exp_dir self.exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir if not os.path.exists("%s/2-name2text.txt" % self.exp_dir): # 如果文件夹为空,寻址最新创建项目
self.path5 = "%s/5-wav32k" % exp_dir parent_dir = "/".join(os.getcwd().split("/")[:-1])
self.find_newest_folder(parent_dir)
self.path2 = "%s/2-name2text.txt" % self.exp_dir
self.path4 = "%s/4-cnhubert" % self.exp_dir
self.path5 = "%s/5-wav32k" % self.exp_dir
f"""-loading self.path2: {self.path2} correctly-"""
assert os.path.exists(self.path2) assert os.path.exists(self.path2)
assert os.path.exists(self.path4) assert os.path.exists(self.path4)
assert os.path.exists(self.path5) assert os.path.exists(self.path5)
@ -104,6 +115,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
self.audiopaths_sid_text = audiopaths_sid_text_new self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths self.lengths = lengths
def find_newest_folder(self, directory):
    """Point ``self.exp_dir`` at the newest immediate sub-folder of *directory*.

    "Newest" means the largest modification time among the direct child
    directories of *directory*. The full path of that folder (not just its
    name) is stored, so paths derived from ``self.exp_dir`` resolve
    regardless of the current working directory.

    ``self.exp_dir`` is left untouched when *directory* is not a directory,
    contains no sub-folders, or cannot be read.

    Args:
        directory: Path of the folder whose children are inspected.

    Returns:
        None. The result is written to ``self.exp_dir``.
    """
    import logging

    if not os.path.isdir(directory):
        return

    try:
        # Only direct children are considered: walking the whole tree would
        # let deeper levels overwrite the choice made at the top level.
        with os.scandir(directory) as entries:
            candidates = [
                (entry.stat().st_mtime, entry.path)
                for entry in entries
                if entry.is_dir()
            ]
    except OSError:
        # Best-effort lookup: report the failure instead of silently dropping it.
        logging.getLogger(__name__).warning(
            "locating newest folder fail in %s", directory
        )
        return

    if candidates:
        # Compare raw timestamps; the first entry wins on ties.
        self.exp_dir = max(candidates, key=lambda item: item[0])[1]
def get_audio_text_speaker_pair(self, audiopath_sid_text): def get_audio_text_speaker_pair(self, audiopath_sid_text):
audiopath, phoneme_ids = audiopath_sid_text audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids) text = torch.FloatTensor(phoneme_ids)
@ -130,7 +155,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
audio_norm = audio audio_norm = audio
audio_norm = audio_norm.unsqueeze(0) audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
center=False) center=False)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio_norm return spec, audio_norm
@ -147,7 +172,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
def random_slice(self, ssl, wav, mel): def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first", ssl.shape, wav.shape) "first", ssl.shape, wav.shape)
len_mel = mel.shape[1] len_mel = mel.shape[1]
if self.val: if self.val:
@ -168,7 +193,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
mel = mel[:, :sep_point] mel = mel[:, :sep_point]
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, ( assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir) ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
return reference_mel, ssl, wav2, mel return reference_mel, ssl, wav2, mel
@ -329,4 +354,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
return -1 return -1
def __len__(self): def __len__(self):
return self.num_samples // self.batch_size return self.num_samples // self.batch_size