mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-04-29 21:00:42 +08:00
feat: Add lib allows tensor saving
This commit is contained in:
parent
2d9193b0d3
commit
6ef7c0b70f
140
GPT_SoVITS/VoiceSave/__init__.py
Normal file
140
GPT_SoVITS/VoiceSave/__init__.py
Normal file
@ -0,0 +1,140 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import zipfile
|
||||
import file_lib as fl
|
||||
import time_lib as tl
|
||||
import info_lib as il
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
# Registry of names already handed out; shared across the whole module.
POOL: set = set()


def get_unique_name(name, MySet: set = None):
    """Return a name that is unique w.r.t. POOL (and the optional extra set MySet).

    The chosen name is recorded in POOL before returning. Collisions are
    resolved by appending a numeric suffix to the ORIGINAL name
    ("foo" -> "foo_2", "foo_3", ...) instead of stacking suffixes onto each
    failed candidate ("foo_2_3"), which the previous implementation did.
    MySet supplements the naming pool (e.g. names reserved by a caller);
    it is only read, never mutated.
    """
    taken = MySet if MySet is not None else set()  # avoid shared mutable default
    if name not in POOL and name not in taken:
        POOL.add(name)
        return name
    base = name
    _id = 1
    while name in POOL or name in taken:
        _id += 1
        name = f'{base}_{_id}'  # always suffix the base name, not the last candidate
    POOL.add(name)
    return name
|
||||
# Scratch area next to this package: Temp/ holds transient files,
# Temp/ZipTemp/ holds the per-archive staging directories used by ZIP_File.
TEMP_DIR = fl.merge_dir_txt2(fl.get_my_dir(), "Temp")
TEMP_ZIP_DIR = fl.merge_dir_txt2(TEMP_DIR, "ZipTemp")
||||
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
|
||||
cloned = tensor.clone().detach()
|
||||
np_array = cloned.cpu().numpy()
|
||||
return np_array
|
||||
|
||||
def save_np(path: str, np_array: np.ndarray) -> None:
    """Save *np_array* in .npy format at exactly *path*.

    np.save(path, ...) silently appends a '.npy' suffix when the path has
    no extension, which would leave the data somewhere other than *path*
    (breaking callers that later look the file up by that exact name, as
    save_tensor does via get_file_path). Writing through an already-open
    file handle keeps the name exactly as given.
    """
    with open(path, 'wb') as f:
        np.save(f, np_array)
||||
|
||||
class ZIP_File:
    """A zip archive edited through a per-instance staging directory.

    Files are created/read under Temp/ZipTemp/<unique name>/ and packed
    back into the archive by save_zip()/close(). The staging-directory
    name is made unique via get_unique_name so concurrent instances
    never clash.
    """

    def __init__(self, path: str, name: str, MySet: set = None):
        self.path = path
        # Make sure an (initially empty) archive exists on disk.
        if not os.path.exists(self.path):
            with zipfile.ZipFile(self.path, 'w'):
                pass
        # MySet supplements the global naming pool so staging dirs never collide.
        self.name = get_unique_name(name, MySet=MySet if MySet is not None else set())
        self.temp_write = fl.merge_dir_txt2(TEMP_ZIP_DIR, self.name)
        if not os.path.exists(self.temp_write):
            os.makedirs(self.temp_write)

    def release(self):
        '''Extract the zip file into a fresh staging directory.'''
        if os.path.exists(self.temp_write):
            fl.delete_dir(self.temp_write)
        fl.create_dir(self.temp_write)
        with zipfile.ZipFile(self.path, 'r') as zipf:
            zipf.extractall(self.temp_write)

    def create_dir(self, dir_: str):
        """Create a sub-directory inside the staging area (idempotent)."""
        dir_path = fl.merge_dir_txt2(self.temp_write, dir_)
        os.makedirs(dir_path, exist_ok=True)

    def create_file(self, file_name: str, location: str = ''):
        """Create an empty file (mode 'x': raises FileExistsError if present).

        *location* is a directory path relative to the staging root and is
        created on demand. (The previous version mistakenly created
        *location* relative to the process CWD, so the subsequent open()
        could fail.)
        """
        if location == '':
            file_path = fl.merge_dir_txt2(self.temp_write, file_name)
        else:
            dir_path = fl.merge_dir_txt2(self.temp_write, location)
            os.makedirs(dir_path, exist_ok=True)
            file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
        with open(file_path, 'x'):
            pass

    def get_file_path(self, file_name: str, location: str = ''):
        """Return the staging-area path of an existing file.

        Raises FileNotFoundError when the file is not there.
        """
        if location == '':
            file_path = fl.merge_dir_txt2(self.temp_write, file_name)
        else:
            file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} does not exist.")
        return file_path

    def get_file_obj(self, file_name: str, location: str, mode: str = 'r'):
        """Open and return a file object for an existing staged file."""
        file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File {file_path} does not exist.")
        return open(file_path, mode)

    def save_file(self, obj):
        """Flush/close a file object obtained from get_file_obj."""
        obj.close()

    def save_zip(self):
        """Pack the staging directory back into the archive (overwrites it)."""
        with zipfile.ZipFile(self.path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(self.temp_write):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, self.temp_write)
                    zipf.write(file_path, arcname)

    def close(self):
        """Write the archive, remove the staging dir and free the unique name."""
        self.save_zip()
        fl.delete_dir(self.temp_write)
        POOL.discard(self.name)  # discard: a second close() must not raise KeyError
