# -*- encoding: utf-8 -*-
'''
@File    :   sr_group.py
@Time    :   2022/04/02 01:17:21
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random

import numpy as np
import torch
import torch.nn.functional as F

from SwissArmyTransformer.resources import auto_create
from .direct_sr import DirectSuperResolution
from .iterative_sr import IterativeSuperResolution


class SRGroup:
    """Two-stage super-resolution: direct SR from 20x20 to 60x60 image tokens, then iterative refinement."""

    def __init__(self, args, home_path=None):
        # Locate (or download) the pretrained checkpoints via SwissArmyTransformer.
        dsr_path = auto_create('cogview2-dsr', path=home_path)
        itersr_path = auto_create('cogview2-itersr', path=home_path)
        dsr = DirectSuperResolution(args, dsr_path)
        # The iterative model reuses the transformer of the direct-SR model.
        itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
        self.dsr = dsr
        self.itersr = itersr

    def sr_base(self, img_tokens, txt_tokens):
        # Expect a batch of 20x20 = 400 low-resolution image tokens.
        assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
        batch_size = img_tokens.shape[0]
        txt_len = txt_tokens.shape[-1]
        if len(txt_tokens.shape) == 1:
            # Broadcast a single text prompt across the whole image batch.
            txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
        sred_tokens = self.dsr(txt_tokens, img_tokens)
        # Refine only the 60x60 = 3600 high-resolution tokens produced by direct SR.
        iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
        return iter_tokens[-batch_size:]

    # def sr_patch(self, img_tokens, txt_tokens):
    #     assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
    #     batch_size = img_tokens.shape[0] * 9
    #     txt_len = txt_tokens.shape[-1]
    #     if len(txt_tokens.shape) == 1:
    #         txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
    #     img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
    #     iter_tokens = self.sr_base(img_tokens, txt_tokens)
    #     return iter_tokens
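

# Minimal usage sketch (not part of the original module). It assumes an `args`
# namespace carrying whatever fields DirectSuperResolution / IterativeSuperResolution
# expect, plus pre-tokenized text and a batch of 20x20 base-resolution image tokens;
# the variable names below are illustrative only.
#
#   srg = SRGroup(args)                                  # fetches cogview2-dsr / cogview2-itersr
#   hires_tokens = srg.sr_base(img_tokens, txt_tokens)   # (batch, 3600) refined 60x60 tokens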