mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)
direct_sr.py · 117 lines · 4.3 KiB · Python
# -*- encoding: utf-8 -*-
'''
@File    :   direct_sr.py
@Time    :   2022/03/02 13:58:11
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# standard-library imports
import os
import sys
import math
import random

import torch
# Portions adapted from inference_cogview2.py (2021/10/10), by the same author.
import argparse
from PIL import ImageEnhance, Image
from torchvision import transforms

from SwissArmyTransformer import get_args
from SwissArmyTransformer.training.model_io import load_checkpoint
from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually

from .dsr_model import DsrModel
from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy

from icetk import icetk as tokenizer

class DirectSuperResolution:
    def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False):
        # Checkpoint path and DSR-specific hyperparameters.
        args.load = path
        args.kernel_size = 5
        args.kernel_size2 = 5
        args.new_sequence_length = 4624
        args.layout = [96, 496, 4096]  # cumulative boundaries: text | low-res image | high-res image

        model = DsrModel(args)
        if args.fp16:
            model = model.half()

        load_checkpoint(model, args)  # loaded on CPU
        model.eval()
        self.model = model
        self.onCUDA = onCUDA
        if onCUDA:
            self.model = self.model.cuda()

        # Forbid sampling anything outside the image-token vocabulary.
        invalid_slices = [slice(tokenizer.num_image_tokens, None)]

        self.strategy = IterativeEntfilterStrategy(
            invalid_slices,
            temperature=1.0, topk=topk)  # temperature is fixed at 1.0 and effectively unused
        self.max_bz = max_bz

    def __call__(self, text_tokens, image_tokens, enhance=False):
        # Accept single samples as well as batched input.
        if len(text_tokens.shape) == 1:
            text_tokens.unsqueeze_(0)
        if len(image_tokens.shape) == 1:
            image_tokens.unsqueeze_(0)

        # ===================== Debug ======================== #
        # new_image_tokens = []
        # for small_img in image_tokens:
        #     decoded = tokenizer.decode(image_ids=small_img)
        #     decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0)
        #     ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        #     image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
        #     small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1)
        #     new_image_tokens.append(small_img2)
        # image_tokens = torch.stack(new_image_tokens)
        # return image_tokens
        # ===================== END OF BLOCK ======================= #

        if enhance:
            # Optionally sharpen the decoded low-resolution image before re-encoding it to tokens.
            new_image_tokens = []
            for small_img in image_tokens:
                decoded = tokenizer.decode(image_ids=small_img).squeeze(0)
                ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr))
                small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1)
                new_image_tokens.append(small_img2)
            image_tokens = torch.stack(new_image_tokens)

        # Conditioning sequence: text tokens followed by low-resolution image tokens.
        seq = torch.cat((text_tokens, image_tokens), dim=1)
        # Target sequence: placeholder tokens to be filled with high-resolution image tokens.
        seq1 = torch.tensor([tokenizer['<start_of_image>']] * 3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1)

        if not self.onCUDA:
            print('Converting Dsr model...')
            model = self.model.cuda()
        else:
            model = self.model

        print('Direct super-resolution...')
        output_list = []
        # Process the batch in chunks of at most self.max_bz samples.
        for tim in range(max((text_tokens.shape[0] + self.max_bz - 1) // self.max_bz, 1)):
            output1 = filling_sequence_dsr(
                model,
                seq[tim * self.max_bz:(tim + 1) * self.max_bz],
                seq1[tim * self.max_bz:(tim + 1) * self.max_bz],
                warmup_steps=1, block_hw=(1, 0),
                strategy=self.strategy
            )
            output_list.extend(output1[1:])

        if not self.onCUDA:
            print('Moving back Dsr to cpu...')
            model = model.cpu()
            torch.cuda.empty_cache()

        return torch.cat(output_list, dim=0)
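

# Usage sketch: a minimal illustration of how DirectSuperResolution is driven.
# Assumptions (not taken from this file): the helper name and checkpoint path
# below are placeholders; `args` is a SwissArmyTransformer argument namespace
# (e.g. from get_args()) carrying the usual fields such as `fp16`;
# `text_tokens` holds the (batch, 96) prompt tokens used for the
# low-resolution stage; `lr_image_tokens` holds the (batch, 400) icetk tokens
# of the 160x160 first-stage image, matching the [96, 496, 4096] layout set
# in __init__.
def example_direct_sr(args, text_tokens, lr_image_tokens,
                      dsr_checkpoint='<path-to-dsr-checkpoint>'):
    dsr = DirectSuperResolution(args, dsr_checkpoint, max_bz=4, topk=6, onCUDA=False)
    hr_image_tokens = dsr(text_tokens, lr_image_tokens, enhance=False)
    # Each returned row is a grid of high-resolution icetk image tokens; it can
    # be decoded back to pixels with tokenizer.decode(image_ids=...).
    return hr_image_tokens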