diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2cce19b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+output/
+*__pycache__/
+samples*/
+runs/
+checkpoints/
+master_ip
+logs/
+*.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 1a47803..d382ea0 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,30 @@
 
 This is the official repo for the paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](http://arxiv.org/abs/2205.15868).
 
+**News!** The [demo](https://wudao.aminer.cn/cogvideo/) for CogVideo is available!
+
+**News!** The code and model for text-to-video generation are now available! Currently we only support *simplified Chinese input*.
 
 https://user-images.githubusercontent.com/48993524/170857367-2033c514-3c9f-4297-876f-2468592a254b.mp4
 
+* **Read** our paper [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) on ArXiv for a formal introduction.
+* **Try** our demo at [https://wudao.aminer.cn/cogvideo/](https://wudao.aminer.cn/cogvideo/).
+* **Run** our pretrained models for text-to-video generation. Please use an A100 GPU.
+* **Cite** our paper if you find our work helpful:
+
+```
+@article{hong2022cogvideo,
+  title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
+  author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
+  journal={arXiv preprint arXiv:2205.15868},
+  year={2022}
+}
+```
+
+## Web Demo
+
+The demo for CogVideo is available at [https://wudao.aminer.cn/cogvideo/](https://wudao.aminer.cn/cogvideo/), where you can try text-to-video generation hands-on. *The original input is in Chinese.*
+
 ## Generated Samples
 
@@ -20,3 +41,40 @@ https://user-images.githubusercontent.com/48993524/170857367-2033c514-3c9f-4297-
 A 4-second clip of 32 frames is shown below.
 
 ![High-frame-rate sample](assets/appendix-sample-highframerate.png)
+
+## Getting Started
+
+### Setup
+
+* Hardware: Linux servers with Nvidia A100s are recommended, but you can also run the pretrained models with a smaller `--max-inference-batch-size` and `--batch-size`, or train smaller models, on less powerful GPUs.
+* Environment: install dependencies via `pip install -r requirements.txt`.
+* LocalAttention: make sure CUDA is installed, then compile the local attention kernel:
+
+```shell
+git clone https://github.com/Sleepychord/Image-Local-Attention
+cd Image-Local-Attention && python setup.py install
+```
+
+### Download
+
+Our code will automatically download the models into (or detect them in) the path defined by the environment variable `SAT_HOME`. You can also manually download [CogVideo-Stage1](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage1.zip) and [CogVideo-Stage2](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage2.zip) and place them under `SAT_HOME` (in folders named `cogvideo-stage1` and `cogvideo-stage2`).
+
+### Text-to-Video Generation
+
+```
+./script/inference_cogvideo_pipeline.sh
+```
+
+The main arguments for inference are:
+
+* `--input-source [path or "interactive"]`. The path of an input file with one query per line; passing "interactive" launches an interactive CLI instead.
+* `--output-path [path]`. The folder where the results are saved.
+* `--batch-size [int]`. The number of samples generated per query.
+* `--max-inference-batch-size [int]`. Maximum batch size per forward pass. Reduce it if you run out of GPU memory.
+* `--stage1-max-inference-batch-size [int]`. Maximum batch size per forward pass in Stage 1. Reduce it if you run out of GPU memory.
+* `--both-stages`. Run stage 1 and stage 2 sequentially.
+* `--use-guidance-stage1`. Use classifier-free guidance in stage 1; strongly recommended for better results.
+
+We recommend setting the environment variable `SAT_HOME` to the directory where the downloaded models are stored.
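+
+For example, a typical run might look like the sketch below (the model directory, query file, and flag values are illustrative placeholders; the flags are the ones listed above and are forwarded to the pipeline by the script):
+
+```shell
+# Sketch only: adapt paths and batch sizes to your setup.
+export SAT_HOME=/path/to/sat_models        # where the pretrained models are (or will be) stored
+echo "花开的延时摄影" > input.txt           # one simplified-Chinese query per line ("time-lapse of a flower blooming")
+./script/inference_cogvideo_pipeline.sh    # see the script for the full argument list; typical settings:
+#   --input-source input.txt --output-path ./output \
+#   --both-stages --use-guidance-stage1 \
+#   --batch-size 4 --max-inference-batch-size 2
+```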
+
+*Currently only Chinese input is supported.*
diff --git a/cluster_label2.npy b/cluster_label2.npy
new file mode 100644
index 0000000..dff3170
Binary files /dev/null and b/cluster_label2.npy differ
diff --git a/coglm_strategy.py b/coglm_strategy.py
new file mode 100644
index 0000000..d485715
--- /dev/null
+++ b/coglm_strategy.py
@@ -0,0 +1,101 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   coglm_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+class CoglmStrategy:
+    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.temperature2 = temperature2
+        self.topk = top_k
+        self.top_p = top_p
+        self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+        self.outlier_count_down = torch.zeros(16)
+        self.vis_list = [[] for i in range(16)]
+        self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
+        self.start_pos = -1
+        self.white_cluster = []
+        # self.fout = open('tmp.txt', 'w')
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
+    def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
+        if temperature is None:
+            temperature = self.temperature
+        if temperature2 is None:
+            temperature2 = self.temperature2
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+
+        rprobs = F.softmax(logits.float(), dim=-1)
+        c = self.cluster_labels.expand(*rprobs.shape)
+        cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
+        # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n')
+        # 
self.fout.flush() + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + for i in range(bz): + selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch + pred = torch.multinomial(probs, num_samples=1) + + if pred.numel() == 1 and pred.item() in self.end_tokens: + self._is_done = True + tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1) + return tokens, mems + + def finalize(self, tokens, mems): + self._is_done = False + return tokens, mems \ No newline at end of file diff --git a/cogvideo_pipeline.py b/cogvideo_pipeline.py new file mode 100644 index 0000000..0efb161 --- /dev/null +++ b/cogvideo_pipeline.py @@ -0,0 +1,793 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_pipeline.py +@Time : 2022/07/15 11:24:56 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +import os +import sys +import torch +import argparse +import time +from torchvision.utils import save_image +import stat +from icetk import icetk as tokenizer +import logging, sys + +import torch.distributed as dist +tokenizer.add_special_tokens(['', '', '']) + + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders +from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually +from SwissArmyTransformer.resources import auto_create + +from models.cogvideo_cache_model import CogVideoCacheModel +from coglm_strategy import CoglmStrategy + + +def get_masks_and_position_ids_stage1(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + # Attention mask (lower triangular). + attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, + device=data.device) + torch.arange(textlen, out=position_ids[:textlen], + dtype=torch.long, device=data.device) + torch.arange(512, 512+seq_length-textlen, out=position_ids[textlen:], + dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + +def get_masks_and_position_ids_stage2(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + + # Attention mask (lower triangular). 
+ attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, + device=data.device) + torch.arange(textlen, out=position_ids[:textlen], + dtype=torch.long, device=data.device) + frame_num = (seq_length-textlen)//framelen + assert frame_num == 5 + torch.arange(512, 512+framelen, out=position_ids[textlen:textlen+framelen], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*2, 512+framelen*3, out=position_ids[textlen+framelen:textlen+framelen*2], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*(frame_num-1), 512+framelen*frame_num, out=position_ids[textlen+framelen*2:textlen+framelen*3], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*1, 512+framelen*2, out=position_ids[textlen+framelen*3:textlen+framelen*4], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*3, 512+framelen*4, out=position_ids[textlen+framelen*4:textlen+framelen*5], + dtype=torch.long, device=data.device) + + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + +def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len): + if hiddens is None: + return None, mems_indexs + mem_num = len(hiddens) + ret_mem = [] + with torch.no_grad(): + for id in range(mem_num): + if hiddens[id][0] is None: + ret_mem.append(None) + else: + if id == 0 and limited_spatial_channel_mem and mems_indexs[id]+hiddens[0][0].shape[1] >= text_len+frame_len: + if mems_indexs[id] == 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, :text_len] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, :text_len] + new_mem_len_part2 = (mems_indexs[id]+hiddens[0][0].shape[1]-text_len)%frame_len + if new_mem_len_part2 > 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, text_len:text_len+new_mem_len_part2] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, -new_mem_len_part2:] + mems_indexs[id] = text_len+new_mem_len_part2 + else: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, mems_indexs[id]:mems_indexs[id]+hidden.shape[1]] = hidden.expand(mems_buffers[id].shape[1], -1, -1) + mems_indexs[id] += hidden.shape[1] + ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]]) + return ret_mem, mems_indexs + + +def my_save_multiple_images(imgs, path, subdir, debug=True): + # imgs: list of tensor images + if debug: + imgs = torch.cat(imgs, dim=0) + print("\nSave to: ", path, flush=True) + save_image(imgs, path, normalize=True) + else: + print("\nSave to: ", path, flush=True) + single_frame_path = os.path.join(path, subdir) + os.makedirs(single_frame_path, exist_ok=True) + for i in range(len(imgs)): + save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True) + os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU) + save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True) + os.chmod(os.path.join(single_frame_path,f'frame_concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU) + +def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len): + # The fisrt token's position id of the frame that the next token belongs to; + if total_len < text_len: + return None 
+ return (total_len-text_len)//frame_len * frame_len + text_len + +def my_filling_sequence( + model, + args, + seq, + batch_size, + get_masks_and_position_ids, + text_len, + frame_len, + strategy=BaseStrategy(), + strategy2=BaseStrategy(), + mems=None, + log_text_attention_weights=0, # default to 0: no artificial change + mode_stage1=True, + enforce_no_swin=False, + guider_seq=None, + guider_text_len=0, + guidance_alpha=1, + limited_spatial_channel_mem=False, # 空间通道的存储限制在本帧内 + **kw_args + ): + ''' + seq: [2, 3, 5, ..., -1(to be generated), -1, ...] + mems: [num_layers, batch_size, len_mems(index), mem_hidden_size] + cache, should be first mems.shape[1] parts of context_tokens. + mems are the first-level citizens here, but we don't assume what is memorized. + input mems are used when multi-phase generation. + ''' + if guider_seq is not None: + logging.debug("Using Guidance In Inference") + if limited_spatial_channel_mem: + logging.debug("Limit spatial-channel's mem to current frame") + assert len(seq.shape) == 2 + + # building the initial tokens, attention_mask, and position_ids + actual_context_length = 0 + + while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens + actual_context_length += 1 # [0, context_length-1] are given + assert actual_context_length > 0 + current_frame_num = (actual_context_length-text_len) // frame_len + assert current_frame_num >= 0 + context_length = text_len + current_frame_num * frame_len + + tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len) + tokens = tokens[..., :context_length] + input_tokens = tokens.clone() + + if guider_seq is not None: + guider_index_delta = text_len - guider_text_len + guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len) + guider_tokens = guider_tokens[..., :context_length-guider_index_delta] + guider_input_tokens = guider_tokens.clone() + + for fid in range(current_frame_num): + input_tokens[:, text_len+400*fid] = tokenizer[''] + if guider_seq is not None: + guider_input_tokens[:, guider_text_len+400*fid] = tokenizer[''] + + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 + # initialize generation + counter = context_length - 1 # Last fixed index is ``counter'' + index = 0 # Next forward starting index, also the length of cache. + mems_buffers_on_GPU = False + mems_indexs = [0, 0] + mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74] + mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype) + for mem_len in mems_len] + + + if guider_seq is not None: + guider_attention_mask = guider_attention_mask.type_as(next(model.parameters())) # if fp16 + guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype) + for mem_len in mems_len] + guider_mems_indexs = [0, 0] + guider_mems = None + + torch.cuda.empty_cache() + # step-by-step generation + while counter < len(seq[0]) - 1: + # we have generated counter+1 tokens + # Now, we want to generate seq[counter + 1], + # token[:, index: counter+1] needs forwarding. 
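+        # First iteration (index == 0): forward the whole given context in chunks of
+        # `group_size` rows and copy each layer's returned mem_kv into the preallocated
+        # mems_buffers; later iterations forward only the newly generated position and
+        # extend the cache incrementally via my_update_mems below.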
+ if index == 0: + group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size + + logits_all = None + for batch_idx in range(0, input_tokens.shape[0], group_size): + logits, *output_per_layers = model( + input_tokens[batch_idx:batch_idx+group_size, index:], + position_ids[..., index: counter+1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args + ) + logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits + mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]] + next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1]) + for id, mem_kv in enumerate(mem_kv01): + for layer, mem_kv_perlayer in enumerate(mem_kv): + if limited_spatial_channel_mem and id == 0: + mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len] + mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\ + mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:] + else: + mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1) + mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1] + if limited_spatial_channel_mem: + mems_indexs[0] -= (next_tokens_frame_begin_id - text_len) + + mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)] + logits = logits_all + + # Guider + if guider_seq is not None: + guider_logits_all = None + for batch_idx in range(0, guider_input_tokens.shape[0], group_size): + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):], + guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter-guider_index_delta, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args + ) + guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits + guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]] + for id, guider_mem_kv in enumerate(guider_mem_kv01): + for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv): + if limited_spatial_channel_mem and id == 0: + guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len] + guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1]) + guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\ + guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:] + else: + guider_mems_buffers[id][layer, 
batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1) + guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1] + if limited_spatial_channel_mem: + guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len) + guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)] + guider_logits = guider_logits_all + else: + if not mems_buffers_on_GPU: + if not mode_stage1: + torch.cuda.empty_cache() + for idx, mem in enumerate(mems): + mems[idx] = mem.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, mem in enumerate(guider_mems): + guider_mems[idx] = mem.to(next(model.parameters()).device) + else: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device) + mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)] + if guider_seq is not None: + for idx, guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device) + guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)] + mems_buffers_on_GPU = True + + logits, *output_per_layers = model( + input_tokens[:, index:], + position_ids[..., index: counter+1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args + ) + mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers] + + if guider_seq is not None: + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[:, max(index-guider_index_delta, 0):], + guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter-guider_index_delta, + log_text_attention_weights=0, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args + ) + guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers] + + if not mems_buffers_on_GPU: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device) + mems_buffers_on_GPU = True + + mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len) + if guider_seq is not None: + guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len) + + + counter += 1 + index = counter + + logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size] + tokens = tokens.expand(batch_size, -1) + if guider_seq is not None: + guider_logits = guider_logits[:, -1].expand(batch_size, -1) + guider_tokens = guider_tokens.expand(batch_size, -1) + + if seq[-1][counter].item() < 0: + # 
sampling + guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits + if mode_stage1 and counter < text_len + 400: + tokens, mems = strategy.forward(guided_logits, tokens, mems) + else: + tokens, mems = strategy2.forward(guided_logits, tokens, mems) + if guider_seq is not None: + guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1) + + if seq[0][counter].item() >= 0: + for si in range(seq.shape[0]): + if seq[si][counter].item() >= 0: + tokens[si, -1] = seq[si, counter] + if guider_seq is not None: + guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta] + + else: + tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1) + if guider_seq is not None: + guider_tokens = torch.cat((guider_tokens, + guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta] + .clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1) + + input_tokens = tokens.clone() + if guider_seq is not None: + guider_input_tokens = guider_tokens.clone() + if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400: + boi_idx = ((index-text_len-1)//400 +1)*400+text_len + while boi_idx < input_tokens.shape[-1]: + input_tokens[:, boi_idx] = tokenizer[''] + if guider_seq is not None: + guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer[''] + boi_idx += 400 + + if strategy.is_done: + break + return strategy.finalize(tokens, mems) + +class InferenceModel_Sequential(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=-1, cogvideo_stage=1) + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + return logits_parallel + +class InferenceModel_Interpolate(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=10, cogvideo_stage=2) + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + return logits_parallel + +def main(args): + assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1 + rank_id = args.device % args.parallel_size + generate_frame_num = args.generate_frame_num + + if args.stage_1 or args.both_stages: + model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1') + model_stage1.eval() + if args.both_stages: + model_stage1 = model_stage1.cpu() + + if args.stage_2 or args.both_stages: + model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2') + model_stage2.eval() + if args.both_stages: + model_stage2 = model_stage2.cpu() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + strategy_cogview2 = CoglmStrategy(invalid_slices, + temperature=1.0, top_k=16) + strategy_cogvideo = CoglmStrategy(invalid_slices, + temperature=args.temperature, top_k=args.top_k, + temperature2=args.coglm_temperature2) + if not args.stage_1: + from sr_pipeline import DirectSuperResolution + dsr_path = auto_create('cogview2-dsr', 
path=None) # path=os.getenv('SAT_HOME', '~/.sat_models') + dsr = DirectSuperResolution(args, dsr_path, + max_bz=12, onCUDA=False) + + def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1): + stage2_starttime = time.time() + use_guidance = args.use_guidance_stage2 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage-2 model to cuda") + model = model.cuda() + logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time)) + + try: + if parent_given_tokens is None: + assert conddir is not None + parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu') + sample_num_allgpu = parent_given_tokens.shape[0] + sample_num = sample_num_allgpu // gpu_parallel_size + assert sample_num * gpu_parallel_size == sample_num_allgpu + parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num] + except: + logging.critical("No frame_tokens found in interpolation, skip") + return False + + # CogVideo Stage2 Generation + while duration >= 0.5: # TODO: You can change the boundary to change the frame rate + parent_given_tokens_num = parent_given_tokens.shape[1] + generate_batchsize_persample = (parent_given_tokens_num-1)//2 + generate_batchsize_total = generate_batchsize_persample * sample_num + total_frames = generate_frame_num + frame_len = 400 + enc_text = tokenizer.encode(seq_text) + enc_duration = tokenizer.encode(str(float(duration))+"秒") + seq = enc_duration + [tokenizer['']] + enc_text + [tokenizer['']] + [-1]*400*generate_frame_num + text_len = len(seq) - frame_len*generate_frame_num - 1 + + logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text))) + + # generation + seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i] + seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1] + seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2] + + if use_guidance: + guider_seq = enc_duration + [tokenizer['']] + tokenizer.encode(video_guidance_text) + [tokenizer['']] + [-1]*400*generate_frame_num + guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1 + guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i] + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1] + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2] + video_log_text_attention_weights = 0 + else: + guider_seq=None + guider_text_len=0 + video_log_text_attention_weights = 1.4 + + mbz = args.max_inference_batch_size + + assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0 + output_list = [] + start_time = time.time() + for tim in 
range(max(generate_batchsize_total // mbz, 1)): + input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone() + guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None + output_list.append( + my_filling_sequence(model, args, input_seq, + batch_size=min(generate_batchsize_total, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage2, + text_len=text_len, frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + mode_stage1=False, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + )[0] + ) + logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time)) + + output_tokens = torch.cat(output_list, dim=0) + output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames) + output_tokens_merge = torch.cat((output_tokens[:, :, :1*400], + output_tokens[:, :, 400*3:4*400], + output_tokens[:, :, 400*1:2*400], + output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400) + + output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1) + duration /= 2 + parent_given_tokens = output_tokens_merge + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 2 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time)) + + logging.info("CogVideo Stage2 completed. 
Taken time {:.2f}\n".format(time.time() - stage2_starttime)) + + # decoding + # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge] + # os.makedirs(output_dir_full_path, exist_ok=True) + # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False) + # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt')) + # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2") + + # direct super-resolution by CogView2 + logging.info("[Direct super-resolution]") + dsr_starttime = time.time() + enc_text = tokenizer.encode(seq_text) + frame_num_per_sample = parent_given_tokens.shape[1] + parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400) + text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1) + sred_tokens = dsr(text_seq, parent_given_tokens_2d) + decoded_sr_videos = [] + + for sample_i in range(sample_num): + decoded_sr_imgs = [] + for frame_i in range(frame_num_per_sample): + decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:]) + decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480))) + decoded_sr_videos.append(decoded_sr_imgs) + + for sample_i in range(sample_num): + my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) + os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125") + + logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime)) + + return True + + + def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1): + process_start_time = time.time() + use_guide = args.use_guidance_stage1 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cuda") + model = model.cuda() + logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time)) + + if video_raw_text is None: + video_raw_text = seq_text + mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size + assert batch_size < mbz or batch_size % mbz == 0 + frame_len = 400 + + # generate the first frame: + enc_text = tokenizer.encode(seq_text+image_text_suffix) + seq_1st = enc_text + [tokenizer['']] + [-1]*400 # IV!! # test local!!! # test randboi!!! 
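+        # seq_1st layout: [text tokens] + [one begin-of-image separator token] + 400 placeholder
+        # positions (-1) to be filled with the image tokens of the first frame, so the
+        # text_len_1st computed below counts the text tokens only.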
+ logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text))) + text_len_1st = len(seq_1st) - frame_len*1 - 1 + + seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0) + output_list_1st = [] + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + output_list_1st.append( + my_filling_sequence(model, args,seq_1st.clone(), + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len_1st, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=1.4, + enforce_no_swin=True, + mode_stage1=True, + )[0] + ) + logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)) + output_tokens_1st = torch.cat(output_list_1st, dim=0) + given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] + + # generate subsequent frames: + total_frames = generate_frame_num + enc_duration = tokenizer.encode(str(float(duration))+"秒") + if use_guide: + video_raw_text = video_raw_text + " 视频" + enc_text_video = tokenizer.encode(video_raw_text) + seq = enc_duration + [tokenizer['']] + enc_text_video + [tokenizer['']] + [-1]*400*generate_frame_num + guider_seq = enc_duration + [tokenizer['']] + tokenizer.encode(video_guidance_text) + [tokenizer['']] + [-1]*400*generate_frame_num + logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video))) + + text_len = len(seq) - frame_len*generate_frame_num - 1 + guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1 + seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1) + guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1) + + for given_frame_id in range(given_tokens.shape[1]): + seq[:, text_len+1+given_frame_id*400: text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id] + guider_seq[:, guider_text_len+1+given_frame_id*400:guider_text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id] + output_list = [] + + if use_guide: + video_log_text_attention_weights = 0 + else: + guider_seq = None + video_log_text_attention_weights = 1.4 + + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone() + guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None + output_list.append( + my_filling_sequence(model, args,input_seq, + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len, frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + mode_stage1=True, + )[0] + ) + + output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:] + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time)) + + # decoding + imgs, sred_imgs, txts = [], [], [] + for seq 
in output_tokens: + decoded_imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()[i*400: (i+1)*400]), size=(480, 480)) for i in range(total_frames)] + imgs.append(decoded_imgs) # only the last image (target) + + assert len(imgs) == batch_size + save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu() + if outputdir is not None: + for clip_i in range(len(imgs)): + # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True) + my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False) + os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25") + torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt')) + + logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time)) + + return save_tokens + + # ====================================================================================================== + + if args.stage_1 or args.both_stages: + if args.input_source != "interactive": + with open(args.input_source, 'r') as fin: + promptlist = fin.readlines() + promptlist = [p.strip() for p in promptlist] + else: + promptlist = None + + now_qi = -1 + while True: + now_qi += 1 + + if promptlist is not None: # with input-source + if args.multi_gpu: + if now_qi % dist.get_world_size() != dist.get_rank(): + continue + rk = dist.get_rank() + else: + rk = 0 + raw_text = promptlist[now_qi] + raw_text = raw_text.strip() + print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]') + else: # interactive + raw_text = input("\nPlease Input Query (stop to exit) >>> ") + raw_text = raw_text.strip() + if not raw_text: + print('Query should not be empty!') + continue + if raw_text == "stop": + return + + try: + path = os.path.join(args.output_path, f"{now_qi}_{raw_text}") + parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频", + image_text_suffix=" 高清摄影", + outputdir=path if args.stage_1 else None, batch_size=args.batch_size) + if args.both_stages: + process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频", + video_guidance_text="视频", parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, gpu_parallel_size=1) # TODO: 修改 + except (ValueError, FileNotFoundError) as e: + print(e) + continue + + elif args.stage_2: + sample_dirs = os.listdir(args.output_path) + for sample in sample_dirs: + raw_text = sample.split('_')[-1] + path = os.path.join(args.output_path, sample, 'Interp') + parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt")) + + process_stage2(raw_text, duration=2.0, video_raw_text=raw_text+" 视频", + video_guidance_text="视频", parent_given_tokens=parent_given_tokens, + outputdir=path, + gpu_rank=0, gpu_parallel_size=1) # TODO: 修改 + + else: + assert False + + +if __name__ == "__main__": + logging.basicConfig(stream=sys.stderr, level=logging.DEBUG) + + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--generate-frame-num', type=int, default=5) + py_parser.add_argument('--coglm-temperature2', type=float, default=0.89) + # py_parser.add_argument("--interp-duration", type=float, default=-1) # -1是顺序生成,0是超分,0.5/1/2是插帧 + # py_parser.add_argument("--total-duration", type=float, default=4.0) # 整个的时间 + py_parser.add_argument('--use-guidance-stage1', action='store_true') + py_parser.add_argument('--use-guidance-stage2', action='store_true') + 
py_parser.add_argument('--guidance-alpha', type=float, default=3.0) + py_parser.add_argument('--stage-1', action='store_true') # stage 1: sequential generation + py_parser.add_argument('--stage-2', action='store_true') # stage 2: interp + dsr + py_parser.add_argument('--both-stages', action='store_true') # stage 1&2: sequential generation; interp + dsr + py_parser.add_argument('--parallel-size', type=int, default=1) + py_parser.add_argument('--stage1-max-inference-batch-size', type=int, default=-1) # -1: use max-inference-batch-size + py_parser.add_argument('--multi-gpu', action='store_true') + + CogVideoCacheModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + args.layout = [int(x) for x in args.layout.split(',')] + args.do_train = False + + torch.cuda.set_device(args.device) + + with torch.no_grad(): + main(args) \ No newline at end of file diff --git a/models/cogvideo_cache_model.py b/models/cogvideo_cache_model.py new file mode 100644 index 0000000..ca39184 --- /dev/null +++ b/models/cogvideo_cache_model.py @@ -0,0 +1,695 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_cache_model.py +@Time : 2022/07/15 11:22:19 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +from multiprocessing import context +from tkinter import E +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = 
x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + time_dim_attend_length=0 + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.time_dim_attend_length = time_dim_attend_length + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + attn_mask = attn_mask.tril() + + causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num) + causal_mask = causal_mask.tril() + + self.shift_sizes = [0, shift_size] + self.attn_mask = attn_mask + self.causal_mask = causal_mask + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = 
self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + if stage == 2: + assert frame_num == 3 + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if stage == 1: + if self.shift_sizes[layer_id%2] > 0: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), + self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\ + - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0)) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\ + - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0)) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + ret_context = context_swin.reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memk = 
window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution) + memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution) + if self.shift_sizes[layer_id%2] > 0: + memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0) + + ret_mem = torch.cat((memk, memv), dim=-1) + return ret_context, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead] + # memkv [batchsize, pos, hidden_size*2] (include frames only) + # if memkv_text is not None: will attend to text + # pos: token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + s1 = memkv.shape[1] if memkv is not None else 0 + frame_len = self.frame_resolution * self.frame_resolution + frame_num_before = s1 // frame_len + + + if memkv is not None: + pos_inframe = pos - frame_num_before * frame_len + + xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos + ypos = pos_inframe % self.frame_resolution + # [start, end) + if self.shift_sizes[layer_id%2] > 0: + xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + xend = xstart + self.window_size + yend = ystart + self.window_size + xstart, ystart = max(0, xstart), max(0, ystart) + xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution) + else: + xstart = (xpos // self.window_size) * self.window_size + ystart = (ypos // self.window_size) * self.window_size + xend, yend = xstart + self.window_size, ystart+self.window_size + + # select index + selected_index = list() + if frame_num_before > 0: + # frames before + frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0 + for x in range(xstart, xend): + for y in range(ystart, yend): + selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start) + cnt_per_frame = len(selected_index) + for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame): + selected_index.append(selected_index[-cnt_per_frame]+frame_len) + + # the last frame + for x in range(xstart, xend): + for y in range(ystart, yend): + tmppos = x*self.frame_resolution+y + frame_num_before * frame_len + if tmppos < pos: + selected_index.append(tmppos) + else: + break + cnt_all = len(selected_index)+1 + selected_index = torch.tensor(selected_index, device=memkv.device) + used_memkv = torch.index_select(memkv, 1, selected_index) + used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + if memkv_text is not None: + cnt_all += memkv_text.shape[-2] + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., 
h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = thisk + used_v = thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + **kwargs, + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert stage == 1 + + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = 
torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril()) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0] + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0) + memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0) + ret_mem = torch.cat((memk, memv), dim=-1) + + return context_swin, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # pos: current token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + assert stage == 1 + + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + + if memkv is not None: + used_k, used_v = memkv[..., :h0], memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + else: + used_k, used_v = thisk, thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + + used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + + +def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask, + n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len] + if stage == 2: + assert frame_num == 3 + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text += log_text_attention_weights + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \ + - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len]) + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) 
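+ # attention dropout below runs inside the checkpointing RNG tracker fork so the same dropout mask is re-drawn if activations are recomputed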
+ if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + if frame_num > 0: + score_any2text_part2 = score_any2text[..., text_len:, :] + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + if stage == 1: + score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \ + - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1)) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + else: + context_frame = None + + return context_text2text, context_frame + +def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs): + # limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame} + b, s0, h0 = k0.shape + frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1 + h = h0 // n_head + assert q0.shape[1] == 1 + assert v0.shape[1] == k0.shape[1] + + q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + if limited_spatial_channel_mem: + assert frame_num_before == 0 + assert stage == 1 # not implemented for stage-2 yet + score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + score[..., :text_len] += log_text_attention_weights + attention_probs_frame = F.softmax(score, dim=-1) + context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0) + + else: + score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_token2text += log_text_attention_weights + score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:]) + score_frame_all = torch.cat((score_token2text, + score_frame_local0), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \ + v0[:, :, text_len+frame_num_before*frame_len:, :]) + context_frame = (context_token2text 
+ context_frame_local0).transpose(1, 2).reshape(b, 1, h0) + + return context_frame + + +class CogVideoCacheModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.layout = args.layout # [64, 64+1024, 64+6*1024] + self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2 + self.n_head = args.num_attention_heads + self.window_size = window_size if window_size is not None else args.window_size + + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if self.stage == 1: + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=self.window_size, + shift_size=self.window_size//2, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064') + group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后 + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后 + return parser + + def disable_untrainable_params(self): + pass + + def position_embedding_forward(self, position_ids, **kw_args): + if position_ids.shape[-1] > 1: + if self.stage == 1: + if position_ids[0,-1] >= (512+400): + frame_num = position_ids.shape[-1] // 400 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400)) + ), + dim=-2 + ) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + else: + # given 3, interpolate 2 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-800]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400)) + ), + dim=-2 + ) + else: + if position_ids[0, 0] >= (512+400): + position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400)) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + + # base model qkv + if mems is None: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + assert (q0.shape[1]-text_len) % frame_len == 0 + memkv0 = 
torch.cat((k0, v0), dim=-1) + context_text, context_frame_local_text = attention_localframe_and_text_NAR( + q0, k0, v0, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=(q0.shape[1]-text_len)//frame_len, + log_text_attention_weights=log_text_attention_weights, + stage=self.stage + ) + + # change: self.swin_attend_to_text默认为True: + memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:] + output_text = attn_module.dense(context_text) + + if (q0.shape[1]-text_len)//frame_len > 0: + assert (q0.shape[1]-text_len) % frame_len == 0 + context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference( + hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :]) + output = torch.cat((output_text, output_frame), dim=-2) + memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame + else: + output = output_text + memkv1 = memkv1_text + kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1) + + + else: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + new_memkv0 = torch.cat((k0, v0), dim=-1) + old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:] + + context_frame_local_text = attention_localframe_and_text_AR( + q0, + torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2), + torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2), + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=None, + log_text_attention_weights=log_text_attention_weights, + layer_id=layer_id, + limited_spatial_channel_mem=limited_spatial_channel_mem, + ) + + old_memkv1 = mems[1][layer_id] if mems[1] is not None else None + + context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states, + old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None, + counter-text_len, + layer_id, + memkv_text=old_memkv1[..., :text_len, :], + log_text_attention_weights=log_text_attention_weights) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output = attn_module.dense(context_frame_local_text) + + kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1) + + return output \ No newline at end of file diff --git a/models/cogvideo_model.py b/models/cogvideo_model.py new file mode 100644 index 0000000..dfbc136 --- /dev/null +++ b/models/cogvideo_model.py @@ -0,0 +1,543 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_model.py +@Time : 2022/07/11 16:12:05 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : 
hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense", + ) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, 
-shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + + self.attn_mask_sequential = attn_mask.clone().tril() + self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril() + + self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num) + self.attn_mask_interp = attn_mask.clone() + + # bi-dir + for bi_idx in range(0, frame_num, 2): + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + # uni-dir + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + for uni_idx2 in range(uni_idx+2, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + + # expand dim + self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None] + self.attn_mask_interp = self.attn_mask_interp[None, None, :, None] + self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None] + self.causal_mask_interp = self.causal_mask_interp[None, None, :, None] + + self.shift_sizes = [0, shift_size] + # self.register_buffer("attn_mask", attn_mask) + # self.register_buffer("causal_mask", causal_mask) + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=True): + # pb relax + swin_pb_relax = True + alpha = 16 + + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, 
dtype=frame_hidden_state.dtype) + self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp + attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp + if text_hidden_state is not None: + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # pb-relax + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if self.shift_sizes[layer_id%2] > 0: + # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\ + - 10000.0 * (1.0 - attn_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\ + - 10000.0 * (1.0 - causal_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + assert text_attn_mask is not None + text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2) + # pb-relax + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - 
text_attn_mask) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + context_swin = context_swin.reshape(b0, s1, h0) + + return context_swin + + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense",) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril() + + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + base_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=False): + # pb relax + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert mode_sequential == True # only + swin_pb_relax = True + alpha = 16 + + if not self.mask_initialized: + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # frames-to-frames + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + 
else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + # frame-to-text + assert text_attn_mask is not None + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + text_attn_mask = text_attn_mask.unsqueeze(2) + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2)) + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask) + attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0) + + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_frame = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + return context_frame + + +def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local, + n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask_totxt [b, 1, 1, text_len] + # attention_mask_local [1, 1, frame_num, frame_len, frame_len] + # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len] + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + # score: any2text + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \ + - 10000.0 * (1.0 - attention_mask_totxt) + score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \ + 10000.0 * (1.0 - attention_mask_totxt) + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \ + - 10000.0 * (1.0 - attention_mask_local) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + 
attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + return context_text2text, context_frame + + +class CogVideoModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.stage = args.cogvideo_stage # 1 or 2 + self.mode_sequential = True if self.stage==1 else False + self.layout = args.layout # [64, 64+400, 64+5*400] + self.n_head = args.num_attention_heads + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]) + frame_len = self.layout[1]-self.layout[0] + + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if args.window_size == -1: + # full attention + assert self.stage == 1 + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=args.window_size, + shift_size=args.window_size//2, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + # attention_mask_local + self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0) + self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len) + + for idx in range(1, frame_num, 2): + self.attention_mask_local_interp[:, :, idx:idx+1].tril_() + self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0) + self.mask_initialized = False + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num') + group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention") + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) + return parser + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :(64+400)] + position_plus = 
position_ids[..., (64+400):] + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400)) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, **kw_args): + # mask.shape=[bs, 1, 1, 64] + if not self.mask_initialized: + self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.mask_initialized = True + + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + bs = hidden_states.shape[0] + + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None + + attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp + context_text, context_frame_local_text = attention_localframe_and_text( + q0, k0, v0, + attention_mask_totxt=mask, + attention_mask_local=attention_mask_local, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + frame_len=self.layout[1]-self.layout[0], + frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]), + attention_dropout=dropout_fn, + layer_id=layer_id, + ) + + context_frame_swin = self.get_mixin('attention_plus').attention_extra( + hidden_states[:, self.layout[0]:], layer_id, dropout_fn, + text_hidden_state=hidden_states[:, :self.layout[0]], + text_attn_mask=mask[..., 0, :], + mode_sequential=self.mode_sequential) + + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + + output_text = attn_module.dense(context_text) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + output = torch.cat((output_text, output_frame), dim=-2) + + return output \ No newline at end of file diff --git a/pretrain_cogvideo.py b/pretrain_cogvideo.py new file mode 100644 index 0000000..defd906 --- /dev/null +++ b/pretrain_cogvideo.py @@ -0,0 +1,184 @@ +# -*- encoding: utf-8 -*- +''' +@File : pretrain_cogvideo.py +@Time : 2021/10/06 00:58:32 +@Author : Wenyi Hong +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import argparse +import numpy as np +from icetk import icetk as tokenizer +tokenizer.add_special_tokens(['', '', '']) + +from models.cogvideo_model import CogVideoModel +from SwissArmyTransformer import mpu, get_args +from SwissArmyTransformer.training.deepspeed_training import training_main +from SwissArmyTransformer.data_utils import BinaryDataset + +def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None): + # Extract batch size and sequence length. 
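+ # text tokens get positions counted up from 0 (after any left padding); frame tokens get positions starting at 512, the slice of pretrained position embeddings reused by PositionEmbeddingMixin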
+ batch_size, seq_length = data.size() + assert attention_mask_totxt is not None + layout = args.layout + assert seq_length == layout[-1] + n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long() + frame_len = layout[1]-layout[0] + position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long, + device=data.device) + for i in range(batch_size): + torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]], + dtype=torch.long, device=data.device) + torch.arange(512, 512+layout[2]-layout[0], + out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device) + return position_ids + + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['text', 'loss_mask', 'attention_mask_totxt'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens_ = data_b['text'].long() + loss_mask = data_b['loss_mask'].float() + attention_mask_totxt = data_b['attention_mask_totxt'].float() + + labels = tokens_[:, 1:].clone().contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + tokens = tokens_[:, :-1].clone().contiguous() + + for idx in range(args.layout[0], args.layout[2], 400): + tokens[:, idx] = tokenizer[''] + # Get the masks and postition ids. + position_ids = get_masks_and_position_ids_video( + tokens, + attention_mask_totxt=attention_mask_totxt, + args=args + ) + attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1) + # Convert + if args.fp16: + attention_mask_totxt = attention_mask_totxt.half() + return tokens, labels, loss_mask, attention_mask_totxt, position_ids + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + # Forward model. + logits, *mems = model(tokens, position_ids, attention_mask_totxt) + # ======= hyper params =======# + perframe_len = 400 + text_len=64 + frame_num = 5 + logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous() + losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:]) + # scaling loss mask + loss_mask = loss_mask[:, text_len:].reshape(-1) + + losses_1d = losses.reshape(-1) * loss_mask + loss = torch.sum(losses_1d) / loss_mask.sum() + # ===================== Log partial losses ======================== # + log_loss_dict = {} + bs = losses.shape[0] + + if args.cogvideo_stage == 1: + for i in range(frame_num): + log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + else: + for i in range(1, frame_num-1): + log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + + # ===================== END OF BLOCK ======================= # + return loss, log_loss_dict + + +def create_dataset_function(path, args): + dataset_layout = [64, 464, 2064] + input_layout = [64, 464, 2064] + # frame_num = 6 + # frame_interval = 2 # DEBUG!!! 
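+ # process_fn: each binary row holds the 64 text tokens followed by the frame tokens; the text is left-padded to layout[0], a single special token is inserted before the frames, and the matching loss mask and text attention mask are returned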
+ def process_fn(row): + row = row.astype(np.int64) + text = row[:dataset_layout[0]] + frames = row[dataset_layout[0]:] + + if text[0] == tokenizer['']: + text = text[1:] # due to our way of data processing + if args.cogvideo_stage == 1: + text, loss_mask, frames = make_text_video_generation(text, frames) + else: + text, loss_mask, frames = mask_video_frame_interpolation(text, frames) + + n_pad = input_layout[0] - len(text) + parts = [ + np.array([tokenizer['']] * n_pad, dtype=np.int64), + text, + np.array([tokenizer['']], dtype=np.int64), + frames, + ] + ret = np.concatenate(parts, axis=0) + + attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad)) + return {'text': ret, + 'loss_mask': loss_mask, + 'attention_mask_totxt': attention_mask_totxt, + } + return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1]) + +def make_text_video_generation(text, frames): + input_layout = [64, 464, 2064] + text = text[text!= tokenizer['']][:input_layout[0]] # dataset format: 1.0秒{text} ... + loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位 + return text, loss_mask, frames + +def mask_video_frame_interpolation(text, frames): + input_layout = [64, 464, 2064] + frame_len = input_layout[1]-input_layout[0] + # text format: 1.0秒 {text} + text = text[text!= tokenizer['']][:input_layout[0]] + loss_mask = np.array([0] * (input_layout[1]+1) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位 + + return text, loss_mask, frames + + + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--txt-loss-scale', type=float, default=1) + CogVideoModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.layout = [int(x) for x in args.layout.split(',')] + + training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5cf5885 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +SwissArmyTransformer>=0.2 +icetk +gifmaker +torchvision \ No newline at end of file diff --git a/scripts/ds_brain_pretrain_cogvideo_stage1.sh b/scripts/ds_brain_pretrain_cogvideo_stage1.sh new file mode 100644 index 0000000..03c1b18 --- /dev/null +++ b/scripts/ds_brain_pretrain_cogvideo_stage1.sh @@ -0,0 +1,108 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +# HOST_FILE_PATH="hostfile_single" + +video_data_test="" # TODO +CHECKPOINT_PATH="" # TODO: CogView2 ckpt + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name pretrain-cogvideo-stage1 \ + --tokenizer-type fake \ + --vocab-size 150010 \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --num-workers 0 \ + --num-layers 48 \ + --hidden-size 3072 \ + --num-attention-heads 48 \ + --layout 64,464,2064 \ + --window-size -1 \ + --cogvideo-stage 1 \ + --additional-seqlen 2000 \ + --train-iters 500000 \ + --resume-dataloader \ + --train-data ${video_data_test} \ + --train-data-weights 1 \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .001 \ + --checkpoint-activations \ + --max-sequence-length 1024 \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 15 \ + --log-interval 50 \ + --save $main_dir/checkpoints \ + --sandwich-ln \ + --load $CHECKPOINT_PATH \ +" + # --load $CHECKPOINT_PATH \ + # \ --sandwich-ln + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +#!/bin/bash + +# Distribute Example +#export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_DISABLE=0 +export NCCL_NET_GDR_LEVEL=2 +#export NCCL_IB_CUDA_SUPPORT=1 +#export NCCL_IB_GID_INDEX=3 +#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null) +export NCCL_DEBUG=info +export OMP_NUM_THREADS=4 + +if [ $RLAUNCH_REPLICA == "0" ]; then + ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip +fi + +function finish { + rm -rf master_ip +} + +trap finish EXIT INT TERM + +while [ ! -f master_ip ]; do + echo "wait master_ip..." + ls > /dev/null && sleep 1; +done + +export MASTER_ADDR=$(cat master_ip) +echo "master_ip: $MASTER_ADDR" + +MP_SIZE=1 +task_set=$2 +source $1 +DATESTR=$(date +"%m-%d-%H-%M") + +mkdir logs +run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \ + --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt" + + +# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/ds_brain_pretrain_cogvideo_stage2.sh b/scripts/ds_brain_pretrain_cogvideo_stage2.sh new file mode 100644 index 0000000..5b89b0a --- /dev/null +++ b/scripts/ds_brain_pretrain_cogvideo_stage2.sh @@ -0,0 +1,108 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +# HOST_FILE_PATH="hostfile_single" + +video_data_test="" # TODO +CHECKPOINT_PATH="" # TODO: CogView2 ckpt + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name pretrain-cogvideo-stage2 \ + --tokenizer-type fake \ + --vocab-size 150010 \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --num-workers 0 \ + --num-layers 48 \ + --hidden-size 3072 \ + --num-attention-heads 48 \ + --layout 64,464,2064 \ + --window-size 10 \ + --cogvideo-stage 2 \ + --additional-seqlen 2000 \ + --train-iters 500000 \ + --resume-dataloader \ + --train-data ${video_data_test} \ + --train-data-weights 1 \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .001 \ + --checkpoint-activations \ + --max-sequence-length 1024 \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 15 \ + --log-interval 50 \ + --save $main_dir/checkpoints \ + --sandwich-ln \ + --load $CHECKPOINT_PATH \ +" + # --load $CHECKPOINT_PATH \ + # \ --sandwich-ln + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +#!/bin/bash + +# Distribute Example +#export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_DISABLE=0 +export NCCL_NET_GDR_LEVEL=2 +#export NCCL_IB_CUDA_SUPPORT=1 +#export NCCL_IB_GID_INDEX=3 +#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null) +export NCCL_DEBUG=info +export OMP_NUM_THREADS=4 + +if [ $RLAUNCH_REPLICA == "0" ]; then + ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip +fi + +function finish { + rm -rf master_ip +} + +trap finish EXIT INT TERM + +while [ ! -f master_ip ]; do + echo "wait master_ip..." 
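+ # poll until the rank-0 replica has written its address into master_ip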
+ ls > /dev/null && sleep 1; +done + +export MASTER_ADDR=$(cat master_ip) +echo "master_ip: $MASTER_ADDR" + +MP_SIZE=1 +task_set=$2 +source $1 +DATESTR=$(date +"%m-%d-%H-%M") + +mkdir logs +run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \ + --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt" + + +# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json new file mode 100644 index 0000000..a9f7ad1 --- /dev/null +++ b/scripts/ds_config_zero.json @@ -0,0 +1,42 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "gradient_clipping": 0.1, + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "contiguous_gradients": false, + "overlap_comm": true, + "reduce_scatter": false, + "reduce_bucket_size": 100000000, + "allgather_bucket_size": 1000000000, + "load_from_fp32_weights": false + }, + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 400, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [ + 0.9, + 0.95 + ], + "eps": 1e-8, + "weight_decay": 1e-4 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false + } + \ No newline at end of file diff --git a/scripts/inference_cogvideo_pipeline.sh b/scripts/inference_cogvideo_pipeline.sh new file mode 100644 index 0000000..ccbc543 --- /dev/null +++ b/scripts/inference_cogvideo_pipeline.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +NLAYERS=48 +NHIDDEN=3072 +NATT=48 +MAXSEQLEN=1024 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +MPSIZE=1 + +#SAMPLING ARGS +TEMP=1.05 +TOPK=12 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +MASTER_PORT=${MASTER_PORT} SAT_HOME=/sharefs/cogview-new python cogvideo_pipeline.py \ + --input-source interactive \ + --output-path ./output \ + --parallel-size 1 \ + --both-stages \ + --use-guidance-stage1 \ + --guidance-alpha 3.0 \ + --generate-frame-num 5 \ + --tokenizer-type fake \ + --mode inference \ + --distributed-backend nccl \ + --fp16 \ + --model-parallel-size $MPSIZE \ + --temperature $TEMP \ + --coglm-temperature2 0.89 \ + --top_k $TOPK \ + --sandwich-ln \ + --seed 1234 \ + --num-workers 0 \ + --batch-size 4 \ + --max-inference-batch-size 8 \ + $@ diff --git a/sr_pipeline/__init__.py b/sr_pipeline/__init__.py new file mode 100644 index 0000000..736cde4 --- /dev/null +++ b/sr_pipeline/__init__.py @@ -0,0 +1,17 @@ +# -*- encoding: utf-8 -*- +''' +@File : __init__.py +@Time : 2022/03/02 13:57:09 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution +from .sr_group import SRGroup \ No newline at end of file diff --git a/sr_pipeline/direct_sr.py b/sr_pipeline/direct_sr.py new file mode 100644 index 0000000..fe32a3a --- /dev/null +++ b/sr_pipeline/direct_sr.py @@ -0,0 +1,117 @@ +# -*- encoding: utf-8 -*- +''' +@File : 
direct_sr.py +@Time : 2022/03/02 13:58:11 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch + +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview2.py +@Time : 2021/10/10 16:31:34 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.training.model_io import load_checkpoint +from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .dsr_model import DsrModel + +from icetk import icetk as tokenizer + +class DirectSuperResolution: + def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [96,496,4096] + + model = DsrModel(args) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model + self.onCUDA = onCUDA + if onCUDA: + self.model = self.model.cuda() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!! + self.max_bz = max_bz + + def __call__(self, text_tokens, image_tokens, enhance=False): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + # ===================== Debug ======================== # + # new_image_tokens = [] + # for small_img in image_tokens: + # decoded = tokenizer.decode(image_ids=small_img) + # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0) + # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + # new_image_tokens.append(small_img2) + # image_tokens = torch.stack(new_image_tokens) + # return image_tokens + # ===================== END OF BLOCK ======================= # + if enhance: + new_image_tokens = [] + for small_img in image_tokens: + decoded = tokenizer.decode(image_ids=small_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1) + new_image_tokens.append(small_img2) + image_tokens = torch.stack(new_image_tokens) + + seq = torch.cat((text_tokens,image_tokens), dim=1) + seq1 = torch.tensor([tokenizer['']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1) + if not self.onCUDA: + print('Converting Dsr model...') + model = self.model.cuda() + else: + model = self.model + print('Direct super-resolution...') + output_list = [] + for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)): + output1 = filling_sequence_dsr(model, + seq[tim*self.max_bz:(tim+1)*self.max_bz], + seq1[tim*self.max_bz:(tim+1)*self.max_bz], + warmup_steps=1, block_hw=(1, 0), + 
strategy=self.strategy + ) + output_list.extend(output1[1:]) + if not self.onCUDA: + print('Moving back Dsr to cpu...') + model = model.cpu() + torch.cuda.empty_cache() + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/sr_pipeline/dsr_model.py b/sr_pipeline/dsr_model.py new file mode 100644 index 0000000..d918d18 --- /dev/null +++ b/sr_pipeline/dsr_model.py @@ -0,0 +1,225 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +class AttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02) + ): + super(AttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3, + gather_output=False, init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear(hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + for layer_id in range(num_layers) + ]) + + def reinit(self, parent_model=None): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data) + self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data) + +class DsrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + 
self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + self.add_mixin('attention_plus', AttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size + )) + self.layout = args.layout + # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[1]] + position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # attention_plus on all layers + query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id] + dense_plus = self.get_mixin('attention_plus').dense[layer_id] + # split two parts + hidden_states_plus = hidden_states[:, self.layout[1]:] + hidden_states = hidden_states[:, :self.layout[1]] + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + # cuda2d model qkv + mixed_raw_layer = query_key_value_plus(hidden_states_plus) + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer0, context_layer1 = sparse_attention_2d_light( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + kernel_size2=self.kernel_size2, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0) + ) + + output_0 = attn_module.dense(context_layer0) + output_1 = dense_plus(context_layer1) + output = torch.cat((output_0, output_1), dim=1) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='96,496,4096') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs): + ''' + q0, k0, v0: [batch_size, 1088, hidden_size] + q1, k1, v1: [batch_size, 
4096, h2] + n_head: int + attention_mask: [batch_size, 1088, 1088] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1) + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + if log_attention_weights is not None: + attention_scores += log_attention_weights + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous() + scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field] + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar, + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + # with get_cuda_rng_tracker().fork(): + attention_probs0 = attention_dropout(attention_probs0) + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head * h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) + v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) + context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False) + context1_to_0 = context1_to_0.view(b, n_head * h, l1**2) + context1 = context1 + context1_to_0 + return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2) \ No newline at end of file diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py new file mode 100644 index 0000000..5b8dded --- /dev/null +++ b/sr_pipeline/dsr_sampling.py @@ -0,0 +1,159 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_sampling.py +@Time : 2021/10/09 00:46:04 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from cv2 import reduce +import torch + +import torch +import torch.nn.functional as F +import numpy as np + +def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')): + indices_to_remove = logits < 
torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + +class IterativeEntfilterStrategy: + def __init__(self, invalid_slices=[], temperature=1., topk=6): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) + + + def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): + # In the iterative strategy, logits are of shape [batch_size, seq_length, hidden_size] + if temperature is None: + temperature = self.temperature + + logits = logits_.float() / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float('Inf') + logits = logits.view(-1, logits.shape[-1]) + + rprobs = F.softmax(logits.float(), dim=-1) + c = self.cluster_labels.expand(*rprobs.shape) + cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True) + sampled_ids = torch.multinomial(best_scores, num_samples=1) + selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids) + selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500) + logits[selected_mask] = -65504 + # for i in range(bz): + # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + # logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax(logits.float()/0.6, dim=-1) # float is essential, due to a bug in PyTorch + pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2]) + + assert tokens.shape[1] == pred.shape[1] + 1 + tokens = torch.cat((tokens[:, :1], pred), dim=1) + return tokens + +def filling_sequence_dsr( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=IterativeEntfilterStrategy(topk=10), + ): + ''' + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token. + Attention: + The sampling temperatures change over the steps; temporarily we hard-code them here. + The temperature in the strategy is not used. + ''' + assert hasattr(model, 'layout') + layout = model.layout + assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \ + and seq0.shape[0] == seq1.shape[0] + assert len(layout) == 3 + assert seq1.shape[1] == layout[-1] - layout[-2] + 1 + assert (seq1 >= 0).all() and (seq0 >= 0).all() + device = seq0.device + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[1] - seq0.shape[1] + assert n_pad > 0, "You should truncate long input before filling."
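+    # pad the (text + low-resolution image) tokens with 0 on the left up to layout[1],
+    # then append the layout[2]-layout[1]+1 placeholder tokens that will hold the
+    # super-resolved image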
+ seq = torch.cat(( + torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + seq0, seq1), dim=1) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + 1 + + # build initial tokens, attention_mask, and position_ids + tokens = seq.clone() + attention_mask = torch.ones(layout[1], layout[1]).to(device) + attention_mask[:layout[0], layout[0]:] = 0 + attention_mask[n_pad:, :n_pad] = 0 + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[0] - n_pad), + torch.arange(513, 513 + layout[1] - layout[0]), + torch.arange(1024, 1024+layout[2]-layout[1]))).to(device) + log_attention_weights = torch.zeros(layout[1], layout[1], + device=device).type_as(next(model.parameters())) + log_attention_weights[layout[0]:, n_pad:layout[0]] = 0. + + # prepare for iteration + unfixed = (tokens < 0) # just init an all-False tensor + unfixed[:, -layout[-1] + layout[-2]:] = True + + ll, rr = block_hw + edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) + num_steps = warmup_steps + ll - 1 + rr + # iterative refining + + # unfixed[..., -(layout[-1] - layout[-2]):].view( + # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False + + + ret = [] + ret.append(tokens[:, layout[-2]+1:].clone()) + for step_cnt in range(1, num_steps+1): + if step_cnt <= warmup_steps: + logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1. + new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + else: + logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1.
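+            # refinement step: resample every position, then overwrite and freeze
+            # only the newly selected blocks (tracked by the unfixed2 mask below)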
+ new_tokens = strategy.forward( + logits, tokens, real_temp, + entfilter=1.3, + filter_topk=5, + temperature2=0.6 + ) + # tokens[unfixed] = new_tokens[unfixed] + # fixed tokens (update unfixed) + unfixed2 = (tokens > 10000000) + for x in range(min(ll, step_cnt - warmup_steps)): + y = step_cnt - warmup_steps - x - 1 + if y < rr: + unfixed[..., -(layout[-1] - layout[-2]):].view( + batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False + unfixed2[..., -(layout[-1] - layout[-2]):].view( + batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True + tokens[unfixed2] = new_tokens[unfixed2] + + ret.append(tokens[:, layout[-2]+1:].clone()) + + return ret diff --git a/sr_pipeline/iterative_sr.py b/sr_pipeline/iterative_sr.py new file mode 100644 index 0000000..a55a6b5 --- /dev/null +++ b/sr_pipeline/iterative_sr.py @@ -0,0 +1,118 @@ +# -*- encoding: utf-8 -*- +''' +@File : iterative_sr.py +@Time : 2022/03/02 15:57:45 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer.training.model_io import load_checkpoint +from SwissArmyTransformer import get_args +from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .itersr_model import ItersrModel + +from icetk import icetk as tokenizer + +class IterativeSuperResolution: + def __init__(self, args, path, max_bz=4, shared_transformer=None): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [16,3616] + + model = ItersrModel(args, transformer=shared_transformer) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model.cuda() + + # save cpu weights + self.saved_weights = dict((k,v.cpu()) + for k, v in model.named_parameters() + if 'transformer' in k + ) + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=args.temp_all_itersr, topk=args.topk_itersr) + self.max_bz = max_bz + + def _restore_transformer_from_cpu(self, non_blocking=False): + for k, v in self.model.named_parameters(): + if k in self.saved_weights: + v.copy_(self.saved_weights[k]) + + def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + text_tokens = text_tokens.clone()[..., :16] + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + if enhance: + new_image_tokens = [] + for big_img in image_tokens: + decoded = tokenizer.decode(image_ids=big_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + new_image_tokens.append(big_img2) + image_tokens = torch.stack(new_image_tokens) + print('Converting Itersr model...') + self._restore_transformer_from_cpu() + model = self.model + print('iterative super-resolution...') + output_list = [] + for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)): + big_img = 
image_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + mask_raw = torch.tensor( + [ + -1, 0, 1, 2, 3, 4, + 0, -1, 2, -1, -2, 5, + 1, -2, 3, 4, 5, 6, + 2, 3, 4, 5, -1, 1, + 3, -1, -2, 0, -1, 2, + 4, 5, 6, 1, 3, -2 + ] + ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous() + + topks = [60, 40, 40, 40, 20, 20, 10] + + for mask_ratio in range(1, 7): + self.strategy.topk = topks[mask_ratio] + mask = (mask_raw.to(big_img.device) >= mask_ratio) + if input_mask is not None: + mask = mask & input_mask + big_img.masked_fill_(mask, tokenizer['']) + seq1 = big_img + output1 = filling_sequence_itersr(model, text_seq, seq1, + warmup_steps=1, block_hw=(1, 0), + strategy=self.strategy + ) + big_img = output1 + print(f'Iter {mask_ratio} times.') + output_list.append(output1.clone()) + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/sr_pipeline/itersr_model.py b/sr_pipeline/itersr_model.py new file mode 100644 index 0000000..40981bc --- /dev/null +++ b/sr_pipeline/itersr_model.py @@ -0,0 +1,232 @@ +# -*- encoding: utf-8 -*- +''' +@File : itersr_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + +class ItersrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + # self.add_mixin('attention_plus', AttentionMixin( + # num_layers=args.num_layers, + # hidden_size=args.hidden_size + # )) + self.layout = args.layout + # [PAD]... [ROI1] text ... 
[BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[0]] + position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3) + # cuda2d model qkv + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer = sparse_attention_2d_text( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + ) + + output = attn_module.dense(context_layer) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float() + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + # def disable_untrainable_params(self): + # self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='16,3616') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 
1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T) + if log_attention_weights is not None: + scores_1_to_0 += log_attention_weights + scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, s1, s0), + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head, h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3]) + + context1_to_0 = torch.matmul(probs_1_to_0, v0) + context1 = context1.transpose(-1, -2) + context1_to_0 + + output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0) + + return output + +def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + attention_probs1 = F.softmax(scores_1_to_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1 + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = 
context1_to_1.view(b, n_head, h, l1**2) + # weighting for cross attention + context1 = context1.transpose(-1, -2) + + output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0) + + return output \ No newline at end of file diff --git a/sr_pipeline/itersr_sampling.py b/sr_pipeline/itersr_sampling.py new file mode 100644 index 0000000..df22a00 --- /dev/null +++ b/sr_pipeline/itersr_sampling.py @@ -0,0 +1,168 @@ +# -*- encoding: utf-8 -*- +''' +@File : itersr_sampling.py +@Time : 2022/03/03 14:24:28 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import numpy as np + +import torch +import torch.nn.functional as F +from icetk import icetk as tokenizer + +def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')): + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + +# class IterativeEntfilterStrategy: +# def __init__(self, invalid_slices=[], temperature=1., topk=10): +# self.invalid_slices = invalid_slices +# self.temperature = temperature +# self.topk = topk +# self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long) + + +# def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): +# # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] +# if temperature is None: +# temperature = self.temperature + +# logits = logits_.float() / temperature +# for invalid_slice in self.invalid_slices: +# logits[..., invalid_slice] = -float('Inf') +# logits = logits.view(-1, logits.shape[-1]) + +# rprobs = F.softmax(logits.float(), dim=-1) +# c = self.cluster_labels.expand(*rprobs.shape) +# cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + +# best_scores, best_clusters = cprobs.topk(self.topk) +# bz = logits.shape[0] +# best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True) +# sampled_ids = torch.multinomial(best_scores, num_samples=1) +# selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids) +# selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500) +# logits[selected_mask] = -65504 +# # for i in range(bz): +# # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] +# # logits[i, self.cluster_labels != selected_cluster] = -65504 + +# # logits = top_k_logits(logits, self.topk, self.top_p) +# probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch +# pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2]) + +# assert tokens.shape[1] == pred.shape[1] +# tokens = pred +# return tokens + +class IterativeEntfilterStrategy: + def __init__(self, invalid_slices=[], temperature=1., topk=10): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + + def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): + # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] + if temperature is None: + temperature = self.temperature + # check entropy filter + # if entfilter is not None: + # assert temperature2 is not None + # topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1) + # ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, 
seq_length] + # temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2 + + logits = logits.float() / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float('Inf') + + # debiased topk + # probs = F.softmax(logits, dim=-1) + # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1) + # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + # edge_idx = tk_idx[:, :, -1:] + # edge_value = tk_value[:, :, -1:] + # edge_mask = probs.gather(dim=-1, index=pred) < edge_value + # pred[edge_mask] = edge_idx[edge_mask] # replace outliers as the "filter_topk"-th token + # pred.squeeze_(-1) # [batch_size, seq_length] + + top_k_logits_(logits, self.topk) + probs = F.softmax(logits, dim=-1) + pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + pred.squeeze_(-1) + + assert tokens.shape[1] == pred.shape[1] + tokens = pred + return tokens + +def filling_sequence_itersr( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=IterativeEntfilterStrategy(topk=10), + ): + ''' + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token. + Attention: + The sampling temperature are changing, temporally we hard code them here. + The temperature in the strategy is not used. + ''' + assert hasattr(model, 'layout') + layout = model.layout + + device = seq0.device + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[0] - seq0.shape[1] + assert n_pad >= 0, "You should truncate long input before filling." + seq = torch.cat(( + torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + seq0, seq1), dim=1) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + + # build initial tokens, attention_mask, and position_ids + tokens = seq.clone() + attention_mask = torch.ones(layout[0]).to(device) + attention_mask[:n_pad] = 0 + attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[0] - n_pad), + torch.arange(1024, 1024+layout[1]-layout[0]))).to(device) + log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters())) + log_attention_weights[n_pad:layout[0]] = 0. + log_attention_weights = log_attention_weights.unsqueeze(0) + + # prepare for interation + unfixed = (tokens == tokenizer['']) + ll, rr = block_hw + edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) + num_steps = 1 + # interative refining + + # unfixed[..., -(layout[-1] - layout[-2]):].view( + # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False + + + ret = [] + # ret.append(tokens[:, layout[-2]:-1].clone()) + for step_cnt in range(1, num_steps+1): + logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1. 
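+        # single refinement pass: resample every position and overwrite only the
+        # positions marked as unfixed (the masked image tokens)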
+ new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + + ret.append(tokens[:, layout[-2]:].clone()) + return torch.cat(ret, dim=0) \ No newline at end of file diff --git a/sr_pipeline/sr_group.py b/sr_pipeline/sr_group.py new file mode 100644 index 0000000..1ec51b6 --- /dev/null +++ b/sr_pipeline/sr_group.py @@ -0,0 +1,49 @@ +# -*- encoding: utf-8 -*- +''' +@File : sr_group.py +@Time : 2022/04/02 01:17:21 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import numpy as np +import torch +import torch.nn.functional as F +from SwissArmyTransformer.resources import auto_create +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution + +class SRGroup: + def __init__(self, args, home_path=None,): + dsr_path = auto_create('cogview2-dsr', path=home_path) + itersr_path = auto_create('cogview2-itersr', path=home_path) + dsr = DirectSuperResolution(args, dsr_path) + itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer) + self.dsr = dsr + self.itersr = itersr + + def sr_base(self, img_tokens, txt_tokens): + assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2 + batch_size = img_tokens.shape[0] + txt_len = txt_tokens.shape[-1] + if len(txt_tokens.shape) == 1: + txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + sred_tokens = self.dsr(txt_tokens, img_tokens) + iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone()) + return iter_tokens[-batch_size:] + + # def sr_patch(self, img_tokens, txt_tokens): + # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2 + # batch_size = img_tokens.shape[0] * 9 + # txt_len = txt_tokens.shape[-1] + # if len(txt_tokens.shape) == 1: + # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400) + # iter_tokens = self.sr_base(img_tokens, txt_tokens) + # return iter_tokens \ No newline at end of file