XXXXRT666 d5e479dad6
Introduce Docker and Windows CI Workflow, Pre-commit Formatting, and Language Resource Auto-Download (#2351)
* Docker Auto-Build Workflow

* Rename

* Update

* Fix Bugs

* Disable Progress Bar When workflows triggered

* Fix Wget

* Fix Bugs

* Fix Bugs

* Update Wget

* Update Workflows

* Accelerate Docker Image Building

* Fix Install.sh

* Add Skip-Check For Action Runner

* Fix Dockerfile

* .

* .

* .

* .

* Delete File in Runner

* Add Sort

* Delete More Files

* Delete More

* .

* .

* .

* Add Pre-Commit Hook
Update Docker

* Add Code Spell Check

* [pre-commit.ci] trigger

* [pre-commit.ci] trigger

* [pre-commit.ci] trigger

* Fix Bugs

* .

* Disable Progress Bar and Logs while using GitHub Actions

* .

* .

* Fix Bugs

* update conda

* fix bugs

* Fix Bugs

* fix bugs

* .

* .

* Quiet Installation

* fix bugs

* .

* fix bug

* .

* Fix pre-commit.ci and Docker

* fix bugs

* .

* Update Docker & Pre-Commit

* fix  bugs

* Update Req

* Update Req

* Update OpenCC

* update precommit

* .

* Update .pre-commit-config.yaml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update Docs and fix bugs

* Fix \

* Fix MacOS

* .

* test

* .

* Add Tag Alias

* .

* fix bugs

* fix bugs

* make image smaller

* update pre-commit config

* .

* .

* fix bugs

* use miniconda

* Fix Wrong Path

* .

* debug

* debug

* revert

* Fix Bugs

* Update Docs, Add Dict Auto Download in install.sh

* update docker_build

* Update Docs for Install.sh

* update docker docs about architecture

* Add Xcode-Commandline-Tool Installation

* Update Docs

1. Add Missing VC17
2. Modified the Order of FFmpeg Installation and Requirements Installation
3. Remove Duplicate FFmpeg

* Fix Wrong Cuda Version

* Update TESTED ENV

* Add PYTHONNOUSERSITE(-s)

* Fix Wrapper

* Update install.sh For Robustness

* Ignore .git

* Preload CUDNN For Ctranslate2

* Remove Gradio Warnings

* Update Colab

* Fix OpenCC Problems

* Update Win DLL Strategy

* Fix Onnxruntime-gpu NVRTC Error

* Fix Path Problems

* Add Windows Packages Workflow

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* .

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* Fix Path

* Fix Path

* Enable Logging

* Set 7-Zip compression level to maximum (-mx=9)

* Use Multithread in ONNX Session

* Fix Tag Bugs

* Add Time

* Add Time

* Add Time

* Compress More

* Copy DLL to Solve VC Runtime DLL Missing Issues

* Expose FFmpeg Errors, Copy Only Part of Visual C++ Runtime

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* WIP

* WIP

* WIP

* Update build_windows_packages.ps1

* Update install.sh

* Update build_windows_packages.ps1

* Update docker-publish.yaml

* Update install.sh

* Update Dockerfile

* Update docker_build.sh

* Update miniconda_install.sh

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update Colab-WebUI.ipynb

* Update Colab-Inference.ipynb

* Update docker-compose.yaml

* 更新 build_windows_packages.ps1

* Update install.sh

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-05-26 10:45:14 +08:00

362 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import glob
import json
import logging
import os
import subprocess
import sys
import traceback
import librosa
import numpy as np
import torch
# Only let ERROR-and-above records through for the numba and matplotlib loggers.
logging.getLogger("numba").setLevel(logging.ERROR)
logging.getLogger("matplotlib").setLevel(logging.ERROR)
# Flipped to True by the plot_*_to_numpy helpers after they select the "Agg" backend,
# so that the backend selection runs only once per process.
MATPLOTLIB_FLAG = False
# Root logging configuration: INFO and above, written to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# `logger` starts out as the logging *module* itself, so logger.info(...) resolves to
# logging.info(...); get_logger() later rebinds this global to a named Logger.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``"model"``,
    ``"optimizer"``, ``"iteration"`` and ``"learning_rate"``.  Parameters that
    are missing from the checkpoint, or whose shape differs from the current
    model, keep the model's current value instead of aborting the load.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        model: Module to load into; if it has a ``module`` attribute (wrapped
            model), the wrapped module is used.
        optimizer: Optional optimizer whose state is restored as well.
        skip_optimizer: When true, the optimizer state is left untouched.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
    """
    # Kept as an assert so callers relying on AssertionError keep working;
    # if asserts are disabled, torch.load below still fails on a missing file.
    assert os.path.isfile(checkpoint_path), "checkpoint not found: %s" % checkpoint_path
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            saved = saved_state_dict[k]
            # Explicit check instead of `assert`, so it is not stripped under -O.
            if saved.shape != v.shape:
                raise ValueError("shape mismatch for %s: %s vs %s" % (k, saved.shape, v.shape))
            new_state_dict[k] = saved
        except Exception:
            # Best effort: report and keep the model's current tensor.  Also hit on
            # a shape mismatch, e.g. text_embedding after the cleaner was changed.
            traceback.print_exc()
            print("error, %s is not in the checkpoint" % k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)
    print("load ")
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(
            checkpoint_path,
            iteration,
        )
    )
    return model, optimizer, learning_rate, iteration
import shutil
from time import time as ttime
def my_save(fea, path):
    """Save ``fea`` with ``torch.save`` and move the result to ``path``.

    Works around ``torch.save`` not supporting Chinese (non-ASCII) paths: the
    object is first written to a temporary, timestamp-named ``.pth`` file in
    the current working directory and then moved to the destination.

    Args:
        fea: Object to serialize.
        path: Destination file path; may be a bare file name.
    """
    tmp_path = "%s.pth" % (ttime())
    # os.path.join copes with an empty directory part; the previous
    # "%s/%s" formatting turned a bare file name into "/name".
    target = os.path.join(os.path.dirname(path), os.path.basename(path))
    try:
        torch.save(fea, tmp_path)
        shutil.move(tmp_path, target)
    finally:
        # After a successful move the temp file is gone; this only cleans up
        # what a failed save or move left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path`` via ``my_save``.

    The saved dict has the keys ``"model"``, ``"iteration"``, ``"optimizer"``
    and ``"learning_rate"``.  If ``model`` has a ``module`` attribute, the
    wrapped module's state dict is stored.
    """
    logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
    net = model.module if hasattr(model, "module") else model
    payload = {
        "model": net.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    my_save(payload, checkpoint_path)
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward a batch of named values to a summary writer.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: Step value passed along with every entry.
        scalars: Mapping tag -> value for ``add_scalar`` (``None`` = nothing).
        histograms: Mapping tag -> value for ``add_histogram``.
        images: Mapping tag -> image for ``add_image``; logged with
            ``dataformats="HWC"``.
        audios: Mapping tag -> audio for ``add_audio``.
        audio_sampling_rate: Sampling rate passed to ``add_audio``.
    """
    # None defaults replace the former shared mutable `{}` defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in ``dir_path`` with the highest number in its name.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (not a regular expression, despite the name)
            selecting candidate files.

    Returns:
        Path of the matching file whose file name contains the largest number
        (all digits of the file name concatenated).  Names without any digit
        sort before all numbered ones.

    Raises:
        IndexError: If no file matches the pattern.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        # Same exception type as the former `f_list[-1]`, but with an explanation.
        raise IndexError("no checkpoint matching %r found in %r" % (regex, dir_path))

    def _step(path):
        # Only look at the file name so digits in directory names cannot
        # influence the ordering, and tolerate names without digits.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list.sort(key=_step)
    x = f_list[-1]
    print(x)
    return x
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` with matplotlib and return the figure as an image array.

    Args:
        spectrogram: 2-D array-like accepted by ``Axes.imshow``.

    Returns:
        ``uint8`` NumPy array holding the RGB pixels of the rendered figure
        (height x width x 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # One-time setup: pick the non-interactive Agg backend before pyplot is used.
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring
        # combination; drop the alpha channel and copy out of the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = np.array(rgba[..., :3])
    finally:
        # Close this specific figure, even if rendering failed.
        plt.close(fig)
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return the figure as an image array.

    Args:
        alignment: 2-D array providing ``transpose()``; the transposed matrix is
            drawn, with the x axis labelled "Decoder timestep" and the y axis
            "Encoder timestep".
        info: Optional text appended below the x-axis label.

    Returns:
        ``uint8`` NumPy array holding the RGB pixels of the rendered figure
        (height x width x 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # One-time setup: pick the non-interactive Agg backend before pyplot is used.
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(
            alignment.transpose(),
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        fig.colorbar(im, ax=ax)
        xlabel = "Decoder timestep"
        if info is not None:
            xlabel += "\n\n" + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Encoder timestep")
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring
        # combination; drop the alpha channel and copy out of the canvas buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = np.array(rgba[..., :3])
    finally:
        # Close this specific figure, even if rendering failed.
        plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Load an audio file with librosa and return ``(FloatTensor, sampling_rate)``.

    ``sr=None`` is passed to ``librosa.load``, i.e. no target rate is requested.
    """
    samples, native_rate = librosa.load(full_path, sr=None)
    waveform = torch.FloatTensor(samples)
    return waveform, native_rate
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split every stripped line on ``split``.

    Returns:
        A list with one list of fields per line of the file.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(split)
            rows.append(fields)
    return rows
def get_hparams(init=True, stage=1):
    """Build hyper-parameters from the command line and the JSON config it names.

    Command-line options (read from ``sys.argv``):
        -c / --config       JSON configuration file (default ``./configs/s2.json``)
        -p / --pretrain     pretrain dir, stored as ``hparams.pretrain``
        -rs / --resume_step resume step, stored as ``hparams.resume_step``

    The raw config text is also copied to ``config.json`` inside the checkpoint
    directory (``s1_ckpt_dir`` when ``stage == 1``, otherwise ``s2_ckpt_dir``),
    which is created if necessary.

    Args:
        init: Unused; kept for backward compatibility.
        stage: Selects which checkpoint directory receives the config copy.

    Returns:
        The populated ``HParams`` instance.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/s2.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir")
    parser.add_argument(
        "-rs",
        "--resume_step",
        type=int,
        required=False,
        default=None,
        help="resume step",
    )
    args = parser.parse_args()

    config_path = args.config
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.pretrain = args.pretrain
    hparams.resume_step = args.resume_step

    model_dir = hparams.s1_ckpt_dir if stage == 1 else hparams.s2_ckpt_dir
    config_save_path = os.path.join(model_dir, "config.json")
    # exist_ok avoids the check-then-create race when several workers start together.
    os.makedirs(model_dir, exist_ok=True)
    with open(config_save_path, "w", encoding="utf-8") as f:
        f.write(data)
    return hparams
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Free up space by deleting old saved checkpoints.

    Only files named ``G_<number>.pth`` or ``D_<number>.pth`` are considered;
    ``G_0.pth`` and ``D_0.pth`` are never deleted.  The G and D series are
    handled independently.

    Arguments:
    path_to_models -- Path to the model directory
    n_ckpts_to_keep -- Number of ckpts to keep per series, excluding G_0.pth
                       and D_0.pth; 0 (or less) deletes all of them
    sort_by_time -- True -> delete the oldest ckpts by modification time
                    False -> delete the ckpts with the lowest step number
    """
    import re

    # Anchored pattern: unrelated files that merely start with "G" or "D" are
    # never candidates for deletion.
    ckpt_pattern = re.compile(r"^[GD]_(\d+)\.pth$")

    def _step(name):
        return int(ckpt_pattern.match(name).group(1))

    def _mtime(name):
        return os.path.getmtime(os.path.join(path_to_models, name))

    sort_key = _mtime if sort_by_time else _step
    n_keep = max(n_ckpts_to_keep, 0)

    candidates = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
        and ckpt_pattern.match(f)
        and not f.endswith("_0.pth")
    ]

    to_del = []
    for prefix in ("G", "D"):
        series = sorted((f for f in candidates if f.startswith(prefix)), key=sort_key)
        # Explicit count instead of `[:-n]`, which selects nothing for n == 0.
        n_delete = max(len(series) - n_keep, 0)
        to_del.extend(os.path.join(path_to_models, fn) for fn in series[:n_delete])

    for fn in to_del:
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")
def get_hparams_from_dir(model_dir):
    """Load ``config.json`` from ``model_dir`` into an ``HParams`` object.

    The returned object additionally gets ``model_dir`` set to the given path.
    """
    config_save_path = os.path.join(model_dir, "config.json")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_save_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
