tqdm progress bar

This commit is contained in:
XXXXRT666 2024-08-26 02:06:35 +08:00
parent fc2161b484
commit be200d182a
4 changed files with 16 additions and 15 deletions

View File

@ -84,11 +84,10 @@ if os.path.exists(txt_path) == False:
return phone_level_feature.T
def process(data, res):
for name, text, lan in data:
for name, text, lan in tqdm(data,position=int(i_part),delay=0.5):
try:
name=clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","), lan, version
)
@ -102,7 +101,7 @@ if os.path.exists(txt_path) == False:
# res.append([name,phones])
res.append([name, phones, word2ph, norm_text])
except:
print(name, text, traceback.format_exc())
tqdm.write(name, text, traceback.format_exc())
todo = []
res = []
@ -135,9 +134,9 @@ if os.path.exists(txt_path) == False:
[wav_name, text, language_v1_to_language_v2.get(language, language)]
)
else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
tqdm.write(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except:
print(line, traceback.format_exc())
tqdm.write(line, traceback.format_exc())
process(todo, res)
opt = []

View File

@ -20,6 +20,7 @@ import librosa
now_dir = os.getcwd()
sys.path.append(now_dir)
from tools.my_utils import load_audio,clean_path
from tqdm import tqdm
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
@ -70,7 +71,7 @@ def name2go(wav_name,wav_path):
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max))
tqdm.write("%s-filtered,%s" % (wav_name, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
@ -85,7 +86,7 @@ def name2go(wav_name,wav_path):
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append((wav_name,wav_path))
print("nan filtered:%s"%wav_name)
tqdm.write("nan filtered:%s"%wav_name)
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
@ -97,7 +98,7 @@ def name2go(wav_name,wav_path):
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]:
for line in tqdm(lines[int(i_part)::int(all_parts)],position=int(i_part)):
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
@ -111,13 +112,13 @@ for line in lines[int(i_part)::int(all_parts)]:
wav_name = os.path.basename(wav_name)
name2go(wav_name,wav_path)
except:
print(line,traceback.format_exc())
tqdm.write(line,traceback.format_exc())
if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
for wav in nan_fails:
for wav in tqdm(nan_fails,position=int(i_part)):
try:
name2go(wav[0],wav[1])
except:
print(wav_name,traceback.format_exc())
tqdm.write(wav_name,traceback.format_exc())

View File

@ -25,6 +25,7 @@ from tqdm import tqdm
import logging, librosa, utils
from module.models import SynthesizerTrn
from tools.my_utils import clean_path
from tqdm import tqdm
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G
@ -87,8 +88,8 @@ if os.path.exists(semantic_path) == False:
lines = f.read().strip("\n").split("\n")
lines1 = []
for line in lines[int(i_part) :: int(all_parts)]:
# print(line)
for line in tqdm(lines[int(i_part) :: int(all_parts)],position=int(i_part)):
# tqdm.write(line)
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
@ -97,6 +98,6 @@ if os.path.exists(semantic_path) == False:
# name2go(name,lines1)
name2go(wav_name, lines1)
except:
print(line, traceback.format_exc())
tqdm.write(line, traceback.format_exc())
with open(semantic_path, "w", encoding="utf8") as f:
f.write("\n".join(lines1))

View File

@ -306,7 +306,7 @@ def train_and_evaluate(
y_lengths,
text,
text_lengths,
) in enumerate(tqdm(train_loader)):
) in enumerate(tqdm(train_loader,position=rank+1,leave=(epoch==hps.train.epochs),postfix=f'epoch:{epoch}')):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True