import itertools
import logging
import math
import os
import tempfile
from datetime import datetime
from typing import List, Union

import imageio
import numpy as np
import PIL.Image
import safetensors.torch
import torch
import tqdm

logger = logging.getLogger(__name__)


def load_torch_file(ckpt, device=None, dtype=torch.float16):
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith((".safetensors", ".sft")):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if "weights_only" in torch.load.__code__.co_varnames:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            logger.warning(
                "torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            pl_sd = torch.load(ckpt, map_location=device)
        if "global_step" in pl_sd:
            logger.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        elif "params_ema" in pl_sd:
            sd = pl_sd["params_ema"]
        else:
            sd = pl_sd
    # Convert all tensors in the state dict to the specified dtype.
    sd = {k: v.to(dtype) for k, v in sd.items()}
    return sd


def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    out = {} if filter_keys else state_dict
    for rp in replace_prefix:
        # Materialize the matches first, since we pop from state_dict below.
        replace = [(key, "{}{}".format(replace_prefix[rp], key[len(rp):]))
                   for key in list(state_dict.keys()) if key.startswith(rp)]
        for old_key, new_key in replace:
            out[new_key] = state_dict.pop(old_key)
    return out


def module_size(module):
    """Return the total size in bytes of a module's state-dict tensors."""
    module_mem = 0
    sd = module.state_dict()
    for k in sd:
        t = sd[k]
        module_mem += t.nelement() * t.element_size()
    return module_mem


def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    # One step per tile; tiles advance by (tile - overlap) along each axis.
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))


@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3,
                         output_device="cpu", pbar=None):
    dims = len(tile)
    logger.debug(f"samples dtype: {samples.dtype}")
    output = torch.empty(
        [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
        device=output_device)

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]
        # Accumulators for the weighted tile outputs and the blend weights.
        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
                          device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
                              device=output_device)

        # Iterate over every tile position; tiles advance by (tile - overlap).
        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
            s_in = s
            upscaled = []

            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            # Feather the tile edges so overlapping tiles blend linearly.
            feather = round(overlap * upscale_amount)
            for t in range(feather):
                for d in range(2, dims + 2):
                    m = mask.narrow(d, t, 1)
                    m *= ((1.0 / feather) * (t + 1))
                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
                    m *= ((1.0 / feather) * (t + 1))

            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o += ps * mask
            o_d += mask

            if pbar is not None:
                pbar.update(1)

        output[b:b + 1] = out / out_div
    return output


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3,
                output_device="cpu", pbar=None):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels,
                                output_device, pbar)
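# A minimal usage sketch (hypothetical, not part of the original module):
# `tiled_scale` accepts any callable that maps a (B, C, H, W) tile to a
# tensor upscaled by `upscale_amount`. The nearest-neighbor "upscaler" below
# is a stand-in for a real model such as an ESRGAN network.
def _tiled_scale_demo():
    def fake_upscaler(tile):
        # Stand-in for a model: plain nearest-neighbor 4x upscaling.
        return torch.nn.functional.interpolate(tile, scale_factor=4, mode="nearest")

    samples = torch.rand(1, 3, 96, 96)
    out = tiled_scale(samples, fake_upscaler, tile_x=64, tile_y=64, overlap=8,
                      upscale_amount=4, out_channels=3)
    # 96x96 input, 64x64 tiles with 8px overlap -> four tiles, 384x384 output.
    assert out.shape == (1, 3, 384, 384)
    return out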
def export_to_video_imageio(
        video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
    """
    Export video frames to a video file using the imageio library, to avoid
    the "green screen" issue (seen, for example, with CogVideoX).
    """
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path


def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video_imageio(tensor, video_path, fps=fps)
    return video_path


class ProgressBar:
    def __init__(self, total, desc=None):
        self.total = total
        self.current = 0
        self.b_unit = tqdm.tqdm(
            total=total, desc="ProgressBar context index: 0" if desc is None else desc
        )

    def update(self, value):
        # `value` is an increment (callers such as tiled_scale_multidim pass
        # update(1) once per tile); clamp so we never advance past the total.
        value = min(value, self.total - self.current)
        self.current += value
        if self.b_unit is not None:
            self.b_unit.set_description(
                "ProgressBar context index: {}".format(self.current)
            )
            self.b_unit.refresh()
            # Advance the underlying tqdm bar by the increment.
            self.b_unit.update(value)
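# A minimal usage sketch (hypothetical, not part of the original module):
# renders a short moving-gradient clip as PIL images, drives a ProgressBar
# the same way tiled_scale_multidim does (one update(1) per unit of work),
# and writes the result with save_video.
if __name__ == "__main__":
    num_frames = 16
    pbar = ProgressBar(num_frames, desc="rendering frames")
    frames = []
    for i in range(num_frames):
        # Horizontal gradient shifted per frame; stands in for model output.
        gradient = np.tile(np.arange(256, dtype=np.uint8), (256, 1))
        rgb = np.stack([np.roll(gradient, i * 16, axis=1)] * 3, axis=-1)
        frames.append(PIL.Image.fromarray(rgb))
        pbar.update(1)
    print(save_video(frames, fps=8))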