mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Merge 5553ebe932f5df6b5bd79e1545958aa1628e43a9 into 5dfce9a3f0def7f1ee1e075df569b0b2d41df9e3
This commit is contained in:
commit
d1fd1d1ad5
@ -3,6 +3,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
@ -19,6 +21,7 @@ from scipy.io import wavfile
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from my_utils import load_audio
|
from my_utils import load_audio
|
||||||
|
|
||||||
|
|
||||||
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
|
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
|
||||||
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
@ -28,10 +31,18 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hparams, val=False):
|
def __init__(self, hparams, val=False):
|
||||||
exp_dir = hparams.exp_dir
|
self.exp_dir = hparams.exp_dir
|
||||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
|
||||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
if not os.path.exists("%s/2-name2text.txt" % self.exp_dir): # 如果文件夹为空,寻址最新创建项目
|
||||||
self.path5 = "%s/5-wav32k" % exp_dir
|
parent_dir = "/".join(os.getcwd().split("/")[:-1])
|
||||||
|
self.find_newest_folder(parent_dir)
|
||||||
|
|
||||||
|
self.path2 = "%s/2-name2text.txt" % self.exp_dir
|
||||||
|
self.path4 = "%s/4-cnhubert" % self.exp_dir
|
||||||
|
self.path5 = "%s/5-wav32k" % self.exp_dir
|
||||||
|
|
||||||
|
f"""-loading self.path2: {self.path2} correctly-"""
|
||||||
|
|
||||||
assert os.path.exists(self.path2)
|
assert os.path.exists(self.path2)
|
||||||
assert os.path.exists(self.path4)
|
assert os.path.exists(self.path4)
|
||||||
assert os.path.exists(self.path5)
|
assert os.path.exists(self.path5)
|
||||||
@ -104,6 +115,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||||
self.lengths = lengths
|
self.lengths = lengths
|
||||||
|
|
||||||
|
def find_newest_folder(self, directory):
|
||||||
|
if os.path.exists(directory):
|
||||||
|
try:
|
||||||
|
for root, dirs, files in os.walk(directory):
|
||||||
|
# 获取当前目录下所有子文件夹的修改时间
|
||||||
|
folder_times = [datetime.fromtimestamp(os.path.getmtime(root + '/' + d)) for d in dirs]
|
||||||
|
|
||||||
|
if len(folder_times) > 0:
|
||||||
|
# 根据修改时间选择最新的子文件夹
|
||||||
|
self.exp_dir = max(zip(folder_times, dirs), key=lambda x: x[0])[-1]
|
||||||
|
except:
|
||||||
|
f"""locating newest folder fail in {directory}"""
|
||||||
|
return
|
||||||
|
|
||||||
def get_audio_text_speaker_pair(self, audiopath_sid_text):
|
def get_audio_text_speaker_pair(self, audiopath_sid_text):
|
||||||
audiopath, phoneme_ids = audiopath_sid_text
|
audiopath, phoneme_ids = audiopath_sid_text
|
||||||
text = torch.FloatTensor(phoneme_ids)
|
text = torch.FloatTensor(phoneme_ids)
|
||||||
@ -130,7 +155,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
audio_norm = audio
|
audio_norm = audio
|
||||||
audio_norm = audio_norm.unsqueeze(0)
|
audio_norm = audio_norm.unsqueeze(0)
|
||||||
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
|
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
|
||||||
center=False)
|
center=False)
|
||||||
spec = torch.squeeze(spec, 0)
|
spec = torch.squeeze(spec, 0)
|
||||||
return spec, audio_norm
|
return spec, audio_norm
|
||||||
|
|
||||||
@ -147,7 +172,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def random_slice(self, ssl, wav, mel):
|
def random_slice(self, ssl, wav, mel):
|
||||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
|
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
|
||||||
"first", ssl.shape, wav.shape)
|
"first", ssl.shape, wav.shape)
|
||||||
|
|
||||||
len_mel = mel.shape[1]
|
len_mel = mel.shape[1]
|
||||||
if self.val:
|
if self.val:
|
||||||
@ -168,7 +193,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
|||||||
mel = mel[:, :sep_point]
|
mel = mel[:, :sep_point]
|
||||||
|
|
||||||
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
|
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
|
||||||
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
|
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
|
||||||
return reference_mel, ssl, wav2, mel
|
return reference_mel, ssl, wav2, mel
|
||||||
|
|
||||||
|
|
||||||
@ -329,4 +354,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
|||||||
return -1
|
return -1
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples // self.batch_size
|
return self.num_samples // self.batch_size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user