||||
|
||||
def save_tensor(path: str, tensors: Union[torch.Tensor, list], name: str,
                MySet: set = None, file_names: Union[str, list, None] = None,
                **info_save) -> None:
    """Save one or more tensors (as .npy payloads) plus a JSON info record
    into the zip archive at *path*.

    A single tensor / single file name may be passed bare; both are
    normalized to lists. Silently returns when file_names is empty/None
    (nothing to save). Raises ValueError when tensor and file-name counts
    differ. Extra keyword arguments are merged into the 'voice.json' info
    record alongside {'name': name}.
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    if not file_names:
        return
    files = [file_names] if isinstance(file_names, str) else file_names
    if len(tensors) != len(files):
        raise ValueError("The number of tensors and files must be the same.")

    np_arrays = [_tensor_to_numpy(t) for t in tensors]

    zf = ZIP_File(path, name, MySet=MySet if MySet is not None else set())
    zf.create_file("voice.json")
    info = {'name': name}
    info.update(info_save)
    # save_info's signature is (info, info_path) — the old call had the
    # arguments reversed and crashed trying to open the dict as a path.
    il.save_info(info, str(zf.get_file_path("voice.json")))
    # NOTE(review): np.save appends '.npy' when a file name lacks that
    # extension, so entries in *file_names* should already end in '.npy' —
    # confirm against load_tensor's find_func.
    for file_name, np_array in zip(files, np_arrays):
        zf.create_file(file_name)
        save_np(str(zf.get_file_path(file_name)), np_array)
    zf.close()
    del zf
||||
|
||||
def load_tensor(path: str, name: str, find_func, MySet: set = None) -> list:
    """Load tensors back out of the zip archive at *path*.

    find_func(zf, il) receives the extracted ZIP_File instance plus the
    info_lib module and must return the list of .npy paths to load; each
    is converted to a tensor via torch.from_numpy. The staging directory
    is cleaned up with zf.close() before returning.
    """
    zf = ZIP_File(path, name, MySet=MySet if MySet is not None else set())
    zf.release()
    voice_paths = find_func(zf, il)
    tensors = [torch.from_numpy(np.load(p)) for p in voice_paths]
    zf.close()
    del zf
    return tensors
|
||||
35
GPT_SoVITS/VoiceSave/file_lib.py
Normal file
35
GPT_SoVITS/VoiceSave/file_lib.py
Normal file
@ -0,0 +1,35 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
def get_my_dir():
    """Absolute path of the directory containing this module."""
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)
|
||||
|
||||
def get_parent_dir(dir_path, depth=1):
    """Walk *depth* levels up from *dir_path* and return the resulting Path.

    depth=0 returns Path(dir_path) unchanged.
    """
    current = Path(dir_path)
    remaining = depth
    while remaining > 0:
        current = current.parent
        remaining -= 1
    return current
|
||||
|
||||
def merge_dir_txt(a, b):
    """Join two path fragments with the platform separator (str result)."""
    return os.path.join(a, b)
|
||||
def merge_dir_txt2(*TXT):
    """Join any number of path fragments and wrap the result in a Path."""
    joined = os.path.join(*TXT)
    return Path(joined)
|
||||
def create_dir(path: Path, overwrite=False) -> bool:
    """Create *path* (parents included) and report whether it now exists.

    Accepts str or Path: the conversion now happens BEFORE the overwrite
    check — the old code called path.exists() on a possibly-plain str,
    which raised AttributeError whenever overwrite=True was used with a
    string path. With overwrite=True an existing tree is wiped first.
    """
    path = Path(path)
    if overwrite and path.exists():
        shutil.rmtree(path)  # remove the old tree before recreating it empty
    path.mkdir(parents=True, exist_ok=True)
    return path.exists()
|
||||
def get_dir_children_dirs(path: Path):
    """List the immediate sub-directories of *path* (non-recursive)."""
    return list(filter(Path.is_dir, path.iterdir()))
|
||||
def get_dir_children_files(path: Path):
    """List the regular files directly inside *path* (non-recursive)."""
    return list(filter(Path.is_file, path.iterdir()))
|
||||
def delete_dir(path: Path):
|
||||
return shutil.rmtree(path)
|
||||
def delete_file(path: Path):
|
||||
return os.remove(path)
|
||||
def file_exists(path: Path):
    """True when something exists at *path*.

    Note: despite the name this is a plain existence check — it also
    returns True for directories, matching Path.exists().
    """
    return Path(path).exists()
|
||||
10
GPT_SoVITS/VoiceSave/info_lib.py
Normal file
10
GPT_SoVITS/VoiceSave/info_lib.py
Normal file
@ -0,0 +1,10 @@
|
||||
import json
|
||||
|
||||
def load_info(info_path):
    """Read and return the JSON object stored at *info_path* (UTF-8)."""
    with open(info_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
|
||||
|
||||
def save_info(info, info_path):
    """Write *info* to *info_path* as pretty-printed UTF-8 JSON.

    Argument order is (data, path) — note this is the reverse of what a
    (path, data) caller might expect.
    """
    with open(info_path, 'w', encoding='utf-8') as handle:
        json.dump(info, handle, ensure_ascii=False, indent=4)
|
||||
38
GPT_SoVITS/VoiceSave/time_lib.py
Normal file
38
GPT_SoVITS/VoiceSave/time_lib.py
Normal file
@ -0,0 +1,38 @@
|
||||
import time
|
||||
# --- strftime patterns for the individual time components --------------
STYLE_Y = "%Y"
STYLE_M = "%m"
STYLE_D = "%d"
STYLE_H = "%H"
STYLE_MIN = "%M"
STYLE_S = "%S"
STYLE_FULL = "%Y-%m-%d_%H.%M.%S"


# --- quick accessors ---------------------------------------------------
def _format_now(pattern):
    """Render the current local time with the given strftime pattern."""
    return time.strftime(pattern, time.localtime())


def get_time_y(STYLE=STYLE_Y):
    """Current year, e.g. '2025'."""
    return _format_now(STYLE)


def get_time_m(STYLE=STYLE_M):
    """Current month, zero-padded."""
    return _format_now(STYLE)


def get_time_d(STYLE=STYLE_D):
    """Current day of month, zero-padded."""
    return _format_now(STYLE)


def get_time_h(STYLE=STYLE_H):
    """Current hour (24-hour clock), zero-padded."""
    return _format_now(STYLE)


def get_time_min(STYLE=STYLE_MIN):
    """Current minute, zero-padded."""
    return _format_now(STYLE)


def get_time_s(STYLE=STYLE_S):
    """Current second, zero-padded."""
    return _format_now(STYLE)


def get_time_full(STYLE=STYLE_FULL):
    """Full timestamp, e.g. '2025-01-31_23.59.59'."""
    return _format_now(STYLE)
|
||||
|
||||
def s(t: float):
    """Sleep for *t* seconds (thin alias for time.sleep); returns None."""
    time.sleep(t)
    return None
|
||||
###

if __name__ == '__main__':
    # Manual smoke test: print each time component, then the full stamp.
    for getter in (get_time_y, get_time_m, get_time_d,
                   get_time_h, get_time_min, get_time_s, get_time_full):
        print(getter())
|
||||
Loading…
x
Reference in New Issue
Block a user