From 51c42e97bd8e15168698f0cdd921e3f99723e702 Mon Sep 17 00:00:00 2001
From: wenyihong
Date: Sat, 16 Jul 2022 21:52:20 +0800
Subject: [PATCH] Release code

---
 .gitignore                                   |   8 +
 README.md                                    |  58 ++
 cluster_label2.npy                           | Bin 0 -> 160128 bytes
 coglm_strategy.py                            | 101 +++
 cogvideo_pipeline.py                         | 793 +++++++++++++++++++
 models/cogvideo_cache_model.py               | 695 ++++++++++++++++
 models/cogvideo_model.py                     | 543 +++++++++++++
 pretrain_cogvideo.py                         | 184 +++++
 requirements.txt                             |   4 +
 scripts/ds_brain_pretrain_cogvideo_stage1.sh | 108 +++
 scripts/ds_brain_pretrain_cogvideo_stage2.sh | 108 +++
 scripts/ds_config_zero.json                  |  42 +
 scripts/inference_cogvideo_pipeline.sh       |  38 +
 sr_pipeline/__init__.py                      |  17 +
 sr_pipeline/direct_sr.py                     | 117 +++
 sr_pipeline/dsr_model.py                     | 225 ++++++
 sr_pipeline/dsr_sampling.py                  | 159 ++++
 sr_pipeline/iterative_sr.py                  | 118 +++
 sr_pipeline/itersr_model.py                  | 232 ++++++
 sr_pipeline/itersr_sampling.py               | 168 ++++
 sr_pipeline/sr_group.py                      |  49 ++
 21 files changed, 3767 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 cluster_label2.npy
 create mode 100644 coglm_strategy.py
 create mode 100644 cogvideo_pipeline.py
 create mode 100644 models/cogvideo_cache_model.py
 create mode 100644 models/cogvideo_model.py
 create mode 100644 pretrain_cogvideo.py
 create mode 100644 requirements.txt
 create mode 100644 scripts/ds_brain_pretrain_cogvideo_stage1.sh
 create mode 100644 scripts/ds_brain_pretrain_cogvideo_stage2.sh
 create mode 100644 scripts/ds_config_zero.json
 create mode 100644 scripts/inference_cogvideo_pipeline.sh
 create mode 100644 sr_pipeline/__init__.py
 create mode 100644 sr_pipeline/direct_sr.py
 create mode 100644 sr_pipeline/dsr_model.py
 create mode 100644 sr_pipeline/dsr_sampling.py
 create mode 100644 sr_pipeline/iterative_sr.py
 create mode 100644 sr_pipeline/itersr_model.py
 create mode 100644 sr_pipeline/itersr_sampling.py
 create mode 100644 sr_pipeline/sr_group.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2cce19b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+output/
+*__pycache__/
+samples*/
+runs/
+checkpoints/
+master_ip
+logs/
+*.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 1a47803..d382ea0 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,30 @@
 This is the official repo for the paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](http://arxiv.org/abs/2205.15868).
 
+**News!** The [demo](https://wudao.aminer.cn/cogvideo/) for CogVideo is available!
+
+**News!** The code and model for text-to-video generation are now available! Currently we only support *simplified Chinese input*.
 
 https://user-images.githubusercontent.com/48993524/170857367-2033c514-3c9f-4297-876f-2468592a254b.mp4
 
+* **Read** our paper [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) on ArXiv for a formal introduction.
+* **Try** our demo at [https://wudao.aminer.cn/cogvideo/](https://wudao.aminer.cn/cogvideo/).
+* **Run** our pretrained models for text-to-video generation. Please use an A100 GPU.
+* **Cite** our paper if you find our work helpful:
+
+```
+@article{hong2022cogvideo,
+  title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
+  author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
+  journal={arXiv preprint arXiv:2205.15868},
+  year={2022}
+}
+```
+
+## Web Demo
+
+The demo for CogVideo is at [https://wudao.aminer.cn/cogvideo/](https://wudao.aminer.cn/cogvideo/), where you can get hands-on experience with text-to-video generation. *The original input is in Chinese.*
+
 
 ## Generated Samples
 
@@ -20,3 +41,40 @@ https://user-images.githubusercontent.com/48993524/170857367-2033c514-3c9f-4297-
 A 4-second clip of 32 frames is shown below.
 
 ![High-frame-rate sample](assets/appendix-sample-highframerate.png)
+
+## Getting Started
+
+### Setup
+
+* Hardware: Linux servers with Nvidia A100s are recommended, but it is also okay to run the pretrained models with a smaller `--max-inference-batch-size` and `--batch-size`, or to train smaller models on less powerful GPUs.
+* Environment: install dependencies via `pip install -r requirements.txt`.
+* LocalAttention: make sure you have CUDA installed, then compile the local attention kernel:
+
+```shell
+git clone https://github.com/Sleepychord/Image-Local-Attention
+cd Image-Local-Attention && python setup.py install
+```
+
+### Download
+
+Our code will automatically download the models into the path defined by the environment variable `SAT_HOME`, or detect them there if they are already present. You can also manually download [CogVideo-Stage1](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage1.zip) and [CogVideo-Stage2](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage2.zip) and place them under `SAT_HOME` (in folders named `cogvideo-stage1` and `cogvideo-stage2`).
+
+### Text-to-Video Generation
+
+```
+./scripts/inference_cogvideo_pipeline.sh
+```
+
+The main arguments for inference are:
+
+* `--input-source [path or "interactive"]`. The path of the input file, with one query per line. A CLI is launched when "interactive" is used.
+* `--output-path [path]`. The folder that will contain the results.
+* `--batch-size [int]`. The number of samples generated per query.
+* `--max-inference-batch-size [int]`. Maximum batch size per forward pass. Reduce it if you run out of GPU memory.
+* `--stage1-max-inference-batch-size [int]`. Maximum batch size per forward pass in Stage 1. Reduce it if you run out of GPU memory.
+* `--both-stages`. Run Stage 1 and Stage 2 sequentially.
+* `--use-guidance-stage1`. Use classifier-free guidance in Stage 1; this is strongly recommended for better results.
+
+You should set the environment variable `SAT_HOME` to the path where the downloaded models are stored; a minimal invocation is sketched below.
+
+*Currently only Chinese input is supported.*
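+
+For example, a minimal run might look like the following (the model directory below is only a placeholder; the generation options listed above are set inside the script):
+
+```shell
+export SAT_HOME=/path/to/sat_models   # pretrained models are downloaded to / loaded from this directory
+./scripts/inference_cogvideo_pipeline.sh
+```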
diff --git a/cluster_label2.npy b/cluster_label2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..dff3170b9605f041d689cbad35d5c0e56590bb9c
GIT binary patch
literal 160128
[base85-encoded binary payload of cluster_label2.npy (160128 bytes) omitted]
zqVgo3o03oLTy~u+zVhk!H1)0S%u{n*U+Fh>q_5;jTrSFw3!>^mT;bGX`577V-n9?U zzPc5st+Fr9ljGlty0^KcUT$t(eXAeet0QyBIp(-}@~&8sox{?n=RVHW|8vs!K0S?M zRQ~k)^yKhPy}1|vGZr}%f_T-yAe#~#b-7O zl23JE|K#-cjN+ZEe$|)wvL}Dy4Ntc1LR&}mJ-z)qZ|5k#yf4iU?<)1-KK%PZ_SLg` zb-w&B>6}}m@`9=daiPx-xS`%=bguUg+~KFV=s)}A>yiBUQIy~CRGi?LbIgNvo&UY| z!vlTgUYEA+_u%W}&z>fRD)xBDGUvXKMylV|D?LYY8 z-s*$gy`M>cMRwr9!O7LD{EDmi!5ew>&S2ksAD6wWI!8Uot8?{%Io7rz`GqfVUcZaW zn1OhxWBFB|^1x5OcT8=+{F)p3Tff-{&&Az-dg7(7<=MJpAEiK#97|dop^h1!nZ>@XIj+n#qw-U9@_fT z(HA-ZUX6@@zdzG^HOepkZ<)UO)(7g?z5O0$|AzK0jJ}qAe`it0&RN>Jdz=&xPj-x_ ze`frN1OCBtbuV7@;EOz|^V_lq*W~+>>^oN+CZ^9{_g269S6tzOyzo=MPZ_v}IjC=+ z=sbSEls^pZ&=WLaH za#SAGmp=Vk`sDD|{l)9<0Xs9|aX$RzAAFM^{ouZjWbfSW z;XLtKm;B1~^u6yv>+s@x>GO|V9N<0w)S>y~ouUuy6CZgehYxV9yL_XOf%pLdS< z_&d-)r!TJecJGBzdH7~}>SAo`;_TczlJ6XKFW+tIL_O%A>(bX};^qGGCNAfs?_7Ct zzI@2LIY|$$_#MJI^e1IsT}-{Is-KT-pZ;feadP(AQy=cZZ~nQ@+V=BHTrbNn z-y7E9vUxM6{oZ+a`sy)r@BiNghm(7@e_>QS4~jn~`WMkv8FD}ViJLltWBS>)YFXg<&BO^1JZyheUA3xNO`{-YJ&^P>34)2KPm-@THH zGd%GggjeGBZ2Hd8kMzwEcKmw){`{x*U&&5Xy_uK%)4#iSFLARzCHb=K$cH#OSN_jP zPd>#@9m}iz=7T(mzkABJxcWPp{o*3N&aqE?`NtpMdG5)N@!i*ZWnKI`-RqpFc=5+~ z%hBokUNR5#gZS&e1Jd{J@c1Ww&O^;#cHj_w@jWd+-;BDa{!kyQTX)Yn@#;eUFG}uQ zcqXpm?47NS;G6uxD}BqJ{&PP3)2IA7FhA}rZhsx`eDi|8@-#Yqb_VMNeu$H}t;;^# zoSB^;B_A98ugTS=b?5L$Jk;AG?YB?5E{m*4@N)#SF8^33I>fAZ<0w>)=-Op;DdeV35Rv-9|P(K~}{&yY!cgo&F(H)}f z@_Swe9%voTvJa2ewyu8En|P`taS=~_BM#*J)<Ym^7z|RTkvnyWWLB3mlkgI?3QrFgU~4&=Am`>M)|Kk)CJs7ujZ+|@KgQB4?ExK zp7i00cYrwTpPoK(j(dpX()8esyz)c+iH~@zPyRWVpW<>(0TK7Xy7Ig41^J<0nj*PJN(`cfR# zmH5zq|NQp5mpYd>|9-*ynt!_sRQr9Gi@SKz8%|H&$>s01*5%hZ`e9S+>Py_8%Mbkl ze?F3&9eBN4cG#5d&iH*x`r;$b=DhRx`C$6) z$?rpo3qRJShxffA@BH__lap`%dkKCy$N5vTEB^dWJmh}PMafs?Cmj7nymx^6sDJhA z_g#I@-wEwkm;BV9Pqa?oeY~satCMvDaf6?5%Q+~yxU%OSgZ8{z#06f;`~La+NYw8f z>WTmG(feOK$>mp{!Kpdfl`npa^Bb+pFF)6};ee=jiTdOpTv8X#cMkvD*L~!NKD=^1 zd(J`meSLmU&A$y%`9PPY#~*nTNB15)pPqd`&HgS?etM6%7iypTi9_4p|85^V*S+WG zx4L78ednY9o<2SGF=(G3_L(p8DsIksEW1zTxIA1RKX~5$>3e708xQ}St6t$0JJ!XI z|KhBxLle(+;RS(KKbBz>^Mifd~e7L|HYjjOFG~GZt8`2 zxUp9}eD%MV(HH9m>T^c=@+t1>TU>uHJ$dx% ze)y?R#8W*uXOKS9J)EzuobTTD9h$!Tjm$si@r%B7{JQj?&CXBbz2BU-d-7e2*tc8!1<}De_j{yzabNK|v3>44c)t9Rqx8iQg$L}z1N})) zT>ZU4ow19GFWx>p+!^ebIo*rD?(u=-`VM7}zh5ox$EUB} 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + + return logits + + +class CoglmStrategy: + def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None, temperature2=0.89): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.temperature2 = temperature2 + self.topk = top_k + self.top_p = top_p + self.eps = eps + if end_tokens is None: + end_tokens = [] + self.end_tokens = end_tokens + 
self._is_done = False + self.outlier_count_down = torch.zeros(16) + self.vis_list = [[]for i in range(16)] + self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) + self.start_pos = -1 + self.white_cluster = [] + # self.fout = open('tmp.txt', 'w') + + @property + def is_done(self) -> bool: + return self._is_done + + def forward(self, logits, tokens, mems, temperature=None, temperature2=None): + if temperature is None: + temperature = self.temperature + if temperature2 is None: + temperature2 = self.temperature2 + logits = logits / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -65504 + + rprobs = F.softmax(logits.float(), dim=-1) + c = self.cluster_labels.expand(*rprobs.shape) + cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + # self.fout.write(str(tokens.shape[-1])+ ' ' + str(cprobs.topk(10)) + '\n') + # self.fout.flush() + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + for i in range(bz): + selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax(logits.float()/temperature2, dim=-1) # float is essetial, due to a bug in Pytorch + pred = torch.multinomial(probs, num_samples=1) + + if pred.numel() == 1 and pred.item() in self.end_tokens: + self._is_done = True + tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1) + return tokens, mems + + def finalize(self, tokens, mems): + self._is_done = False + return tokens, mems \ No newline at end of file diff --git a/cogvideo_pipeline.py b/cogvideo_pipeline.py new file mode 100644 index 0000000..0efb161 --- /dev/null +++ b/cogvideo_pipeline.py @@ -0,0 +1,793 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_pipeline.py +@Time : 2022/07/15 11:24:56 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +import os +import sys +import torch +import argparse +import time +from torchvision.utils import save_image +import stat +from icetk import icetk as tokenizer +import logging, sys + +import torch.distributed as dist +tokenizer.add_special_tokens(['', '', '']) + + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders +from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually +from SwissArmyTransformer.resources import auto_create + +from models.cogvideo_cache_model import CogVideoCacheModel +from coglm_strategy import CoglmStrategy + + +def get_masks_and_position_ids_stage1(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + # Attention mask (lower triangular). 
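+    # Text tokens attend only to text tokens (bidirectionally), while frame tokens
+    # attend to all text tokens and causally (lower-triangular) to frame tokens.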
+ attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, + device=data.device) + torch.arange(textlen, out=position_ids[:textlen], + dtype=torch.long, device=data.device) + torch.arange(512, 512+seq_length-textlen, out=position_ids[textlen:], + dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + +def get_masks_and_position_ids_stage2(data, textlen, framelen): + # Extract batch size and sequence length. + tokens = data + seq_length = len(data[0]) + + # Attention mask (lower triangular). + attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device) + attention_mask[:, :textlen, textlen:] = 0 + attention_mask[:, textlen:, textlen:].tril_() + attention_mask.unsqueeze_(1) + + # Unaligned version + position_ids = torch.zeros(seq_length, dtype=torch.long, + device=data.device) + torch.arange(textlen, out=position_ids[:textlen], + dtype=torch.long, device=data.device) + frame_num = (seq_length-textlen)//framelen + assert frame_num == 5 + torch.arange(512, 512+framelen, out=position_ids[textlen:textlen+framelen], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*2, 512+framelen*3, out=position_ids[textlen+framelen:textlen+framelen*2], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*(frame_num-1), 512+framelen*frame_num, out=position_ids[textlen+framelen*2:textlen+framelen*3], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*1, 512+framelen*2, out=position_ids[textlen+framelen*3:textlen+framelen*4], + dtype=torch.long, device=data.device) + torch.arange(512+framelen*3, 512+framelen*4, out=position_ids[textlen+framelen*4:textlen+framelen*5], + dtype=torch.long, device=data.device) + + position_ids = position_ids.unsqueeze(0) + + return tokens, attention_mask, position_ids + +def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len): + if hiddens is None: + return None, mems_indexs + mem_num = len(hiddens) + ret_mem = [] + with torch.no_grad(): + for id in range(mem_num): + if hiddens[id][0] is None: + ret_mem.append(None) + else: + if id == 0 and limited_spatial_channel_mem and mems_indexs[id]+hiddens[0][0].shape[1] >= text_len+frame_len: + if mems_indexs[id] == 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, :text_len] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, :text_len] + new_mem_len_part2 = (mems_indexs[id]+hiddens[0][0].shape[1]-text_len)%frame_len + if new_mem_len_part2 > 0: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, text_len:text_len+new_mem_len_part2] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, -new_mem_len_part2:] + mems_indexs[id] = text_len+new_mem_len_part2 + else: + for layer, hidden in enumerate(hiddens[id]): + mems_buffers[id][layer, :, mems_indexs[id]:mems_indexs[id]+hidden.shape[1]] = hidden.expand(mems_buffers[id].shape[1], -1, -1) + mems_indexs[id] += hidden.shape[1] + ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]]) + return ret_mem, mems_indexs + + +def my_save_multiple_images(imgs, path, subdir, debug=True): + # imgs: list of tensor images + if debug: + imgs = torch.cat(imgs, dim=0) + print("\nSave to: ", path, flush=True) + save_image(imgs, 
path, normalize=True)
+    else:
+        print("\nSave to: ", path, flush=True)
+        single_frame_path = os.path.join(path, subdir)
+        os.makedirs(single_frame_path, exist_ok=True)
+        for i in range(len(imgs)):
+            save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
+            os.chmod(os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+        save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path, f'frame_concat.jpg'), normalize=True)
+        os.chmod(os.path.join(single_frame_path, f'frame_concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+
+def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
+    # The position id of the first token of the frame that the next token belongs to.
+    if total_len < text_len:
+        return None
+    return (total_len-text_len)//frame_len * frame_len + text_len
+
+def my_filling_sequence(
+        model,
+        args,
+        seq,
+        batch_size,
+        get_masks_and_position_ids,
+        text_len,
+        frame_len,
+        strategy=BaseStrategy(),
+        strategy2=BaseStrategy(),
+        mems=None,
+        log_text_attention_weights=0, # default to 0: no artificial change
+        mode_stage1=True,
+        enforce_no_swin=False,
+        guider_seq=None,
+        guider_text_len=0,
+        guidance_alpha=1,
+        limited_spatial_channel_mem=False, # limit the spatial channel's memory to the current frame
+        **kw_args
+        ):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be first mems.shape[1] parts of context_tokens.
+            mems are the first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
+    '''
+    if guider_seq is not None:
+        logging.debug("Using Guidance In Inference")
+    if limited_spatial_channel_mem:
+        logging.debug("Limit spatial-channel's mem to current frame")
+    assert len(seq.shape) == 2
+
+    # building the initial tokens, attention_mask, and position_ids
+    actual_context_length = 0
+
+    while seq[-1][actual_context_length] >= 0: # the last sequence has the fewest given tokens
+        actual_context_length += 1 # [0, context_length-1] are given
+    assert actual_context_length > 0
+    current_frame_num = (actual_context_length-text_len) // frame_len
+    assert current_frame_num >= 0
+    context_length = text_len + current_frame_num * frame_len
+
+    tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len)
+    tokens = tokens[..., :context_length]
+    input_tokens = tokens.clone()
+
+    if guider_seq is not None:
+        guider_index_delta = text_len - guider_text_len
+        guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
+        guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
+        guider_input_tokens = guider_tokens.clone()
+
+    for fid in range(current_frame_num):
+        input_tokens[:, text_len+400*fid] = tokenizer['']
+        if guider_seq is not None:
+            guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['']
+
+    attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
+    # initialize generation
+    counter = context_length - 1 # Last fixed index is `counter`
+    index = 0 # Next forward starting index, also the length of cache.
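+    # Two per-layer KV-cache buffers are pre-allocated below: buffer 0 holds the
+    # spatial-channel (local/Swin attention) memory, buffer 1 the full-sequence memory.
+    # Each frame contributes 400 image tokens; the extra 74 slots presumably budget the
+    # text prompt and separator tokens (an assumption, not stated in the code). The last
+    # dimension is hidden_size*2 because keys and values are cached together.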
+ mems_buffers_on_GPU = False + mems_indexs = [0, 0] + mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74] + mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype) + for mem_len in mems_len] + + + if guider_seq is not None: + guider_attention_mask = guider_attention_mask.type_as(next(model.parameters())) # if fp16 + guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype) + for mem_len in mems_len] + guider_mems_indexs = [0, 0] + guider_mems = None + + torch.cuda.empty_cache() + # step-by-step generation + while counter < len(seq[0]) - 1: + # we have generated counter+1 tokens + # Now, we want to generate seq[counter + 1], + # token[:, index: counter+1] needs forwarding. + if index == 0: + group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size + + logits_all = None + for batch_idx in range(0, input_tokens.shape[0], group_size): + logits, *output_per_layers = model( + input_tokens[batch_idx:batch_idx+group_size, index:], + position_ids[..., index: counter+1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args + ) + logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits + mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]] + next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1]) + for id, mem_kv in enumerate(mem_kv01): + for layer, mem_kv_perlayer in enumerate(mem_kv): + if limited_spatial_channel_mem and id == 0: + mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len] + mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\ + mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:] + else: + mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1) + mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1] + if limited_spatial_channel_mem: + mems_indexs[0] -= (next_tokens_frame_begin_id - text_len) + + mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)] + logits = logits_all + + # Guider + if guider_seq is not None: + guider_logits_all = None + for batch_idx in range(0, guider_input_tokens.shape[0], group_size): + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):], + guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter-guider_index_delta, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + **kw_args + ) + guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits + guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in 
guider_output_per_layers]] + for id, guider_mem_kv in enumerate(guider_mem_kv01): + for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv): + if limited_spatial_channel_mem and id == 0: + guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len] + guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1]) + guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\ + guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:] + else: + guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1) + guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1] + if limited_spatial_channel_mem: + guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len) + guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)] + guider_logits = guider_logits_all + else: + if not mems_buffers_on_GPU: + if not mode_stage1: + torch.cuda.empty_cache() + for idx, mem in enumerate(mems): + mems[idx] = mem.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, mem in enumerate(guider_mems): + guider_mems[idx] = mem.to(next(model.parameters()).device) + else: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device) + mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)] + if guider_seq is not None: + for idx, guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device) + guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)] + mems_buffers_on_GPU = True + + logits, *output_per_layers = model( + input_tokens[:, index:], + position_ids[..., index: counter+1], + attention_mask, # TODO memlen + mems=mems, + text_len=text_len, + frame_len=frame_len, + counter=counter, + log_text_attention_weights=log_text_attention_weights, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args + ) + mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers] + + if guider_seq is not None: + guider_logits, *guider_output_per_layers = model( + guider_input_tokens[:, max(index-guider_index_delta, 0):], + guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta], + guider_attention_mask, + mems=guider_mems, + text_len=guider_text_len, + frame_len=frame_len, + counter=counter-guider_index_delta, + log_text_attention_weights=0, + enforce_no_swin=enforce_no_swin, + limited_spatial_channel_mem=limited_spatial_channel_mem, + **kw_args + ) + guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers] + + if not mems_buffers_on_GPU: + torch.cuda.empty_cache() + for idx, mem_buffer in enumerate(mems_buffers): + mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device) + if guider_seq is not None: + for idx, 
guider_mem_buffer in enumerate(guider_mems_buffers): + guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device) + mems_buffers_on_GPU = True + + mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len) + if guider_seq is not None: + guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len) + + + counter += 1 + index = counter + + logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size] + tokens = tokens.expand(batch_size, -1) + if guider_seq is not None: + guider_logits = guider_logits[:, -1].expand(batch_size, -1) + guider_tokens = guider_tokens.expand(batch_size, -1) + + if seq[-1][counter].item() < 0: + # sampling + guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits + if mode_stage1 and counter < text_len + 400: + tokens, mems = strategy.forward(guided_logits, tokens, mems) + else: + tokens, mems = strategy2.forward(guided_logits, tokens, mems) + if guider_seq is not None: + guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1) + + if seq[0][counter].item() >= 0: + for si in range(seq.shape[0]): + if seq[si][counter].item() >= 0: + tokens[si, -1] = seq[si, counter] + if guider_seq is not None: + guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta] + + else: + tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1) + if guider_seq is not None: + guider_tokens = torch.cat((guider_tokens, + guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta] + .clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1) + + input_tokens = tokens.clone() + if guider_seq is not None: + guider_input_tokens = guider_tokens.clone() + if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400: + boi_idx = ((index-text_len-1)//400 +1)*400+text_len + while boi_idx < input_tokens.shape[-1]: + input_tokens[:, boi_idx] = tokenizer[''] + if guider_seq is not None: + guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer[''] + boi_idx += 400 + + if strategy.is_done: + break + return strategy.finalize(tokens, mems) + +class InferenceModel_Sequential(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=-1, cogvideo_stage=1) + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + return logits_parallel + +class InferenceModel_Interpolate(CogVideoCacheModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=10, cogvideo_stage=2) + # TODO: check it + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + return logits_parallel + +def main(args): + assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1 + rank_id = args.device % args.parallel_size + generate_frame_num = 
args.generate_frame_num + + if args.stage_1 or args.both_stages: + model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1') + model_stage1.eval() + if args.both_stages: + model_stage1 = model_stage1.cpu() + + if args.stage_2 or args.both_stages: + model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2') + model_stage2.eval() + if args.both_stages: + model_stage2 = model_stage2.cpu() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + strategy_cogview2 = CoglmStrategy(invalid_slices, + temperature=1.0, top_k=16) + strategy_cogvideo = CoglmStrategy(invalid_slices, + temperature=args.temperature, top_k=args.top_k, + temperature2=args.coglm_temperature2) + if not args.stage_1: + from sr_pipeline import DirectSuperResolution + dsr_path = auto_create('cogview2-dsr', path=None) # path=os.getenv('SAT_HOME', '~/.sat_models') + dsr = DirectSuperResolution(args, dsr_path, + max_bz=12, onCUDA=False) + + def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1): + stage2_starttime = time.time() + use_guidance = args.use_guidance_stage2 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage-2 model to cuda") + model = model.cuda() + logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time)) + + try: + if parent_given_tokens is None: + assert conddir is not None + parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu') + sample_num_allgpu = parent_given_tokens.shape[0] + sample_num = sample_num_allgpu // gpu_parallel_size + assert sample_num * gpu_parallel_size == sample_num_allgpu + parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num] + except: + logging.critical("No frame_tokens found in interpolation, skip") + return False + + # CogVideo Stage2 Generation + while duration >= 0.5: # TODO: You can change the boundary to change the frame rate + parent_given_tokens_num = parent_given_tokens.shape[1] + generate_batchsize_persample = (parent_given_tokens_num-1)//2 + generate_batchsize_total = generate_batchsize_persample * sample_num + total_frames = generate_frame_num + frame_len = 400 + enc_text = tokenizer.encode(seq_text) + enc_duration = tokenizer.encode(str(float(duration))+"秒") + seq = enc_duration + [tokenizer['']] + enc_text + [tokenizer['']] + [-1]*400*generate_frame_num + text_len = len(seq) - frame_len*generate_frame_num - 1 + + logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text))) + + # generation + seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i] + seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1] + seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2] + + if use_guidance: + guider_seq = enc_duration + [tokenizer['']] + tokenizer.encode(video_guidance_text) + [tokenizer['']] + [-1]*400*generate_frame_num + guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1 + guider_seq = 
torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1) + for sample_i in range(sample_num): + for i in range(generate_batchsize_persample): + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i] + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1] + guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2] + video_log_text_attention_weights = 0 + else: + guider_seq=None + guider_text_len=0 + video_log_text_attention_weights = 1.4 + + mbz = args.max_inference_batch_size + + assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0 + output_list = [] + start_time = time.time() + for tim in range(max(generate_batchsize_total // mbz, 1)): + input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone() + guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None + output_list.append( + my_filling_sequence(model, args, input_seq, + batch_size=min(generate_batchsize_total, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage2, + text_len=text_len, frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + mode_stage1=False, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + )[0] + ) + logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time)) + + output_tokens = torch.cat(output_list, dim=0) + output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames) + output_tokens_merge = torch.cat((output_tokens[:, :, :1*400], + output_tokens[:, :, 400*3:4*400], + output_tokens[:, :, 400*1:2*400], + output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400) + + output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1) + duration /= 2 + parent_given_tokens = output_tokens_merge + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 2 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time)) + + logging.info("CogVideo Stage2 completed. 
Taken time {:.2f}\n".format(time.time() - stage2_starttime)) + + # decoding + # imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge] + # os.makedirs(output_dir_full_path, exist_ok=True) + # my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False) + # torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt')) + # os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2") + + # direct super-resolution by CogView2 + logging.info("[Direct super-resolution]") + dsr_starttime = time.time() + enc_text = tokenizer.encode(seq_text) + frame_num_per_sample = parent_given_tokens.shape[1] + parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400) + text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1) + sred_tokens = dsr(text_seq, parent_given_tokens_2d) + decoded_sr_videos = [] + + for sample_i in range(sample_num): + decoded_sr_imgs = [] + for frame_i in range(frame_num_per_sample): + decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:]) + decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480))) + decoded_sr_videos.append(decoded_sr_imgs) + + for sample_i in range(sample_num): + my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) + os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125") + + logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime)) + + return True + + + def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1): + process_start_time = time.time() + use_guide = args.use_guidance_stage1 + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cuda") + model = model.cuda() + logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time)) + + if video_raw_text is None: + video_raw_text = seq_text + mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size + assert batch_size < mbz or batch_size % mbz == 0 + frame_len = 400 + + # generate the first frame: + enc_text = tokenizer.encode(seq_text+image_text_suffix) + seq_1st = enc_text + [tokenizer['']] + [-1]*400 # IV!! # test local!!! # test randboi!!! 
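+        # First-frame sequence layout: [text tokens][special boundary token][400 image-token
+        # placeholders], where -1 marks a position still to be sampled (one frame = 400 = 20x20
+        # tokens). The first frame is generated as a plain CogView2-style text-to-image sample
+        # (strategy_cogview2, enforce_no_swin=True) before the subsequent video frames are filled in.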
+ logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text))) + text_len_1st = len(seq_1st) - frame_len*1 - 1 + + seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0) + output_list_1st = [] + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + output_list_1st.append( + my_filling_sequence(model, args,seq_1st.clone(), + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len_1st, + frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=1.4, + enforce_no_swin=True, + mode_stage1=True, + )[0] + ) + logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)) + output_tokens_1st = torch.cat(output_list_1st, dim=0) + given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] + + # generate subsequent frames: + total_frames = generate_frame_num + enc_duration = tokenizer.encode(str(float(duration))+"秒") + if use_guide: + video_raw_text = video_raw_text + " 视频" + enc_text_video = tokenizer.encode(video_raw_text) + seq = enc_duration + [tokenizer['']] + enc_text_video + [tokenizer['']] + [-1]*400*generate_frame_num + guider_seq = enc_duration + [tokenizer['']] + tokenizer.encode(video_guidance_text) + [tokenizer['']] + [-1]*400*generate_frame_num + logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video))) + + text_len = len(seq) - frame_len*generate_frame_num - 1 + guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1 + seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1) + guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1) + + for given_frame_id in range(given_tokens.shape[1]): + seq[:, text_len+1+given_frame_id*400: text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id] + guider_seq[:, guider_text_len+1+given_frame_id*400:guider_text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id] + output_list = [] + + if use_guide: + video_log_text_attention_weights = 0 + else: + guider_seq = None + video_log_text_attention_weights = 1.4 + + for tim in range(max(batch_size // mbz, 1)): + start_time = time.time() + input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone() + guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None + output_list.append( + my_filling_sequence(model, args,input_seq, + batch_size=min(batch_size, mbz), + get_masks_and_position_ids=get_masks_and_position_ids_stage1, + text_len=text_len, frame_len=frame_len, + strategy=strategy_cogview2, + strategy2=strategy_cogvideo, + log_text_attention_weights=video_log_text_attention_weights, + guider_seq=guider_seq2, + guider_text_len=guider_text_len, + guidance_alpha=args.guidance_alpha, + limited_spatial_channel_mem=True, + mode_stage1=True, + )[0] + ) + + output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:] + + if args.both_stages: + move_start_time = time.time() + logging.debug("moving stage 1 model to cpu") + model = model.cpu() + torch.cuda.empty_cache() + logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time)) + + # decoding + imgs, sred_imgs, txts = [], [], [] + for seq 
in output_tokens:
+            decoded_imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()[i*400: (i+1)*400]), size=(480, 480)) for i in range(total_frames)]
+            imgs.append(decoded_imgs)  # all decoded frames of this clip
+
+        assert len(imgs) == batch_size
+        save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
+        if outputdir is not None:
+            for clip_i in range(len(imgs)):
+                # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
+                my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
+                os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
+            torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
+
+        logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
+
+        return save_tokens
+
+    # ======================================================================================================
+
+    if args.stage_1 or args.both_stages:
+        if args.input_source != "interactive":
+            with open(args.input_source, 'r') as fin:
+                promptlist = fin.readlines()
+                promptlist = [p.strip() for p in promptlist]
+        else:
+            promptlist = None
+
+        now_qi = -1
+        while True:
+            now_qi += 1
+
+            if promptlist is not None:  # with input-source
+                if args.multi_gpu:
+                    if now_qi % dist.get_world_size() != dist.get_rank():
+                        continue
+                    rk = dist.get_rank()
+                else:
+                    rk = 0
+                raw_text = promptlist[now_qi]
+                raw_text = raw_text.strip()
+                print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]')
+            else:  # interactive
+                raw_text = input("\nPlease Input Query (stop to exit) >>> ")
+                raw_text = raw_text.strip()
+                if not raw_text:
+                    print('Query should not be empty!')
+                    continue
+                if raw_text == "stop":
+                    return
+
+            try:
+                path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
+                parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
+                                                     image_text_suffix=" 高清摄影",
+                                                     outputdir=path if args.stage_1 else None, batch_size=args.batch_size)
+                if args.both_stages:
+                    process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
+                                   video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
+                                   outputdir=path,
+                                   gpu_rank=0, gpu_parallel_size=1)  # TODO: modify
+            except (ValueError, FileNotFoundError) as e:
+                print(e)
+                continue
+
+    elif args.stage_2:
+        sample_dirs = os.listdir(args.output_path)
+        for sample in sample_dirs:
+            raw_text = sample.split('_')[-1]
+            path = os.path.join(args.output_path, sample, 'Interp')
+            parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt"))
+
+            process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
+                           video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
+                           outputdir=path,
+                           gpu_rank=0, gpu_parallel_size=1)  # TODO: modify
+
+    else:
+        assert False
+
+
+if __name__ == "__main__":
+    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
+
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--generate-frame-num', type=int, default=5)
+    py_parser.add_argument('--coglm-temperature2', type=float, default=0.89)
+    # py_parser.add_argument("--interp-duration", type=float, default=-1)  # -1: sequential generation, 0: super-resolution, 0.5/1/2: frame interpolation
+    # py_parser.add_argument("--total-duration", type=float, default=4.0)  # total video duration
+    py_parser.add_argument('--use-guidance-stage1', action='store_true')
+    py_parser.add_argument('--use-guidance-stage2', action='store_true')
+
py_parser.add_argument('--guidance-alpha', type=float, default=3.0) + py_parser.add_argument('--stage-1', action='store_true') # stage 1: sequential generation + py_parser.add_argument('--stage-2', action='store_true') # stage 2: interp + dsr + py_parser.add_argument('--both-stages', action='store_true') # stage 1&2: sequential generation; interp + dsr + py_parser.add_argument('--parallel-size', type=int, default=1) + py_parser.add_argument('--stage1-max-inference-batch-size', type=int, default=-1) # -1: use max-inference-batch-size + py_parser.add_argument('--multi-gpu', action='store_true') + + CogVideoCacheModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + args.layout = [int(x) for x in args.layout.split(',')] + args.do_train = False + + torch.cuda.set_device(args.device) + + with torch.no_grad(): + main(args) \ No newline at end of file diff --git a/models/cogvideo_cache_model.py b/models/cogvideo_cache_model.py new file mode 100644 index 0000000..ca39184 --- /dev/null +++ b/models/cogvideo_cache_model.py @@ -0,0 +1,695 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_cache_model.py +@Time : 2022/07/15 11:22:19 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +from multiprocessing import context +from tkinter import E +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = 
x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + time_dim_attend_length=0 + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.time_dim_attend_length = time_dim_attend_length + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + attn_mask = attn_mask.tril() + + causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num) + causal_mask = causal_mask.tril() + + self.shift_sizes = [0, shift_size] + self.attn_mask = attn_mask + self.causal_mask = causal_mask + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = 
self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + if stage == 2: + assert frame_num == 3 + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if stage == 1: + if self.shift_sizes[layer_id%2] > 0: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), + self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\ + - 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0)) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\ + - 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0)) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + ret_context = context_swin.reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + memk = 
window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution) + memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution) + if self.shift_sizes[layer_id%2] > 0: + memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0) + + ret_mem = torch.cat((memk, memv), dim=-1) + return ret_context, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead] + # memkv [batchsize, pos, hidden_size*2] (include frames only) + # if memkv_text is not None: will attend to text + # pos: token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + s1 = memkv.shape[1] if memkv is not None else 0 + frame_len = self.frame_resolution * self.frame_resolution + frame_num_before = s1 // frame_len + + + if memkv is not None: + pos_inframe = pos - frame_num_before * frame_len + + xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos + ypos = pos_inframe % self.frame_resolution + # [start, end) + if self.shift_sizes[layer_id%2] > 0: + xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2] + xend = xstart + self.window_size + yend = ystart + self.window_size + xstart, ystart = max(0, xstart), max(0, ystart) + xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution) + else: + xstart = (xpos // self.window_size) * self.window_size + ystart = (ypos // self.window_size) * self.window_size + xend, yend = xstart + self.window_size, ystart+self.window_size + + # select index + selected_index = list() + if frame_num_before > 0: + # frames before + frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0 + for x in range(xstart, xend): + for y in range(ystart, yend): + selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start) + cnt_per_frame = len(selected_index) + for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame): + selected_index.append(selected_index[-cnt_per_frame]+frame_len) + + # the last frame + for x in range(xstart, xend): + for y in range(ystart, yend): + tmppos = x*self.frame_resolution+y + frame_num_before * frame_len + if tmppos < pos: + selected_index.append(tmppos) + else: + break + cnt_all = len(selected_index)+1 + selected_index = torch.tensor(selected_index, device=memkv.device) + used_memkv = torch.index_select(memkv, 1, selected_index) + used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + if memkv_text is not None: + cnt_all += memkv_text.shape[-2] + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., 
h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = thisk + used_v = thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3) + else: + used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) + + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + **kwargs, + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense") + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + + def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1): + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert stage == 1 + + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + if memkv_text is not None: + s0 = memkv_text.shape[-2] + k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + attn = 
torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril()) + + if memkv_text is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0] + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + # for mem + memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0) + memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0) + ret_mem = torch.cat((memk, memv), dim=-1) + + return context_swin, ret_mem + + def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1): + # pos: current token's pos + b0, sin, h0 = frame_hidden_state.shape + h = h0 // self.n_head + assert sin == 1 + assert stage == 1 + + this_qkv = self.query_key_value[layer_id](frame_hidden_state) + thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:] + + if memkv is not None: + used_k, used_v = memkv[..., :h0], memkv[..., h0:] + used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2) + used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2) + else: + used_k, used_v = thisk, thisv + + if memkv_text is not None: + used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2) + used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2) + + used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3) + thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h] + attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2)) + if memkv_text is not None: + attn[..., :memkv_text.shape[-2]] += log_text_attention_weights + attn = F.softmax(attn, dim=-1) + + context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0) + + return context_swin, this_qkv[..., h0:] + + +def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask, + n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len] + if stage == 2: + assert frame_num == 3 + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text += log_text_attention_weights + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \ + - 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len]) + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) 
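+    # attention_probs_text: text-to-text attention weights -- text tokens attend only to the
+    # (masked) text prefix here. Optional attention dropout below runs inside
+    # get_cuda_rng_tracker().fork(); the weighted sum over v0[..., :text_len, :] then yields
+    # the text part of the output context (context_text2text).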
+ if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + if frame_num > 0: + score_any2text_part2 = score_any2text[..., text_len:, :] + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + if stage == 1: + score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \ + - 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1)) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + else: + context_frame = None + + return context_text2text, context_frame + +def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num, + attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs): + # limited_spatial_channel_mem=True means: mems in spatial channel is consisted of {mem_text, mem_current_frame} + b, s0, h0 = k0.shape + frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1 + h = h0 // n_head + assert q0.shape[1] == 1 + assert v0.shape[1] == k0.shape[1] + + q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + if limited_spatial_channel_mem: + assert frame_num_before == 0 + assert stage == 1 # not implemented for stage-2 yet + score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + score[..., :text_len] += log_text_attention_weights + attention_probs_frame = F.softmax(score, dim=-1) + context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0) + + else: + score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_token2text += log_text_attention_weights + score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:]) + score_frame_all = torch.cat((score_token2text, + score_frame_local0), dim=-1) + attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \ + v0[:, :, text_len+frame_num_before*frame_len:, :]) + context_frame = (context_token2text 
+ context_frame_local0).transpose(1, 2).reshape(b, 1, h0) + + return context_frame + + +class CogVideoCacheModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.layout = args.layout # [64, 64+1024, 64+6*1024] + self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2 + self.n_head = args.num_attention_heads + self.window_size = window_size if window_size is not None else args.window_size + + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if self.stage == 1: + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=self.window_size, + shift_size=self.window_size//2, + n_head=args.num_attention_heads, + frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]), + )) + + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064') + group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后 + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后 + return parser + + def disable_untrainable_params(self): + pass + + def position_embedding_forward(self, position_ids, **kw_args): + if position_ids.shape[-1] > 1: + if self.stage == 1: + if position_ids[0,-1] >= (512+400): + frame_num = position_ids.shape[-1] // 400 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400)) + ), + dim=-2 + ) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + else: + # given 3, interpolate 2 + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position_ids[..., :-800]), + self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400)) + ), + dim=-2 + ) + else: + if position_ids[0, 0] >= (512+400): + position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400)) + else: + position_embeddings = self.transformer.position_embeddings(position_ids) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + + # base model qkv + if mems is None: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + assert (q0.shape[1]-text_len) % frame_len == 0 + memkv0 = 
torch.cat((k0, v0), dim=-1) + context_text, context_frame_local_text = attention_localframe_and_text_NAR( + q0, k0, v0, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=(q0.shape[1]-text_len)//frame_len, + log_text_attention_weights=log_text_attention_weights, + stage=self.stage + ) + + # change: self.swin_attend_to_text默认为True: + memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:] + output_text = attn_module.dense(context_text) + + if (q0.shape[1]-text_len)//frame_len > 0: + assert (q0.shape[1]-text_len) % frame_len == 0 + context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference( + hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :]) + output = torch.cat((output_text, output_frame), dim=-2) + memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame + else: + output = output_text + memkv1 = memkv1_text + kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1) + + + else: + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + new_memkv0 = torch.cat((k0, v0), dim=-1) + old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:] + + context_frame_local_text = attention_localframe_and_text_AR( + q0, + torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2), + torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2), + n_head=attn_module.num_attention_heads_per_partition, + text_len=text_len, + frame_len=frame_len, + frame_num=None, + log_text_attention_weights=log_text_attention_weights, + layer_id=layer_id, + limited_spatial_channel_mem=limited_spatial_channel_mem, + ) + + old_memkv1 = mems[1][layer_id] if mems[1] is not None else None + + context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states, + old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None, + counter-text_len, + layer_id, + memkv_text=old_memkv1[..., :text_len, :], + log_text_attention_weights=log_text_attention_weights) + if not enforce_no_swin: + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + else: + output = attn_module.dense(context_frame_local_text) + + kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1) + + return output \ No newline at end of file diff --git a/models/cogvideo_model.py b/models/cogvideo_model.py new file mode 100644 index 0000000..dfbc136 --- /dev/null +++ b/models/cogvideo_model.py @@ -0,0 +1,543 @@ +# -*- encoding: utf-8 -*- +''' +@File : cogvideo_model.py +@Time : 2022/07/11 16:12:05 +@Author : Wenyi Hong +@Version : 1.0 +@Contact : 
hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib + +import torch +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim +from SwissArmyTransformer.model.transformer import unscaled_init_method +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +import torch.nn.functional as F +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +import math + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 912), + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + +def window_partition(x, window_size): + """ + Args: + x: (B, framenum, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, frame_num, window_size, window_size, C) + """ + B, framenum, H, W, C = x.shape + x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C) + return windows + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, frame_num, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, frame_num, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + framenum = windows.shape[1] + x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1) + return x + +class WindowAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + window_size, + shift_size, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(WindowAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense", + ) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.window_size = window_size + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + assert frame_resolution % window_size == 0 + assert 0 < shift_size < window_size + nW = (self.frame_resolution // self.window_size) ** 2 + ws_squre = self.window_size * self.window_size + + # odd non-shift, even shift + img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1)) + h_slices = (slice(0, 
-shift_size), + slice(-shift_size, None)) + w_slices = (slice(0, -shift_size), + slice(-shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, :, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size] + sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00)) + attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num) + + self.attn_mask_sequential = attn_mask.clone().tril() + self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril() + + self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num) + self.attn_mask_interp = attn_mask.clone() + + # bi-dir + for bi_idx in range(0, frame_num, 2): + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0 + # uni-dir + for uni_idx in range(1, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_() + for uni_idx2 in range(uni_idx+2, frame_num, 2): + self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0 + + # expand dim + self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None] + self.attn_mask_interp = self.attn_mask_interp[None, None, :, None] + self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None] + self.causal_mask_interp = self.causal_mask_interp[None, None, :, None] + + self.shift_sizes = [0, shift_size] + # self.register_buffer("attn_mask", attn_mask) + # self.register_buffer("causal_mask", causal_mask) + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=True): + # pb relax + swin_pb_relax = True + alpha = 16 + + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + if not self.mask_initialized: + self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, 
dtype=frame_hidden_state.dtype) + self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + wind_square = self.window_size * self.window_size + nW = frame_len // wind_square + bswin = b0 * nW + + causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp + attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp + if text_hidden_state is not None: + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + + # shift + frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0) + if self.shift_sizes[layer_id%2] > 0: + frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3)) + # window partition + frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0) + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # pb-relax + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + + if self.shift_sizes[layer_id%2] > 0: + # attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\ + - 10000.0 * (1.0 - attn_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + else: + attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\ + - 10000.0 * (1.0 - causal_mask) + attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + else: + assert text_attn_mask is not None + text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2) + # pb-relax + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2)) + + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - 
text_attn_mask) + attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0) + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_swin = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\ + .reshape(bswin, self.n_head, frame_num*wind_square, h))\ + .permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0) + + context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution) + # reverse cycle shift + if self.shift_sizes[layer_id%2] > 0: + context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3)) + context_swin = context_swin.reshape(b0, s1, h0) + + return context_swin + + +class FullAttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + frame_resolution, + n_head, + frame_num, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02), + ): + super(FullAttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3, + gather_output=False,init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear( + hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + bias=True, + module=self, + name="dense",) + for layer_id in range(num_layers) + ]) + + self.n_head = n_head + self.frame_resolution = frame_resolution + self.frame_len = frame_resolution * frame_resolution + self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril() + + self.mask_initialized = False + + self.attn_distribution = torch.nn.ParameterList([ + torch.nn.Parameter(torch.zeros(hidden_size)) + for _ in range(num_layers) + ]) + + def reinit(self, *pre_mixins): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + base_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data) + + def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None, + text_attn_mask=None, mode_sequential=False): + # pb relax + # frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead] + assert mode_sequential == True # only + swin_pb_relax = True + alpha = 16 + + if not self.mask_initialized: + self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype) + self.mask_initialized = True + b0, s1, h0 = frame_hidden_state.shape + h = h0 // self.n_head + frame_len = self.frame_resolution * self.frame_resolution + frame_num = s1 // frame_len + assert frame_num*frame_len == s1 + + qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\ + .permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h] + q, k, v = qkv[0], qkv[1], qkv[2] + + # frames-to-frames + if swin_pb_relax: + attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2)) + 
else: + attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2)) + attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask) + if swin_pb_relax: + swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1) + attn = (attn - swin_pb_relax_const)*alpha + + if text_hidden_state is None: + attn = F.softmax(attn, dim=-1) + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0) + else: + # frame-to-text + assert text_attn_mask is not None + s0 = text_hidden_state.shape[1] + qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h] + q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2] + text_attn_mask = text_attn_mask.unsqueeze(2) + if swin_pb_relax: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2)) + attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha + else: + attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2)) + attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask) + attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0) + + attn = torch.cat((attn, attn_frame2text), dim=-1) + attn = F.softmax(attn, dim=-1) + + if attn_dropout is not None: + with get_cuda_rng_tracker().fork(): + attn = attn_dropout(attn) + + context_frame = (torch.matmul(attn[..., :-s0], v) + + torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\ + .permute(0, 2, 1, 3).reshape(b0, s1, h0) + + return context_frame + + +def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local, + n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs): + b, s0, h0 = q0.shape + s1 = s0 - text_len + h = h0 // n_head + assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num + # attention_mask_totxt [b, 1, 1, text_len] + # attention_mask_local [1, 1, frame_num, frame_len, frame_len] + # attention_mask: [1, 1, text_len+frame_len, text_len+frame_len] + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.transpose(-1, -2) + + # score: any2text + score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len]) + score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \ + - 10000.0 * (1.0 - attention_mask_totxt) + score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \ + 10000.0 * (1.0 - attention_mask_totxt) + + # score: frame local + q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h) + k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2) + score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame) + score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \ + - 10000.0 * (1.0 - attention_mask_local) + + # context for frame + score_frame_all = torch.cat((score_any2text_part2, + score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1) + 
attention_probs_frame = F.softmax(score_frame_all, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_frame = attention_dropout(attention_probs_frame) + + context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h] + context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\ + view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h) + context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0) + + # context for text + attention_probs_text = F.softmax(score_any2text_part1, dim=-1) + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs_text = attention_dropout(attention_probs_text) + context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :]) + context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0) + + return context_text2text, context_frame + + +class CogVideoModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output) + self.stage = args.cogvideo_stage # 1 or 2 + self.mode_sequential = True if self.stage==1 else False + self.layout = args.layout # [64, 64+400, 64+5*400] + self.n_head = args.num_attention_heads + frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0])) + frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]) + frame_len = self.layout[1]-self.layout[0] + + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + args.additional_seqlen, args.hidden_size + )) + + if args.window_size == -1: + # full attention + assert self.stage == 1 + self.add_mixin('attention_plus', FullAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + else: + self.add_mixin('attention_plus', WindowAttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size, + frame_resolution=frame_resolution, + window_size=args.window_size, + shift_size=args.window_size//2, + n_head=args.num_attention_heads, + frame_num=frame_num, + )) + # attention_mask_local + self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0) + self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len) + + for idx in range(1, frame_num, 2): + self.attention_mask_local_interp[:, :, idx:idx+1].tril_() + self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0) + self.mask_initialized = False + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations') + group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num') + group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention") + group.add_argument("--additional-seqlen", type=int, default=2000) + group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) + return parser + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :(64+400)] + position_plus = 
position_ids[..., (64+400):] + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400)) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, layer_id, **kw_args): + # mask.shape=[bs, 1, 1, 64] + if not self.mask_initialized: + self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype) + self.mask_initialized = True + + attn_module = self.transformer.layers[layer_id].attention + hidden_size = hidden_states.shape[-1] + bs = hidden_states.shape[0] + + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None + + attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp + context_text, context_frame_local_text = attention_localframe_and_text( + q0, k0, v0, + attention_mask_totxt=mask, + attention_mask_local=attention_mask_local, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + frame_len=self.layout[1]-self.layout[0], + frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]), + attention_dropout=dropout_fn, + layer_id=layer_id, + ) + + context_frame_swin = self.get_mixin('attention_plus').attention_extra( + hidden_states[:, self.layout[0]:], layer_id, dropout_fn, + text_hidden_state=hidden_states[:, :self.layout[0]], + text_attn_mask=mask[..., 0, :], + mode_sequential=self.mode_sequential) + + attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id]) + attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0) + + output_text = attn_module.dense(context_text) + output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\ + +torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib) + output = torch.cat((output_text, output_frame), dim=-2) + + return output \ No newline at end of file diff --git a/pretrain_cogvideo.py b/pretrain_cogvideo.py new file mode 100644 index 0000000..defd906 --- /dev/null +++ b/pretrain_cogvideo.py @@ -0,0 +1,184 @@ +# -*- encoding: utf-8 -*- +''' +@File : pretrain_cogvideo.py +@Time : 2021/10/06 00:58:32 +@Author : Wenyi Hong +@Contact : hwy22@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import argparse +import numpy as np +from icetk import icetk as tokenizer +tokenizer.add_special_tokens(['', '', '']) + +from models.cogvideo_model import CogVideoModel +from SwissArmyTransformer import mpu, get_args +from SwissArmyTransformer.training.deepspeed_training import training_main +from SwissArmyTransformer.data_utils import BinaryDataset + +def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None): + # Extract batch size and sequence length. 
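
To make the position-id bookkeeping in this file and in `position_embedding_forward` above easier to follow, here is a minimal standalone sketch (not part of the patch) of the layout that `get_masks_and_position_ids_video` builds, assuming the default `--layout 64,464,2064` (64 text tokens followed by 5 frames of 400 tokens each); the amount of left padding is an invented example. Text tokens count up from 0 after the pads, while frame tokens start at position 512 so that positions beyond 512+400 are served by the extra position-embedding table.

```python
import torch

# Illustration only (toy padding amount); mirrors get_masks_and_position_ids_video.
layout = [64, 464, 2064]          # text_len, text_len+frame_len, text_len+frame_len*frame_num
n_pad = 10                        # example: 10 padding tokens on the left of the text

position_ids = torch.zeros(layout[2], dtype=torch.long)
position_ids[n_pad:layout[0]] = torch.arange(layout[0] - n_pad)            # text: 0..53
position_ids[layout[0]:] = torch.arange(512, 512 + layout[2] - layout[0])  # frames: 512..2511

print(position_ids[n_pad:n_pad + 4])          # tensor([0, 1, 2, 3])
print(position_ids[layout[0]:layout[0] + 4])  # tensor([512, 513, 514, 515])
```
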
+ batch_size, seq_length = data.size() + assert attention_mask_totxt is not None + layout = args.layout + assert seq_length == layout[-1] + n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long() + frame_len = layout[1]-layout[0] + position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long, + device=data.device) + for i in range(batch_size): + torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]], + dtype=torch.long, device=data.device) + torch.arange(512, 512+layout[2]-layout[0], + out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device) + return position_ids + + +def get_batch(data_iterator, args, timers): + # Items and their type. + keys = ['text', 'loss_mask', 'attention_mask_totxt'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + + data_b = mpu.broadcast_data(keys, data, datatype) + # Unpack. + tokens_ = data_b['text'].long() + loss_mask = data_b['loss_mask'].float() + attention_mask_totxt = data_b['attention_mask_totxt'].float() + + labels = tokens_[:, 1:].clone().contiguous() + loss_mask = loss_mask[:, 1:].contiguous() + tokens = tokens_[:, :-1].clone().contiguous() + + for idx in range(args.layout[0], args.layout[2], 400): + tokens[:, idx] = tokenizer[''] + # Get the masks and postition ids. + position_ids = get_masks_and_position_ids_video( + tokens, + attention_mask_totxt=attention_mask_totxt, + args=args + ) + attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1) + # Convert + if args.fp16: + attention_mask_totxt = attention_mask_totxt.half() + return tokens, labels, loss_mask, attention_mask_totxt, position_ids + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + # Forward model. + logits, *mems = model(tokens, position_ids, attention_mask_totxt) + # ======= hyper params =======# + perframe_len = 400 + text_len=64 + frame_num = 5 + logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous() + losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:]) + # scaling loss mask + loss_mask = loss_mask[:, text_len:].reshape(-1) + + losses_1d = losses.reshape(-1) * loss_mask + loss = torch.sum(losses_1d) / loss_mask.sum() + # ===================== Log partial losses ======================== # + log_loss_dict = {} + bs = losses.shape[0] + + if args.cogvideo_stage == 1: + for i in range(frame_num): + log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + else: + for i in range(1, frame_num-1): + log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1) + + # ===================== END OF BLOCK ======================= # + return loss, log_loss_dict + + +def create_dataset_function(path, args): + dataset_layout = [64, 464, 2064] + input_layout = [64, 464, 2064] + # frame_num = 6 + # frame_interval = 2 # DEBUG!!! 
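
For orientation, the sketch below (illustration only, not part of the patch) shows which frame positions the two pretraining stages actually train on, following the loss masks built by `make_text_video_generation` and `mask_video_frame_interpolation` defined later in this file: stage 1 predicts every frame after the first, while stage 2 only predicts the interpolated (odd-indexed) frames. The extra leading zero accounts for the one-position left shift applied in `get_batch`.

```python
import numpy as np

# Toy re-creation of the dataset loss masks (layout 64,464,2064: 64 text tokens,
# then 5 frames of 400 tokens each); illustration only.
text_len, frame_len, frame_num = 64, 400, 5

# Stage 1 (sequential generation): no loss on text or the first (given) frame.
stage1 = np.array([0] * (text_len + frame_len + 1) + [1] * (frame_len * (frame_num - 1)))

# Stage 2 (frame interpolation): loss only on the frames to be interpolated.
stage2 = np.array([0] * (text_len + frame_len + 1)
                  + [1] * frame_len + [0] * frame_len
                  + [1] * frame_len + [0] * frame_len)

for name, mask in (("stage 1", stage1), ("stage 2", stage2)):
    trained = [i for i in range(frame_num)
               if mask[text_len + 1 + i * frame_len: text_len + 1 + (i + 1) * frame_len].any()]
    print(name, "trains on frames", trained)   # stage 1 -> [1, 2, 3, 4], stage 2 -> [1, 3]
```
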
+ def process_fn(row): + row = row.astype(np.int64) + text = row[:dataset_layout[0]] + frames = row[dataset_layout[0]:] + + if text[0] == tokenizer['']: + text = text[1:] # due to our way of data processing + if args.cogvideo_stage == 1: + text, loss_mask, frames = make_text_video_generation(text, frames) + else: + text, loss_mask, frames = mask_video_frame_interpolation(text, frames) + + n_pad = input_layout[0] - len(text) + parts = [ + np.array([tokenizer['']] * n_pad, dtype=np.int64), + text, + np.array([tokenizer['']], dtype=np.int64), + frames, + ] + ret = np.concatenate(parts, axis=0) + + attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad)) + return {'text': ret, + 'loss_mask': loss_mask, + 'attention_mask_totxt': attention_mask_totxt, + } + return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1]) + +def make_text_video_generation(text, frames): + input_layout = [64, 464, 2064] + text = text[text!= tokenizer['']][:input_layout[0]] # dataset format: 1.0秒{text} ... + loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1])) # 按照input的,之后loss_mask会左移一位 + return text, loss_mask, frames + +def mask_video_frame_interpolation(text, frames): + input_layout = [64, 464, 2064] + frame_len = input_layout[1]-input_layout[0] + # text format: 1.0秒 {text} + text = text[text!= tokenizer['']][:input_layout[0]] + loss_mask = np.array([0] * (input_layout[1]+1) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) + + [1] * (input_layout[1]-input_layout[0]) + + [0] * (input_layout[1]-input_layout[0]) )# 按照input的,之后loss_mask会左移一位 + + return text, loss_mask, frames + + + +if __name__ == '__main__': + py_parser = argparse.ArgumentParser(add_help=False) + py_parser.add_argument('--txt-loss-scale', type=float, default=1) + CogVideoModel.add_model_specific_args(py_parser) + + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.layout = [int(x) for x in args.layout.split(',')] + + training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5cf5885 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +SwissArmyTransformer>=0.2 +icetk +gifmaker +torchvision \ No newline at end of file diff --git a/scripts/ds_brain_pretrain_cogvideo_stage1.sh b/scripts/ds_brain_pretrain_cogvideo_stage1.sh new file mode 100644 index 0000000..03c1b18 --- /dev/null +++ b/scripts/ds_brain_pretrain_cogvideo_stage1.sh @@ -0,0 +1,108 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +# HOST_FILE_PATH="hostfile_single" + +video_data_test="" # TODO +CHECKPOINT_PATH="" # TODO: CogView2 ckpt + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name pretrain-cogvideo-stage1 \ + --tokenizer-type fake \ + --vocab-size 150010 \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --num-workers 0 \ + --num-layers 48 \ + --hidden-size 3072 \ + --num-attention-heads 48 \ + --layout 64,464,2064 \ + --window-size -1 \ + --cogvideo-stage 1 \ + --additional-seqlen 2000 \ + --train-iters 500000 \ + --resume-dataloader \ + --train-data ${video_data_test} \ + --train-data-weights 1 \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .001 \ + --checkpoint-activations \ + --max-sequence-length 1024 \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 15 \ + --log-interval 50 \ + --save $main_dir/checkpoints \ + --sandwich-ln \ + --load $CHECKPOINT_PATH \ +" + # --load $CHECKPOINT_PATH \ + # \ --sandwich-ln + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +#!/bin/bash + +# Distribute Example +#export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_DISABLE=0 +export NCCL_NET_GDR_LEVEL=2 +#export NCCL_IB_CUDA_SUPPORT=1 +#export NCCL_IB_GID_INDEX=3 +#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null) +export NCCL_DEBUG=info +export OMP_NUM_THREADS=4 + +if [ $RLAUNCH_REPLICA == "0" ]; then + ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip +fi + +function finish { + rm -rf master_ip +} + +trap finish EXIT INT TERM + +while [ ! -f master_ip ]; do + echo "wait master_ip..." + ls > /dev/null && sleep 1; +done + +export MASTER_ADDR=$(cat master_ip) +echo "master_ip: $MASTER_ADDR" + +MP_SIZE=1 +task_set=$2 +source $1 +DATESTR=$(date +"%m-%d-%H-%M") + +mkdir logs +run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \ + --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt" + + +# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/ds_brain_pretrain_cogvideo_stage2.sh b/scripts/ds_brain_pretrain_cogvideo_stage2.sh new file mode 100644 index 0000000..5b89b0a --- /dev/null +++ b/scripts/ds_brain_pretrain_cogvideo_stage2.sh @@ -0,0 +1,108 @@ +#! 
/bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=warning NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +# HOST_FILE_PATH="hostfile_single" + +video_data_test="" # TODO +CHECKPOINT_PATH="" # TODO: CogView2 ckpt + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name pretrain-cogvideo-stage2 \ + --tokenizer-type fake \ + --vocab-size 150010 \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --num-workers 0 \ + --num-layers 48 \ + --hidden-size 3072 \ + --num-attention-heads 48 \ + --layout 64,464,2064 \ + --window-size 10 \ + --cogvideo-stage 2 \ + --additional-seqlen 2000 \ + --train-iters 500000 \ + --resume-dataloader \ + --train-data ${video_data_test} \ + --train-data-weights 1 \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .001 \ + --checkpoint-activations \ + --max-sequence-length 1024 \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 15 \ + --log-interval 50 \ + --save $main_dir/checkpoints \ + --sandwich-ln \ + --load $CHECKPOINT_PATH \ +" + # --load $CHECKPOINT_PATH \ + # \ --sandwich-ln + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +#!/bin/bash + +# Distribute Example +#export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_DISABLE=0 +export NCCL_NET_GDR_LEVEL=2 +#export NCCL_IB_CUDA_SUPPORT=1 +#export NCCL_IB_GID_INDEX=3 +#export NCCL_IB_HCA=$(pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; do cat $i/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo $i ; done; popd > /dev/null) +export NCCL_DEBUG=info +export OMP_NUM_THREADS=4 + +if [ $RLAUNCH_REPLICA == "0" ]; then + ifconfig eth0 | grep inet | grep -v inet6 | awk '{print $2}' > master_ip +fi + +function finish { + rm -rf master_ip +} + +trap finish EXIT INT TERM + +while [ ! -f master_ip ]; do + echo "wait master_ip..." 
+ ls > /dev/null && sleep 1; +done + +export MASTER_ADDR=$(cat master_ip) +echo "master_ip: $MASTER_ADDR" + +MP_SIZE=1 +task_set=$2 +source $1 +DATESTR=$(date +"%m-%d-%H-%M") + +mkdir logs +run_cmd="sudo /opt/conda/bin/python -m torch.distributed.launch --nproc_per_node=8 \ + --nnodes=$RLAUNCH_REPLICA_TOTAL --node_rank=$RLAUNCH_REPLICA \ + --master_addr=$MASTER_ADDR --master_port=12355 pretrain_cogvideo.py $@ ${gpt_options} 2>&1 | tee logs/log-${DATESTR}-${RLAUNCH_REPLICA}.txt" + + +# run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_video_swin_cond_glm_interp.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/scripts/ds_config_zero.json b/scripts/ds_config_zero.json new file mode 100644 index 0000000..a9f7ad1 --- /dev/null +++ b/scripts/ds_config_zero.json @@ -0,0 +1,42 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "gradient_clipping": 0.1, + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "contiguous_gradients": false, + "overlap_comm": true, + "reduce_scatter": false, + "reduce_bucket_size": 100000000, + "allgather_bucket_size": 1000000000, + "load_from_fp32_weights": false + }, + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 400, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0002, + "betas": [ + 0.9, + 0.95 + ], + "eps": 1e-8, + "weight_decay": 1e-4 + } + }, + "activation_checkpointing": { + "partition_activations": false, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false + } + \ No newline at end of file diff --git a/scripts/inference_cogvideo_pipeline.sh b/scripts/inference_cogvideo_pipeline.sh new file mode 100644 index 0000000..ccbc543 --- /dev/null +++ b/scripts/inference_cogvideo_pipeline.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +NLAYERS=48 +NHIDDEN=3072 +NATT=48 +MAXSEQLEN=1024 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +MPSIZE=1 + +#SAMPLING ARGS +TEMP=1.05 +TOPK=12 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +MASTER_PORT=${MASTER_PORT} SAT_HOME=/sharefs/cogview-new python cogvideo_pipeline.py \ + --input-source interactive \ + --output-path ./output \ + --parallel-size 1 \ + --both-stages \ + --use-guidance-stage1 \ + --guidance-alpha 3.0 \ + --generate-frame-num 5 \ + --tokenizer-type fake \ + --mode inference \ + --distributed-backend nccl \ + --fp16 \ + --model-parallel-size $MPSIZE \ + --temperature $TEMP \ + --coglm-temperature2 0.89 \ + --top_k $TOPK \ + --sandwich-ln \ + --seed 1234 \ + --num-workers 0 \ + --batch-size 4 \ + --max-inference-batch-size 8 \ + $@ diff --git a/sr_pipeline/__init__.py b/sr_pipeline/__init__.py new file mode 100644 index 0000000..736cde4 --- /dev/null +++ b/sr_pipeline/__init__.py @@ -0,0 +1,17 @@ +# -*- encoding: utf-8 -*- +''' +@File : __init__.py +@Time : 2022/03/02 13:57:09 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution +from .sr_group import SRGroup \ No newline at end of file diff --git a/sr_pipeline/direct_sr.py b/sr_pipeline/direct_sr.py new file mode 100644 index 0000000..fe32a3a --- /dev/null +++ b/sr_pipeline/direct_sr.py @@ -0,0 +1,117 @@ +# -*- encoding: utf-8 -*- +''' +@File : 
direct_sr.py +@Time : 2022/03/02 13:58:11 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch + +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview2.py +@Time : 2021/10/10 16:31:34 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer import get_args +from SwissArmyTransformer.training.model_io import load_checkpoint +from .dsr_sampling import filling_sequence_dsr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .dsr_model import DsrModel + +from icetk import icetk as tokenizer + +class DirectSuperResolution: + def __init__(self, args, path, max_bz=4, topk=6, onCUDA=False): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [96,496,4096] + + model = DsrModel(args) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model + self.onCUDA = onCUDA + if onCUDA: + self.model = self.model.cuda() + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=1.0, topk=topk) # temperature not used # Temperature Freezed Here!! + self.max_bz = max_bz + + def __call__(self, text_tokens, image_tokens, enhance=False): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + # ===================== Debug ======================== # + # new_image_tokens = [] + # for small_img in image_tokens: + # decoded = tokenizer.decode(image_ids=small_img) + # decoded = torch.nn.functional.interpolate(decoded, size=(480, 480)).squeeze(0) + # ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + # image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + # small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + # new_image_tokens.append(small_img2) + # image_tokens = torch.stack(new_image_tokens) + # return image_tokens + # ===================== END OF BLOCK ======================= # + if enhance: + new_image_tokens = [] + for small_img in image_tokens: + decoded = tokenizer.decode(image_ids=small_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + small_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.), image_size=160).view(-1) + new_image_tokens.append(small_img2) + image_tokens = torch.stack(new_image_tokens) + + seq = torch.cat((text_tokens,image_tokens), dim=1) + seq1 = torch.tensor([tokenizer['']]*3601, device=image_tokens.device).unsqueeze(0).expand(text_tokens.shape[0], -1) + if not self.onCUDA: + print('Converting Dsr model...') + model = self.model.cuda() + else: + model = self.model + print('Direct super-resolution...') + output_list = [] + for tim in range(max((text_tokens.shape[0]+self.max_bz-1) // self.max_bz, 1)): + output1 = filling_sequence_dsr(model, + seq[tim*self.max_bz:(tim+1)*self.max_bz], + seq1[tim*self.max_bz:(tim+1)*self.max_bz], + warmup_steps=1, block_hw=(1, 0), + 
strategy=self.strategy + ) + output_list.extend(output1[1:]) + if not self.onCUDA: + print('Moving back Dsr to cpu...') + model = model.cpu() + torch.cuda.empty_cache() + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/sr_pipeline/dsr_model.py b/sr_pipeline/dsr_model.py new file mode 100644 index 0000000..d918d18 --- /dev/null +++ b/sr_pipeline/dsr_model.py @@ -0,0 +1,225 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim, unscaled_init_method +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + # self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + + +class AttentionMixin(BaseMixin): + def __init__(self, num_layers, + hidden_size, + init_method=unscaled_init_method(0.02), + output_layer_init_method=unscaled_init_method(0.02) + ): + super(AttentionMixin, self).__init__() + self.num_layers = num_layers # replace attention in the LAST n layers + self.query_key_value = torch.nn.ModuleList( + [ColumnParallelLinear(hidden_size, 3 * hidden_size, stride=3, + gather_output=False, init_method=init_method) + for layer_id in range(num_layers) + ]) + self.dense = torch.nn.ModuleList( + [RowParallelLinear(hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + for layer_id in range(num_layers) + ]) + + def reinit(self, parent_model=None): + start_layer = len(self.transformer.layers) - self.num_layers + assert start_layer >= 0 + for layer_id in range(self.num_layers): + old_attention = self.transformer.layers[start_layer + layer_id].attention + self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data) + self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data) + self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data) + self.dense[layer_id].bias.data.copy_(old_attention.dense.bias.data) + +class DsrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + 
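
`PositionEmbeddingMixin.reinit` above enlarges the pretrained 2D position-embedding grid by tiling it periodically over the new, larger grid through a reshaped `copy_`. A toy standalone sketch of that view-and-copy trick (invented sizes, not part of the patch), mirroring the reinit code:

```python
import torch

# Tile a pretrained old_edge x old_edge grid over a larger new_edge x new_edge grid,
# as in PositionEmbeddingMixin.reinit (toy sizes: 2x2 -> 4x4, hidden=1).
old_edge, new_edge, hidden = 2, 4, 1
old = torch.arange(old_edge * old_edge, dtype=torch.float).view(-1, hidden)   # grid [[0,1],[2,3]]
new = torch.zeros(new_edge * new_edge, hidden)

new.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden).copy_(
    old.view(1, old_edge, 1, old_edge, hidden))

print(new.view(new_edge, new_edge))
# tensor([[0., 1., 0., 1.],
#         [2., 3., 2., 3.],
#         [0., 1., 0., 1.],
#         [2., 3., 2., 3.]])
```

The periodic tiling presumably keeps the pretrained local geometry of the original grid while covering the larger super-resolution sequence length.
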
self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + self.add_mixin('attention_plus', AttentionMixin( + num_layers=args.num_layers, + hidden_size=args.hidden_size + )) + self.layout = args.layout + # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[1]] + position_plus = position_ids[..., self.layout[1]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # attention_plus on all layers + query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id] + dense_plus = self.get_mixin('attention_plus').dense[layer_id] + # split two parts + hidden_states_plus = hidden_states[:, self.layout[1]:] + hidden_states = hidden_states[:, :self.layout[1]] + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3) + # cuda2d model qkv + mixed_raw_layer = query_key_value_plus(hidden_states_plus) + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer, 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer0, context_layer1 = sparse_attention_2d_light( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + kernel_size2=self.kernel_size2, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + add_scalar=(kw_args['add_scalar'] if 'add_scalar' in kw_args else 0) + ) + + output_0 = attn_module.dense(context_layer0) + output_1 = dense_plus(context_layer1) + output = torch.cat((output_0, output_1), dim=1) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + def disable_untrainable_params(self): + self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='96,496,4096') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_light(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, kernel_size2=7, attention_dropout=None, log_attention_weights = None, add_scalar=0, **kwargs): + ''' + q0, k0, v0: [batch_size, 1088, hidden_size] + q1, k1, v1: [batch_size, 
4096, h2] + n_head: int + attention_mask: [batch_size, 1088, 1088] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l0, l1 = h0 // n_head, sqrt(s0-text_len), sqrt(s1) + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + if log_attention_weights is not None: + attention_scores += log_attention_weights + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + # scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, True) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + k0T = k0T[..., -l0**2:].reshape(b*n_head, h, l0, l0).contiguous() + scores_1_to_0 = f_similar(q1, k0T, kernel_size2, kernel_size2, False) # [b*n_head, l1, l1, field] + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, -1, scores_1_to_0.shape[3]) + add_scalar, + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + # with get_cuda_rng_tracker().fork(): + attention_probs0 = attention_dropout(attention_probs0) + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + # context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, True) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head * h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view_as(scores_1_to_0) + v0_part = v0[:, :, -l0**2:].transpose(-1, -2).contiguous().view(b*n_head, h, l0, l0) + context1_to_0 = f_weighting(v0_part, probs_1_to_0.contiguous(), kernel_size2, kernel_size2, False) + context1_to_0 = context1_to_0.view(b, n_head * h, l1**2) + context1 = context1 + context1_to_0 + return context0.transpose(1, 2).reshape(b, s0, h0), context1.transpose(-1, -2) \ No newline at end of file diff --git a/sr_pipeline/dsr_sampling.py b/sr_pipeline/dsr_sampling.py new file mode 100644 index 0000000..5b8dded --- /dev/null +++ b/sr_pipeline/dsr_sampling.py @@ -0,0 +1,159 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_sampling.py +@Time : 2021/10/09 00:46:04 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from cv2 import reduce +import torch + +import torch +import torch.nn.functional as F +import numpy as np + +def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')): + indices_to_remove = logits < 
torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + +class IterativeEntfilterStrategy: + def __init__(self, invalid_slices=[], temperature=1., topk=6): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long) + + + def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None): + # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size] + if temperature is None: + temperature = self.temperature + + logits = logits_.float() / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float('Inf') + logits = logits.view(-1, logits.shape[-1]) + + rprobs = F.softmax(logits.float(), dim=-1) + c = self.cluster_labels.expand(*rprobs.shape) + cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs) + + best_scores, best_clusters = cprobs.topk(self.topk) + bz = logits.shape[0] + best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True) + sampled_ids = torch.multinomial(best_scores, num_samples=1) + selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids) + selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500) + logits[selected_mask] = -65504 + # for i in range(bz): + # selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)] + # logits[i, self.cluster_labels != selected_cluster] = -65504 + + # logits = top_k_logits(logits, self.topk, self.top_p) + probs = F.softmax(logits.float()/0.6, dim=-1) # float is essetial, due to a bug in Pytorch + pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2]) + + assert tokens.shape[1] == pred.shape[1] + 1 + tokens = torch.cat((tokens[:, :1], pred), dim=1) + return tokens + +def filling_sequence_dsr( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=IterativeEntfilterStrategy(topk=10), + ): + ''' + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token. + Attention: + The sampling temperature are changing, temporally we hard code them here. + The temperature in the strategy is not used. + ''' + assert hasattr(model, 'layout') + layout = model.layout + assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \ + and seq0.shape[0] == seq1.shape[0] + assert len(layout) == 3 + assert seq1.shape[1] == layout[-1] - layout[-2] + 1 + assert (seq1 >= 0).all() and (seq0 >= 0).all() + device = seq0.device + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[1] - seq0.shape[1] + assert n_pad > 0, "You should truncate long input before filling." 
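
The cluster-restricted sampling in `IterativeEntfilterStrategy.forward` above is easier to see at toy scale. The sketch below (illustration only; the vocabulary size, cluster count, and labels are invented, whereas the real code uses 20,000 image tokens and the 500-way `cluster_label2.npy` mapping) pools token probabilities per cluster with `scatter_add_`, samples a cluster from the renormalized top-k cluster scores, and then samples a token only from inside that cluster.

```python
import torch
import torch.nn.functional as F

# Toy cluster-restricted sampling; sizes and labels are made up for illustration.
vocab, n_clusters, topk = 8, 3, 2
cluster_labels = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2])        # token id -> cluster id
logits = torch.randn(1, vocab)

probs = F.softmax(logits, dim=-1)
cluster_probs = torch.zeros(1, n_clusters).scatter_add_(1, cluster_labels.expand(1, vocab), probs)

best_scores, best_clusters = cluster_probs.topk(topk)
best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
chosen = torch.gather(best_clusters, 1, torch.multinomial(best_scores, 1))

masked = logits.clone()
masked[cluster_labels.unsqueeze(0) != chosen] = -65504.0        # drop tokens outside the cluster
token = torch.multinomial(F.softmax(masked, dim=-1), 1)
print("sampled cluster:", chosen.item(), "sampled token:", token.item())
```
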
+ seq = torch.cat(( + torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + seq0, seq1), dim=1) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + 1 + + # build initial tokens, attention_mask, and position_ids + tokens = seq.clone() + attention_mask = torch.ones(layout[1], layout[1]).to(device) + attention_mask[:layout[0], layout[0]:] = 0 + attention_mask[n_pad:, :n_pad] = 0 + attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[0] - n_pad), + torch.arange(513, 513 + layout[1] - layout[0]), + torch.arange(1024, 1024+layout[2]-layout[1]))).to(device) + log_attention_weights = torch.zeros(layout[1], layout[1], + device=device).type_as(next(model.parameters())) + log_attention_weights[layout[0]:, n_pad:layout[0]] = 0. + + # prepare for interation + unfixed = (tokens < 0) # just init an all-False tensor + unfixed[:, -layout[-1] + layout[-2]:] = True + + ll, rr = block_hw + edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4) + num_steps = warmup_steps + ll - 1 + rr + # interative refining + + # unfixed[..., -(layout[-1] - layout[-2]):].view( + # batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False + + + ret = [] + ret.append(tokens[:, layout[-2]+1:].clone()) + for step_cnt in range(1, num_steps+1): + if step_cnt <= warmup_steps: + logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1. + new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + else: + logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights) + real_temp = 1. 
+ new_tokens = strategy.forward( + logits, tokens, real_temp, + entfilter=1.3, + filter_topk=5, + temperature2=0.6 + ) + # tokens[unfixed] = new_tokens[unfixed] + # fixed tokens (update unfixed) + unfixed2 = (tokens > 10000000) + for x in range(min(ll, step_cnt - warmup_steps)): + y = step_cnt - warmup_steps - x - 1 + if y < rr: + unfixed[..., -(layout[-1] - layout[-2]):].view( + batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False + unfixed2[..., -(layout[-1] - layout[-2]):].view( + batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True + tokens[unfixed2] = new_tokens[unfixed2] + + ret.append(tokens[:, layout[-2]+1:].clone()) + + return ret diff --git a/sr_pipeline/iterative_sr.py b/sr_pipeline/iterative_sr.py new file mode 100644 index 0000000..a55a6b5 --- /dev/null +++ b/sr_pipeline/iterative_sr.py @@ -0,0 +1,118 @@ +# -*- encoding: utf-8 -*- +''' +@File : iterative_sr.py +@Time : 2022/03/02 15:57:45 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +# here put the import lib +import os +import sys +import math +import random +from PIL import ImageEnhance, Image + +import torch +import argparse +from torchvision import transforms + +from SwissArmyTransformer.training.model_io import load_checkpoint +from SwissArmyTransformer import get_args +from .itersr_sampling import filling_sequence_itersr, IterativeEntfilterStrategy +from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually + +from .itersr_model import ItersrModel + +from icetk import icetk as tokenizer + +class IterativeSuperResolution: + def __init__(self, args, path, max_bz=4, shared_transformer=None): + args.load = path + args.kernel_size = 5 + args.kernel_size2 = 5 + args.new_sequence_length = 4624 + args.layout = [16,3616] + + model = ItersrModel(args, transformer=shared_transformer) + if args.fp16: + model = model.half() + + load_checkpoint(model, args) # on cpu + model.eval() + self.model = model.cuda() + + # save cpu weights + self.saved_weights = dict((k,v.cpu()) + for k, v in model.named_parameters() + if 'transformer' in k + ) + + invalid_slices = [slice(tokenizer.num_image_tokens, None)] + + self.strategy = IterativeEntfilterStrategy(invalid_slices, + temperature=args.temp_all_itersr, topk=args.topk_itersr) + self.max_bz = max_bz + + def _restore_transformer_from_cpu(self, non_blocking=False): + for k, v in self.model.named_parameters(): + if k in self.saved_weights: + v.copy_(self.saved_weights[k]) + + def __call__(self, text_tokens, image_tokens, enhance=False, input_mask=None): + if len(text_tokens.shape) == 1: + text_tokens.unsqueeze_(0) + text_tokens = text_tokens.clone()[..., :16] + if len(image_tokens.shape) == 1: + image_tokens.unsqueeze_(0) + if enhance: + new_image_tokens = [] + for big_img in image_tokens: + decoded = tokenizer.decode(image_ids=big_img).squeeze(0) + ndarr = decoded.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + image_pil_raw = ImageEnhance.Sharpness(Image.fromarray(ndarr)) + big_img2 = tokenizer.encode(image_pil=image_pil_raw.enhance(1.5), image_size=480).view(-1) + new_image_tokens.append(big_img2) + image_tokens = torch.stack(new_image_tokens) + print('Converting Itersr model...') + self._restore_transformer_from_cpu() + model = self.model + print('iterative super-resolution...') + output_list = [] + for tim in range(max(text_tokens.shape[0] // self.max_bz, 1)): + big_img = 
image_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + text_seq = text_tokens[tim*self.max_bz:(tim+1)*self.max_bz] + mask_raw = torch.tensor( + [ + -1, 0, 1, 2, 3, 4, + 0, -1, 2, -1, -2, 5, + 1, -2, 3, 4, 5, 6, + 2, 3, 4, 5, -1, 1, + 3, -1, -2, 0, -1, 2, + 4, 5, 6, 1, 3, -2 + ] + ).view(1, 6, 1, 6).expand(10, 6, 10, 6).reshape(-1).contiguous() + + topks = [60, 40, 40, 40, 20, 20, 10] + + for mask_ratio in range(1, 7): + self.strategy.topk = topks[mask_ratio] + mask = (mask_raw.to(big_img.device) >= mask_ratio) + if input_mask is not None: + mask = mask & input_mask + big_img.masked_fill_(mask, tokenizer['']) + seq1 = big_img + output1 = filling_sequence_itersr(model, text_seq, seq1, + warmup_steps=1, block_hw=(1, 0), + strategy=self.strategy + ) + big_img = output1 + print(f'Iter {mask_ratio} times.') + output_list.append(output1.clone()) + return torch.cat(output_list, dim=0) \ No newline at end of file diff --git a/sr_pipeline/itersr_model.py b/sr_pipeline/itersr_model.py new file mode 100644 index 0000000..40981bc --- /dev/null +++ b/sr_pipeline/itersr_model.py @@ -0,0 +1,232 @@ +# -*- encoding: utf-8 -*- +''' +@File : itersr_model.py +@Time : 2021/10/02 01:36:32 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + + +from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin + +from SwissArmyTransformer.mpu.utils import sqrt +from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker +from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear +from SwissArmyTransformer.model.transformer import unscaled_init_method, split_tensor_along_last_dim + +class PositionEmbeddingMixin(BaseMixin): + def __init__(self, additional_sequence_length, hidden_size, + init_method_std=0.02, reinit_slice=slice(512, 512+400) + ): + super(PositionEmbeddingMixin, self).__init__() + self.reinit_slice = reinit_slice + self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size) + torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) + + def reinit(self, parent_model=None): + old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice] + old_len, hidden_size = old_weights.shape + assert hidden_size == self.position_embeddings.weight.shape[-1] + old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2]) + assert new_edge % old_edge == 0 + self.position_embeddings.weight.data.view(new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size)) + +class ItersrModel(BaseModel): + def __init__(self, args, transformer=None): + super().__init__(args, transformer=transformer) + self.original_sequence_length = args.max_sequence_length + additional_seqlen = args.new_sequence_length - args.max_sequence_length + self.add_mixin('extra_position_embedding', PositionEmbeddingMixin( + additional_seqlen, args.hidden_size + )) + # self.add_mixin('attention_plus', AttentionMixin( + # num_layers=args.num_layers, + # hidden_size=args.hidden_size + # )) + self.layout = args.layout + # [PAD]... [ROI1] text ... 
[BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]} + self.kernel_size = args.kernel_size + self.kernel_size2 = args.kernel_size2 + self.log_attention_weights = None + + def position_embedding_forward(self, position_ids, **kw_args): + position = position_ids[..., :self.layout[0]] + position_plus = position_ids[..., self.layout[0]:] - self.original_sequence_length + position_embeddings = torch.cat( + ( + self.transformer.position_embeddings(position), + self.get_mixin('extra_position_embedding').position_embeddings(position_plus) + ), + dim=-2 + ) + return position_embeddings + + def attention_forward(self, hidden_states, mask, + layer_id=None, log_attention_weights=None, **kw_args): + attn_module = self.transformer.layers[layer_id].attention + # base model qkv + mixed_raw_layer = attn_module.query_key_value(hidden_states) + q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, :self.layout[0]], 3) + # cuda2d model qkv + q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0]:], 3) + + dropout_fn = attn_module.attention_dropout if self.training else None + + # cuda2d attention + context_layer = sparse_attention_2d_text( + q0, k0, v0, + q1, k1, v1, + mask, + n_head=attn_module.num_attention_heads_per_partition, + text_len=self.layout[0], + kernel_size=self.kernel_size, + attention_dropout=dropout_fn, + log_attention_weights=log_attention_weights, + ) + + output = attn_module.dense(context_layer) + + return output + + def final_forward(self, logits, **kwargs): + logits_parallel = logits + logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]).float() + # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000]) + return logits_parallel + + # def disable_untrainable_params(self): + # self.transformer.requires_grad_(False) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('Cuda2dModel', 'cuda2d model configurations') + group.add_argument("--kernel-size", type=int, default=5) + group.add_argument("--kernel-size2", type=int, default=5) + group.add_argument("--layout", type=str, default='16,3616') + group.add_argument("--new-sequence-length", type=int, default=4096) + return parser + +def sparse_attention_2d_text(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 
1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + # cross attention + scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T) + if log_attention_weights is not None: + scores_1_to_0 += log_attention_weights + scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + scores_1 = torch.cat( + ( + scores_1_to_0.view(b*n_head, s1, s0), + scores_1_to_1.view(b*n_head, -1, scores_1_to_1.shape[3]) + ), + dim=-1) + attention_probs1 = F.softmax(scores_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3]:].view_as(scores_1_to_1) + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = context1_to_1.view(b, n_head, h, l1**2) + # weighting for cross attention + probs_1_to_0 = attention_probs1[:, :, :scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3]) + + context1_to_0 = torch.matmul(probs_1_to_0, v0) + context1 = context1.transpose(-1, -2) + context1_to_0 + + output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0) + + return output + +def sparse_attention_2d_notext(q0, k0, v0, q1, k1, v1, attention_mask, n_head, text_len, kernel_size=9, attention_dropout=None, log_attention_weights = None, **kwargs): + ''' + q0, k0, v0: [batch_size, 16, hidden_size] + q1, k1, v1: [batch_size, 3600, hidden_size] + n_head: int + attention_mask: [batch_size, 16] + ''' + from SwissArmyTransformer.mpu.local_attention_function import f_similar, f_weighting + b, s0, h0 = q0.shape + b, s1, h1 = q1.shape + h, l1 = h0 // n_head, sqrt(s1) + assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}" + + q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3) + k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1) + + # standard attention for level 0 + attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T) + + attention_scores = torch.mul(attention_scores, attention_mask) - \ + 10000.0 * (1.0 - attention_mask) + + attention_probs0 = F.softmax(attention_scores, dim=-1) + + # local attention for level 1 + q1 = (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1//n_head)).contiguous().view(b*n_head, h1//n_head, l1, l1) + k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b*n_head, h1//n_head, l1, l1) + scores_1_to_1 = f_similar(q1, k1, kernel_size*2-1, kernel_size, False) + + attention_probs1 = F.softmax(scores_1_to_1, dim=-1) + + if attention_dropout is not None: + with get_cuda_rng_tracker().fork(): + attention_probs1 = attention_dropout(attention_probs1) + + # weighting for level 0 + context0 = torch.matmul(attention_probs0, v0) # [b, n_head, s0, h] + # weighting for level 1 + probs_1_to_1 = attention_probs1 + context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size*2-1, kernel_size, False) + + context1 = 
context1_to_1.view(b, n_head, h, l1**2)
+    # weighting for cross attention
+    context1 = context1.transpose(-1, -2)
+
+    output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0+s1, h0)
+
+    return output
\ No newline at end of file
diff --git a/sr_pipeline/itersr_sampling.py b/sr_pipeline/itersr_sampling.py
new file mode 100644
index 0000000..df22a00
--- /dev/null
+++ b/sr_pipeline/itersr_sampling.py
@@ -0,0 +1,168 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   itersr_sampling.py
+@Time    :   2022/03/03 14:24:28
+@Author  :   Ming Ding
+@Contact :   dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from icetk import icetk as tokenizer
+
+def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits[indices_to_remove] = filter_value
+    return logits
+
+# class IterativeEntfilterStrategy:
+#     def __init__(self, invalid_slices=[], temperature=1., topk=10):
+#         self.invalid_slices = invalid_slices
+#         self.temperature = temperature
+#         self.topk = topk
+#         self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
+
+
+#     def forward(self, logits_, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
+#         # In the iterative strategy, logits are of shape [batch_size, seq_length, hidden_size]
+#         if temperature is None:
+#             temperature = self.temperature
+
+#         logits = logits_.float() / temperature
+#         for invalid_slice in self.invalid_slices:
+#             logits[..., invalid_slice] = -float('Inf')
+#         logits = logits.view(-1, logits.shape[-1])
+
+#         rprobs = F.softmax(logits.float(), dim=-1)
+#         c = self.cluster_labels.expand(*rprobs.shape)
+#         cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
+
+#         best_scores, best_clusters = cprobs.topk(self.topk)
+#         bz = logits.shape[0]
+#         best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
+#         sampled_ids = torch.multinomial(best_scores, num_samples=1)
+#         selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
+#         selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
+#         logits[selected_mask] = -65504
+#         # for i in range(bz):
+#         #     selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
+#         #     logits[i, self.cluster_labels != selected_cluster] = -65504
+
+#         # logits = top_k_logits(logits, self.topk, self.top_p)
+#         probs = F.softmax(logits.float(), dim=-1) # float is essential, due to a bug in PyTorch
+#         pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
+
+#         assert tokens.shape[1] == pred.shape[1]
+#         tokens = pred
+#         return tokens
+
+class IterativeEntfilterStrategy:
+    def __init__(self, invalid_slices=[], temperature=1., topk=10):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = topk
+
+    def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
+        # In the iterative strategy, logits are of shape [batch_size, seq_length, hidden_size]
+        if temperature is None:
+            temperature = self.temperature
+        # check entropy filter
+        # if entfilter is not None:
+        #     assert temperature2 is not None
+        #     topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
+        #     ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
+        #     temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
+
+        logits = logits.float() / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -float('Inf')
+
+        # debiased topk
+        # probs = F.softmax(logits, dim=-1)
+        # tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
+        # pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
+        # edge_idx = tk_idx[:, :, -1:]
+        # edge_value = tk_value[:, :, -1:]
+        # edge_mask = probs.gather(dim=-1, index=pred) < edge_value
+        # pred[edge_mask] = edge_idx[edge_mask] # replace outliers with the "filter_topk"-th token
+        # pred.squeeze_(-1) # [batch_size, seq_length]
+
+        top_k_logits_(logits, self.topk)
+        probs = F.softmax(logits, dim=-1)
+        pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
+        pred.squeeze_(-1)
+
+        assert tokens.shape[1] == pred.shape[1]
+        tokens = pred
+        return tokens
+
+def filling_sequence_itersr(
+        model,
+        seq0,
+        seq1,
+        warmup_steps=3,
+        block_hw=(4, 4),
+        strategy=IterativeEntfilterStrategy(topk=10),
+        ):
+    '''
+        seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
+        4095 {layout[2]} final_token.
+        Attention:
+        The sampling temperatures change during the iterations; for now we hard-code them here.
+        The temperature stored in the strategy is not used.
+    '''
+    assert hasattr(model, 'layout')
+    layout = model.layout
+
+    device = seq0.device
+    # concat and pad sequences
+    batch_size = seq0.shape[0]
+    n_pad = layout[0] - seq0.shape[1]
+    assert n_pad >= 0, "You should truncate long input before filling."
+    seq = torch.cat((
+        torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
+            .unsqueeze(0).expand(batch_size, n_pad),
+        seq0, seq1), dim=1) # [b, layout[-1]+1]
+    assert seq.shape[1] == layout[-1]
+
+    # build initial tokens, attention_mask, and position_ids
+    tokens = seq.clone()
+    attention_mask = torch.ones(layout[0]).to(device)
+    attention_mask[:n_pad] = 0
+    attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters())) # if fp16
+    position_ids = torch.cat((
+        torch.zeros(n_pad, dtype=torch.long),
+        torch.arange(0, layout[0] - n_pad),
+        torch.arange(1024, 1024+layout[1]-layout[0]))).to(device)
+    log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
+    log_attention_weights[n_pad:layout[0]] = 0.
+    log_attention_weights = log_attention_weights.unsqueeze(0)
+
+    # prepare for iteration
+    unfixed = (tokens == tokenizer[''])
+    ll, rr = block_hw
+    edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
+    num_steps = 1
+    # iterative refining
+
+    # unfixed[..., -(layout[-1] - layout[-2]):].view(
+    #     batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
+
+
+    ret = []
+    # ret.append(tokens[:, layout[-2]:-1].clone())
+    for step_cnt in range(1, num_steps+1):
+        logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
+        real_temp = 1.
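+        # Refinement step: the strategy below re-samples every image position from the
+        # top-k filtered distribution, but only positions marked `unfixed` above are
+        # overwritten; tokens that were already fixed in the input are kept unchanged.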
+ new_tokens = strategy.forward(logits, tokens, real_temp) + tokens[unfixed] = new_tokens[unfixed] + + ret.append(tokens[:, layout[-2]:].clone()) + return torch.cat(ret, dim=0) \ No newline at end of file diff --git a/sr_pipeline/sr_group.py b/sr_pipeline/sr_group.py new file mode 100644 index 0000000..1ec51b6 --- /dev/null +++ b/sr_pipeline/sr_group.py @@ -0,0 +1,49 @@ +# -*- encoding: utf-8 -*- +''' +@File : sr_group.py +@Time : 2022/04/02 01:17:21 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random + +import numpy as np +import torch +import torch.nn.functional as F +from SwissArmyTransformer.resources import auto_create +from .direct_sr import DirectSuperResolution +from .iterative_sr import IterativeSuperResolution + +class SRGroup: + def __init__(self, args, home_path=None,): + dsr_path = auto_create('cogview2-dsr', path=home_path) + itersr_path = auto_create('cogview2-itersr', path=home_path) + dsr = DirectSuperResolution(args, dsr_path) + itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer) + self.dsr = dsr + self.itersr = itersr + + def sr_base(self, img_tokens, txt_tokens): + assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2 + batch_size = img_tokens.shape[0] + txt_len = txt_tokens.shape[-1] + if len(txt_tokens.shape) == 1: + txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + sred_tokens = self.dsr(txt_tokens, img_tokens) + iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone()) + return iter_tokens[-batch_size:] + + # def sr_patch(self, img_tokens, txt_tokens): + # assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2 + # batch_size = img_tokens.shape[0] * 9 + # txt_len = txt_tokens.shape[-1] + # if len(txt_tokens.shape) == 1: + # txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len) + # img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400) + # iter_tokens = self.sr_base(img_tokens, txt_tokens) + # return iter_tokens \ No newline at end of file
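
A short usage sketch may help readers wire up the super-resolution group added in `sr_pipeline/sr_group.py`. The snippet below is an illustrative assumption rather than an entry point shipped in this patch: the `args` namespace and the helper function `upscale_frame_tokens` are hypothetical, and only the class and method names (`SRGroup`, `sr_base`) and the token shapes asserted in the code above are taken from the patch.

```python
# Hypothetical wiring of SRGroup; `args` and this helper are assumptions,
# only the class/method names come from sr_pipeline/sr_group.py in this patch.
import torch
from sr_pipeline.sr_group import SRGroup

def upscale_frame_tokens(args, txt_tokens: torch.Tensor, img_tokens: torch.Tensor) -> torch.Tensor:
    # txt_tokens: [text_len] or [batch, text_len]; img_tokens: [batch, 400] (20x20 icetk codes)
    sr_group = SRGroup(args)  # loads cogview2-dsr and cogview2-itersr via auto_create
    hires = sr_group.sr_base(img_tokens, txt_tokens)
    return hires  # expected [batch, 3600] tokens, i.e. a 60x60 grid after direct + iterative SR
```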