Add files via upload

Author: RVC-Boss (committed by GitHub)
Date: 2024-01-21 22:47:51 +08:00
Parent: ea62d6e0cf
Commit: 7b89c9ed56


@@ -1,6 +1,8 @@
-import time, logging
+import time
+import logging
 import os
-import random, traceback
+import random
+import traceback
 import numpy as np
 import torch
 import torch.utils.data
@@ -12,15 +14,12 @@ from text import cleaned_text_to_sequence
 from utils import load_wav_to_torch, load_filepaths_and_text
 import torch.nn.functional as F
 from functools import lru_cache
-import torch
 import requests
 from scipy.io import wavfile
 from io import BytesIO
-# from config import exp_dir
 from my_utils import load_audio

-# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)

 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
     """
     1) loads audio, speaker_id, text pairs
@@ -44,7 +43,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         for line in lines:
             tmp = line.split("\t")
-            if len(tmp) != 4:
+            if (len(tmp) != 4):
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]
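
Note: each line of the phoneme annotation file is expected to carry exactly four tab-separated fields, of which only the first two (audio file name and space-separated phoneme string) are consumed here. A minimal sketch of the assumed layout; the field values and the meaning of fields 3-4 are invented for illustration:

    # Hypothetical annotation line; only fields 1-2 are used by this loader.
    line = "spk1_0001.wav\tn i2 h ao3\t2 2\tnihao"
    tmp = line.split("\t")
    if len(tmp) == 4:
        phoneme_data = {tmp[0]: [tmp[1]]}  # {audio name: [phoneme string]}
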
@@ -52,7 +51,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
-        if leng < min_num:
+        if (leng < min_num):
             self.audiopaths_sid_text = []
             for _ in range(max(2, int(min_num / leng))):
                 self.audiopaths_sid_text += tmp
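
Note: this block oversamples very small datasets by repeating the file list until it reaches roughly min_num entries (and always at least doubles it). A standalone sketch of the same logic:

    def oversample(items, min_num=100):
        # Repeat a short list so later bucketing/batching has enough samples.
        # Caveats visible in the original: int() truncation can leave the
        # result just under min_num (7 items -> 7 * 14 = 98), and an empty
        # list would raise ZeroDivisionError at min_num / len(items).
        if len(items) < min_num:
            out = []
            for _ in range(max(2, int(min_num / len(items)))):
                out += items
            return out
        return items

    print(len(oversample(list(range(7)))))  # 98
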
@@ -77,20 +76,28 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         for audiopath in tqdm(self.audiopaths_sid_text):
             try:
                 phoneme = self.phoneme_data[audiopath][0]
-                phoneme = phoneme.split(" ")
+                phoneme = phoneme.split(' ')
                 phoneme_ids = cleaned_text_to_sequence(phoneme)
             except Exception:
                 print(f"{audiopath} not in self.phoneme_data !")
                 skipped_phone += 1
                 continue
             size = os.path.getsize("%s/%s" % (self.path5, audiopath))
             duration = size / self.sampling_rate / 2
+
+            if duration == 0:
+                print(f"Zero duration for {audiopath}, skipping...")
+                skipped_dur += 1
+                continue
+
             if 54 > duration > 0.6 or self.val:
                 audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                 lengths.append(size // (2 * self.hop_length))
             else:
                 skipped_dur += 1
                 continue
         print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
         print("total left: ", len(audiopaths_sid_text_new))
         assert len(audiopaths_sid_text_new) > 1  # TODO: need at least enough samples to fill a batch
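
Note: duration here is estimated from file size, assuming 16-bit mono PCM at self.sampling_rate (2 bytes per sample), so an empty or truncated file yields duration == 0. Under the old code a zero-duration file was usually skipped by the range check anyway, but when self.val was true it could slip through; the new guard skips it unconditionally and counts it in skipped_dur. The size-to-seconds estimate, as a sketch:

    import os

    def estimated_duration(path, sampling_rate):
        # 16-bit mono PCM: 2 bytes per sample => seconds = bytes / rate / 2.
        return os.path.getsize(path) / sampling_rate / 2
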
@@ -103,10 +110,8 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         try:
             spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
             with torch.no_grad():
-                ssl = torch.load(
-                    "%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
-                )
-                if ssl.shape[-1] != spec.shape[-1]:
+                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
+                if (ssl.shape[-1] != spec.shape[-1]):
                     typee = ssl.dtype
                     ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                     ssl.requires_grad = False
@@ -117,25 +122,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             ssl = torch.zeros(1, 768, 100)
             text = text[-1:]
             print("load audio or ssl error!!!!!!", audiopath)
-        # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
         return (ssl, spec, wav, text)

     def get_audio(self, filename):
-        audio_array = load_audio(
-            filename, self.sampling_rate
-        )  # load_audio already normalizes to [-1, 1]; no need to /32768 again
-        # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
+        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to [-1, 1]; no need to /32768 again
         audio = torch.FloatTensor(audio_array)  # /32768
         audio_norm = audio
         audio_norm = audio_norm.unsqueeze(0)
-        spec = spectrogram_torch(
-            audio_norm,
-            self.filter_length,
-            self.sampling_rate,
-            self.hop_length,
-            self.win_length,
-            center=False,
-        )
+        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
+                                 center=False)
         spec = torch.squeeze(spec, 0)
         return spec, audio_norm
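
Note: spectrogram_torch is presumably the VITS-style implementation, which reflect-pads the waveform by (filter_length - hop_length) // 2 per side and then runs an uncentered STFT; under that assumption the frame count comes out near samples // hop_length, which is what the ssl-versus-spec length checks in this file rely on. A sketch of the arithmetic:

    def expected_frames(n_samples, filter_length, hop_length):
        # Assumes VITS-style padding of (filter_length - hop_length) // 2
        # per side, followed by torch.stft(..., center=False).
        padded = n_samples + 2 * ((filter_length - hop_length) // 2)
        return 1 + (padded - filter_length) // hop_length  # ~= n_samples // hop_length
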
@@ -152,10 +147,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
     def random_slice(self, ssl, wav, mel):
-        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
-            "first",
-            ssl.shape,
-            wav.shape,
-        )
+        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
+            "first", ssl.shape, wav.shape)

         len_mel = mel.shape[1]
         if self.val:
@@ -176,20 +168,13 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             mel = mel[:, :sep_point]
         assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
-            ssl.shape,
-            wav.shape,
-            wav2.shape,
-            mel.shape,
-            sep_point,
-            self.hop_length,
-            sep_point * self.hop_length,
-            dir,
-        )
+            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
         return reference_mel, ssl, wav2, mel


-class TextAudioSpeakerCollate:
-    """Zero-pads model inputs and targets"""
+class TextAudioSpeakerCollate():
+    """ Zero-pads model inputs and targets
+    """

     def __init__(self, return_ids=False):
         self.return_ids = return_ids
@@ -202,8 +187,8 @@ class TextAudioSpeakerCollate:
         """
         # Right zero-pad all one-hot text sequences to max input length
         _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
-        )
+            torch.LongTensor([x[1].size(1) for x in batch]),
+            dim=0, descending=True)

         max_ssl_len = max([x[0].size(2) for x in batch])
         max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -246,16 +231,7 @@ class TextAudioSpeakerCollate:
             text_padded[i, :text.size(0)] = text
             text_lengths[i] = text.size(0)

-        return (
-            ssl_padded,
-            ssl_lengths,
-            spec_padded,
-            spec_lengths,
-            wav_padded,
-            wav_lengths,
-            text_padded,
-            text_lengths,
-        )
+        return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths


 class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
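
Note: collapsing the parenthesized return into a plain tuple is purely cosmetic; callers unpack both forms identically. Typical consumption through a DataLoader might look like the following sketch (dataset construction and hyperparameters are illustrative, not taken from this commit):

    from torch.utils.data import DataLoader

    # train_dataset: a TextAudioSpeakerLoader instance (hypothetical wiring).
    loader = DataLoader(train_dataset, batch_size=4, shuffle=False,
                        collate_fn=TextAudioSpeakerCollate(), num_workers=2)
    for ssl, ssl_lens, spec, spec_lens, wav, wav_lens, text, text_lens in loader:
        pass  # forward pass goes here
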
@@ -268,18 +244,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
     """

-    def __init__(
-        self,
-        dataset,
-        batch_size,
-        boundaries,
-        num_replicas=None,
-        rank=None,
-        shuffle=True,
-    ):
+    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
         self.lengths = dataset.lengths
-        # print(233333333333333,self.lengths,dir(dataset))
         self.batch_size = batch_size
         self.boundaries = boundaries
@@ -295,24 +262,22 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
             if idx_bucket != -1:
                 buckets[idx_bucket].append(i)

-        for i in range(len(buckets) - 1, 0, -1):
-            # for i in range(len(buckets) - 1, -1, -1):
+        i = len(buckets) - 1
+        while i >= 0:
             if len(buckets[i]) == 0:
                 buckets.pop(i)
                 self.boundaries.pop(i + 1)
+            i -= 1

         num_samples_per_bucket = []
         for i in range(len(buckets)):
             len_bucket = len(buckets[i])
             total_batch_size = self.num_replicas * self.batch_size
-            rem = (
-                total_batch_size - (len_bucket % total_batch_size)
-            ) % total_batch_size
+            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
             num_samples_per_bucket.append(len_bucket + rem)
         return buckets, num_samples_per_bucket

     def __iter__(self):
-        # deterministically shuffle based on epoch
         g = torch.Generator()
         g.manual_seed(self.epoch)
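
Note: this is the substantive change in the sampler, matching the ZeroDivisionError report in issue #79 (credited to Tybost in the comment removed above). The old loop for i in range(len(buckets) - 1, 0, -1) never visited index 0, so an empty first bucket survived pruning; later, __iter__ computes ids_bucket * (rem // len_bucket), and a surviving empty bucket divides by zero. The replacement while loop walks all the way down to index 0. A minimal reproduction of the pruning difference:

    buckets = [[], [3, 4], []]        # index 0 is empty
    boundaries = [32, 64, 96, 128]

    # Old: range(len(buckets) - 1, 0, -1) stops at index 1, leaving buckets[0].
    # New: walk every index, including 0.
    i = len(buckets) - 1
    while i >= 0:
        if len(buckets[i]) == 0:
            buckets.pop(i)
            boundaries.pop(i + 1)
        i -= 1

    print(buckets)     # [[3, 4]]
    print(boundaries)  # [32, 96]
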
@@ -331,25 +296,13 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
             ids_bucket = indices[i]
             num_samples_bucket = self.num_samples_per_bucket[i]

-            # add extra samples to make it evenly divisible
             rem = num_samples_bucket - len_bucket
-            ids_bucket = (
-                ids_bucket
-                + ids_bucket * (rem // len_bucket)
-                + ids_bucket[: (rem % len_bucket)]
-            )
+            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]

-            # subsample
             ids_bucket = ids_bucket[self.rank::self.num_replicas]

-            # batching
             for j in range(len(ids_bucket) // self.batch_size):
-                batch = [
-                    bucket[idx]
-                    for idx in ids_bucket[
-                        j * self.batch_size : (j + 1) * self.batch_size
-                    ]
-                ]
+                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
                 batches.append(batch)

         if self.shuffle: