Add progress bar

This commit is contained in:
KamioRinn 2024-07-16 00:22:56 +08:00
parent 1295bf22fb
commit f043bfebbb

View File

@ -60,6 +60,9 @@ class BsRoformer_Loader:
length_init = mix.shape[-1] length_init = mix.shape[-1]
progress_bar = tqdm(total=(length_init//step)+3)
progress_bar.set_description("Processing")
# Do pad from the beginning and end to account floating window results better # Do pad from the beginning and end to account floating window results better
if length_init > 2 * border and (border > 0): if length_init > 2 * border and (border > 0):
mix = nn.functional.pad(mix, (border, border), mode='reflect') mix = nn.functional.pad(mix, (border, border), mode='reflect')
@ -96,6 +99,7 @@ class BsRoformer_Loader:
batch_data.append(part) batch_data.append(part)
batch_locations.append((i, length)) batch_locations.append((i, length))
i += step i += step
progress_bar.update(1)
if len(batch_data) >= batch_size or (i >= mix.shape[1]): if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0) arr = torch.stack(batch_data, dim=0)
@ -123,11 +127,13 @@ class BsRoformer_Loader:
# Remove pad # Remove pad
estimated_sources = estimated_sources[..., border:-border] estimated_sources = estimated_sources[..., border:-border]
progress_bar.close()
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)} return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
def run_folder(self,input, vocal_root, others_root, format): def run_folder(self,input, vocal_root, others_root, format):
start_time = time.time() # start_time = time.time()
self.model.eval() self.model.eval()
path = input path = input
@ -185,7 +191,7 @@ class BsRoformer_Loader:
except: except:
pass pass
print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) # print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def __init__(self, model_path, device): def __init__(self, model_path, device):