def get_hparams_from_file(config_path):
    """Load the JSON file at ``config_path`` into an ``HParams`` object."""
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return HParams(**config)
def check_git_hash(model_dir):
    """Compare the current git commit with the one recorded in ``model_dir``.

    The directory containing this source file must contain ``.git``; otherwise
    a warning is logged and nothing else happens.  On first use the current
    ``HEAD`` hash is written to ``<model_dir>/githash``; on later calls a
    warning is logged when the stored hash differs from the current one.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir,
            )
        )
        return

    # Query the repository that holds this file (not whatever the CWD is), and
    # never treat a git error message as if it were a hash.
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=source_dir,
            capture_output=True,
            text=True,
        )
    except OSError as err:
        logger.warning("could not run git ({}), hash value comparison will be ignored.".format(err))
        return
    if result.returncode != 0:
        logger.warning(
            "git rev-parse failed ({}), hash value comparison will be ignored.".format(
                result.stderr.strip(),
            )
        )
        return
    cur_hash = result.stdout.strip()

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read().strip()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8],
                    cur_hash[:8],
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
def get_logger(model_dir, filename="train.log"):
    """Create (or fetch) the file logger for ``model_dir`` and make it the module logger.

    The logger is named after the last path component of ``model_dir``, set to
    INFO level, and writes tab-separated records to ``<model_dir>/<filename>``.
    The module-level ``logger`` global is rebound to it.

    Args:
        model_dir: Directory for the log file; created if missing.
        filename: Name of the log file inside ``model_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.INFO)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))

    # logging.getLogger returns the same object for the same name, so guard
    # against stacking a second handler for the same file (duplicate lines).
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
        h = logging.FileHandler(log_path)
        h.setLevel(logging.INFO)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger
class HParams:
    """Attribute-style container for (nested) hyper-parameters.

    Keyword arguments become attributes; values that are dicts (including dict
    subclasses) are converted to nested ``HParams`` recursively.  Items can
    also be accessed with ``obj[key]``, and the object supports ``len``,
    ``in`` and the ``keys`` / ``items`` / ``values`` views of its attributes.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance also covers dict subclasses such as OrderedDict.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)
if __name__ == "__main__":
    # Manual smoke test: load one hard-coded sample file and print the result.
    sample_path = "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
    loaded = load_wav_to_torch(sample_path)
    print(loaded)