feat: Add lib allows tensor saving

This commit is contained in:
__kaning123__ 2026-02-23 09:51:55 +08:00 committed by GitHub
parent 2d9193b0d3
commit 6ef7c0b70f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 223 additions and 0 deletions

View File

@ -0,0 +1,140 @@
import numpy as np
import torch
import zipfile
import file_lib as fl
import time_lib as tl
import info_lib as il
import os
from typing import Union
POOL:set = set()
def get_unique_name(name,MySet:set=set()):
_id = 1
if name not in POOL and name not in MySet:
POOL.add(name)
return name
while name in POOL or name in MySet:
_id += 1
name = f'{name}_{_id}'
POOL.add(name)
return name
TEMP_DIR = fl.merge_dir_txt2(fl.get_my_dir(), "Temp")
TEMP_ZIP_DIR = fl.merge_dir_txt2(TEMP_DIR, "ZipTemp")
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
cloned = tensor.clone().detach()
np_array = cloned.cpu().numpy()
return np_array
def save_np(path: str, np_array: np.ndarray) -> None:
np.save(path, np_array)
class ZIP_File:
def __init__(self, path: str,name:str,MySet:set=set()):
self.path = path
if not os.path.exists(self.path):
with zipfile.ZipFile(self.path, 'w') as zipf:
pass
self.name = get_unique_name(name,MySet=MySet)#MySet用于补充命名集合防止文件夹混淆
self.temp_write = fl.merge_dir_txt2(TEMP_ZIP_DIR, self.name)
if not os.path.exists(self.temp_write):
os.makedirs(self.temp_write)
def release(self):
'''relaese the zip file, extract it to temp dir'''
if os.path.exists(self.temp_write):
fl.delete_dir(self.temp_write)
fl.create_dir(self.temp_write)
with zipfile.ZipFile(self.path, 'r') as zipf:
zipf.extractall(self.temp_write)
#fl.delete_file(self.path)
def create_dir(self, dir_:str):
dir_path = fl.merge_dir_txt2(self.temp_write, dir_)
if not os.path.exists(dir_path):
os.makedirs(dir_path,exist_ok=True)
def create_file(self, file_name:str,location:str=''):
if location == '':
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
else:
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(location):
os.makedirs(location,exist_ok=True)
with open(file_path, 'x') as f:
pass
def get_file_path(self, file_name:str,location:str=''):
if location == '':
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
else:
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist.")
return file_path
def get_file_obj(self, file_name:str,location:str,mode:str='r'):
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist.")
return open(file_path, mode)
def save_file(self, obj):
obj.close()
def save_zip(self):
with zipfile.ZipFile(self.path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(self.temp_write):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, self.temp_write)
zipf.write(file_path, arcname)
#fl.delete_dir(self.temp_write)
def close(self):
self.save_zip()
fl.delete_dir(self.temp_write)
POOL.remove(self.name)
def save_tensor(path: str, tensors: Union[torch.Tensor, list],name:str,MySet:set=set(),file_names:Union[str,list,None]=None,**info_save) -> None:
if isinstance(tensors, torch.Tensor):
tensors = [tensors]
if not file_names:
return
if isinstance(file_names, str):
files = [file_names]
else:
files = file_names
if len(tensors) != len(files):
raise ValueError("The number of tensors and files must be the same.")
np_arrays = []
for tensor in tensors:
np_array = _tensor_to_numpy(tensor)
np_arrays.append(np_array)
zf = ZIP_File(path, name, MySet=MySet)
zf.create_file("voice.json")
info = {'name': name}
info.update(info_save)
il.save_info(str(zf.get_file_path("voice.json")), info)
for i in range(len(files)):
file_name = files[i]
np_array = np_arrays[i]
zf.create_file(file_name)
save_np(str(zf.get_file_path(file_name)), np_array)
zf.close()
del zf
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list:
zf = ZIP_File(path, name, MySet=MySet)
zf.release()
voice_path = find_func(zf,il)
tensors = []
for i in range(len(voice_path)):
v = voice_path[i]
np_array = np.load(v)
tensor = torch.from_numpy(np_array)
tensors.append(tensor)
zf.close()
del zf
return tensors

View File

@ -0,0 +1,35 @@
import os
import shutil
from pathlib import Path
def get_my_dir():
return os.path.dirname(os.path.abspath(__file__))
def get_parent_dir(dir_path,depth=1):
parent_path = Path(dir_path)
for _ in range(depth):
parent_path = parent_path.parent
return parent_path
def merge_dir_txt(a,b):
c=os.path.join(a,b)
return c
def merge_dir_txt2(*TXT):
return Path(os.path.join(*TXT))
def create_dir(path: Path, overwrite=False) -> bool:
if overwrite and path.exists():
shutil.rmtree(path)
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
return path.exists()
def get_dir_children_dirs(path: Path):
return [item for item in path.iterdir() if item.is_dir()]
def get_dir_children_files(path: Path):
return [item for item in path.iterdir() if item.is_file()]
def delete_dir(path: Path):
return shutil.rmtree(path)
def delete_file(path: Path):
return os.remove(path)
def file_exists(path: Path):
path = Path(path)
return path.exists()

View File

@ -0,0 +1,10 @@
import json
def load_info(info_path):
with open(info_path, 'r', encoding='utf-8') as f:
info = json.load(f)
return info
def save_info(info, info_path):
with open(info_path, 'w', encoding='utf-8') as f:
json.dump(info, f, ensure_ascii=False, indent=4)

View File

@ -0,0 +1,38 @@
import time
#time styles
STYLE_Y = "%Y"
STYLE_M = "%m"
STYLE_D = "%d"
STYLE_H = "%H"
STYLE_MIN = "%M"
STYLE_S = "%S"
STYLE_FULL = "%Y-%m-%d_%H.%M.%S"
#quick calls
def get_time_y(STYLE = STYLE_Y):
return time.strftime(STYLE, time.localtime())
def get_time_m(STYLE = STYLE_M):
return time.strftime(STYLE, time.localtime())
def get_time_d(STYLE = STYLE_D):
return time.strftime(STYLE, time.localtime())
def get_time_h(STYLE = STYLE_H):
return time.strftime(STYLE, time.localtime())
def get_time_min(STYLE = STYLE_MIN):
return time.strftime(STYLE, time.localtime())
def get_time_s(STYLE = STYLE_S):
return time.strftime(STYLE, time.localtime())
def get_time_full(STYLE = STYLE_FULL):
return time.strftime(STYLE, time.localtime())
def s(t:float):
time.sleep(t)
return
###
if __name__ == '__main__':
print(get_time_y())
print(get_time_m())
print(get_time_d())
print(get_time_h())
print(get_time_min())
print(get_time_s())
print(get_time_full())