import itertools
import logging
import math
import os
import tempfile
from datetime import datetime
from typing import List, Optional, Union

import imageio
import numpy as np
import PIL.Image
import safetensors.torch
import torch
import tqdm

logger = logging.getLogger(__name__)


def load_torch_file(ckpt, device=None, dtype=torch.float16):
    """Load a checkpoint (.safetensors/.sft or a torch pickle) and cast its tensors to `dtype`."""
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith((".safetensors", ".sft")):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if "weights_only" in torch.load.__code__.co_varnames:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            logger.warning("torch.load doesn't support weights_only on this PyTorch version, loading unsafely.")
            pl_sd = torch.load(ckpt, map_location=device)
        if "global_step" in pl_sd:
            logger.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        elif "params_ema" in pl_sd:
            sd = pl_sd["params_ema"]
        else:
            sd = pl_sd
    # Convert floating-point tensors to the specified dtype; leave integer buffers untouched.
    sd = {k: (v.to(dtype) if v.is_floating_point() else v) for k, v in sd.items()}
    return sd


def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
    """Rename keys according to the {old_prefix: new_prefix} mapping `replace_prefix`.

    With `filter_keys=True` only the renamed keys are returned; otherwise the
    renamed keys are merged back into the original dict.
    """
    if filter_keys:
        out = {}
    else:
        out = state_dict
    for rp in replace_prefix:
        replace = list(
            map(
                lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])),
                filter(lambda a: a.startswith(rp), state_dict.keys()),
            )
        )
        for x in replace:
            w = state_dict.pop(x[0])
            out[x[1]] = w
    return out


def module_size(module):
    """Total size in bytes of all tensors in `module`'s state dict."""
    module_mem = 0
    sd = module.state_dict()
    for k in sd:
        t = sd[k]
        module_mem += t.nelement() * t.element_size()
    return module_mem


def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    """Number of tiles `tiled_scale` will process for one image of the given size."""
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))


@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3,
                         output_device="cpu", pbar=None):
    """Apply `function` to overlapping tiles of `samples` and blend the results.

    Each tile's output is weighted by a mask that ramps linearly to zero over the
    feathered overlap region; the accumulated output is normalized by the
    accumulated mask so adjacent tiles blend smoothly.
    """
    dims = len(tile)
    logger.debug(f"samples dtype: {samples.dtype}")
    output = torch.empty(
        [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
        device=output_device,
    )

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]
        out = torch.zeros(
            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
            device=output_device,
        )
        out_div = torch.zeros(
            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
            device=output_device,
        )

        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
            s_in = s
            upscaled = []
            for d in range(dims):
                # Clamp the tile position so the final tile still lies inside the input.
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)
            # Ramp the blend mask down linearly towards every tile edge.
            for t in range(feather):
                for d in range(2, dims + 2):
                    m = mask.narrow(d, t, 1)
                    m *= ((1.0 / feather) * (t + 1))
                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
                    m *= ((1.0 / feather) * (t + 1))

            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o += ps * mask
            o_d += mask
            if pbar is not None:
                pbar.update(1)

        output[b:b + 1] = out / out_div
    return output


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3,
                output_device="cpu", pbar=None):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels,
                                output_device, pbar)
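# A minimal usage sketch for `tiled_scale` (illustrative only): `double_res` is a
# hypothetical stand-in for any callable mapping (B, C, H, W) -> (B, C, 2H, 2W);
# the shapes, tile size, and overlap below are arbitrary example values.
#
#     frames = torch.rand(1, 3, 256, 256)
#     double_res = lambda a: torch.nn.functional.interpolate(a, scale_factor=2)
#     steps = frames.shape[0] * get_tiled_scale_steps(256, 256, tile_x=64, tile_y=64, overlap=8)
#     pbar = ProgressBar(steps, desc="Tiled upscale")
#     out = tiled_scale(frames, double_res, tile_x=64, tile_y=64, overlap=8,
#                       upscale_amount=2, pbar=pbar)  # -> shape (1, 3, 512, 512)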
def load_sd_upscale(ckpt, inf_device):
    """Load a spandrel-compatible upscale model from `ckpt` and convert it to fp16."""
    from spandrel import ModelLoader

    pbar = ProgressBar(1, "Loading Upscale Model")
    sd = load_torch_file(ckpt, device=inf_device)
    # Strip the DataParallel "module." prefix if present.
    if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
        sd = state_dict_prefix_replace(sd, {"module.": ""})
    out = ModelLoader().load_from_state_dict(sd).half()
    pbar.update(1)
    return out


def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu") -> torch.Tensor:
    memory_required = module_size(upscale_model.model)
    # The 384.0 factor is a rough empirical estimate of activation memory for these
    # models; TODO: make it more accurate.
    memory_required += (512 * 512 * 3) * tensor.element_size() * max(upscale_model.scale, 1.0) * 384.0
    memory_required += tensor.nelement() * tensor.element_size()
    logger.debug(f"Upscale memory required: {memory_required / 1024 / 1024 / 1024:.2f} GB")

    upscale_model.to(inf_device)
    tile = 512
    overlap = 32

    steps = tensor.shape[0] * get_tiled_scale_steps(
        tensor.shape[3], tensor.shape[2], tile_x=tile, tile_y=tile, overlap=overlap
    )

    pbar = ProgressBar(steps, desc="Tiling and Upscaling")

    s = tiled_scale(
        tensor,
        lambda a: upscale_model(a),
        tile_x=tile,
        tile_y=tile,
        overlap=overlap,
        upscale_amount=upscale_model.scale,
        pbar=pbar,
    )

    upscale_model.to(output_device)
    return s


def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
    # Collect the upscaled result for each entry along the batch dimension.
    upscaled_latents = []
    for i in range(latents.size(0)):
        # Take the i-th batch entry (e.g. shape [49, 3, 512, 480]).
        latent = latents[i]
        # Run the upscale model on this entry.
        upscaled_latent = upscale(upscale_model, latent, inf_device, output_device)
        upscaled_latents.append(upscaled_latent)
    return torch.stack(upscaled_latents)


def export_to_video_imageio(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
    output_video_path: Optional[str] = None,
    fps: int = 8,
) -> str:
    """
    Export video frames to a video file using the imageio library, which avoids
    the "green screen" issue seen with some models (for example CogVideoX).
    """
    if output_video_path is None:
        # delete=False keeps the file from being removed when the handle is garbage collected.
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    if isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]
    with imageio.get_writer(output_video_path, fps=fps) as writer:
        for frame in video_frames:
            writer.append_data(frame)
    return output_video_path


def save_video(tensor: Union[List[np.ndarray], List[PIL.Image.Image]], fps: int = 8):
    """Save frames as an MP4 under ./output/, named with the current timestamp."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{timestamp}.mp4"
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    export_to_video_imageio(tensor, video_path, fps=fps)
    return video_path


class ProgressBar:
    def __init__(self, total, desc=None):
        self.total = total
        self.current = 0
        self.desc = desc
        self.b_unit = tqdm.tqdm(
            total=total, desc="ProgressBar context index: 0" if desc is None else desc
        )

    def update(self, value):
        # `value` is an increment; accumulate it and clamp to the total.
        self.current = min(self.current + value, self.total)
        if self.b_unit is not None:
            if self.desc is None:
                self.b_unit.set_description(
                    "ProgressBar context index: {}".format(self.current)
                )
            self.b_unit.refresh()
            # Advance the underlying tqdm bar by the increment, not the running total.
            self.b_unit.update(value)
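# End-to-end sketch (assumptions flagged: "model.pt" is a hypothetical
# spandrel-supported upscaler checkpoint, and a CUDA device is available;
# neither is provided by this module):
#
#     model = load_sd_upscale("model.pt", torch.device("cuda"))
#     frames = torch.rand(49, 3, 480, 720, dtype=torch.float16, device="cuda")  # fake frames in [0, 1]
#     upscaled = upscale(model, frames, torch.device("cuda"))  # tiled inference, returned on CPU
#     pil_frames = [
#         PIL.Image.fromarray((f.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).numpy())
#         for f in upscaled
#     ]
#     print(save_video(pil_frames, fps=8))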