Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-13 19:19:15 +08:00)

Commit 34dd61430e ("init"), parent 30ebd13299
.gitignore (vendored, 3 lines changed)
@@ -5,4 +5,5 @@ runs/
 checkpoints/
 master_ip
 logs/
-*.DS_Store
+*.DS_Store
+.idea
Dockerfile (26 lines removed)
@@ -1,26 +0,0 @@
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04

ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=Europe/London

SHELL ["/bin/bash", "-c"]


RUN apt-get update && apt-get upgrade -y
RUN apt install git -y \
    && apt install python3 -y \
    && apt install python3-pip -y \
    && apt install git -y \
    && apt install python-is-python3 -y \
    && apt install python3-tk -y \
    && apt install ffmpeg libsm6 libxext6 -y \
    && apt install git -y


WORKDIR /workspace

RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 \
    && pip install opencv-python


RUN pip install SwissArmyTransformer>=0.2.9 icetk gifmaker
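For context, the image this Dockerfile described was built and entered with standard Docker commands; the tag and mount path below are illustrative, and the repository's `build_image.sh` simply wrapped the build step.

```shell
# Build the image from the repository root (tag name is illustrative).
docker build -t cogvideo .

# Start an interactive container with GPU access and the repository mounted at /workspace.
docker run --gpus all -it --rm -v "$(pwd)":/workspace cogvideo bash
```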
LICENSE (2 lines changed)
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright [yyyy] [name of copyright owner]
+   Copyright 2024 CogVideo Model Team @ Zhipu AI

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
Model_License (116 lines changed)
@@ -1,79 +1,71 @@
|
||||
The CogVideo License
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
||||
|
||||
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
||||
|
||||
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
||||
|
||||
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
||||
|
||||
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
||||
|
||||
NOW THEREFORE, You and Licensor agree as follows:
|
||||
The CogVideoX License
|
||||
|
||||
1. Definitions
|
||||
|
||||
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
||||
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
||||
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
||||
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
||||
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
||||
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
||||
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
||||
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
||||
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
||||
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
||||
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
||||
“Licensor” means the CogVideoX Model Team that distributes its Software.
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
“Software” means the CogVideoX model parameters made available under this license.
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
||||
2. License Grant
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
||||
Under the terms and conditions of this license, the licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. The intellectual property rights of the generated content belong to the user to the extent permitted by applicable local laws.
|
||||
This license allows you to use all open-source models in this repository free of charge for academic research. Users who wish to use the models for commercial purposes must register at https://open.bigmodel.cn/mla/form and obtain a basic commercial license.
|
||||
Users who have registered and obtained the basic commercial license can use the models for commercial activities for free, but must comply with all terms and conditions of this license. Additionally, the number of service users (visits) for your commercial activities must not exceed 1 million visits per month.
|
||||
If the number of service users (visits) for your commercial activities exceeds 1 million visits per month, you need to contact our business team to obtain more commercial licenses.
|
||||
The above copyright statement and this license statement should be included in all copies or significant portions of this software.
|
||||
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
3. Restriction
|
||||
|
||||
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
||||
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
||||
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
||||
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
||||
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
||||
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
|
||||
|
||||
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
|
||||
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
||||
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
||||
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
4. Disclaimer
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
5. Limitation of Liability
|
||||
|
||||
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
||||
|
||||
6. Dispute Resolution
|
||||
|
||||
Attachment A
|
||||
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
|
||||
|
||||
Use Restrictions
|
||||
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
||||
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
||||
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
||||
- To defame, disparage or otherwise harass others;
|
||||
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
||||
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
||||
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
||||
- To provide medical advice and medical results interpretation;
|
||||
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
||||
1. Definitions

"Licensor" means the CogVideoX Model Team that distributes its Software.

"Software" means the CogVideoX model parameters made available under this license.

2. License Grant

Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. To the extent permitted by applicable local laws, the intellectual property rights in the generated content belong to the user.
This license allows you to use all open-source models in this repository free of charge for academic research. Users who wish to use the models for commercial purposes must register at https://open.bigmodel.cn/mla/form and obtain a basic commercial license.

Users who have registered and obtained the basic commercial license may use the models for commercial activities free of charge, but must comply with all terms and conditions of this license.
Under this license, the number of service users (visits) of your commercial activities must not exceed 1 million visits per month. If it does, you need to contact our business team to obtain additional commercial licenses.
The above copyright notice and this license notice shall be included in all copies or substantial portions of the Software.

3. Restrictions

You shall not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military or unlawful purpose.

You shall not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.

4. Disclaimer

The Software is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose and non-infringement.
In no event shall the authors or copyright holders be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the Software or the use of or other dealings in the Software.

5. Limitation of Liability

Except to the extent prohibited by applicable law, in no event and under no legal theory, whether based in tort, negligence, contract, liability, or otherwise, will any Licensor be liable to you for any direct, indirect, special, incidental, exemplary, or consequential damages, or any other commercial losses, even if the Licensor has been advised of the possibility of such damages.

6. Dispute Resolution

This license shall be governed by and construed in accordance with the laws of the People's Republic of China. Any dispute arising from or in connection with this license shall be submitted to the Haidian District People's Court in Beijing.

Note that this license may be updated to a more comprehensive version. For any questions related to the license or copyright, please contact us at license@zhipuai.cn.
README.md (150 lines changed)
@@ -1,93 +1,125 @@
# CogVideo
# CogVideoX

This is the official repo for the paper: [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](http://arxiv.org/abs/2205.15868).
[中文阅读](./README_zh.md)

<div align="center">
<img src=resources/logo.svg width="50%"/>
<p align="center">

**News!** The [demo](https://models.aminer.cn/cogvideo/) for CogVideo is available!
🤗 Experience on the <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a>
</p>
</div>
<p align="center">
👋 Join our <a href="resources/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
</p>
<p align="center">
📍 Visit <a href="https://chatglm.cn/video">清影 (QingYing)</a> and the <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">API Platform</a> to experience larger-scale commercial video generation models.
</p>

It's also integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [](https://huggingface.co/spaces/THUDM/CogVideo)
## Update and News

- 🔥 **News**: ``2024/8/6``: We have also open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly.
- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video generation models.

**News!** The code and model for text-to-video generation are now available! Currently we only support *simplified Chinese input*.
**More powerful models with larger parameter counts are on the way. Stay tuned!**

https://user-images.githubusercontent.com/48993524/170857367-2033c514-3c9f-4297-876f-2468592a254b.mp4
## CogVideoX-2B Gallery

* **Read** our paper [CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers](https://arxiv.org/abs/2205.15868) on arXiv for a formal introduction.
* **Try** our demo at [https://models.aminer.cn/cogvideo/](https://models.aminer.cn/cogvideo/)
* **Run** our pretrained models for text-to-video generation. Please use an A100 GPU.
* **Cite** our paper if you find our work helpful
<div align="center">
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
<p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
</div>

```
@article{hong2022cogvideo,
  title={CogVideo: Large-scale Pretraining for Text-to-Video Generation via Transformers},
  author={Hong, Wenyi and Ding, Ming and Zheng, Wendi and Liu, Xinghan and Tang, Jie},
  journal={arXiv preprint arXiv:2205.15868},
  year={2022}
}
```
<div align="center">
<video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
<p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
</div>

## Web Demo
<div align="center">
<video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
<p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
</div>

The demo for CogVideo is at [https://models.aminer.cn/cogvideo/](https://models.aminer.cn/cogvideo/), where you can get hands-on practice with text-to-video generation. *The original input is in Chinese.*
<div align="center">
<video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
<p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
</div>

## Model Introduction

## Generated Samples
CogVideoX is an open-source video generation model that shares the same origins as [清影 (QingYing)](https://chatglm.cn/video).

**Video samples generated by CogVideo.** The actual text inputs are in Chinese. Each sample is a 4-second clip of 32 frames, and here we sample 9 frames uniformly for display purposes.
The table below lists the video generation models we currently provide, along with their basic information:

| Model Name                                 | CogVideoX-2B                                                 |
|--------------------------------------------|--------------------------------------------------------------|
| Prompt Language                            | English                                                      |
| GPU Memory Required for Inference (FP16)   | 21.6 GB                                                      |
| GPU Memory Required for Fine-tuning (bs=1) | 46.2 GB                                                      |
| Prompt Max Length                          | 226 tokens                                                   |
| Video Length                               | 6 seconds                                                    |
| Frames Per Second                          | 8 frames                                                     |
| Resolution                                 | 720 × 480                                                    |
| Quantized Inference                        | Not supported                                                |
| Multi-GPU Inference                        | Not supported                                                |
| Download Link                              | 🤗 [CogVideoX-2B](https://huggingface.co/THUDM/CogVideoX-2B) |

## Project Structure

This open-source repository will help developers quickly get started with the basic usage and fine-tuning examples of the **CogVideoX** open-source model.

### Inference

**CogVideo is able to generate relatively high-frame-rate videos.**
A 4-second clip of 32 frames is shown below.
+ [cli_demo](inference/cli_demo.py): A detailed walkthrough of the inference code, covering the meaning of the common parameters.
+ [cli_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; this currently requires 71 GB of GPU memory and will be optimized in the future.
+ [convert_demo](inference/converter_demo.py): How to convert user input into a format suitable for CogVideoX.
+ [web_demo](inference/web_demo.py): A simple Streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos (a minimal stand-alone sketch follows below).

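For orientation, a minimal text-to-video call might look like the sketch below. It assumes the Hugging Face `diffusers` integration of CogVideoX (`CogVideoXPipeline`) is available; treat it as illustrative only and prefer the maintained [cli_demo](inference/cli_demo.py).

```python
# Illustrative sketch only; assumes a diffusers build that ships CogVideoXPipeline.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # trades speed for lower peak GPU memory

prompt = "A golden retriever runs across a sunlit meadow, slow motion, shallow depth of field."
video = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=6.0).frames[0]
export_to_video(video, "output.mp4", fps=8)
```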
<div style="text-align: center;">
    <img src="resources/web_demo.png" style="width: 100%; height: auto;" />
</div>

## Getting Started
### sat

### Setup
+ [sat_demo](sat/configs/README_zh.md): Contains the inference and fine-tuning code for the SAT weights. It is recommended as the basis for improving on the CogVideoX model structure, so that researchers can stack and prototype new ideas quickly.

* Hardware: Linux servers with Nvidia A100s are recommended, but it is also okay to run the pretrained models with a smaller `--max-inference-batch-size` and `--batch-size`, or to train smaller models, on less powerful GPUs.
* Environment: install dependencies via `pip install -r requirements.txt`.
* LocalAttention: Make sure you have CUDA installed and compile the local attention kernel.
### Tools

```shell
pip install git+https://github.com/Sleepychord/Image-Local-Attention
```
This folder contains some tools for model conversion, caption generation, etc.

## Docker
Alternatively, you can use Docker to handle all dependencies.
+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Converts SAT model weights to Huggingface model weights.
+ [caption_demo](tools/caption_demo.py): Caption tool: a model that understands videos and describes them in text.

1. Run ```./build_image.sh```
2. Run ```./run_image.sh```
3. Run ```./install_image_local_attention```
## Project Plan

Optionally, after that, you can recommit the image to avoid having to install the local attention kernel again.
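Recommitting the container can be done with standard Docker commands; the container ID and tag below are placeholders.

```shell
docker ps                                            # look up the running container's ID
docker commit <container_id> cog:local-attention     # snapshot it with the kernel installed
```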
- [x] Open source CogVideoX model
- [x] Open source 3D Causal VAE used in CogVideoX
- [x] CogVideoX model inference example (CLI / Web Demo)
- [x] CogVideoX online experience demo (Huggingface Space)
- [x] CogVideoX open source model API interface example (Huggingface)
- [x] CogVideoX model fine-tuning example (SAT)
- [ ] CogVideoX model fine-tuning example (Huggingface / SAT)
- [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite)
- [ ] Release CogVideoX technical report

We welcome your contributions. You can click [here](resources/contribute.md) for more information.

### Download
## Model License

Our code will automatically download or detect the models in the path defined by the environment variable `SAT_HOME`. You can also manually download [CogVideo-Stage1](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage1.zip), [CogVideo-Stage2](https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage2.zip) and [CogView2-dsr](https://model.baai.ac.cn/model-detail/100041), and place them under `SAT_HOME` (in folders named `cogvideo-stage1`, `cogvideo-stage2` and `cogview2-dsr`).
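A manual download might look like the following sketch; the target directory is a placeholder, and the URLs are the ones listed above.

```shell
export SAT_HOME=/path/to/sat_models
mkdir -p "$SAT_HOME" && cd "$SAT_HOME"
wget https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage1.zip && unzip cogvideo-stage1.zip
wget https://lfs.aminer.cn/misc/cogvideo/cogvideo-stage2.zip && unzip cogvideo-stage2.zip
# cogview2-dsr is fetched from https://model.baai.ac.cn/model-detail/100041 and placed here as cogview2-dsr/
```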
The code in this repository is released under the [Apache 2.0 License](LICENSE).

### Text-to-Video Generation
The model weights and implementation code are released under the [CogVideoX LICENSE](MODEL_LICENSE).

```
./scripts/inference_cogvideo_pipeline.sh
```
## Citation

The main arguments for inference are:
🌟 If you find our work helpful, please leave us a star. 🌟

* `--input-source [path or "interactive"]`. The path of the input file, with one query per line. A CLI is launched when "interactive" is used.
* `--output-path [path]`. The folder containing the results.
* `--batch-size [int]`. The number of samples generated per query.
* `--max-inference-batch-size [int]`. Maximum batch size per forward pass. Reduce it if you run out of memory.
* `--stage1-max-inference-batch-size [int]`. Maximum batch size per forward pass in Stage 1. Reduce it if you run out of memory.
* `--both-stages`. Run Stage 1 and Stage 2 sequentially.
* `--use-guidance-stage1`. Use classifier-free guidance in Stage 1; strongly recommended for better results.

Set the environment variable `SAT_HOME` to the path where the downloaded models should be stored.
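Putting it together, an invocation might look like the sketch below; the paths are placeholders, and in practice the generation flags above are set inside `scripts/inference_cogvideo_pipeline.sh` rather than passed on the command line.

```shell
export SAT_HOME=/path/to/sat_models   # where cogvideo-stage1, cogvideo-stage2 and cogview2-dsr live
# Edit the flags in the script (e.g. --input-source interactive --both-stages --use-guidance-stage1), then:
./scripts/inference_cogvideo_pipeline.sh
```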
*Currently only Chinese input is supported.*
The paper is still being written and will be released soon. Stay tuned!
README_zh.md (new file, 118 lines)
@@ -0,0 +1,118 @@
# CogVideoX

[Read this in English.](./README.md)


<div align="center">
<img src=resources/logo.svg width="50%"/>
<p align="center">

🤗 Try the video generation model on the <a href="https://huggingface.co/spaces/THUDM/CogVideoX" target="_blank">CogVideoX Huggingface Space</a>
</p>
</div>
<p align="center">
👋 Join our <a href="resources/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/Ewaabk6s" target="_blank">Discord</a>
</p>
<p align="center">
📍 Visit <a href="https://chatglm.cn/video">清影 (QingYing)</a> and the <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">API Platform</a> to experience larger-scale commercial video generation models.
</p>

## Project Updates

- 🔥 **News**: ``2024/8/6``: We have open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly.
- 🔥 **News**: ``2024/8/6``: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video generation models.

**More powerful models with larger parameter counts are on the way. Stay tuned!**

## CogVideoX-2B Gallery

<div align="center">
<video src="https://github.com/user-attachments/assets/ea3af39a-3160-4999-90ec-2f7863c5b0e9" width="80%" controls autoplay></video>
<p>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</p>
</div>

<div align="center">
<video src="https://github.com/user-attachments/assets/9de41efd-d4d1-4095-aeda-246dd834e91d" width="80%" controls autoplay></video>
<p>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</p>
</div>

<div align="center">
<video src="https://github.com/user-attachments/assets/941d6661-6a8d-4a1b-b912-59606f0b2841" width="80%" controls autoplay></video>
<p>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</p>
</div>

<div align="center">
<video src="https://github.com/user-attachments/assets/938529c4-91ae-4f60-b96b-3c3947fa63cb" width="80%" controls autoplay></video>
<p>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</p>
</div>

## Model Introduction

CogVideoX is an open-source video generation model that shares the same origins as [清影 (QingYing)](https://chatglm.cn/video).

The table below lists the video generation models we currently provide, along with their basic information:

| Model Name                                 | CogVideoX-2B                                                 |
|--------------------------------------------|--------------------------------------------------------------|
| Prompt Language                            | English                                                      |
| GPU Memory Required for Inference (FP16)   | 21.6 GB                                                      |
| GPU Memory Required for Fine-tuning (bs=1) | 46.2 GB                                                      |
| Prompt Max Length                          | 226 tokens                                                   |
| Video Length                               | 6 seconds                                                    |
| Frames Per Second                          | 8 frames                                                     |
| Resolution                                 | 720 × 480                                                    |
| Quantized Inference                        | Not supported                                                |
| Multi-GPU Inference                        | Not supported                                                |
| Download Link                              | 🤗 [CogVideoX-2B](https://huggingface.co/THUDM/CogVideoX-2B) |

## Project Structure

This open-source repository will help developers quickly get started with the basic usage and fine-tuning examples of the **CogVideoX** open-source model.

### inference

+ [cli_demo](inference/cli_demo.py): A detailed walkthrough of the inference code, covering the meaning of the common parameters.
+ [cli_vae_demo](inference/cli_vae_demo.py): Runs the VAE inference code on its own; this currently requires 71 GB of GPU memory and will be optimized in the future.
+ [convert_demo](inference/converter_demo.py): How to convert user input into the long-form prompt format that suits CogVideoX.
+ [web_demo](inference/web_demo.py): A simple Streamlit web application demonstrating how to use the CogVideoX-2B model to generate videos.

<div style="text-align: center;">
    <img src="resources/web_demo.png" style="width: 100%; height: auto;" />
</div>

### sat

+ [sat_demo](sat/configs/README_zh.md): Contains the inference and fine-tuning code for the SAT weights. It is recommended as the basis for improving on the CogVideoX model structure, so that researchers can stack and prototype new ideas quickly.

### tools

This folder contains some tools for model conversion, caption generation, etc.

+ [convert_weight_sat2hf](tools/convert_weight_sat2hf.py): Converts SAT model weights to Huggingface model weights.
+ [caption_demo](tools/caption_demo.py): Caption tool: a model that understands videos and describes them in text.

## Project Plan

- [x] Open source CogVideoX model
- [x] CogVideoX model inference example (CLI / Web Demo)
- [x] CogVideoX online experience demo (Huggingface Space)
- [x] CogVideoX open source model API interface example (Huggingface)
- [x] CogVideoX model fine-tuning example (SAT)
- [ ] CogVideoX model fine-tuning example (Huggingface / SAT)
- [ ] Open source CogVideoX-Pro (adapted for CogVideoX-2B suite)
- [ ] Release CogVideoX technical report

We welcome your contributions. You can click [here](resources/contribute_zh.md) for more information.

## Model License

The code in this repository is released under the [Apache 2.0 License](LICENSE).

The model weights and implementation code are released under the [CogVideoX LICENSE](MODEL_LICENSE).

## Citation

🌟 If you find our work helpful, please leave us a star. 🌟

The paper is still being written and will be released soon.
Binary file not shown.
Binary file not shown. (before: 24 MiB)
Binary file not shown. (before: 6.5 MiB)
Binary file not shown. (before: 9.1 MiB)
@@ -1,2 +0,0 @@
#!/bin/sh
docker build -t cog .
Binary file not shown.
@@ -1,101 +0,0 @@
# -*- encoding: utf-8 -*-
'''
@File    :   coglm_strategy.py
@Time    :   2021/10/08 22:22:42
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch
import numpy as np
import torch.nn.functional as F


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # convert to 1D
        logits = logits.view(logits.size()[1]).contiguous()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
        # going back to 2D
        logits = logits.view(1, -1).contiguous()

    return logits


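As a quick illustration (not part of the original file), the helper above can be smoke-tested on dummy logits; the snippet assumes `top_k_logits` as defined above is in scope.

```python
import torch
import torch.nn.functional as F

# Keep only the 3 largest logits; the rest are set to filter_value (-65504),
# which underflows to exactly zero probability after the softmax.
logits = torch.randn(1, 10)
filtered = top_k_logits(logits.clone(), top_k=3)
probs = F.softmax(filtered, dim=-1)
print(probs.count_nonzero().item())  # -> 3
```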
class CoglmStrategy:
    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None,
                 temperature2=0.89):
        self.invalid_slices = invalid_slices
        self.temperature = temperature
        self.temperature2 = temperature2
        self.topk = top_k
        self.top_p = top_p
        self.eps = eps
        if end_tokens is None:
            end_tokens = []
        self.end_tokens = end_tokens
        self._is_done = False
        self.outlier_count_down = torch.zeros(16)
        self.vis_list = [[] for i in range(16)]
        self.cluster_labels = torch.tensor(np.load('cluster_label2.npy'), device='cuda', dtype=torch.long)
        self.start_pos = -1
        self.white_cluster = []
        # self.fout = open('tmp.txt', 'w')

    @property
    def is_done(self) -> bool:
        return self._is_done

    def forward(self, logits, tokens, mems, temperature=None, temperature2=None):
        if temperature is None:
            temperature = self.temperature
        if temperature2 is None:
            temperature2 = self.temperature2
        logits = logits / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -65504

        # Sample at the cluster level first, then mask every token outside the selected cluster.
        rprobs = F.softmax(logits.float(), dim=-1)
        c = self.cluster_labels.expand(*rprobs.shape)
        cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
        # self.fout.write(str(tokens.shape[-1]) + ' ' + str(cprobs.topk(10)) + '\n')
        # self.fout.flush()
        best_scores, best_clusters = cprobs.topk(self.topk)
        bz = logits.shape[0]
        for i in range(bz):
            selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
            logits[i, self.cluster_labels != selected_cluster] = -65504

        # logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float() / temperature2, dim=-1)  # float is essential, due to a bug in PyTorch
        pred = torch.multinomial(probs, num_samples=1)

        if pred.numel() == 1 and pred.item() in self.end_tokens:
            self._is_done = True
        tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
        return tokens, mems

    def finalize(self, tokens, mems):
        self._is_done = False
        return tokens, mems
@@ -1,793 +0,0 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@File : cogvideo_pipeline.py
|
||||
@Time : 2022/07/15 11:24:56
|
||||
@Author : Wenyi Hong
|
||||
@Version : 1.0
|
||||
@Contact : hwy22@mails.tsinghua.edu.cn
|
||||
'''
|
||||
|
||||
# here put the import lib
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import time
|
||||
from torchvision.utils import save_image
|
||||
import stat
|
||||
from icetk import icetk as tokenizer
|
||||
import logging, sys
|
||||
|
||||
import torch.distributed as dist
|
||||
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
|
||||
|
||||
|
||||
from SwissArmyTransformer import get_args
|
||||
from SwissArmyTransformer.data_utils import BinaryDataset, make_loaders
|
||||
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
|
||||
from SwissArmyTransformer.generation.utils import timed_name, save_multiple_images, generate_continually
|
||||
from SwissArmyTransformer.resources import auto_create
|
||||
|
||||
from models.cogvideo_cache_model import CogVideoCacheModel
|
||||
from coglm_strategy import CoglmStrategy
|
||||
|
||||
|
||||
def get_masks_and_position_ids_stage1(data, textlen, framelen):
|
||||
# Extract batch size and sequence length.
|
||||
tokens = data
|
||||
seq_length = len(data[0])
|
||||
# Attention mask (lower triangular).
|
||||
attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
|
||||
attention_mask[:, :textlen, textlen:] = 0
|
||||
attention_mask[:, textlen:, textlen:].tril_()
|
||||
attention_mask.unsqueeze_(1)
|
||||
# Unaligned version
|
||||
position_ids = torch.zeros(seq_length, dtype=torch.long,
|
||||
device=data.device)
|
||||
torch.arange(textlen, out=position_ids[:textlen],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512, 512+seq_length-textlen, out=position_ids[textlen:],
|
||||
dtype=torch.long, device=data.device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
def get_masks_and_position_ids_stage2(data, textlen, framelen):
|
||||
# Extract batch size and sequence length.
|
||||
tokens = data
|
||||
seq_length = len(data[0])
|
||||
|
||||
# Attention mask (lower triangular).
|
||||
attention_mask = torch.ones((1, textlen+framelen, textlen+framelen), device=data.device)
|
||||
attention_mask[:, :textlen, textlen:] = 0
|
||||
attention_mask[:, textlen:, textlen:].tril_()
|
||||
attention_mask.unsqueeze_(1)
|
||||
|
||||
# Unaligned version
|
||||
position_ids = torch.zeros(seq_length, dtype=torch.long,
|
||||
device=data.device)
|
||||
torch.arange(textlen, out=position_ids[:textlen],
|
||||
dtype=torch.long, device=data.device)
|
||||
frame_num = (seq_length-textlen)//framelen
|
||||
assert frame_num == 5
|
||||
torch.arange(512, 512+framelen, out=position_ids[textlen:textlen+framelen],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512+framelen*2, 512+framelen*3, out=position_ids[textlen+framelen:textlen+framelen*2],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512+framelen*(frame_num-1), 512+framelen*frame_num, out=position_ids[textlen+framelen*2:textlen+framelen*3],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512+framelen*1, 512+framelen*2, out=position_ids[textlen+framelen*3:textlen+framelen*4],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512+framelen*3, 512+framelen*4, out=position_ids[textlen+framelen*4:textlen+framelen*5],
|
||||
dtype=torch.long, device=data.device)
|
||||
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
def my_update_mems(hiddens, mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len):
|
||||
if hiddens is None:
|
||||
return None, mems_indexs
|
||||
mem_num = len(hiddens)
|
||||
ret_mem = []
|
||||
with torch.no_grad():
|
||||
for id in range(mem_num):
|
||||
if hiddens[id][0] is None:
|
||||
ret_mem.append(None)
|
||||
else:
|
||||
if id == 0 and limited_spatial_channel_mem and mems_indexs[id]+hiddens[0][0].shape[1] >= text_len+frame_len:
|
||||
if mems_indexs[id] == 0:
|
||||
for layer, hidden in enumerate(hiddens[id]):
|
||||
mems_buffers[id][layer, :, :text_len] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, :text_len]
|
||||
new_mem_len_part2 = (mems_indexs[id]+hiddens[0][0].shape[1]-text_len)%frame_len
|
||||
if new_mem_len_part2 > 0:
|
||||
for layer, hidden in enumerate(hiddens[id]):
|
||||
mems_buffers[id][layer, :, text_len:text_len+new_mem_len_part2] = hidden.expand(mems_buffers[id].shape[1], -1, -1)[:, -new_mem_len_part2:]
|
||||
mems_indexs[id] = text_len+new_mem_len_part2
|
||||
else:
|
||||
for layer, hidden in enumerate(hiddens[id]):
|
||||
mems_buffers[id][layer, :, mems_indexs[id]:mems_indexs[id]+hidden.shape[1]] = hidden.expand(mems_buffers[id].shape[1], -1, -1)
|
||||
mems_indexs[id] += hidden.shape[1]
|
||||
ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
|
||||
return ret_mem, mems_indexs
|
||||
|
||||
|
||||
def my_save_multiple_images(imgs, path, subdir, debug=True):
|
||||
# imgs: list of tensor images
|
||||
if debug:
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
print("\nSave to: ", path, flush=True)
|
||||
save_image(imgs, path, normalize=True)
|
||||
else:
|
||||
print("\nSave to: ", path, flush=True)
|
||||
single_frame_path = os.path.join(path, subdir)
|
||||
os.makedirs(single_frame_path, exist_ok=True)
|
||||
for i in range(len(imgs)):
|
||||
save_image(imgs[i], os.path.join(single_frame_path, f'{str(i).rjust(4,"0")}.jpg'), normalize=True)
|
||||
os.chmod(os.path.join(single_frame_path,f'{str(i).rjust(4,"0")}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
|
||||
save_image(torch.cat(imgs, dim=0), os.path.join(single_frame_path,f'frame_concat.jpg'), normalize=True)
|
||||
os.chmod(os.path.join(single_frame_path,f'frame_concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
|
||||
|
||||
def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
|
||||
# The first token's position id of the frame that the next token belongs to;
|
||||
if total_len < text_len:
|
||||
return None
|
||||
return (total_len-text_len)//frame_len * frame_len + text_len
|
||||
|
||||
def my_filling_sequence(
|
||||
model,
|
||||
args,
|
||||
seq,
|
||||
batch_size,
|
||||
get_masks_and_position_ids,
|
||||
text_len,
|
||||
frame_len,
|
||||
strategy=BaseStrategy(),
|
||||
strategy2=BaseStrategy(),
|
||||
mems=None,
|
||||
log_text_attention_weights=0, # default to 0: no artificial change
|
||||
mode_stage1=True,
|
||||
enforce_no_swin=False,
|
||||
guider_seq=None,
|
||||
guider_text_len=0,
|
||||
guidance_alpha=1,
|
||||
limited_spatial_channel_mem=False,  # restrict the spatial-channel memory to the current frame
|
||||
**kw_args
|
||||
):
|
||||
'''
|
||||
seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
|
||||
mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
|
||||
cache, should be first mems.shape[1] parts of context_tokens.
|
||||
mems are the first-level citizens here, but we don't assume what is memorized.
|
||||
input mems are used when multi-phase generation.
|
||||
'''
|
||||
if guider_seq is not None:
|
||||
logging.debug("Using Guidance In Inference")
|
||||
if limited_spatial_channel_mem:
|
||||
logging.debug("Limit spatial-channel's mem to current frame")
|
||||
assert len(seq.shape) == 2
|
||||
|
||||
# building the initial tokens, attention_mask, and position_ids
|
||||
actual_context_length = 0
|
||||
|
||||
while seq[-1][actual_context_length] >= 0: # the last seq has least given tokens
|
||||
actual_context_length += 1 # [0, context_length-1] are given
|
||||
assert actual_context_length > 0
|
||||
current_frame_num = (actual_context_length-text_len) // frame_len
|
||||
assert current_frame_num >= 0
|
||||
context_length = text_len + current_frame_num * frame_len
|
||||
|
||||
tokens, attention_mask, position_ids = get_masks_and_position_ids(seq, text_len, frame_len)
|
||||
tokens = tokens[..., :context_length]
|
||||
input_tokens = tokens.clone()
|
||||
|
||||
if guider_seq is not None:
|
||||
guider_index_delta = text_len - guider_text_len
|
||||
guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(guider_seq, guider_text_len, frame_len)
|
||||
guider_tokens = guider_tokens[..., :context_length-guider_index_delta]
|
||||
guider_input_tokens = guider_tokens.clone()
|
||||
|
||||
for fid in range(current_frame_num):
|
||||
input_tokens[:, text_len+400*fid] = tokenizer['<start_of_image>']
|
||||
if guider_seq is not None:
|
||||
guider_input_tokens[:, guider_text_len+400*fid] = tokenizer['<start_of_image>']
|
||||
|
||||
attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
|
||||
# initialize generation
|
||||
counter = context_length - 1 # Last fixed index is ``counter''
|
||||
index = 0 # Next forward starting index, also the length of cache.
|
||||
mems_buffers_on_GPU = False
|
||||
mems_indexs = [0, 0]
|
||||
mems_len = [(400+74) if limited_spatial_channel_mem else 5*400+74, 5*400+74]
|
||||
mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
|
||||
for mem_len in mems_len]
|
||||
|
||||
|
||||
if guider_seq is not None:
|
||||
guider_attention_mask = guider_attention_mask.type_as(next(model.parameters())) # if fp16
|
||||
guider_mems_buffers = [torch.zeros(args.num_layers, batch_size, mem_len, args.hidden_size*2, dtype=next(model.parameters()).dtype)
|
||||
for mem_len in mems_len]
|
||||
guider_mems_indexs = [0, 0]
|
||||
guider_mems = None
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# step-by-step generation
|
||||
while counter < len(seq[0]) - 1:
|
||||
# we have generated counter+1 tokens
|
||||
# Now, we want to generate seq[counter + 1],
|
||||
# token[:, index: counter+1] needs forwarding.
|
||||
if index == 0:
|
||||
group_size = 2 if (input_tokens.shape[0] == batch_size and not mode_stage1) else batch_size
|
||||
|
||||
logits_all = None
|
||||
for batch_idx in range(0, input_tokens.shape[0], group_size):
|
||||
logits, *output_per_layers = model(
|
||||
input_tokens[batch_idx:batch_idx+group_size, index:],
|
||||
position_ids[..., index: counter+1],
|
||||
attention_mask, # TODO memlen
|
||||
mems=mems,
|
||||
text_len=text_len,
|
||||
frame_len=frame_len,
|
||||
counter=counter,
|
||||
log_text_attention_weights=log_text_attention_weights,
|
||||
enforce_no_swin=enforce_no_swin,
|
||||
**kw_args
|
||||
)
|
||||
logits_all = torch.cat((logits_all, logits), dim=0) if logits_all is not None else logits
|
||||
mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]]
|
||||
next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(text_len, frame_len, mem_kv01[0][0].shape[1])
|
||||
for id, mem_kv in enumerate(mem_kv01):
|
||||
for layer, mem_kv_perlayer in enumerate(mem_kv):
|
||||
if limited_spatial_channel_mem and id == 0:
|
||||
mems_buffers[id][layer, batch_idx:batch_idx+group_size, :text_len] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :text_len]
|
||||
mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
|
||||
mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
|
||||
else:
|
||||
mems_buffers[id][layer, batch_idx:batch_idx+group_size, :mem_kv_perlayer.shape[1]] = mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
|
||||
mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[1], mem_kv01[1][0].shape[1]
|
||||
if limited_spatial_channel_mem:
|
||||
mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
|
||||
|
||||
mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
|
||||
logits = logits_all
|
||||
|
||||
# Guider
|
||||
if guider_seq is not None:
|
||||
guider_logits_all = None
|
||||
for batch_idx in range(0, guider_input_tokens.shape[0], group_size):
|
||||
guider_logits, *guider_output_per_layers = model(
|
||||
guider_input_tokens[batch_idx:batch_idx+group_size, max(index-guider_index_delta, 0):],
|
||||
guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
|
||||
guider_attention_mask,
|
||||
mems=guider_mems,
|
||||
text_len=guider_text_len,
|
||||
frame_len=frame_len,
|
||||
counter=counter-guider_index_delta,
|
||||
log_text_attention_weights=log_text_attention_weights,
|
||||
enforce_no_swin=enforce_no_swin,
|
||||
**kw_args
|
||||
)
|
||||
guider_logits_all = torch.cat((guider_logits_all, guider_logits), dim=0) if guider_logits_all is not None else guider_logits
|
||||
guider_mem_kv01 = [[o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]]
|
||||
for id, guider_mem_kv in enumerate(guider_mem_kv01):
|
||||
for layer, guider_mem_kv_perlayer in enumerate(guider_mem_kv):
|
||||
if limited_spatial_channel_mem and id == 0:
|
||||
guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_text_len] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, :guider_text_len]
|
||||
guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(guider_text_len, frame_len, guider_mem_kv_perlayer.shape[1])
|
||||
guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
|
||||
guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
|
||||
else:
|
||||
guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, :guider_mem_kv_perlayer.shape[1]] = guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)
|
||||
guider_mems_indexs[0], guider_mems_indexs[1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[1][0].shape[1]
|
||||
if limited_spatial_channel_mem:
|
||||
guider_mems_indexs[0] -= (guider_next_tokens_frame_begin_id-guider_text_len)
|
||||
guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
|
||||
guider_logits = guider_logits_all
|
||||
else:
|
||||
if not mems_buffers_on_GPU:
|
||||
if not mode_stage1:
|
||||
torch.cuda.empty_cache()
|
||||
for idx, mem in enumerate(mems):
|
||||
mems[idx] = mem.to(next(model.parameters()).device)
|
||||
if guider_seq is not None:
|
||||
for idx, mem in enumerate(guider_mems):
|
||||
guider_mems[idx] = mem.to(next(model.parameters()).device)
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
for idx, mem_buffer in enumerate(mems_buffers):
|
||||
mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
|
||||
mems = [mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)]
|
||||
if guider_seq is not None:
|
||||
for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
|
||||
guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
|
||||
guider_mems = [guider_mems_buffers[id][:, :, :guider_mems_indexs[id]] for id in range(2)]
|
||||
mems_buffers_on_GPU = True
|
||||
|
||||
logits, *output_per_layers = model(
|
||||
input_tokens[:, index:],
|
||||
position_ids[..., index: counter+1],
|
||||
attention_mask, # TODO memlen
|
||||
mems=mems,
|
||||
text_len=text_len,
|
||||
frame_len=frame_len,
|
||||
counter=counter,
|
||||
log_text_attention_weights=log_text_attention_weights,
|
||||
enforce_no_swin=enforce_no_swin,
|
||||
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
||||
**kw_args
|
||||
)
|
||||
mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers], [o['mem_kv'][1] for o in output_per_layers]
|
||||
|
||||
if guider_seq is not None:
|
||||
guider_logits, *guider_output_per_layers = model(
|
||||
guider_input_tokens[:, max(index-guider_index_delta, 0):],
|
||||
guider_position_ids[..., max(index-guider_index_delta, 0): counter+1-guider_index_delta],
|
||||
guider_attention_mask,
|
||||
mems=guider_mems,
|
||||
text_len=guider_text_len,
|
||||
frame_len=frame_len,
|
||||
counter=counter-guider_index_delta,
|
||||
log_text_attention_weights=0,
|
||||
enforce_no_swin=enforce_no_swin,
|
||||
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
||||
**kw_args
|
||||
)
|
||||
guider_mem_kv0, guider_mem_kv1 = [o['mem_kv'][0] for o in guider_output_per_layers], [o['mem_kv'][1] for o in guider_output_per_layers]
|
||||
|
||||
if not mems_buffers_on_GPU:
|
||||
torch.cuda.empty_cache()
|
||||
for idx, mem_buffer in enumerate(mems_buffers):
|
||||
mems_buffers[idx] = mem_buffer.to(next(model.parameters()).device)
|
||||
if guider_seq is not None:
|
||||
for idx, guider_mem_buffer in enumerate(guider_mems_buffers):
|
||||
guider_mems_buffers[idx] = guider_mem_buffer.to(next(model.parameters()).device)
|
||||
mems_buffers_on_GPU = True
|
||||
|
||||
mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1], mems_buffers, mems_indexs, limited_spatial_channel_mem, text_len, frame_len)
|
||||
if guider_seq is not None:
|
||||
guider_mems, guider_mems_indexs = my_update_mems([guider_mem_kv0, guider_mem_kv1], guider_mems_buffers, guider_mems_indexs, limited_spatial_channel_mem, guider_text_len, frame_len)
|
||||
|
||||
|
||||
counter += 1
|
||||
index = counter
|
||||
|
||||
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
|
||||
tokens = tokens.expand(batch_size, -1)
|
||||
if guider_seq is not None:
|
||||
guider_logits = guider_logits[:, -1].expand(batch_size, -1)
|
||||
guider_tokens = guider_tokens.expand(batch_size, -1)
|
||||
|
||||
if seq[-1][counter].item() < 0:
|
||||
# sampling
|
||||
guided_logits = guider_logits+(logits-guider_logits)*guidance_alpha if guider_seq is not None else logits
|
||||
if mode_stage1 and counter < text_len + 400:
|
||||
tokens, mems = strategy.forward(guided_logits, tokens, mems)
|
||||
else:
|
||||
tokens, mems = strategy2.forward(guided_logits, tokens, mems)
|
||||
if guider_seq is not None:
|
||||
guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]), dim=1)
|
||||
|
||||
if seq[0][counter].item() >= 0:
|
||||
for si in range(seq.shape[0]):
|
||||
if seq[si][counter].item() >= 0:
|
||||
tokens[si, -1] = seq[si, counter]
|
||||
if guider_seq is not None:
|
||||
guider_tokens[si, -1] = guider_seq[si, counter-guider_index_delta]
|
||||
|
||||
else:
|
||||
tokens = torch.cat((tokens, seq[:, counter:counter+1].clone().expand(tokens.shape[0], 1).to(device=tokens.device, dtype=tokens.dtype)), dim=1)
|
||||
if guider_seq is not None:
|
||||
guider_tokens = torch.cat((guider_tokens,
|
||||
guider_seq[:, counter-guider_index_delta:counter+1-guider_index_delta]
|
||||
.clone().expand(guider_tokens.shape[0], 1).to(device=guider_tokens.device, dtype=guider_tokens.dtype)), dim=1)
|
||||
|
||||
input_tokens = tokens.clone()
|
||||
if guider_seq is not None:
|
||||
guider_input_tokens = guider_tokens.clone()
|
||||
if (index-text_len-1)//400 < (input_tokens.shape[-1]-text_len-1)//400:
|
||||
boi_idx = ((index-text_len-1)//400 +1)*400+text_len
|
||||
while boi_idx < input_tokens.shape[-1]:
|
||||
input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
|
||||
if guider_seq is not None:
|
||||
guider_input_tokens[:, boi_idx-guider_index_delta] = tokenizer['<start_of_image>']
|
||||
boi_idx += 400
|
||||
|
||||
if strategy.is_done:
|
||||
break
|
||||
return strategy.finalize(tokens, mems)
|
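# The sampling branch above mixes the two forward passes in classifier-free-guidance style:
# guided_logits = guider_logits + (logits - guider_logits) * guidance_alpha.
# A minimal, self-contained sketch of that rule, assuming plain torch tensors
# (the helper name is illustrative and not part of the original pipeline):

import torch

def mix_guidance_logits(cond_logits: torch.Tensor, guider_logits: torch.Tensor, alpha: float) -> torch.Tensor:
    # alpha == 1 recovers the conditional logits; alpha > 1 extrapolates further away
    # from the guidance branch, strengthening the effect of the conditional text.
    return guider_logits + (cond_logits - guider_logits) * alpha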
||||
|
||||
class InferenceModel_Sequential(CogVideoCacheModel):
|
||||
def __init__(self, args, transformer=None, parallel_output=True):
|
||||
super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=-1, cogvideo_stage=1)
|
||||
# TODO: check it
|
||||
|
||||
def final_forward(self, logits, **kwargs):
|
||||
logits_parallel = logits
|
||||
logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
|
||||
return logits_parallel
|
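# final_forward projects the hidden states onto the first 20,000 rows of the tied
# word-embedding matrix, which (with the icetk tokenizer) correspond to the image-token
# part of the vocabulary, so sampling can only produce image tokens.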
||||
|
||||
class InferenceModel_Interpolate(CogVideoCacheModel):
|
||||
def __init__(self, args, transformer=None, parallel_output=True):
|
||||
super().__init__(args, transformer=transformer, parallel_output=parallel_output, window_size=10, cogvideo_stage=2)
|
||||
# TODO: check it
|
||||
|
||||
def final_forward(self, logits, **kwargs):
|
||||
logits_parallel = logits
|
||||
logits_parallel = torch.nn.functional.linear(logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float())
|
||||
return logits_parallel
|
||||
|
||||
def main(args):
|
||||
assert int(args.stage_1) + int(args.stage_2) + int(args.both_stages) == 1
|
||||
rank_id = args.device % args.parallel_size
|
||||
generate_frame_num = args.generate_frame_num
|
||||
|
||||
if args.stage_1 or args.both_stages:
|
||||
model_stage1, args = InferenceModel_Sequential.from_pretrained(args, 'cogvideo-stage1')
|
||||
model_stage1.eval()
|
||||
if args.both_stages:
|
||||
model_stage1 = model_stage1.cpu()
|
||||
|
||||
if args.stage_2 or args.both_stages:
|
||||
model_stage2, args = InferenceModel_Interpolate.from_pretrained(args, 'cogvideo-stage2')
|
||||
model_stage2.eval()
|
||||
if args.both_stages:
|
||||
model_stage2 = model_stage2.cpu()
|
||||
|
||||
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
||||
strategy_cogview2 = CoglmStrategy(invalid_slices,
|
||||
temperature=1.0, top_k=16)
|
||||
strategy_cogvideo = CoglmStrategy(invalid_slices,
|
||||
temperature=args.temperature, top_k=args.top_k,
|
||||
temperature2=args.coglm_temperature2)
|
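# Two sampling strategies are kept on purpose: strategy_cogview2 handles the CogView2-style
# image part (the first frame in stage 1), while strategy_cogvideo takes over for the
# remaining frame tokens; see the strategy/strategy2 switch in my_filling_sequence above.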
||||
if not args.stage_1:
|
||||
from sr_pipeline import DirectSuperResolution
|
||||
dsr_path = auto_create('cogview2-dsr', path=None) # path=os.getenv('SAT_HOME', '~/.sat_models')
|
||||
dsr = DirectSuperResolution(args, dsr_path,
|
||||
max_bz=12, onCUDA=False)
|
||||
|
||||
def process_stage2(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", parent_given_tokens=None, conddir=None, outputdir=None, gpu_rank=0, gpu_parallel_size=1):
|
||||
stage2_starttime = time.time()
|
||||
use_guidance = args.use_guidance_stage2
|
||||
if args.both_stages:
|
||||
move_start_time = time.time()
|
||||
logging.debug("moving stage-2 model to cuda")
|
||||
model = model.cuda()
|
||||
logging.debug("moving in stage-2 model takes time: {:.2f}".format(time.time()-move_start_time))
|
||||
|
||||
try:
|
||||
if parent_given_tokens is None:
|
||||
assert conddir is not None
|
||||
parent_given_tokens = torch.load(os.path.join(conddir, 'frame_tokens.pt'), map_location='cpu')
|
||||
sample_num_allgpu = parent_given_tokens.shape[0]
|
||||
sample_num = sample_num_allgpu // gpu_parallel_size
|
||||
assert sample_num * gpu_parallel_size == sample_num_allgpu
|
||||
parent_given_tokens = parent_given_tokens[gpu_rank*sample_num:(gpu_rank+1)*sample_num]
|
||||
except:
|
||||
logging.critical("No frame_tokens found in interpolation, skip")
|
||||
return False
|
||||
|
||||
# CogVideo Stage2 Generation
|
||||
while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
|
||||
parent_given_tokens_num = parent_given_tokens.shape[1]
|
||||
generate_batchsize_persample = (parent_given_tokens_num-1)//2
|
||||
generate_batchsize_total = generate_batchsize_persample * sample_num
|
||||
total_frames = generate_frame_num
|
||||
frame_len = 400
|
||||
enc_text = tokenizer.encode(seq_text)
|
||||
enc_duration = tokenizer.encode(str(float(duration))+"秒")
|
||||
seq = enc_duration + [tokenizer['<n>']] + enc_text + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
|
||||
text_len = len(seq) - frame_len*generate_frame_num - 1
|
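# Sequence layout for stage 2: [duration tokens] [<n>] [text tokens] [<start_of_image>]
# followed by 400 token slots per frame (generate_frame_num frames in total). -1 marks
# slots still to be generated, and text_len counts everything before the first frame
# slot except the <start_of_image> separator.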
||||
|
||||
logging.info("[Stage2: Generating Frames, Frame Rate {:d}]\nraw text: {:s}".format(int(4/duration), tokenizer.decode(enc_text)))
|
||||
|
||||
# generation
|
||||
seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
|
||||
for sample_i in range(sample_num):
|
||||
for i in range(generate_batchsize_persample):
|
||||
seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
|
||||
seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
|
||||
seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
|
||||
|
||||
if use_guidance:
|
||||
guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
|
||||
guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
|
||||
guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(generate_batchsize_total, 1)
|
||||
for sample_i in range(sample_num):
|
||||
for i in range(generate_batchsize_persample):
|
||||
guider_seq[sample_i*generate_batchsize_persample+i][text_len+1:text_len+1+400] = parent_given_tokens[sample_i][2*i]
|
||||
guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+400:text_len+1+800] = parent_given_tokens[sample_i][2*i+1]
|
||||
guider_seq[sample_i*generate_batchsize_persample+i][text_len+1+800:text_len+1+1200] = parent_given_tokens[sample_i][2*i+2]
|
||||
video_log_text_attention_weights = 0
|
||||
else:
|
||||
guider_seq=None
|
||||
guider_text_len=0
|
||||
video_log_text_attention_weights = 1.4
|
||||
|
||||
mbz = args.max_inference_batch_size
|
||||
|
||||
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
|
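# Generation is split into micro-batches of at most max_inference_batch_size sequences;
# the assert above guarantees the total batch either fits into a single micro-batch or
# divides evenly across them.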
||||
output_list = []
|
||||
start_time = time.time()
|
||||
for tim in range(max(generate_batchsize_total // mbz, 1)):
|
||||
input_seq = seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(generate_batchsize_total, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
output_list.append(
|
||||
my_filling_sequence(model, args, input_seq,
|
||||
batch_size=min(generate_batchsize_total, mbz),
|
||||
get_masks_and_position_ids=get_masks_and_position_ids_stage2,
|
||||
text_len=text_len, frame_len=frame_len,
|
||||
strategy=strategy_cogview2,
|
||||
strategy2=strategy_cogvideo,
|
||||
log_text_attention_weights=video_log_text_attention_weights,
|
||||
mode_stage1=False,
|
||||
guider_seq=guider_seq2,
|
||||
guider_text_len=guider_text_len,
|
||||
guidance_alpha=args.guidance_alpha,
|
||||
limited_spatial_channel_mem=True,
|
||||
)[0]
|
||||
)
|
||||
logging.info("Duration {:.2f}, Taken time {:.2f}\n".format(duration, time.time() - start_time))
|
||||
|
||||
output_tokens = torch.cat(output_list, dim=0)
|
||||
output_tokens = output_tokens[:, text_len+1:text_len+1+(total_frames)*400].reshape(sample_num, -1, 400*total_frames)
|
||||
output_tokens_merge = torch.cat((output_tokens[:, :, :1*400],
|
||||
output_tokens[:, :, 400*3:4*400],
|
||||
output_tokens[:, :, 400*1:2*400],
|
||||
output_tokens[:, :, 400*4:(total_frames)*400]), dim=2).reshape(sample_num, -1, 400)
|
||||
|
||||
output_tokens_merge = torch.cat((output_tokens_merge, output_tokens[:, -1:, 400*2:3*400]), dim=1)
|
||||
duration /= 2
|
||||
parent_given_tokens = output_tokens_merge
|
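# Each pass through this loop interpolates new frames between the given ones and then
# halves `duration`, i.e. doubles the effective frame rate; the loop stops once duration
# drops below 0.5 (see the while-condition above), and the merged tokens become the
# parent frames for the next pass.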
||||
|
||||
if args.both_stages:
|
||||
move_start_time = time.time()
|
||||
logging.debug("moving stage 2 model to cpu")
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
logging.debug("moving out model2 takes time: {:.2f}".format(time.time()-move_start_time))
|
||||
|
||||
logging.info("CogVideo Stage2 completed. Taken time {:.2f}\n".format(time.time() - stage2_starttime))
|
||||
|
||||
# decoding
|
||||
# imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()), size=(480, 480)) for seq in output_tokens_merge]
|
||||
# os.makedirs(output_dir_full_path, exist_ok=True)
|
||||
# my_save_multiple_images(imgs, output_dir_full_path,subdir="frames", debug=False)
|
||||
# torch.save(output_tokens_merge.cpu(), os.path.join(output_dir_full_path, 'frame_token.pt'))
|
||||
# os.system(f"gifmaker -i '{output_dir_full_path}'/frames/0*.jpg -o '{output_dir_full_path}/{str(float(duration))}_concat.gif' -d 0.2")
|
||||
|
||||
# direct super-resolution by CogView2
|
||||
logging.info("[Direct super-resolution]")
|
||||
dsr_starttime = time.time()
|
||||
enc_text = tokenizer.encode(seq_text)
|
||||
frame_num_per_sample = parent_given_tokens.shape[1]
|
||||
parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
|
||||
text_seq = torch.cuda.LongTensor(enc_text, device=args.device).unsqueeze(0).repeat(parent_given_tokens_2d.shape[0], 1)
|
||||
sred_tokens = dsr(text_seq, parent_given_tokens_2d)
|
||||
decoded_sr_videos = []
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
decoded_sr_imgs = []
|
||||
for frame_i in range(frame_num_per_sample):
|
||||
decoded_sr_img = tokenizer.decode(image_ids=sred_tokens[frame_i+sample_i*frame_num_per_sample][-3600:])
|
||||
decoded_sr_imgs.append(torch.nn.functional.interpolate(decoded_sr_img, size=(480, 480)))
|
||||
decoded_sr_videos.append(decoded_sr_imgs)
|
||||
|
||||
for sample_i in range(sample_num):
|
||||
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125")
|
||||
|
||||
logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def process_stage1(model, seq_text, duration, video_raw_text=None, video_guidance_text="视频", image_text_suffix="", outputdir=None, batch_size=1):
|
||||
process_start_time = time.time()
|
||||
use_guide = args.use_guidance_stage1
|
||||
if args.both_stages:
|
||||
move_start_time = time.time()
|
||||
logging.debug("moving stage 1 model to cuda")
|
||||
model = model.cuda()
|
||||
logging.debug("moving in model1 takes time: {:.2f}".format(time.time()-move_start_time))
|
||||
|
||||
if video_raw_text is None:
|
||||
video_raw_text = seq_text
|
||||
mbz = args.stage1_max_inference_batch_size if args.stage1_max_inference_batch_size > 0 else args.max_inference_batch_size
|
||||
assert batch_size < mbz or batch_size % mbz == 0
|
||||
frame_len = 400
|
||||
|
||||
# generate the first frame:
|
||||
enc_text = tokenizer.encode(seq_text+image_text_suffix)
|
||||
seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1]*400  # first frame: text + <start_of_image> + 400 empty image-token slots
|
||||
logging.info("[Generating First Frame with CogView2]Raw text: {:s}".format(tokenizer.decode(enc_text)))
|
||||
text_len_1st = len(seq_1st) - frame_len*1 - 1
|
||||
|
||||
seq_1st = torch.cuda.LongTensor(seq_1st, device=args.device).unsqueeze(0)
|
||||
output_list_1st = []
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
output_list_1st.append(
|
||||
my_filling_sequence(model, args,seq_1st.clone(),
|
||||
batch_size=min(batch_size, mbz),
|
||||
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
||||
text_len=text_len_1st,
|
||||
frame_len=frame_len,
|
||||
strategy=strategy_cogview2,
|
||||
strategy2=strategy_cogvideo,
|
||||
log_text_attention_weights=1.4,
|
||||
enforce_no_swin=True,
|
||||
mode_stage1=True,
|
||||
)[0]
|
||||
)
|
||||
logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
|
||||
output_tokens_1st = torch.cat(output_list_1st, dim=0)
|
||||
given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
|
||||
|
||||
# generate subsequent frames:
|
||||
total_frames = generate_frame_num
|
||||
enc_duration = tokenizer.encode(str(float(duration))+"秒")
|
||||
if use_guide:
|
||||
video_raw_text = video_raw_text + " 视频"
|
||||
enc_text_video = tokenizer.encode(video_raw_text)
|
||||
seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
|
||||
guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(video_guidance_text) + [tokenizer['<start_of_image>']] + [-1]*400*generate_frame_num
|
||||
logging.info("[Stage1: Generating Subsequent Frames, Frame Rate {:.1f}]\nraw text: {:s}".format(4/duration, tokenizer.decode(enc_text_video)))
|
||||
|
||||
text_len = len(seq) - frame_len*generate_frame_num - 1
|
||||
guider_text_len = len(guider_seq) - frame_len*generate_frame_num - 1
|
||||
seq = torch.cuda.LongTensor(seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
|
||||
guider_seq = torch.cuda.LongTensor(guider_seq, device=args.device).unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
for given_frame_id in range(given_tokens.shape[1]):
|
||||
seq[:, text_len+1+given_frame_id*400: text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
|
||||
guider_seq[:, guider_text_len+1+given_frame_id*400:guider_text_len+1+(given_frame_id+1)*400] = given_tokens[:, given_frame_id]
|
||||
output_list = []
|
||||
|
||||
if use_guide:
|
||||
video_log_text_attention_weights = 0
|
||||
else:
|
||||
guider_seq = None
|
||||
video_log_text_attention_weights = 1.4
|
||||
|
||||
for tim in range(max(batch_size // mbz, 1)):
|
||||
start_time = time.time()
|
||||
input_seq = seq[:min(batch_size, mbz)].clone() if tim == 0 else seq[mbz*tim:mbz*(tim+1)].clone()
|
||||
guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone() if tim == 0 else guider_seq[mbz*tim:mbz*(tim+1)].clone()) if guider_seq is not None else None
|
||||
output_list.append(
|
||||
my_filling_sequence(model, args,input_seq,
|
||||
batch_size=min(batch_size, mbz),
|
||||
get_masks_and_position_ids=get_masks_and_position_ids_stage1,
|
||||
text_len=text_len, frame_len=frame_len,
|
||||
strategy=strategy_cogview2,
|
||||
strategy2=strategy_cogvideo,
|
||||
log_text_attention_weights=video_log_text_attention_weights,
|
||||
guider_seq=guider_seq2,
|
||||
guider_text_len=guider_text_len,
|
||||
guidance_alpha=args.guidance_alpha,
|
||||
limited_spatial_channel_mem=True,
|
||||
mode_stage1=True,
|
||||
)[0]
|
||||
)
|
||||
|
||||
output_tokens = torch.cat(output_list, dim=0)[:, 1+text_len:]
|
||||
|
||||
if args.both_stages:
|
||||
move_start_time = time.time()
|
||||
logging.debug("moving stage 1 model to cpu")
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
logging.debug("moving out model1 takes time: {:.2f}".format(time.time()-move_start_time))
|
||||
|
||||
# decoding
|
||||
imgs, sred_imgs, txts = [], [], []
|
||||
for seq in output_tokens:
|
||||
decoded_imgs = [torch.nn.functional.interpolate(tokenizer.decode(image_ids=seq.tolist()[i*400: (i+1)*400]), size=(480, 480)) for i in range(total_frames)]
|
||||
imgs.append(decoded_imgs)  # all decoded frames for this sample
|
||||
|
||||
assert len(imgs) == batch_size
|
||||
save_tokens = output_tokens[:, :+total_frames*400].reshape(-1, total_frames, 400).cpu()
|
||||
if outputdir is not None:
|
||||
for clip_i in range(len(imgs)):
|
||||
# os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
|
||||
my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
|
||||
os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25")
|
||||
torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
|
||||
|
||||
logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
|
||||
|
||||
return save_tokens
|
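# process_stage1 returns (and, when outputdir is set, saves to frame_tokens.pt) the
# key-frame tokens; process_stage2 above consumes exactly these tokens, either passed in
# directly as parent_given_tokens or reloaded from frame_tokens.pt, to interpolate
# intermediate frames and run super-resolution.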
||||
|
||||
# ======================================================================================================
|
||||
|
||||
if args.stage_1 or args.both_stages:
|
||||
if args.input_source != "interactive":
|
||||
with open(args.input_source, 'r') as fin:
|
||||
promptlist = fin.readlines()
|
||||
promptlist = [p.strip() for p in promptlist]
|
||||
else:
|
||||
promptlist = None
|
||||
|
||||
now_qi = -1
|
||||
while True:
|
||||
now_qi += 1
|
||||
|
||||
if promptlist is not None: # with input-source
|
||||
if args.multi_gpu:
|
||||
if now_qi % dist.get_world_size() != dist.get_rank():
|
||||
continue
|
||||
rk = dist.get_rank()
|
||||
else:
|
||||
rk = 0
|
||||
raw_text = promptlist[now_qi]
|
||||
raw_text = raw_text.strip()
|
||||
print(f'Working on Line No. {now_qi} on {rk}... [{raw_text}]')
|
||||
else: # interactive
|
||||
raw_text = input("\nPlease Input Query (stop to exit) >>> ")
|
||||
raw_text = raw_text.strip()
|
||||
if not raw_text:
|
||||
print('Query should not be empty!')
|
||||
continue
|
||||
if raw_text == "stop":
|
||||
return
|
||||
|
||||
try:
|
||||
path = os.path.join(args.output_path, f"{now_qi}_{raw_text}")
|
||||
parent_given_tokens = process_stage1(model_stage1, raw_text, duration=4.0, video_raw_text=raw_text, video_guidance_text="视频",
|
||||
image_text_suffix=" 高清摄影",
|
||||
outputdir=path if args.stage_1 else None, batch_size=args.batch_size)
|
||||
if args.both_stages:
|
||||
process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
|
||||
video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
|
||||
outputdir=path,
|
||||
gpu_rank=0, gpu_parallel_size=1) # TODO: revise
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
elif args.stage_2:
|
||||
sample_dirs = os.listdir(args.output_path)
|
||||
for sample in sample_dirs:
|
||||
raw_text = sample.split('_')[-1]
|
||||
path = os.path.join(args.output_path, sample, 'Interp')
|
||||
parent_given_tokens = torch.load(os.path.join(args.output_path, sample, "frame_tokens.pt"))
|
||||
|
||||
process_stage2(model_stage2, raw_text, duration=2.0, video_raw_text=raw_text+" 视频",
|
||||
video_guidance_text="视频", parent_given_tokens=parent_given_tokens,
|
||||
outputdir=path,
|
||||
gpu_rank=0, gpu_parallel_size=1) # TODO: revise
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
|
||||
|
||||
py_parser = argparse.ArgumentParser(add_help=False)
|
||||
py_parser.add_argument('--generate-frame-num', type=int, default=5)
|
||||
py_parser.add_argument('--coglm-temperature2', type=float, default=0.89)
|
||||
# py_parser.add_argument("--interp-duration", type=float, default=-1) # -1: sequential generation, 0: super-resolution only, 0.5/1/2: frame interpolation
|
||||
# py_parser.add_argument("--total-duration", type=float, default=4.0) # total duration of the generated video
|
||||
py_parser.add_argument('--use-guidance-stage1', action='store_true')
|
||||
py_parser.add_argument('--use-guidance-stage2', action='store_true')
|
||||
py_parser.add_argument('--guidance-alpha', type=float, default=3.0)
|
||||
py_parser.add_argument('--stage-1', action='store_true') # stage 1: sequential generation
|
||||
py_parser.add_argument('--stage-2', action='store_true') # stage 2: interp + dsr
|
||||
py_parser.add_argument('--both-stages', action='store_true') # stage 1&2: sequential generation; interp + dsr
|
||||
py_parser.add_argument('--parallel-size', type=int, default=1)
|
||||
py_parser.add_argument('--stage1-max-inference-batch-size', type=int, default=-1) # -1: use max-inference-batch-size
|
||||
py_parser.add_argument('--multi-gpu', action='store_true')
|
||||
|
||||
CogVideoCacheModel.add_model_specific_args(py_parser)
|
||||
|
||||
known, args_list = py_parser.parse_known_args()
|
||||
args = get_args(args_list)
|
||||
args = argparse.Namespace(**vars(args), **vars(known))
|
||||
args.layout = [int(x) for x in args.layout.split(',')]
|
||||
args.do_train = False
|
||||
|
||||
torch.cuda.set_device(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
main(args)
|
254
gradio_demo.py
Normal file
254
gradio_demo.py
Normal file
@ -0,0 +1,254 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from datetime import datetime, timedelta
|
||||
from openai import OpenAI
|
||||
import spaces
|
||||
import imageio
|
||||
import moviepy.editor as mp
|
||||
from typing import List, Union
|
||||
import PIL
|
||||
|
||||
dtype = torch.bfloat16
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
pipe = CogVideoXPipeline.from_pretrained("/share/home/zyx/Models/cogvideox-hf-0805", torch_dtype=dtype).to(device)
|
||||
|
||||
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
||||
|
||||
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
||||
There are a few rules to follow:
|
||||
|
||||
You will only ever output a single video description per user request.
|
||||
|
||||
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
||||
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
||||
|
||||
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
||||
"""
|
||||
|
||||
|
||||
def export_to_video_imageio(
|
||||
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
||||
) -> str:
|
||||
"""
|
||||
Export the video frames to a video file using the imageio library, to avoid the "green screen" issue some players show (observed, for example, with CogVideoX output).
|
||||
"""
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
||||
|
||||
if isinstance(video_frames[0], PIL.Image.Image):
|
||||
video_frames = [np.array(frame) for frame in video_frames]
|
||||
|
||||
with imageio.get_writer(output_video_path, fps=fps) as writer:
|
||||
for frame in video_frames:
|
||||
writer.append_data(frame)
|
||||
|
||||
return output_video_path
|
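# Example usage (a minimal sketch; `frames` is illustrative and stands for any list of
# H x W x 3 uint8 arrays or PIL images, it is not defined in this file):
#
#   frames = [np.zeros((480, 720, 3), dtype=np.uint8) for _ in range(8)]
#   export_to_video_imageio(frames, "sample.mp4", fps=8)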
||||
|
||||
|
||||
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
return prompt
|
||||
client = OpenAI()
|
||||
text = prompt.strip()
|
||||
|
||||
for i in range(retry_times):
|
||||
response = client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": sys_prompt},
|
||||
{"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
|
||||
{"role": "assistant",
|
||||
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
|
||||
{"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
|
||||
{"role": "assistant",
|
||||
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
|
||||
{"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
|
||||
{"role": "assistant",
|
||||
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
|
||||
{"role": "user",
|
||||
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
|
||||
],
|
||||
model="glm-4-0520",
|
||||
temperature=0.01,
|
||||
top_p=0.7,
|
||||
stream=False,
|
||||
max_tokens=250,
|
||||
)
|
||||
if response.choices:
|
||||
return response.choices[0].message.content
|
||||
return prompt
|
||||
|
||||
|
||||
@spaces.GPU(duration=240)
|
||||
def infer(
|
||||
prompt: str,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
progress=gr.Progress(track_tqdm=True)
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
prompt_embeds, _ = pipe.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=None,
|
||||
do_classifier_free_guidance=True,
|
||||
num_videos_per_prompt=1,
|
||||
max_sequence_length=226,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
|
||||
).frames[0]
|
||||
|
||||
|
||||
return video
|
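# Note: negative_prompt_embeds is passed as all zeros above because this demo does not
# expose a free-form negative prompt; with do_classifier_free_guidance enabled the two
# branches are combined roughly as uncond + guidance_scale * (cond - uncond), with the
# zero embedding acting as the unconditional branch.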
||||
|
||||
|
||||
def save_video(tensor):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
video_path = f"./output/{timestamp}.mp4"
|
||||
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
||||
export_to_video_imageio(tensor[1:], video_path)
|
||||
return video_path
|
||||
|
||||
def convert_to_gif(video_path):
|
||||
clip = mp.VideoFileClip(video_path)
|
||||
clip = clip.set_fps(8)
|
||||
clip = clip.resize(height=240)
|
||||
gif_path = video_path.replace('.mp4', '.gif')
|
||||
clip.write_gif(gif_path, fps=8)
|
||||
return gif_path
|
||||
|
||||
|
||||
def delete_old_files():
|
||||
while True:
|
||||
now = datetime.now()
|
||||
cutoff = now - timedelta(minutes=10)
|
||||
output_dir = './output'
|
||||
for filename in os.listdir(output_dir):
|
||||
file_path = os.path.join(output_dir, filename)
|
||||
if os.path.isfile(file_path):
|
||||
file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
|
||||
if file_mtime < cutoff:
|
||||
os.remove(file_path)
|
||||
time.sleep(600) # Sleep for 10 minutes
|
||||
|
||||
|
||||
threading.Thread(target=delete_old_files, daemon=True).start()
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("""
|
||||
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
||||
CogVideoX-2B Huggingface Space🤗
|
||||
</div>
|
||||
<div style="text-align: center;">
|
||||
<a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 Model Hub</a> |
|
||||
<a href="https://github.com/THUDM/CogVideo">🌐 Github</a>
|
||||
</div>
|
||||
|
||||
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
|
||||
⚠️ This demo is for academic research and experiential use only.
|
||||
Users should strictly adhere to local laws and ethics.
|
||||
</div>
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
||||
with gr.Row():
|
||||
gr.Markdown(
"✨ Upon pressing the Enhance Prompt button, we will use the [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
|
||||
enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
"Increase Inference Steps for a more detailed video, but generation will be slower.<br>"
"50 steps are recommended for most cases; inference then takes about 120 seconds.<br>")
|
||||
with gr.Row():
|
||||
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
||||
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
||||
generate_button = gr.Button("🎬 Generate Video")
|
||||
|
||||
with gr.Column():
|
||||
video_output = gr.Video(label="CogVideoX Generated Video", width=720, height=480)
|
||||
with gr.Row():
|
||||
download_video_button = gr.File(label="📥 Download Video", visible=False)
|
||||
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
|
||||
|
||||
gr.Markdown("""
|
||||
<table border="1" style="width: 100%; text-align: left; margin-top: 20px;">
|
||||
<tr>
|
||||
<th>Prompt</th>
|
||||
<th>Video URL</th>
|
||||
<th>Inference Steps</th>
|
||||
<th>Guidance Scale</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</td>
|
||||
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/1.mp4">Video 1</a></td>
|
||||
<td>50</td>
|
||||
<td>6</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</td>
|
||||
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/2.mp4">Video 2</a></td>
|
||||
<td>50</td>
|
||||
<td>6</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</td>
|
||||
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/3.mp4">Video 3</a></td>
|
||||
<td>50</td>
|
||||
<td>6</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</td>
|
||||
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/4.mp4">Video 4</a></td>
|
||||
<td>50</td>
|
||||
<td>6</td>
|
||||
</tr>
|
||||
</table>
|
||||
""")
|
||||
|
||||
|
||||
def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
||||
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
||||
video_path = save_video(tensor)
|
||||
video_update = gr.update(visible=True, value=video_path)
|
||||
gif_path = convert_to_gif(video_path)
|
||||
gif_update = gr.update(visible=True, value=gif_path)
|
||||
|
||||
return video_path, video_update, gif_update
|
||||
|
||||
|
||||
def enhance_prompt_func(prompt):
|
||||
return convert_prompt(prompt, retry_times=1)
|
||||
|
||||
|
||||
generate_button.click(
|
||||
generate,
|
||||
inputs=[prompt, num_inference_steps, guidance_scale],
|
||||
outputs=[video_output, download_video_button, download_gif_button]
|
||||
)
|
||||
|
||||
enhance_button.click(
|
||||
enhance_prompt_func,
|
||||
inputs=[prompt],
|
||||
outputs=[prompt]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
|
107
inference/cli_demo.py
Normal file
107
inference/cli_demo.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""
|
||||
This script demonstrates how to generate a video from a text prompt using CogVideoX with the 🤗 Hugging Face Diffusers pipeline.
|
||||
|
||||
Note:
|
||||
This script requires the `diffusers>=0.30.0` library to be installed.
|
||||
If the video exported using OpenCV appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a known player issue, not an error.
|
||||
|
||||
Run the script:
|
||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
model_path: str,
|
||||
output_path: str = "./output.mp4",
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
"""
|
||||
Generates a video based on the given prompt and saves it to the specified path.
|
||||
|
||||
Parameters:
|
||||
- prompt (str): The description of the video to be generated.
|
||||
- model_path (str): The path of the pre-trained model to be used.
|
||||
- output_path (str): The path where the generated video will be saved.
|
||||
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
|
||||
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
|
||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
|
||||
- dtype (torch.dtype): The data type for computation (default is torch.float16).
|
||||
"""
|
||||
|
||||
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
|
||||
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
|
||||
# Encode the prompt to get the prompt embeddings
|
||||
prompt_embeds, _ = pipe.encode_prompt(
|
||||
prompt=prompt, # The textual description for video generation
|
||||
negative_prompt=None, # The negative prompt to guide the video generation
|
||||
do_classifier_free_guidance=True, # Whether to use classifier-free guidance
|
||||
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
|
||||
max_sequence_length=226, # Maximum length of the sequence, must be 226
|
||||
device=device, # Device to use for computation
|
||||
dtype=dtype, # Data type for computation
|
||||
)
|
||||
|
||||
# Generate the video frames using the pipeline
|
||||
video = pipe(
|
||||
num_inference_steps=num_inference_steps, # Number of inference steps
|
||||
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
|
||||
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
|
||||
negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # negative prompts are not supported; zero embeddings are passed instead
|
||||
).frames[0]
|
||||
|
||||
# Export the generated frames to a video file; the fps must be 8 for CogVideoX
|
||||
export_to_video(video, output_path, fps=8)
|
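# Example of calling generate_video directly from Python (prompt and paths are
# illustrative; the CLI entry point below wraps essentially this call):
#
#   generate_video(
#       prompt="A panda playing guitar in a bamboo forest.",
#       model_path="THUDM/CogVideoX-2b",
#       output_path="./panda.mp4",
#       num_inference_steps=50,
#       guidance_scale=6.0,
#   )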
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||
parser.add_argument(
|
||||
"--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
|
||||
)
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert the dtype argument to torch.dtype; bfloat16 is not recommended here
|
||||
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
||||
|
||||
# main function to generate video.
|
||||
generate_video(
|
||||
prompt=args.prompt,
|
||||
model_path=args.model_path,
|
||||
output_path=args.output_path,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
num_videos_per_prompt=args.num_videos_per_prompt,
|
||||
device=args.device,
|
||||
dtype=dtype,
|
||||
)
|
103
inference/cli_vae_demo.py
Normal file
103
inference/cli_vae_demo.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""
|
||||
This script demonstrates how to encode and decode video frames using the pre-trained CogVideoX VAE with 🤗 Hugging Face Diffusers.
|
||||
|
||||
Note:
|
||||
This script requires the `diffusers>=0.30.0` library to be installed.
|
||||
If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a known player issue, not an error.
|
||||
Encoding a 1-minute video at 720p resolution costs about 71 GB of GPU memory.
|
||||
|
||||
Run the script:
|
||||
$ python cli_vae_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import imageio
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLCogVideoX
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def vae_demo(model_path, video_path, dtype, device):
|
||||
"""
|
||||
Loads a pre-trained AutoencoderKLCogVideoX model and runs the video frames through its encoder and decoder.
|
||||
|
||||
Parameters:
|
||||
- model_path (str): The path to the pre-trained model.
|
||||
- video_path (str): The path to the video file.
|
||||
- dtype (torch.dtype): The data type for computation.
|
||||
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The reconstructed (encoded and then decoded) video frames.
|
||||
"""
|
||||
# Load the pre-trained model
|
||||
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
|
||||
# Load video frames
|
||||
video_reader = imageio.get_reader(video_path, 'ffmpeg')
|
||||
frames = []
|
||||
for frame in video_reader:
|
||||
frames.append(frame)
|
||||
video_reader.close()
|
||||
|
||||
# Transform frames to Tensor
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
|
||||
|
||||
# Permute to [C, T, H, W] and add a batch dimension -> [1, C, T, H, W] (e.g. [1, 3, 49, 480, 720])
|
||||
frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
|
||||
|
||||
# Run the model with Encoder and Decoder
|
||||
with torch.no_grad():
|
||||
output = model(frames_tensor)
|
||||
|
||||
return output
|
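# The forward call above runs both the encoder and the decoder, so `output` holds a
# reconstruction of the input clip rather than latent codes; save_video below writes
# that reconstruction back to disk.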
||||
|
||||
|
||||
def save_video(tensor, output_path):
|
||||
"""
|
||||
Saves the encoded video frames to a video file.
|
||||
|
||||
Parameters:
|
||||
- tensor (torch.Tensor): The encoded video frames.
|
||||
- output_path (str): The path to save the output video.
|
||||
"""
|
||||
# Remove the batch dimension and permute back to [T, H, W, C] (e.g. [49, 480, 720, 3])
|
||||
frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
|
||||
|
||||
# Clip values to [0, 1] and convert to uint8
|
||||
frames = np.clip(frames, 0, 1)
|
||||
frames = (frames * 255).astype(np.uint8)
|
||||
|
||||
# Save frames to video
|
||||
writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Encode and decode a video with the CogVideoX VAE")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
|
||||
parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="./", help="The path to save the output video"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set device and dtype
|
||||
device = torch.device(args.device)
|
||||
dtype = torch.float16 if args.dtype == "float16" else torch.float32
|
||||
|
||||
output = vae_demo(args.model_path, args.video_path, dtype, device)
|
||||
save_video(output, args.output_path)
|
92
inference/convert_demo.py
Normal file
92
inference/convert_demo.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""
|
||||
|
||||
The CogVideoX model is pre-trained and fine-tuned with long, detailed prompts, so it requires highly granular and descriptive prompts as input. This script transforms short user inputs into prompts suitable for CogVideoX, enabling higher-quality video generation.
|
||||
|
||||
This step is not mandatory; the model will still function correctly and without errors even if the prompts are not refined using this script. However, we strongly recommend using it to ensure the generation of high-quality videos.
|
||||
|
||||
Note:
|
||||
Please set the OPENAI_API_KEY (and, if needed, OPENAI_BASE_URL) environment variables before running this script.
|
||||
|
||||
Run the script:
|
||||
$ python convert_demo.py --prompt "A girl riding a bike."  # using OpenAI's API
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
||||
|
||||
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
||||
There are a few rules to follow:
|
||||
|
||||
You will only ever output a single video description per user request.
|
||||
|
||||
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
||||
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
||||
|
||||
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
||||
"""
|
||||
|
||||
|
||||
def convert_prompt(prompt: str, retry_times: int = 3):
|
||||
"""
|
||||
Convert a prompt to a format that can be used by the model for inference
|
||||
"""
|
||||
|
||||
client = OpenAI()
|
||||
text = prompt.strip()
|
||||
|
||||
for i in range(retry_times):
|
||||
response = client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": f"{sys_prompt}"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " a girl is on the beach"',
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A man jogging on a football field"',
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
|
||||
},
|
||||
],
|
||||
model="glm-4-0520",  # glm-4-0520 and gpt-4o have been tested
|
||||
temperature=0.01,
|
||||
top_p=0.7,
|
||||
stream=False,
|
||||
max_tokens=250,
|
||||
)
|
||||
if response.choices:
|
||||
return response.choices[0].message.content
|
||||
return prompt
|
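# If the API returns no choices, the request is retried up to retry_times; when every
# attempt fails, the original (unrefined) prompt is returned unchanged.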
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
|
||||
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_prompt = convert_prompt(args.prompt, args.retry_times)
|
||||
print(converted_prompt)
|
214
inference/web_demo.py
Normal file
214
inference/web_demo.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""
|
||||
This script is used to create a Streamlit web application for generating videos using the CogVideoX model.
|
||||
|
||||
Run the script using Streamlit:
|
||||
$ export OPENAI_API_KEY=your OpenAI or ZhipuAI API key
|
||||
$ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/  # needed when using ZhipuAI; omit when using OpenAI
|
||||
$ streamlit run web_demo.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from convert_demo import convert_prompt
|
||||
from diffusers import CogVideoXPipeline
|
||||
|
||||
|
||||
model_path: str = "THUDM/CogVideoX-2b"
|
||||
|
||||
|
||||
# Load the model at the start
|
||||
@st.cache_resource
|
||||
def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
|
||||
"""
|
||||
Load the CogVideoX model.
|
||||
|
||||
Args:
|
||||
- model_path (str): Path to the model.
|
||||
- dtype (torch.dtype): Data type for model.
|
||||
- device (str): Device to load the model on.
|
||||
|
||||
Returns:
|
||||
- CogVideoXPipeline: Loaded model pipeline.
|
||||
"""
|
||||
return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
|
||||
|
||||
# Define a function to generate video based on the provided prompt and model path
|
||||
def generate_video(
|
||||
pipe: CogVideoXPipeline,
|
||||
prompt: str,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 6.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate a video based on the provided prompt and model path.
|
||||
|
||||
Args:
|
||||
- pipe (CogVideoXPipeline): The pipeline for generating videos.
|
||||
- prompt (str): Text prompt for video generation.
|
||||
- num_inference_steps (int): Number of inference steps.
|
||||
- guidance_scale (float): Guidance scale for generation.
|
||||
- num_videos_per_prompt (int): Number of videos to generate per prompt.
|
||||
- device (str): Device to run the generation on.
|
||||
- dtype (torch.dtype): Data type for the model.
|
||||
|
||||
Returns:
|
||||
- List[np.ndarray]: Generated video frames.
|
||||
"""
|
||||
prompt_embeds, _ = pipe.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=None,
|
||||
do_classifier_free_guidance=True,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=226,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Generate video
|
||||
video = pipe(
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
|
||||
).frames[0]
|
||||
return video
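# Example (illustrative prompt; mirrors the call made in main() below):
#   frames = generate_video(pipe, "a girl riding a bike", num_inference_steps=50,
#                           guidance_scale=6.0, num_videos_per_prompt=1, device="cuda", dtype=torch.float16)
#   save_video(frames, "output.mp4", fps=8)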
|
||||
|
||||
|
||||
def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
|
||||
"""
|
||||
Save the generated video to a file.
|
||||
|
||||
Args:
|
||||
- video (List[np.ndarray]): Video frames.
|
||||
- path (str): Path to save the video.
|
||||
- fps (int): Frames per second for the video.
|
||||
"""
|
||||
# Remove the first frame
|
||||
video = video[1:]
|
||||
|
||||
writer = imageio.get_writer(path, fps=fps, codec="libx264")
|
||||
for frame in video:
|
||||
np_frame = np.array(frame)
|
||||
writer.append_data(np_frame)
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def save_metadata(
|
||||
prompt: str,
|
||||
converted_prompt: str,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_videos_per_prompt: int,
|
||||
path: str,
|
||||
) -> None:
|
||||
"""
|
||||
Save metadata to a JSON file.
|
||||
|
||||
Args:
|
||||
- prompt (str): Original prompt.
|
||||
- converted_prompt (str): Converted prompt.
|
||||
- num_inference_steps (int): Number of inference steps.
|
||||
- guidance_scale (float): Guidance scale.
|
||||
- num_videos_per_prompt (int): Number of videos per prompt.
|
||||
- path (str): Path to save the metadata.
|
||||
"""
|
||||
metadata = {
|
||||
"prompt": prompt,
|
||||
"converted_prompt": converted_prompt,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_videos_per_prompt": num_videos_per_prompt,
|
||||
}
|
||||
with open(path, "w") as f:
|
||||
json.dump(metadata, f, indent=4)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
Main function to run the Streamlit web application.
|
||||
"""
|
||||
st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
|
||||
st.write("# CogVideoX 🎥")
|
||||
dtype: torch.dtype = torch.float16
|
||||
device: str = "cuda"
|
||||
|
||||
global pipe
|
||||
pipe = load_model(model_path, dtype, device)
|
||||
|
||||
with st.sidebar:
|
||||
st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️")
|
||||
num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
|
||||
guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
|
||||
num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
|
||||
|
||||
share_links_container = st.empty()
|
||||
|
||||
prompt: str = st.chat_input("Prompt")
|
||||
|
||||
if prompt:
|
||||
# Optional but suggested: refine the prompt with an LLM before generation
|
||||
with st.spinner("Refining prompts..."):
|
||||
converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
|
||||
if converted_prompt is None:
|
||||
st.error("Failed to Refining the prompt, Using origin one.")
|
||||
|
||||
st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with st.spinner("Generating Video..."):
|
||||
start_time = time.time()
|
||||
video_paths = []
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = f"./output/{timestamp}"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
metadata_path = os.path.join(output_dir, "config.json")
|
||||
save_metadata(
|
||||
prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
|
||||
)
|
||||
|
||||
for i in range(num_videos_per_prompt):
|
||||
video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
|
||||
|
||||
video = generate_video(
|
||||
pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
|
||||
)
|
||||
save_video(video, video_path, fps=8)
|
||||
video_paths.append(video_path)
|
||||
with open(video_path, "rb") as video_file:
|
||||
video_bytes: bytes = video_file.read()
|
||||
st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
used_time: float = time.time() - start_time
|
||||
st.success(f"Videos generated in {used_time:.2f} seconds.")
|
||||
|
||||
# Create download links in the sidebar
|
||||
with share_links_container:
|
||||
st.sidebar.write("### Download Links:")
|
||||
for video_path in video_paths:
|
||||
video_name = os.path.basename(video_path)
|
||||
with open(video_path, "rb") as f:
|
||||
video_bytes: bytes = f.read()
|
||||
b64_video = base64.b64encode(video_bytes).decode()
|
||||
href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
|
||||
st.sidebar.markdown(href, unsafe_allow_html=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,3 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
pip install git+https://github.com/Sleepychord/Image-Local-Attention
|
@ -1,695 +0,0 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@File : cogvideo_cache_model.py
|
||||
@Time : 2022/07/15 11:22:19
|
||||
@Author : Wenyi Hong
|
||||
@Version : 1.0
|
||||
@Contact : hwy22@mails.tsinghua.edu.cn
|
||||
'''
|
||||
|
||||
# here put the import lib
|
||||
|
||||
import torch
|
||||
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
||||
|
||||
from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
|
||||
from SwissArmyTransformer.model.transformer import unscaled_init_method
|
||||
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
||||
import torch.nn.functional as F
|
||||
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
||||
import math
|
||||
|
||||
|
||||
class PositionEmbeddingMixin(BaseMixin):
|
||||
def __init__(self, additional_sequence_length, hidden_size,
|
||||
init_method_std=0.02, reinit_slice=slice(512, 912),
|
||||
):
|
||||
super(PositionEmbeddingMixin, self).__init__()
|
||||
self.reinit_slice = reinit_slice
|
||||
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
||||
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
||||
old_len, hidden_size = old_weights.shape
|
||||
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
||||
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, framenum, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
||||
"""
|
||||
B, framenum, H, W, C = x.shape
|
||||
x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, frame_num, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
framenum = windows.shape[1]
|
||||
x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
|
||||
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
|
||||
return x
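# Shape round-trip (illustrative numbers): with x of shape (2, 5, 20, 20, 64) and window_size=10,
# window_partition(x, 10) returns windows of shape (2*4, 5, 10, 10, 64), and
# window_reverse(windows, 10, 20, 20) restores the original (2, 5, 20, 20, 64) tensor.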
|
||||
|
||||
class WindowAttentionMixin(BaseMixin):
|
||||
def __init__(self, num_layers,
|
||||
hidden_size,
|
||||
frame_resolution,
|
||||
window_size,
|
||||
shift_size,
|
||||
n_head,
|
||||
frame_num,
|
||||
init_method=unscaled_init_method(0.02),
|
||||
output_layer_init_method=unscaled_init_method(0.02),
|
||||
time_dim_attend_length=0
|
||||
):
|
||||
super(WindowAttentionMixin, self).__init__()
|
||||
self.num_layers = num_layers # replace attention in the LAST n layers
|
||||
self.query_key_value = torch.nn.ModuleList(
|
||||
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
||||
gather_output=False,init_method=init_method)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
self.dense = torch.nn.ModuleList(
|
||||
[RowParallelLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
init_method=output_layer_init_method,
|
||||
bias=True,
|
||||
module=self,
|
||||
name="dense")
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
|
||||
self.n_head = n_head
|
||||
self.window_size = window_size
|
||||
self.frame_resolution = frame_resolution
|
||||
self.frame_len = frame_resolution * frame_resolution
|
||||
self.time_dim_attend_length = time_dim_attend_length
|
||||
assert frame_resolution % window_size == 0
|
||||
assert 0 < shift_size < window_size
|
||||
nW = (self.frame_resolution // self.window_size) ** 2
|
||||
ws_squre = self.window_size * self.window_size
|
||||
|
||||
# odd non-shift, even shift
|
||||
img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
|
||||
h_slices = (slice(0, -shift_size),
|
||||
slice(-shift_size, None))
|
||||
w_slices = (slice(0, -shift_size),
|
||||
slice(-shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, :, h, w, :] = cnt
|
||||
cnt += 1
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
|
||||
sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
|
||||
attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
|
||||
attn_mask = attn_mask.tril()
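# The resulting attn_mask lets a token attend only to positions that (a) lie in the same
# shifted-window region of the frame and (b) come no later in the flattened frame-by-frame
# order, i.e. a Swin-style shift mask combined with a causal (lower-triangular) constraint.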
|
||||
|
||||
causal_mask = torch.ones(ws_squre*frame_num, ws_squre*frame_num)
|
||||
causal_mask = causal_mask.tril()
|
||||
|
||||
self.shift_sizes = [0, shift_size]
|
||||
self.attn_mask = attn_mask
|
||||
self.causal_mask = causal_mask
|
||||
self.mask_initialized = False
|
||||
|
||||
self.attn_distribution = torch.nn.ParameterList([
|
||||
torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def reinit(self, *pre_mixins):
|
||||
start_layer = len(self.transformer.layers) - self.num_layers
|
||||
assert start_layer >= 0
|
||||
for layer_id in range(self.num_layers):
|
||||
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
||||
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
||||
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
||||
|
||||
def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
|
||||
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
||||
if not self.mask_initialized:
|
||||
self.attn_mask = self.attn_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.mask_initialized = True
|
||||
b0, s1, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
frame_len = self.frame_resolution * self.frame_resolution
|
||||
frame_num = s1 // frame_len
|
||||
if stage == 2:
|
||||
assert frame_num == 3
|
||||
assert frame_num*frame_len == s1
|
||||
wind_square = self.window_size * self.window_size
|
||||
nW = frame_len // wind_square
|
||||
bswin = b0 * nW
|
||||
|
||||
if memkv_text is not None:
|
||||
s0 = memkv_text.shape[-2]
|
||||
k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
||||
v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
||||
|
||||
# shift
|
||||
frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
# window partition
|
||||
frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
|
||||
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
|
||||
.permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
||||
|
||||
if stage == 1:
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square),
|
||||
self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))\
|
||||
- 10000.0 * (1.0 - self.attn_mask[:,:frame_num*wind_square, :frame_num*wind_square].unsqueeze(1).unsqueeze(0))
|
||||
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
||||
else:
|
||||
attn = torch.mul(attn, self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))\
|
||||
- 10000.0 * (1.0 - self.causal_mask[:frame_num*wind_square, :frame_num*wind_square].unsqueeze(0).unsqueeze(0))
|
||||
|
||||
if memkv_text is None:
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
else:
|
||||
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
|
||||
attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
|
||||
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
|
||||
context_swin = (torch.matmul(attn[..., :-s0], v) +
|
||||
torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
|
||||
.reshape(bswin, self.n_head, frame_num*wind_square, h))\
|
||||
.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
|
||||
context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
|
||||
|
||||
# reverse cycle shift
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
ret_context = context_swin.reshape(b0, s1, h0)
|
||||
|
||||
# for mem
|
||||
memk = k.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
memv = v.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
memk = window_reverse(memk, self.window_size, self.frame_resolution, self.frame_resolution)
|
||||
memv = window_reverse(memv, self.window_size, self.frame_resolution, self.frame_resolution)
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
memk = torch.roll(memk, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
memv = torch.roll(memv, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
memk, memv = memk.reshape(b0, s1, h0), memv.reshape(b0, s1, h0)
|
||||
|
||||
ret_mem = torch.cat((memk, memv), dim=-1)
|
||||
return ret_context, ret_mem
|
||||
|
||||
def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
|
||||
# frame_hidden_state [batchsize, 1, n_head*hiddensize_perhead]
|
||||
# memkv [batchsize, pos, hidden_size*2] (include frames only)
|
||||
# if memkv_text is not None: will attend to text
|
||||
# pos: token's pos
|
||||
b0, sin, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
assert sin == 1
|
||||
this_qkv = self.query_key_value[layer_id](frame_hidden_state)
|
||||
thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
|
||||
s1 = memkv.shape[1] if memkv is not None else 0
|
||||
frame_len = self.frame_resolution * self.frame_resolution
|
||||
frame_num_before = s1 // frame_len
|
||||
|
||||
|
||||
if memkv is not None:
|
||||
pos_inframe = pos - frame_num_before * frame_len
|
||||
|
||||
xpos = pos_inframe // self.frame_resolution # pos = xpos*self.frame_resolution + ypos
|
||||
ypos = pos_inframe % self.frame_resolution
|
||||
# [start, end)
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
xstart = ((xpos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
|
||||
ystart = ((ypos+self.shift_sizes[layer_id%2]) // self.window_size) * self.window_size - self.shift_sizes[layer_id%2]
|
||||
xend = xstart + self.window_size
|
||||
yend = ystart + self.window_size
|
||||
xstart, ystart = max(0, xstart), max(0, ystart)
|
||||
xend, yend = min(xend, self.frame_resolution), min(yend, self.frame_resolution)
|
||||
else:
|
||||
xstart = (xpos // self.window_size) * self.window_size
|
||||
ystart = (ypos // self.window_size) * self.window_size
|
||||
xend, yend = xstart + self.window_size, ystart+self.window_size
|
||||
|
||||
# select index
|
||||
selected_index = list()
|
||||
if frame_num_before > 0:
|
||||
# frames before
|
||||
frame_attended_start = max(0, frame_num_before-self.time_dim_attend_length+1) if self.time_dim_attend_length > 0 else 0
|
||||
for x in range(xstart, xend):
|
||||
for y in range(ystart, yend):
|
||||
selected_index.append(x*self.frame_resolution+y+frame_len*frame_attended_start)
|
||||
cnt_per_frame = len(selected_index)
|
||||
for _ in range((frame_num_before-frame_attended_start-1)*cnt_per_frame):
|
||||
selected_index.append(selected_index[-cnt_per_frame]+frame_len)
|
||||
|
||||
# the last frame
|
||||
for x in range(xstart, xend):
|
||||
for y in range(ystart, yend):
|
||||
tmppos = x*self.frame_resolution+y + frame_num_before * frame_len
|
||||
if tmppos < pos:
|
||||
selected_index.append(tmppos)
|
||||
else:
|
||||
break
|
||||
cnt_all = len(selected_index)+1
|
||||
selected_index = torch.tensor(selected_index, device=memkv.device)
|
||||
used_memkv = torch.index_select(memkv, 1, selected_index)
|
||||
used_k, used_v = used_memkv[..., :h0], used_memkv[..., h0:]
|
||||
used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
|
||||
used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
|
||||
if memkv_text is not None:
|
||||
cnt_all += memkv_text.shape[-2]
|
||||
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
||||
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
||||
used_k = used_k.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
|
||||
used_v = used_v.reshape(b0, cnt_all, self.n_head, h).permute(0, 2, 1, 3)
|
||||
else:
|
||||
used_k = thisk
|
||||
used_v = thisv
|
||||
|
||||
if memkv_text is not None:
|
||||
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
||||
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
||||
used_k = used_k.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
|
||||
used_v = used_v.reshape(b0, 1+memkv_text.shape[-2], self.n_head, h).permute(0, 2, 1, 3)
|
||||
else:
|
||||
used_k = used_k.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
|
||||
used_v = used_v.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3)
|
||||
|
||||
thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
|
||||
attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
|
||||
if memkv_text is not None:
|
||||
attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
|
||||
|
||||
return context_swin, this_qkv[..., h0:]
|
||||
|
||||
class FullAttentionMixin(BaseMixin):
|
||||
def __init__(self, num_layers,
|
||||
hidden_size,
|
||||
frame_resolution,
|
||||
n_head,
|
||||
frame_num,
|
||||
init_method=unscaled_init_method(0.02),
|
||||
output_layer_init_method=unscaled_init_method(0.02),
|
||||
**kwargs,
|
||||
):
|
||||
super(FullAttentionMixin, self).__init__()
|
||||
self.num_layers = num_layers # replace attention in the LAST n layers
|
||||
self.query_key_value = torch.nn.ModuleList(
|
||||
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
||||
gather_output=False,init_method=init_method)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
self.dense = torch.nn.ModuleList(
|
||||
[RowParallelLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
init_method=output_layer_init_method,
|
||||
bias=True,
|
||||
module=self,
|
||||
name="dense")
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
|
||||
self.n_head = n_head
|
||||
self.frame_resolution = frame_resolution
|
||||
self.frame_len = frame_resolution * frame_resolution
|
||||
|
||||
self.attn_distribution = torch.nn.ParameterList([
|
||||
torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def reinit(self, *pre_mixins):
|
||||
start_layer = len(self.transformer.layers) - self.num_layers
|
||||
assert start_layer >= 0
|
||||
for layer_id in range(self.num_layers):
|
||||
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
||||
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
||||
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
||||
|
||||
|
||||
def attention_extra_NAR_inference(self, frame_hidden_state, layer_id, attn_dropout=None, memkv_text=None, stage=1):
|
||||
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
||||
assert stage == 1
|
||||
|
||||
b0, s1, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
frame_len = self.frame_resolution * self.frame_resolution
|
||||
frame_num = s1 // frame_len
|
||||
assert frame_num*frame_len == s1
|
||||
|
||||
if memkv_text is not None:
|
||||
s0 = memkv_text.shape[-2]
|
||||
k_text = memkv_text[..., :h0].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
||||
v_text = memkv_text[..., h0:].expand(b0, -1, -1).reshape(b0, s0, self.n_head, h).permute(0, 2, 1, 3)
|
||||
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
|
||||
.permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
||||
attn = attn - 10000.0 * (1.0-torch.ones(b0, self.n_head, s1, s1, device=attn.device, dtype=attn.dtype).tril())
|
||||
|
||||
if memkv_text is None:
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
else:
|
||||
attn_frame2text = torch.matmul(q / math.sqrt(h), k_text.transpose(-1, -2)) #[b0, s1, s0]
|
||||
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
context_swin = (torch.matmul(attn[..., :-s0], v) + torch.matmul(attn[..., -s0:], v_text))\
|
||||
.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
|
||||
# for mem
|
||||
memk = k.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
memv = v.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
ret_mem = torch.cat((memk, memv), dim=-1)
|
||||
|
||||
return context_swin, ret_mem
|
||||
|
||||
def attention_extra_AR_inference(self, frame_hidden_state, memkv, pos, layer_id, log_text_attention_weights=0, attn_dropout=None, memkv_text=None, stage=1):
|
||||
# pos: current token's pos
|
||||
b0, sin, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
assert sin == 1
|
||||
assert stage == 1
|
||||
|
||||
this_qkv = self.query_key_value[layer_id](frame_hidden_state)
|
||||
thisq, thisk, thisv = this_qkv[..., :h0], this_qkv[..., h0:2*h0], this_qkv[..., 2*h0:]
|
||||
|
||||
if memkv is not None:
|
||||
used_k, used_v = memkv[..., :h0], memkv[..., h0:]
|
||||
used_k = torch.cat((used_k.expand(thisk.shape[0], -1, -1), thisk), dim=-2)
|
||||
used_v = torch.cat((used_v.expand(thisv.shape[0], -1, -1), thisv), dim=-2)
|
||||
else:
|
||||
used_k, used_v = thisk, thisv
|
||||
|
||||
if memkv_text is not None:
|
||||
used_k = torch.cat((memkv_text[..., :h0].expand(thisk.shape[0], -1, -1), used_k), dim=-2)
|
||||
used_v = torch.cat((memkv_text[..., h0:].expand(thisv.shape[0], -1, -1), used_v), dim=-2)
|
||||
|
||||
used_k = used_k.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
|
||||
used_v = used_v.reshape(b0, -1, self.n_head, h).permute(0, 2, 1, 3)
|
||||
thisq = thisq.reshape(b0, 1, self.n_head, h).permute(0, 2, 1, 3) # [b0, n_head, 1, h]
|
||||
attn = torch.matmul(thisq / math.sqrt(h), used_k.transpose(-1, -2))
|
||||
if memkv_text is not None:
|
||||
attn[..., :memkv_text.shape[-2]] += log_text_attention_weights
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
context_swin = torch.matmul(attn, used_v).permute(0, 2, 1, 3).reshape(b0, 1, h0)
|
||||
|
||||
return context_swin, this_qkv[..., h0:]
|
||||
|
||||
|
||||
def attention_localframe_and_text_NAR(q0, k0, v0, attention_mask,
|
||||
n_head, text_len, frame_len, frame_num,
|
||||
attention_dropout=None, log_text_attention_weights=0, stage=1, **kwargs):
|
||||
b, s0, h0 = q0.shape
|
||||
s1 = s0 - text_len
|
||||
h = h0 // n_head
|
||||
assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
|
||||
# attention_mask.shape [4, b or 1, 1, text_len+frame_len, text_len+frame_len]
|
||||
if stage == 2:
|
||||
assert frame_num == 3
|
||||
|
||||
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
k0T = k0.transpose(-1, -2)
|
||||
|
||||
score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
||||
score_any2text += log_text_attention_weights
|
||||
score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask[..., :text_len, :text_len]) \
|
||||
- 10000.0 * (1.0 - attention_mask[..., :text_len, :text_len])
|
||||
# context for text
|
||||
attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
|
||||
if attention_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attention_probs_text = attention_dropout(attention_probs_text)
|
||||
context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
|
||||
context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
|
||||
|
||||
if frame_num > 0:
|
||||
score_any2text_part2 = score_any2text[..., text_len:, :]
|
||||
|
||||
# score: frame local
|
||||
q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
||||
v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
||||
k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
|
||||
score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
|
||||
if stage == 1:
|
||||
score_frame_local0 = torch.mul(score_frame_local0, attention_mask[..., text_len:, text_len:].unsqueeze(1)) \
|
||||
- 10000.0 * (1.0 - attention_mask[..., text_len:, text_len:].unsqueeze(1))
|
||||
|
||||
# context for frame
|
||||
score_frame_all = torch.cat((score_any2text_part2,
|
||||
score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
|
||||
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
||||
if attention_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attention_probs_frame = attention_dropout(attention_probs_frame)
|
||||
context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
||||
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
|
||||
view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
|
||||
|
||||
context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
|
||||
else:
|
||||
context_frame = None
|
||||
|
||||
return context_text2text, context_frame
|
||||
|
||||
def attention_localframe_and_text_AR(q0, k0, v0, n_head, text_len, frame_len, frame_num,
|
||||
attention_dropout=None, log_text_attention_weights=0, layer_id=None, limited_spatial_channel_mem=False, stage=1, **kwargs):
|
||||
# limited_spatial_channel_mem=True means: mems in the spatial channel consist of {mem_text, mem_current_frame}
|
||||
b, s0, h0 = k0.shape
|
||||
frame_num_before = (s0-text_len-1) // frame_len # frame_num == frame_num_before or frame_num == frame_num_before+1
|
||||
h = h0 // n_head
|
||||
assert q0.shape[1] == 1
|
||||
assert v0.shape[1] == k0.shape[1]
|
||||
|
||||
q0 = q0.reshape(b, 1, n_head, h).permute(0, 2, 1, 3)
|
||||
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)
|
||||
|
||||
if limited_spatial_channel_mem:
|
||||
assert frame_num_before == 0
|
||||
assert stage == 1 # not implemented for stage-2 yet
|
||||
score = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)
|
||||
score[..., :text_len] += log_text_attention_weights
|
||||
attention_probs_frame = F.softmax(score, dim=-1)
|
||||
context_frame = torch.matmul(attention_probs_frame, v0).transpose(1, 2).reshape(b, 1, h0)
|
||||
|
||||
else:
|
||||
score_token2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
||||
score_token2text += log_text_attention_weights
|
||||
score_frame_local0 = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., text_len+frame_num_before*frame_len:])
|
||||
score_frame_all = torch.cat((score_token2text,
|
||||
score_frame_local0), dim=-1)
|
||||
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
||||
|
||||
context_token2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
||||
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:], \
|
||||
v0[:, :, text_len+frame_num_before*frame_len:, :])
|
||||
context_frame = (context_token2text + context_frame_local0).transpose(1, 2).reshape(b, 1, h0)
|
||||
|
||||
return context_frame
|
||||
|
||||
|
||||
class CogVideoCacheModel(BaseModel):
|
||||
def __init__(self, args, transformer=None, parallel_output=True, window_size=None, cogvideo_stage=None):
|
||||
super().__init__(args, transformer=transformer, parallel_output=parallel_output)
|
||||
self.layout = args.layout # [64, 64+1024, 64+6*1024]
|
||||
self.stage = cogvideo_stage if cogvideo_stage is not None else args.cogvideo_stage # 1 or 2
|
||||
self.n_head = args.num_attention_heads
|
||||
self.window_size = window_size if window_size is not None else args.window_size
|
||||
|
||||
frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
|
||||
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
||||
args.additional_seqlen, args.hidden_size
|
||||
))
|
||||
|
||||
if self.stage == 1:
|
||||
self.add_mixin('attention_plus', FullAttentionMixin(
|
||||
num_layers=args.num_layers,
|
||||
hidden_size=args.hidden_size,
|
||||
frame_resolution=frame_resolution,
|
||||
n_head=args.num_attention_heads,
|
||||
frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
|
||||
))
|
||||
else:
|
||||
self.add_mixin('attention_plus', WindowAttentionMixin(
|
||||
num_layers=args.num_layers,
|
||||
hidden_size=args.hidden_size,
|
||||
frame_resolution=frame_resolution,
|
||||
window_size=self.window_size,
|
||||
shift_size=self.window_size//2,
|
||||
n_head=args.num_attention_heads,
|
||||
frame_num=(args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0]),
|
||||
))
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_model_specific_args(cls, parser):
|
||||
group = parser.add_argument_group('VideoSwinLocalModel', 'video swin local model configurations')
|
||||
group.add_argument("--layout", type=str, default='64, 464, 2064')
|
||||
group.add_argument("--window-size", type=int, default=10) # 优先级在直接参数赋值之后
|
||||
group.add_argument("--additional-seqlen", type=int, default=2000)
|
||||
group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2]) # 优先级在直接参数赋值之后
|
||||
return parser
|
||||
|
||||
def disable_untrainable_params(self):
|
||||
pass
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kw_args):
|
||||
if position_ids.shape[-1] > 1:
|
||||
if self.stage == 1:
|
||||
if position_ids[0,-1] >= (512+400):
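# The base table covers the text and the first frame (positions below 512+400); position ids of the
# remaining frames are shifted down by 512+400 and looked up in the extra embedding table.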
|
||||
frame_num = position_ids.shape[-1] // 400
|
||||
position_embeddings = torch.cat(
|
||||
(
|
||||
self.transformer.position_embeddings(position_ids[..., :-400*(frame_num-1)]),
|
||||
self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -400*(frame_num-1):]-(512+400))
|
||||
),
|
||||
dim=-2
|
||||
)
|
||||
else:
|
||||
position_embeddings = self.transformer.position_embeddings(position_ids)
|
||||
else:
|
||||
# given 3, interpolate 2
|
||||
position_embeddings = torch.cat(
|
||||
(
|
||||
self.transformer.position_embeddings(position_ids[..., :-800]),
|
||||
self.get_mixin('extra_position_embedding').position_embeddings(position_ids[..., -800:]-(512+400))
|
||||
),
|
||||
dim=-2
|
||||
)
|
||||
else:
|
||||
if position_ids[0, 0] >= (512+400):
|
||||
position_embeddings = self.get_mixin('extra_position_embedding').position_embeddings(position_ids-(512+400))
|
||||
else:
|
||||
position_embeddings = self.transformer.position_embeddings(position_ids)
|
||||
return position_embeddings
|
||||
|
||||
def attention_forward(self, hidden_states, mask, layer_id, mems=None, log_text_attention_weights=0, text_len=0, frame_len=0, counter=0, enforce_no_swin=False, limited_spatial_channel_mem=False, **kw_args):
|
||||
attn_module = self.transformer.layers[layer_id].attention
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
|
||||
# base model qkv
|
||||
if mems is None:
|
||||
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
||||
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
||||
assert (q0.shape[1]-text_len) % frame_len == 0
|
||||
memkv0 = torch.cat((k0, v0), dim=-1)
|
||||
context_text, context_frame_local_text = attention_localframe_and_text_NAR(
|
||||
q0, k0, v0,
|
||||
mask,
|
||||
n_head=attn_module.num_attention_heads_per_partition,
|
||||
text_len=text_len,
|
||||
frame_len=frame_len,
|
||||
frame_num=(q0.shape[1]-text_len)//frame_len,
|
||||
log_text_attention_weights=log_text_attention_weights,
|
||||
stage=self.stage
|
||||
)
|
||||
|
||||
# change: self.swin_attend_to_text defaults to True:
|
||||
memkv1_text = self.get_mixin('attention_plus').query_key_value[layer_id](hidden_states[..., :text_len, :])[..., hidden_size:]
|
||||
output_text = attn_module.dense(context_text)
|
||||
|
||||
if (q0.shape[1]-text_len)//frame_len > 0:
|
||||
assert (q0.shape[1]-text_len) % frame_len == 0
|
||||
context_frame_swin, memkv1_frame = self.get_mixin('attention_plus').attention_extra_NAR_inference(
|
||||
hidden_states[:,text_len:], layer_id, memkv_text=memkv1_text, stage=self.stage)
|
||||
if not enforce_no_swin:
|
||||
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
||||
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
|
||||
output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
||||
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
||||
else:
|
||||
output_frame = attn_module.dense(context_frame_local_text[..., :frame_len, :])
|
||||
output = torch.cat((output_text, output_frame), dim=-2)
|
||||
memkv1 = torch.cat((memkv1_text, memkv1_frame), dim=-2) if memkv1_text is not None else memkv1_frame
|
||||
else:
|
||||
output = output_text
|
||||
memkv1 = memkv1_text
|
||||
kw_args['output_this_layer']['mem_kv'] = (memkv0, memkv1)
|
||||
|
||||
|
||||
else:
|
||||
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
||||
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
||||
new_memkv0 = torch.cat((k0, v0), dim=-1)
|
||||
old_k0, old_v0 = mems[0][layer_id][..., :hidden_size], mems[0][layer_id][..., hidden_size:]
|
||||
|
||||
context_frame_local_text = attention_localframe_and_text_AR(
|
||||
q0,
|
||||
torch.cat((old_k0.expand(k0.shape[0], -1, -1), k0), dim=-2),
|
||||
torch.cat((old_v0.expand(v0.shape[0], -1, -1), v0), dim=-2),
|
||||
n_head=attn_module.num_attention_heads_per_partition,
|
||||
text_len=text_len,
|
||||
frame_len=frame_len,
|
||||
frame_num=None,
|
||||
log_text_attention_weights=log_text_attention_weights,
|
||||
layer_id=layer_id,
|
||||
limited_spatial_channel_mem=limited_spatial_channel_mem,
|
||||
)
|
||||
|
||||
old_memkv1 = mems[1][layer_id] if mems[1] is not None else None
|
||||
|
||||
context_frame_swin, new_memkv1 = self.get_mixin('attention_plus').attention_extra_AR_inference(hidden_states,
|
||||
old_memkv1[..., text_len:, :] if old_memkv1.shape[-2]>text_len else None,
|
||||
counter-text_len,
|
||||
layer_id,
|
||||
memkv_text=old_memkv1[..., :text_len, :],
|
||||
log_text_attention_weights=log_text_attention_weights)
|
||||
if not enforce_no_swin:
|
||||
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
||||
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
|
||||
output = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
||||
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
||||
else:
|
||||
output = attn_module.dense(context_frame_local_text)
|
||||
|
||||
kw_args['output_this_layer']['mem_kv'] = (new_memkv0, new_memkv1)
|
||||
|
||||
return output
|
@ -1,543 +0,0 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@File : cogvideo_model.py
|
||||
@Time : 2022/07/11 16:12:05
|
||||
@Author : Wenyi Hong
|
||||
@Version : 1.0
|
||||
@Contact : hwy22@mails.tsinghua.edu.cn
|
||||
'''
|
||||
|
||||
# here put the import lib
|
||||
|
||||
import torch
|
||||
from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
|
||||
|
||||
from SwissArmyTransformer.mpu.utils import split_tensor_along_last_dim
|
||||
from SwissArmyTransformer.model.transformer import unscaled_init_method
|
||||
from SwissArmyTransformer.mpu import ColumnParallelLinear, RowParallelLinear
|
||||
import torch.nn.functional as F
|
||||
from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
|
||||
import math
|
||||
|
||||
class PositionEmbeddingMixin(BaseMixin):
|
||||
def __init__(self, additional_sequence_length, hidden_size,
|
||||
init_method_std=0.02, reinit_slice=slice(512, 912),
|
||||
):
|
||||
super(PositionEmbeddingMixin, self).__init__()
|
||||
self.reinit_slice = reinit_slice
|
||||
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
|
||||
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
|
||||
old_len, hidden_size = old_weights.shape
|
||||
assert hidden_size == self.position_embeddings.weight.shape[-1]
|
||||
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, framenum, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
||||
"""
|
||||
B, framenum, H, W, C = x.shape
|
||||
x = x.view(B, framenum, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, framenum, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, frame_num, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, frame_num, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
framenum = windows.shape[1]
|
||||
x = windows.view(B, H // window_size, W // window_size, framenum, window_size, window_size, -1)
|
||||
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, framenum, H, W, -1)
|
||||
return x
|
||||
|
||||
class WindowAttentionMixin(BaseMixin):
|
||||
def __init__(self, num_layers,
|
||||
hidden_size,
|
||||
frame_resolution,
|
||||
window_size,
|
||||
shift_size,
|
||||
n_head,
|
||||
frame_num,
|
||||
init_method=unscaled_init_method(0.02),
|
||||
output_layer_init_method=unscaled_init_method(0.02),
|
||||
):
|
||||
super(WindowAttentionMixin, self).__init__()
|
||||
self.num_layers = num_layers # replace attention in the LAST n layers
|
||||
self.query_key_value = torch.nn.ModuleList(
|
||||
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
||||
gather_output=False,init_method=init_method)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
self.dense = torch.nn.ModuleList(
|
||||
[RowParallelLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
init_method=output_layer_init_method,
|
||||
bias=True,
|
||||
module=self,
|
||||
name="dense",
|
||||
)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
|
||||
self.n_head = n_head
|
||||
self.window_size = window_size
|
||||
self.frame_resolution = frame_resolution
|
||||
self.frame_len = frame_resolution * frame_resolution
|
||||
assert frame_resolution % window_size == 0
|
||||
assert 0 < shift_size < window_size
|
||||
nW = (self.frame_resolution // self.window_size) ** 2
|
||||
ws_squre = self.window_size * self.window_size
|
||||
|
||||
# odd non-shift, even shift
|
||||
img_mask = torch.zeros((1, 1, frame_resolution, frame_resolution, 1))
|
||||
h_slices = (slice(0, -shift_size),
|
||||
slice(-shift_size, None))
|
||||
w_slices = (slice(0, -shift_size),
|
||||
slice(-shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, :, h, w, :] = cnt
|
||||
cnt += 1
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, 1, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
sub_attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) #[nW, self.window_size * self.window_size, self.window_size * self.window_size]
|
||||
sub_attn_mask = sub_attn_mask.masked_fill(sub_attn_mask != 0, float(0.0)).masked_fill(sub_attn_mask == 0, float(1.00))
|
||||
attn_mask = sub_attn_mask.repeat(1, frame_num, frame_num)
|
||||
|
||||
self.attn_mask_sequential = attn_mask.clone().tril()
|
||||
self.causal_mask_sequential = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num).tril()
|
||||
|
||||
self.causal_mask_interp = torch.ones(1, ws_squre*frame_num, ws_squre*frame_num)
|
||||
self.attn_mask_interp = attn_mask.clone()
|
||||
|
||||
# bi-dir
|
||||
for bi_idx in range(0, frame_num, 2):
|
||||
for uni_idx in range(1, frame_num, 2):
|
||||
self.attn_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
|
||||
self.causal_mask_interp[:, bi_idx*ws_squre:(bi_idx+1)*ws_squre, uni_idx*ws_squre:(uni_idx+1)*ws_squre] = 0
|
||||
# uni-dir
|
||||
for uni_idx in range(1, frame_num, 2):
|
||||
self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
|
||||
self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx:ws_squre*(uni_idx+1)].tril_()
|
||||
for uni_idx2 in range(uni_idx+2, frame_num, 2):
|
||||
self.attn_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
|
||||
self.causal_mask_interp[:, ws_squre*uni_idx:ws_squre*(uni_idx+1), ws_squre*uni_idx2:ws_squre*(uni_idx2+1)] = 0
|
||||
|
||||
# expand dim
|
||||
self.attn_mask_sequential = self.attn_mask_sequential[None, None, :, None]
|
||||
self.attn_mask_interp = self.attn_mask_interp[None, None, :, None]
|
||||
self.causal_mask_sequential = self.causal_mask_sequential[None, None, :, None]
|
||||
self.causal_mask_interp = self.causal_mask_interp[None, None, :, None]
|
||||
|
||||
self.shift_sizes = [0, shift_size]
|
||||
# self.register_buffer("attn_mask", attn_mask)
|
||||
# self.register_buffer("causal_mask", causal_mask)
|
||||
self.mask_initialized = False
|
||||
|
||||
self.attn_distribution = torch.nn.ParameterList([
|
||||
torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def reinit(self, *pre_mixins):
|
||||
start_layer = len(self.transformer.layers) - self.num_layers
|
||||
assert start_layer >= 0
|
||||
for layer_id in range(self.num_layers):
|
||||
old_attention = self.transformer.layers[start_layer + layer_id].attention
|
||||
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
|
||||
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
|
||||
|
||||
def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
|
||||
text_attn_mask=None, mode_sequential=True):
|
||||
# pb relax
|
||||
swin_pb_relax = True
|
||||
alpha = 16
|
||||
|
||||
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
||||
if not self.mask_initialized:
|
||||
self.attn_mask_sequential = self.attn_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.causal_mask_sequential = self.causal_mask_sequential.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.attn_mask_interp = self.attn_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.causal_mask_interp = self.causal_mask_interp.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.mask_initialized = True
|
||||
b0, s1, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
frame_len = self.frame_resolution * self.frame_resolution
|
||||
frame_num = s1 // frame_len
|
||||
assert frame_num*frame_len == s1
|
||||
wind_square = self.window_size * self.window_size
|
||||
nW = frame_len // wind_square
|
||||
bswin = b0 * nW
|
||||
|
||||
causal_mask = self.causal_mask_sequential if mode_sequential else self.causal_mask_interp
|
||||
attn_mask = self.attn_mask_sequential if mode_sequential else self.attn_mask_interp
|
||||
if text_hidden_state is not None:
|
||||
s0 = text_hidden_state.shape[1]
|
||||
qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
|
||||
q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
|
||||
|
||||
# shift
|
||||
frame_hidden_state = frame_hidden_state.reshape(b0, frame_num, self.frame_resolution, self.frame_resolution, h0)
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
frame_hidden_state = torch.roll(frame_hidden_state, shifts=(-self.shift_sizes[layer_id%2], -self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
# window partition
|
||||
frame_hidden_state = window_partition(frame_hidden_state, self.window_size).reshape(bswin, frame_num*wind_square, h0)
|
||||
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(bswin, frame_num*wind_square, 3, self.n_head, h)\
|
||||
.permute(2, 0, 3, 1, 4) #[3, bswin, n_head, frame_num*wind_size*wind_size, h]
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
# pb-relax
|
||||
if swin_pb_relax:
|
||||
attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
|
||||
else:
|
||||
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
||||
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
# attn = attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square) + self.attn_mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), attn_mask)\
|
||||
- 10000.0 * (1.0 - attn_mask)
|
||||
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
||||
else:
|
||||
attn = torch.mul(attn.view(bswin // nW, nW, self.n_head, frame_num*wind_square, frame_num*wind_square), causal_mask)\
|
||||
- 10000.0 * (1.0 - causal_mask)
|
||||
attn = attn.view(bswin, self.n_head, frame_num*wind_square, frame_num*wind_square)
|
||||
if swin_pb_relax:
|
||||
swin_pb_relax_const = torch.max(attn.reshape(bswin, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
|
||||
attn = (attn - swin_pb_relax_const)*alpha
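# PB-relax: the scores above were computed with q pre-scaled by 1/(sqrt(h)*alpha); subtracting the
# detached per-head maximum and multiplying back by alpha restores the true logits up to a constant
# shift, which softmax ignores, while keeping the intermediate matmul safely within fp16 range.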
|
||||
|
||||
if text_hidden_state is None:
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
else:
|
||||
assert text_attn_mask is not None
|
||||
text_attn_mask = text_attn_mask.unsqueeze(2).unsqueeze(2)
|
||||
# pb-relax
|
||||
if swin_pb_relax:
|
||||
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / (math.sqrt(h)*alpha), k_text.unsqueeze(1).transpose(-1, -2))
|
||||
attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, -1, self.n_head, 1, 1))*alpha
|
||||
else:
|
||||
attn_frame2text = torch.matmul(q.reshape(b0, -1, self.n_head, frame_num*wind_square, h) / math.sqrt(h), k_text.unsqueeze(1).transpose(-1, -2))
|
||||
|
||||
attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
|
||||
attn_frame2text = attn_frame2text.reshape(bswin, self.n_head, frame_num*wind_square, s0)
|
||||
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
|
||||
context_swin = (torch.matmul(attn[..., :-s0], v) +
|
||||
torch.matmul(attn[..., -s0:].reshape(b0, -1, self.n_head,frame_num*wind_square, s0), v_text.unsqueeze(1))\
|
||||
.reshape(bswin, self.n_head, frame_num*wind_square, h))\
|
||||
.permute(0, 2, 1, 3).reshape(bswin, frame_num, self.window_size, self.window_size, h0)
|
||||
|
||||
context_swin = window_reverse(context_swin, self.window_size, self.frame_resolution, self.frame_resolution)
|
||||
# reverse cycle shift
|
||||
if self.shift_sizes[layer_id%2] > 0:
|
||||
context_swin = torch.roll(context_swin, shifts=(self.shift_sizes[layer_id%2], self.shift_sizes[layer_id%2]), dims=(2,3))
|
||||
context_swin = context_swin.reshape(b0, s1, h0)
|
||||
|
||||
return context_swin
|
||||
|
||||
|
||||
class FullAttentionMixin(BaseMixin):
|
||||
def __init__(self, num_layers,
|
||||
hidden_size,
|
||||
frame_resolution,
|
||||
n_head,
|
||||
frame_num,
|
||||
init_method=unscaled_init_method(0.02),
|
||||
output_layer_init_method=unscaled_init_method(0.02),
|
||||
):
|
||||
super(FullAttentionMixin, self).__init__()
|
||||
self.num_layers = num_layers # replace attention in the LAST n layers
|
||||
self.query_key_value = torch.nn.ModuleList(
|
||||
[ColumnParallelLinear(hidden_size, 3*hidden_size,stride=3,
|
||||
gather_output=False,init_method=init_method)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
self.dense = torch.nn.ModuleList(
|
||||
[RowParallelLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
init_method=output_layer_init_method,
|
||||
bias=True,
|
||||
module=self,
|
||||
name="dense",)
|
||||
for layer_id in range(num_layers)
|
||||
])
|
||||
|
||||
self.n_head = n_head
|
||||
self.frame_resolution = frame_resolution
|
||||
self.frame_len = frame_resolution * frame_resolution
|
||||
self.causal_mask = torch.ones(1, 1, self.frame_len*frame_num, self.frame_len*frame_num).tril()
|
||||
|
||||
self.mask_initialized = False
|
||||
|
||||
self.attn_distribution = torch.nn.ParameterList([
|
||||
torch.nn.Parameter(torch.zeros(hidden_size))
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def reinit(self, *pre_mixins):
|
||||
start_layer = len(self.transformer.layers) - self.num_layers
|
||||
assert start_layer >= 0
|
||||
for layer_id in range(self.num_layers):
|
||||
base_attention = self.transformer.layers[start_layer + layer_id].attention
|
||||
self.query_key_value[layer_id].weight.data.copy_(base_attention.query_key_value.weight.data)
|
||||
self.query_key_value[layer_id].bias.data.copy_(base_attention.query_key_value.bias.data)
|
||||
|
||||
def attention_extra(self, frame_hidden_state, layer_id, attn_dropout, text_hidden_state=None,
|
||||
text_attn_mask=None, mode_sequential=False):
|
||||
# pb relax
|
||||
# frame_hidden_state [batchsize, frame_num*frame_size, n_head*hiddensize_perhead]
|
||||
assert mode_sequential == True # only sequential generation is supported here
|
||||
swin_pb_relax = True
|
||||
alpha = 16
|
||||
|
||||
if not self.mask_initialized:
|
||||
self.causal_mask = self.causal_mask.to(device=frame_hidden_state.device, dtype=frame_hidden_state.dtype)
|
||||
self.mask_initialized = True
|
||||
b0, s1, h0 = frame_hidden_state.shape
|
||||
h = h0 // self.n_head
|
||||
frame_len = self.frame_resolution * self.frame_resolution
|
||||
frame_num = s1 // frame_len
|
||||
assert frame_num*frame_len == s1
|
||||
|
||||
qkv = self.query_key_value[layer_id](frame_hidden_state).reshape(b0, s1, 3, self.n_head, h)\
|
||||
.permute(2, 0, 3, 1, 4) #[3, b0, n_head, s1, h]
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
# frames-to-frames
|
||||
if swin_pb_relax:
|
||||
attn = torch.matmul(q / (math.sqrt(h)*alpha), k.transpose(-1, -2))
|
||||
else:
|
||||
attn = torch.matmul(q / math.sqrt(h), k.transpose(-1, -2))
|
||||
attn = torch.mul(attn, self.causal_mask) - 10000.0 * (1.0 - self.causal_mask)
|
||||
if swin_pb_relax:
|
||||
swin_pb_relax_const = torch.max(attn.reshape(b0, self.n_head, -1), dim=-1, keepdim=True)[0].detach().unsqueeze(-1)
|
||||
attn = (attn - swin_pb_relax_const)*alpha
|
||||
|
||||
if text_hidden_state is None:
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
context_swin = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
else:
|
||||
# frame-to-text
|
||||
assert text_attn_mask is not None
|
||||
s0 = text_hidden_state.shape[1]
|
||||
qkv_text = self.query_key_value[layer_id](text_hidden_state).reshape(b0, s0, 3, self.n_head, h).permute(2, 0, 3, 1, 4) #[3, b0, n_head, s0, h]
|
||||
q_text, k_text, v_text = qkv_text[0], qkv_text[1], qkv_text[2]
|
||||
text_attn_mask = text_attn_mask.unsqueeze(2)
|
||||
if swin_pb_relax:
|
||||
attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / (math.sqrt(h)*alpha), k_text.transpose(-1, -2))
|
||||
attn_frame2text = (attn_frame2text-swin_pb_relax_const.reshape(b0, self.n_head, 1, 1))*alpha
|
||||
else:
|
||||
attn_frame2text = torch.matmul(q.reshape(b0, self.n_head, s1, h) / math.sqrt(h), k_text.transpose(-1, -2))
|
||||
attn_frame2text = torch.mul(text_attn_mask, attn_frame2text) - 10000.0 * (1.0 - text_attn_mask)
|
||||
attn_frame2text = attn_frame2text.reshape(b0, self.n_head, s1, s0)
|
||||
|
||||
attn = torch.cat((attn, attn_frame2text), dim=-1)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
if attn_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attn = attn_dropout(attn)
|
||||
|
||||
context_frame = (torch.matmul(attn[..., :-s0], v) +
|
||||
torch.matmul(attn[..., -s0:].reshape(b0, self.n_head,s1, s0), v_text))\
|
||||
.permute(0, 2, 1, 3).reshape(b0, s1, h0)
|
||||
|
||||
return context_frame
|
||||
|
||||
|
||||
def attention_localframe_and_text(q0, k0, v0, attention_mask_totxt, attention_mask_local,
|
||||
n_head, text_len, frame_len, frame_num, attention_dropout=None, layer_id=0, **kwargs):
|
||||
b, s0, h0 = q0.shape
|
||||
s1 = s0 - text_len
|
||||
h = h0 // n_head
|
||||
assert q0.shape[1] == v0.shape[1] == k0.shape[1] == text_len+frame_len*frame_num
|
||||
# attention_mask_totxt [b, 1, 1, text_len]
|
||||
# attention_mask_local [1, 1, frame_num, frame_len, frame_len]
|
||||
# attention_mask: [1, 1, text_len+frame_len, text_len+frame_len]
|
||||
|
||||
q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
k0 = k0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
|
||||
k0T = k0.transpose(-1, -2)
|
||||
|
||||
# score: any2text
|
||||
score_any2text = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T[..., :text_len])
|
||||
score_any2text_part1 = torch.mul(score_any2text[..., :text_len, :], attention_mask_totxt) \
|
||||
- 10000.0 * (1.0 - attention_mask_totxt)
|
||||
score_any2text_part2 = torch.mul(score_any2text[..., text_len:, :], attention_mask_totxt) - \
|
||||
10000.0 * (1.0 - attention_mask_totxt)
|
||||
|
||||
# score: frame local
|
||||
q0_frame = q0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
||||
v0_frame = v0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h)
|
||||
k0T_frame = k0[:, :, text_len:].reshape(b, n_head, frame_num, frame_len, h).transpose(-1, -2)
|
||||
score_frame_local0 = torch.matmul(q0_frame / math.sqrt(q0_frame.shape[-1]), k0T_frame)
|
||||
score_frame_local0 = torch.mul(score_frame_local0, attention_mask_local) \
|
||||
- 10000.0 * (1.0 - attention_mask_local)
|
||||
|
||||
# context for frame
|
||||
score_frame_all = torch.cat((score_any2text_part2,
|
||||
score_frame_local0.view(b, n_head, s1, frame_len)), dim=-1)
|
||||
attention_probs_frame = F.softmax(score_frame_all, dim=-1)
|
||||
|
||||
if attention_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attention_probs_frame = attention_dropout(attention_probs_frame)
|
||||
|
||||
context_frame2text = torch.matmul(attention_probs_frame[..., :text_len], v0[..., :text_len, :]) # [b, n_head, s1, h]
|
||||
context_frame_local0 = torch.matmul(attention_probs_frame[..., text_len:text_len+frame_len].\
|
||||
view(b, n_head, frame_num, frame_len, frame_len), v0_frame).view(b, n_head, s1, h)
|
||||
context_frame = (context_frame2text + context_frame_local0).transpose(1, 2).reshape(b, s1, h0)
|
||||
|
||||
# context for text
|
||||
attention_probs_text = F.softmax(score_any2text_part1, dim=-1)
|
||||
if attention_dropout is not None:
|
||||
with get_cuda_rng_tracker().fork():
|
||||
attention_probs_text = attention_dropout(attention_probs_text)
|
||||
context_text2text = torch.matmul(attention_probs_text, v0[..., :text_len, :])
|
||||
context_text2text = context_text2text.transpose(1, 2).reshape(b, text_len, h0)
|
||||
|
||||
return context_text2text, context_frame
|
||||
|
||||
|
||||
class CogVideoModel(BaseModel):
|
||||
def __init__(self, args, transformer=None, parallel_output=True):
|
||||
super().__init__(args, transformer=transformer, parallel_output=parallel_output)
|
||||
self.stage = args.cogvideo_stage # 1 or 2
|
||||
self.mode_sequential = True if self.stage==1 else False
|
||||
self.layout = args.layout # [64, 64+400, 64+5*400]
|
||||
self.n_head = args.num_attention_heads
|
||||
frame_resolution = int(math.sqrt(self.layout[1]-self.layout[0]))
|
||||
frame_num = (args.layout[2]-args.layout[0])//(args.layout[1]-args.layout[0])
|
||||
frame_len = self.layout[1]-self.layout[0]
|
||||
|
||||
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
|
||||
args.additional_seqlen, args.hidden_size
|
||||
))
|
||||
|
||||
if args.window_size == -1:
|
||||
# full attention
|
||||
assert self.stage == 1
|
||||
self.add_mixin('attention_plus', FullAttentionMixin(
|
||||
num_layers=args.num_layers,
|
||||
hidden_size=args.hidden_size,
|
||||
frame_resolution=frame_resolution,
|
||||
n_head=args.num_attention_heads,
|
||||
frame_num=frame_num,
|
||||
))
|
||||
else:
|
||||
self.add_mixin('attention_plus', WindowAttentionMixin(
|
||||
num_layers=args.num_layers,
|
||||
hidden_size=args.hidden_size,
|
||||
frame_resolution=frame_resolution,
|
||||
window_size=args.window_size,
|
||||
shift_size=args.window_size//2,
|
||||
n_head=args.num_attention_heads,
|
||||
frame_num=frame_num,
|
||||
))
|
||||
# attention_mask_local
|
||||
self.attention_mask_local_sequential = torch.ones(1, 1, frame_num, frame_len, frame_len).tril().unsqueeze(0)
|
||||
self.attention_mask_local_interp = torch.ones(1, 1, frame_num, frame_len, frame_len)
|
||||
|
||||
for idx in range(1, frame_num, 2):
|
||||
self.attention_mask_local_interp[:, :, idx:idx+1].tril_()
|
||||
self.attention_mask_local_interp = self.attention_mask_local_interp.unsqueeze(0)
|
||||
self.mask_initialized = False
|
||||
|
||||
@classmethod
|
||||
def add_model_specific_args(cls, parser):
|
||||
group = parser.add_argument_group('CogVideoModel', 'CogVideo model configurations')
|
||||
group.add_argument("--layout", type=str, default='64, 464, 2064', help='text_len, textlen+frame_len, textlen+frame_len*frame_num')
|
||||
group.add_argument("--window-size", type=int, default=10, help="swin attention's window size in temperal channel, -1 represents full attention")
|
||||
group.add_argument("--additional-seqlen", type=int, default=2000)
|
||||
group.add_argument("--cogvideo-stage", type=int, default=1, choices=[1,2])
|
||||
return parser
|
||||
|
||||
def disable_untrainable_params(self):
|
||||
self.transformer.requires_grad_(False)
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kw_args):
|
||||
position = position_ids[..., :(64+400)]
|
||||
position_plus = position_ids[..., (64+400):]
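# The first 64 + 400 positions (text plus the first frame) reuse the base transformer's position
# table; the remaining frame positions carry ids starting at 512 + 400, so subtracting (512 + 400)
# below indexes the mixin's extra position-embedding table from zero.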
|
||||
position_embeddings = torch.cat(
|
||||
(
|
||||
self.transformer.position_embeddings(position),
|
||||
self.get_mixin('extra_position_embedding').position_embeddings(position_plus-(512+400))
|
||||
),
|
||||
dim=-2
|
||||
)
|
||||
return position_embeddings
|
||||
|
||||
def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
|
||||
# mask.shape=[bs, 1, 1, 64]
|
||||
if not self.mask_initialized:
|
||||
self.attention_mask_local_sequential = self.attention_mask_local_sequential.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
self.attention_mask_local_interp = self.attention_mask_local_interp.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
self.mask_initialized = True
|
||||
|
||||
attn_module = self.transformer.layers[layer_id].attention
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
bs = hidden_states.shape[0]
|
||||
|
||||
# base model qkv
|
||||
mixed_raw_layer = attn_module.query_key_value(hidden_states)
|
||||
q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer, 3)
|
||||
dropout_fn = self.transformer.layers[layer_id].attention.attention_dropout if self.training else None
|
||||
|
||||
attention_mask_local = self.attention_mask_local_sequential if self.mode_sequential else self.attention_mask_local_interp
|
||||
context_text, context_frame_local_text = attention_localframe_and_text(
|
||||
q0, k0, v0,
|
||||
attention_mask_totxt=mask,
|
||||
attention_mask_local=attention_mask_local,
|
||||
n_head=attn_module.num_attention_heads_per_partition,
|
||||
text_len=self.layout[0],
|
||||
frame_len=self.layout[1]-self.layout[0],
|
||||
frame_num=(self.layout[2]-self.layout[0])//(self.layout[1]-self.layout[0]),
|
||||
attention_dropout=dropout_fn,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
context_frame_swin = self.get_mixin('attention_plus').attention_extra(
|
||||
hidden_states[:, self.layout[0]:], layer_id, dropout_fn,
|
||||
text_hidden_state=hidden_states[:, :self.layout[0]],
|
||||
text_attn_mask=mask[..., 0, :],
|
||||
mode_sequential=self.mode_sequential)
|
||||
|
||||
attn_distrib = torch.sigmoid(self.get_mixin('attention_plus').attn_distribution[layer_id])
|
||||
attn_distrib = attn_distrib.unsqueeze(0).unsqueeze(0)
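# attn_distribution is a learned per-channel gate squashed to (0, 1) by the sigmoid above; the
# base model's text/local-frame attention output and the mixin's swin/full attention output are
# blended channel-wise below as gate * base + (1 - gate) * mixin.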
|
||||
|
||||
output_text = attn_module.dense(context_text)
|
||||
output_frame = torch.mul(attn_module.dense(context_frame_local_text), attn_distrib)\
|
||||
+torch.mul(self.get_mixin('attention_plus').dense[layer_id](context_frame_swin), 1-attn_distrib)
|
||||
output = torch.cat((output_text, output_frame), dim=-2)
|
||||
|
||||
return output
|
Binary file not shown.
@ -1,184 +0,0 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
'''
|
||||
@File : pretrain_cogvideo.py
|
||||
@Time : 2021/10/06 00:58:32
|
||||
@Author : Wenyi Hong
|
||||
@Contact : hwy22@mails.tsinghua.edu.cn
|
||||
'''
|
||||
|
||||
# here put the import lib
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from icetk import icetk as tokenizer
|
||||
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
|
||||
|
||||
from models.cogvideo_model import CogVideoModel
|
||||
from SwissArmyTransformer import mpu, get_args
|
||||
from SwissArmyTransformer.training.deepspeed_training import training_main
|
||||
from SwissArmyTransformer.data_utils import BinaryDataset
|
||||
|
||||
def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
|
||||
# Extract batch size and sequence length.
|
||||
batch_size, seq_length = data.size()
|
||||
assert attention_mask_totxt is not None
|
||||
layout = args.layout
|
||||
assert seq_length == layout[-1]
|
||||
n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
|
||||
frame_len = layout[1]-layout[0]
|
||||
position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
|
||||
device=data.device)
|
||||
for i in range(batch_size):
|
||||
torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
|
||||
dtype=torch.long, device=data.device)
|
||||
torch.arange(512, 512+layout[2]-layout[0],
|
||||
out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
|
||||
return position_ids
|
||||
|
||||
|
||||
def get_batch(data_iterator, args, timers):
|
||||
# Items and their type.
|
||||
keys = ['text', 'loss_mask', 'attention_mask_totxt']
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
timers('data loader').start()
|
||||
if data_iterator is not None:
|
||||
data = next(data_iterator)
|
||||
else:
|
||||
data = None
|
||||
timers('data loader').stop()
|
||||
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
# Unpack.
|
||||
tokens_ = data_b['text'].long()
|
||||
loss_mask = data_b['loss_mask'].float()
|
||||
attention_mask_totxt = data_b['attention_mask_totxt'].float()
|
||||
|
||||
labels = tokens_[:, 1:].clone().contiguous()
|
||||
loss_mask = loss_mask[:, 1:].contiguous()
|
||||
tokens = tokens_[:, :-1].clone().contiguous()
|
||||
|
||||
for idx in range(args.layout[0], args.layout[2], 400):
|
||||
tokens[:, idx] = tokenizer['<start_of_image>']
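# Each frame occupies 400 tokens, so the loop above overwrites the first token of every frame
# slot with the <start_of_image> separator token.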
|
||||
# Get the masks and postition ids.
|
||||
position_ids = get_masks_and_position_ids_video(
|
||||
tokens,
|
||||
attention_mask_totxt=attention_mask_totxt,
|
||||
args=args
|
||||
)
|
||||
attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
|
||||
# Convert
|
||||
if args.fp16:
|
||||
attention_mask_totxt = attention_mask_totxt.half()
|
||||
return tokens, labels, loss_mask, attention_mask_totxt, position_ids
|
||||
|
||||
|
||||
def forward_step(data_iterator, model, args, timers):
|
||||
"""Forward step."""
|
||||
|
||||
# Get the batch.
|
||||
timers('batch generator').start()
|
||||
tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
|
||||
data_iterator, args, timers)
|
||||
timers('batch generator').stop()
|
||||
|
||||
# Forward model.
|
||||
logits, *mems = model(tokens, position_ids, attention_mask_totxt)
|
||||
# ======= hyper params =======#
|
||||
perframe_len = 400
|
||||
text_len=64
|
||||
frame_num = 5
|
||||
logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
|
||||
losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
|
||||
# scaling loss mask
|
||||
loss_mask = loss_mask[:, text_len:].reshape(-1)
|
||||
|
||||
losses_1d = losses.reshape(-1) * loss_mask
|
||||
loss = torch.sum(losses_1d) / loss_mask.sum()
|
||||
# ===================== Log partial losses ======================== #
|
||||
log_loss_dict = {}
|
||||
bs = losses.shape[0]
|
||||
|
||||
if args.cogvideo_stage == 1:
|
||||
for i in range(frame_num):
|
||||
log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
|
||||
else:
|
||||
for i in range(1, frame_num-1):
|
||||
log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
|
||||
|
||||
# ===================== END OF BLOCK ======================= #
|
||||
return loss, log_loss_dict
|
||||
|
||||
|
||||
def create_dataset_function(path, args):
|
||||
dataset_layout = [64, 464, 2064]
|
||||
input_layout = [64, 464, 2064]
|
||||
# frame_num = 6
|
||||
# frame_interval = 2 # DEBUG!!!
|
||||
def process_fn(row):
|
||||
row = row.astype(np.int64)
|
||||
text = row[:dataset_layout[0]]
|
||||
frames = row[dataset_layout[0]:]
|
||||
|
||||
if text[0] == tokenizer['<pad>']:
|
||||
text = text[1:] # due to our way of data processing
|
||||
if args.cogvideo_stage == 1:
|
||||
text, loss_mask, frames = make_text_video_generation(text, frames)
|
||||
else:
|
||||
text, loss_mask, frames = mask_video_frame_interpolation(text, frames)
|
||||
|
||||
n_pad = input_layout[0] - len(text)
|
||||
parts = [
|
||||
np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
|
||||
text,
|
||||
np.array([tokenizer['<start_of_image>']], dtype=np.int64),
|
||||
frames,
|
||||
]
|
||||
ret = np.concatenate(parts, axis=0)
|
||||
|
||||
attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
|
||||
return {'text': ret,
|
||||
'loss_mask': loss_mask,
|
||||
'attention_mask_totxt': attention_mask_totxt,
|
||||
}
|
||||
return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
|
||||
|
||||
def make_text_video_generation(text, frames):
|
||||
input_layout = [64, 464, 2064]
|
||||
text = text[text!= tokenizer['<pad>']][:input_layout[0]] # dataset format: 1.0秒<n>{text}<pad><pad> ...
|
||||
loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1]))  # aligned with the input; the loss mask is shifted left by one position later
|
||||
return text, loss_mask, frames
|
||||
|
||||
def mask_video_frame_interpolation(text, frames):
|
||||
input_layout = [64, 464, 2064]
|
||||
frame_len = input_layout[1]-input_layout[0]
|
||||
# text format: <pad> 1.0秒 <n> {text} <pad> <pad>
|
||||
text = text[text!= tokenizer['<pad>']][:input_layout[0]]
|
||||
loss_mask = np.array([0] * (input_layout[1]+1)
|
||||
+ [1] * (input_layout[1]-input_layout[0])
|
||||
+ [0] * (input_layout[1]-input_layout[0])
|
||||
+ [1] * (input_layout[1]-input_layout[0])
|
||||
+ [0] * (input_layout[1]-input_layout[0]) )  # aligned with the input; the loss mask is shifted left by one position later
|
||||
|
||||
return text, loss_mask, frames
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
py_parser = argparse.ArgumentParser(add_help=False)
|
||||
py_parser.add_argument('--txt-loss-scale', type=float, default=1)
|
||||
CogVideoModel.add_model_specific_args(py_parser)
|
||||
|
||||
known, args_list = py_parser.parse_known_args()
|
||||
|
||||
args = get_args(args_list)
|
||||
args = argparse.Namespace(**vars(args), **vars(known))
|
||||
|
||||
args.layout = [int(x) for x in args.layout.split(',')]
|
||||
|
||||
training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
|
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@ -0,0 +1,27 @@
|
||||
[tool.ruff]
|
||||
line-length = 119
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Never enforce `E501` (line length violations).
|
||||
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
|
||||
# Ignore import violations in all `__init__.py` files.
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
lines-after-imports = 2
|
||||
|
||||
[tool.ruff.format]
|
||||
# Like Black, use double quotes for strings.
|
||||
quote-style = "double"
|
||||
|
||||
# Like Black, indent with spaces, rather than tabs.
|
||||
indent-style = "space"
|
||||
|
||||
# Like Black, respect magic trailing commas.
|
||||
skip-magic-trailing-comma = false
|
||||
|
||||
# Like Black, automatically detect the appropriate line ending.
|
||||
line-ending = "auto"
|
@ -1,4 +1,8 @@
|
||||
SwissArmyTransformer==0.2.9
|
||||
icetk
|
||||
gifmaker
|
||||
torchvision
|
||||
git+https://github.com/huggingface/diffusers.git@878f609aa5ce4a78fea0f048726889debde1d7e8#egg=diffusers
|
||||
torch>=2.4.0
|
||||
torchvision>=0.19.0
|
||||
streamlit>=1.37.0
|
||||
opencv-python>=4.10
|
||||
imageio-ffmpeg>=0.5.1
|
||||
openai>=1.38.0
|
||||
transformers>=4.43.3
|
7
resources/WECHAT.md
Normal file
7
resources/WECHAT.md
Normal file
@ -0,0 +1,7 @@
|
||||
<div align="center">
|
||||
<img src="wechat.jpg" width="60%"/>
|
||||
|
||||
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
|
||||
<p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
|
||||
</div>
|
||||
|
50
resources/contribute.md
Normal file
50
resources/contribute.md
Normal file
@ -0,0 +1,50 @@
|
||||
# Contribution Guide
|
||||
|
||||
There may still be many incomplete aspects in this project.
|
||||
|
||||
We look forward to your contributions to the repository in the following areas. If you complete any of the work listed
below and are willing to submit a PR and share it with the community, we will, upon review, acknowledge your
contribution on the project homepage.
|
||||
|
||||
## Model Algorithms
|
||||
|
||||
- Support for model quantization inference (Int4, Int8, etc. quantization engineering)
|
||||
- Support for multi-card inference / model inference concurrency engineering
|
||||
- Support for non-CUDA architecture inference devices
|
||||
|
||||
## Model Engineering / Secondary Development
|
||||
|
||||
- Model fine-tuning examples / best prompt practices
|
||||
- Video super-resolution/frame interpolation for enhancing video generation quality.
|
||||
- Any peripheral tools for the model
|
||||
- Any minimal complete open-source projects using the CogVideoX open-source model
|
||||
|
||||
## Code Standards
|
||||
|
||||
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
|
||||
style. You can organize the code according to the following specifications:
|
||||
|
||||
1. Install the `ruff` tool
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
Then, run the `ruff` tool
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
Once your code meets the standard, there should be no errors.
|
||||
|
||||
## Naming Conventions
|
||||
1. Please use English names; do not use Pinyin or names in other languages. All comments should be in English.
|
||||
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
|
||||
|
45
resources/contribute_zh.md
Normal file
45
resources/contribute_zh.md
Normal file
@ -0,0 +1,45 @@
|
||||
# 贡献指南
|
||||
|
||||
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
|
||||
|
||||
## 模型算法
|
||||
|
||||
- 模型量化推理支持 (Int4,Int8等量化工程)
|
||||
- 模型多卡推理支持 / 模型推理并发工程
|
||||
- 非 CUDA 架构 推理设备支持
|
||||
|
||||
## 模型工程 / 模型二次开发
|
||||
|
||||
- 模型微调示例 / 最佳提示词实践
|
||||
- 视频超分/插帧,用于美化视频生成效果。
|
||||
- 任何模型周边工具
|
||||
- 任何使用CogVideoX开源模型制作的最小完整开源项目
|
||||
|
||||
## 代码规范
|
||||
|
||||
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
|
||||
|
||||
1. 安装`ruff`工具
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
接着,运行`ruff`工具
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
检查代码风格,如果有问题,您可以通过`ruff format`命令自动修复。
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
如果您的代码符合规范,应该不会出现任何的错误。
|
||||
|
||||
## 命名规范
|
||||
|
||||
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
|
||||
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
|
298
resources/logo.svg
Normal file
298
resources/logo.svg
Normal file
@ -0,0 +1,298 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 28.2.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<svg version="1.1" id="图层_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 841.89 368.6" style="enable-background:new 0 0 841.89 368.6;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#1F08A3;}
|
||||
.st1{fill:url(#SVGID_1_);}
|
||||
.st2{fill:url(#SVGID_00000072969730447947445580000008550966321517582978_);}
|
||||
.st3{fill:#191935;}
|
||||
.st4{fill:url(#SVGID_00000134944507322615418430000016680692556048730035_);}
|
||||
.st5{fill:url(#SVGID_00000079475197965678598910000014857864397887059643_);}
|
||||
.st6{fill:url(#SVGID_00000011017870202623612780000003647410284342320817_);}
|
||||
.st7{fill:url(#SVGID_00000158717432109861805260000005726541443920151984_);}
|
||||
.st8{fill:url(#SVGID_00000142876900428461769000000008069403543541257898_);}
|
||||
.st9{fill:url(#SVGID_00000043431278852236617580000017143239672604928432_);}
|
||||
.st10{fill:url(#SVGID_00000117640229439763589850000006678067399342795696_);}
|
||||
.st11{fill:url(#SVGID_00000110458653123900400710000000913276274190314169_);}
|
||||
.st12{fill:url(#SVGID_00000031170474287730726890000016823097828633490353_);}
|
||||
.st13{fill:url(#SVGID_00000009584095435268706030000012050499560522802058_);}
|
||||
.st14{fill:url(#SVGID_00000020368880049868222280000013138163128272475547_);}
|
||||
.st15{fill:url(#SVGID_00000091714870927564046460000005861143291349344438_);}
|
||||
.st16{fill:url(#SVGID_00000121271759264793556690000000860873829552106905_);}
|
||||
.st17{fill:url(#SVGID_00000038375467373794976910000013035737086102657724_);}
|
||||
.st18{fill:url(#SVGID_00000012456568760464884060000014807595715017774739_);}
|
||||
.st19{fill:url(#SVGID_00000131350652581417219040000010050473447924313765_);}
|
||||
</style>
|
||||
<g id="图层_1_00000155127390320397588370000011557700728645101750_">
|
||||
</g>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="st0" d="M251.92,232.88c-4.82-2.55-8.57-6.17-11.24-10.85c-2.68-4.69-4-10.11-4-16.28V179.1
|
||||
c0-6.08,1.34-11.45,4-16.1s6.41-8.24,11.24-10.79c4.82-2.55,10.43-3.83,16.8-3.83c6.38,0,11.97,1.21,16.8,3.64
|
||||
c4.82,2.43,8.57,5.84,11.24,10.24c2.66,4.39,4,9.48,4,15.23c0,0.41-0.14,0.74-0.43,0.99c-0.28,0.25-0.61,0.36-1.02,0.36
|
||||
l-16.98,1.1c-0.96,0-1.43-0.46-1.43-1.35c0-3.86-1.12-6.95-3.35-9.24s-5.18-3.45-8.85-3.45s-6.61,1.17-8.85,3.51
|
||||
c-2.24,2.35-3.35,5.4-3.35,9.2v28c0,3.78,1.12,6.82,3.35,9.13c2.24,2.3,5.18,3.45,8.85,3.45s6.61-1.15,8.85-3.45
|
||||
c2.24-2.3,3.35-5.34,3.35-9.13c0-0.9,0.47-1.35,1.43-1.35l16.98,0.87c0.39,0,0.74,0.13,1.02,0.36c0.28,0.25,0.43,0.54,0.43,0.87
|
||||
c0,5.84-1.34,10.98-4,15.42c-2.68,4.44-6.41,7.87-11.24,10.3c-4.82,2.43-10.43,3.64-16.8,3.64
|
||||
C262.33,236.71,256.74,235.43,251.92,232.88z"/>
|
||||
<path class="st0" d="M316.18,230.79c-5.02-3.95-8.36-9.29-10.05-16.03c-0.96-3.37-1.43-7.07-1.43-11.1
|
||||
c0-4.52,0.52-8.5,1.56-11.97c1.83-6.49,5.24-11.58,10.22-15.23c4.98-3.65,11.02-5.5,18.11-5.5c7.01,0,12.95,1.83,17.81,5.5
|
||||
c4.87,3.65,8.25,8.69,10.16,15.1c1.12,3.7,1.67,7.65,1.67,11.84c0,3.78-0.44,7.4-1.32,10.85c-1.67,6.91-5.02,12.38-10.05,16.41
|
||||
c-5.02,4.03-11.17,6.05-18.41,6.05C327.3,236.71,321.2,234.74,316.18,230.79z M340.21,216.91c1.51-1.53,2.63-3.59,3.34-6.22
|
||||
c0.47-2.14,0.72-4.49,0.72-7.02c0-2.47-0.28-4.85-0.83-7.15c-0.65-2.55-1.72-4.52-3.23-5.92s-3.43-2.09-5.73-2.09
|
||||
c-4.63,0-7.65,2.68-9.09,8.02c-0.47,1.97-0.72,4.36-0.72,7.15c0,2.55,0.24,4.9,0.72,7.02c0.63,2.63,1.73,4.71,3.29,6.22
|
||||
c1.56,1.53,3.48,2.28,5.8,2.28C336.78,219.2,338.69,218.44,340.21,216.91z"/>
|
||||
<path class="st0" d="M407.21,172.39c0.28-0.28,0.61-0.43,1.02-0.43h16.98c0.39,0,0.74,0.14,1.02,0.43
|
||||
c0.28,0.28,0.43,0.65,0.43,1.06v55.14c0,11.26-3.15,19.32-9.45,24.17c-6.3,4.85-14.27,7.28-23.91,7.28
|
||||
c-3.59,0-7.37-0.33-11.35-0.99c-0.8-0.08-1.2-0.61-1.2-1.61l0.6-15.29c0-0.58,0.16-0.96,0.47-1.17c0.32-0.2,0.72-0.22,1.2-0.06
|
||||
c3.28,0.74,6.3,1.1,9.09,1.1c4.54,0,8.13-1.07,10.76-3.21c2.63-2.14,3.95-5.51,3.95-10.11l-0.83,0.87
|
||||
c-2.87,2.96-7.01,4.44-12.43,4.44c-5.26,0-10.08-1.2-14.47-3.57c-4.38-2.38-7.53-6.58-9.45-12.58c-1.28-3.95-1.91-8.8-1.91-14.55
|
||||
c0-6.33,0.76-11.51,2.27-15.54c1.83-5.1,4.76-9.17,8.79-12.21c4.02-3.04,8.71-4.57,14.05-4.57c5.73,0,10.16,1.81,13.28,5.43
|
||||
c0.16,0.17,0.32,0.22,0.47,0.19c0.16-0.05,0.24-0.19,0.24-0.43v-2.71C406.8,173.02,406.92,172.67,407.21,172.39z M406.8,203.16
|
||||
c0-2.22-0.08-3.97-0.24-5.24c-0.16-1.28-0.47-2.49-0.96-3.64c-0.63-1.81-1.65-3.23-3.06-4.25c-1.4-1.02-3.09-1.54-5.09-1.54
|
||||
c-3.75,0-6.41,1.94-8.02,5.8c-1.2,2.3-1.8,5.34-1.8,9.13c0,4.03,0.52,6.99,1.56,8.88c0.72,1.73,1.8,3.13,3.23,4.19
|
||||
c1.43,1.07,3.15,1.61,5.13,1.61c4.06,0,6.77-1.89,8.13-5.67C406.44,210.52,406.8,207.43,406.8,203.16z"/>
|
||||
<path class="st0" d="M451.94,234.49l-24.14-83.37l-0.13-0.49c0-0.82,0.44-1.23,1.31-1.23h18.3c0.88,0,1.43,0.41,1.67,1.23
|
||||
l13.75,56.13c0.08,0.25,0.2,0.36,0.36,0.36s0.28-0.13,0.36-0.36l13.51-56.13c0.24-0.82,0.8-1.23,1.67-1.23h17.94
|
||||
c0.47,0,0.83,0.17,1.07,0.49c0.24,0.33,0.28,0.74,0.13,1.23l-24.51,83.37c-0.24,0.82-0.76,1.23-1.56,1.23h-18.17
|
||||
C452.69,235.72,452.17,235.31,451.94,234.49z"/>
|
||||
<path class="st0" d="M504.84,163.26c-1.95-2.02-2.93-4.58-2.93-7.7c0-3.21,0.98-5.8,2.93-7.76c1.95-1.97,4.44-2.96,7.46-2.96
|
||||
s5.51,0.99,7.46,2.96c1.95,1.97,2.93,4.57,2.93,7.76c0,3.04-0.98,5.59-2.93,7.65c-1.95,2.06-4.44,3.09-7.46,3.09
|
||||
S506.79,165.27,504.84,163.26z M502.69,235.29c-0.28-0.28-0.43-0.63-0.43-1.04v-60.81c0-0.41,0.14-0.76,0.43-1.06
|
||||
c0.28-0.28,0.61-0.43,1.02-0.43h16.98c0.39,0,0.74,0.14,1.02,0.43c0.28,0.28,0.43,0.65,0.43,1.06v60.81
|
||||
c0,0.41-0.14,0.76-0.43,1.04c-0.28,0.28-0.61,0.43-1.02,0.43h-16.98C503.31,235.72,502.96,235.58,502.69,235.29z"/>
|
||||
<path class="st0" d="M651.35,207.97c-0.08,0.99-0.6,1.48-1.56,1.48h-36.46c-0.16,0-0.32,0.06-0.47,0.19
|
||||
c-0.16,0.13-0.2,0.27-0.13,0.43c0.16,0.91,0.55,2.09,1.2,3.57c0.96,1.73,2.39,3.12,4.3,4.19c1.91,1.07,4.27,1.61,7.06,1.61
|
||||
c5.02,0,8.96-1.69,11.84-5.06c0.32-0.41,0.68-0.61,1.07-0.61c0.39,0,0.72,0.16,0.96,0.49l9.21,10.85
|
||||
c0.32,0.25,0.47,0.58,0.47,0.99c0,0.33-0.16,0.66-0.47,0.99c-2.79,3.13-6.24,5.51-10.35,7.15s-8.58,2.47-13.45,2.47
|
||||
c-7.26,0-13.4-1.62-18.47-4.87c-5.07-3.24-8.71-7.8-10.95-13.62c-1.67-4.11-2.5-9.24-2.5-15.42c0-4.27,0.63-8.39,1.91-12.33
|
||||
c2.08-6.08,5.48-10.87,10.22-14.36s10.38-5.24,16.91-5.24c5.26,0,9.97,1.17,14.11,3.51c4.14,2.35,7.51,5.59,10.09,9.75
|
||||
c2.58,4.16,4.28,8.87,5.09,14.13C651.39,200.69,651.5,203.94,651.35,207.97z M613.46,194.52c-0.32,0.91-0.52,1.76-0.6,2.58
|
||||
c-0.16,0.41,0,0.61,0.47,0.61h16.98c0.32,0,0.47-0.16,0.47-0.49c0-0.66-0.16-1.48-0.47-2.47c-0.55-2.05-1.57-3.62-3.06-4.69
|
||||
c-1.48-1.07-3.32-1.61-5.56-1.61C617.55,188.49,614.81,190.5,613.46,194.52z"/>
|
||||
<path class="st0" d="M665.93,230.79c-5.02-3.95-8.36-9.29-10.05-16.03c-0.96-3.37-1.43-7.07-1.43-11.1
|
||||
c0-4.52,0.52-8.5,1.56-11.97c1.83-6.49,5.24-11.58,10.22-15.23c4.98-3.65,11.02-5.5,18.11-5.5c7.01,0,12.95,1.83,17.81,5.5
|
||||
c4.87,3.65,8.25,8.69,10.16,15.1c1.12,3.7,1.67,7.65,1.67,11.84c0,3.78-0.44,7.4-1.32,10.85c-1.67,6.91-5.02,12.38-10.05,16.41
|
||||
c-5.02,4.03-11.17,6.05-18.41,6.05C677.05,236.71,670.95,234.74,665.93,230.79z M689.96,216.91c1.51-1.53,2.63-3.59,3.35-6.22
|
||||
c0.47-2.14,0.72-4.49,0.72-7.02c0-2.47-0.28-4.85-0.83-7.15c-0.65-2.55-1.72-4.52-3.23-5.92c-1.51-1.4-3.43-2.09-5.73-2.09
|
||||
c-4.63,0-7.65,2.68-9.09,8.02c-0.47,1.97-0.72,4.36-0.72,7.15c0,2.55,0.24,4.9,0.72,7.02c0.63,2.63,1.73,4.71,3.29,6.22
|
||||
c1.56,1.53,3.48,2.28,5.8,2.28C686.53,219.2,688.45,218.44,689.96,216.91z"/>
|
||||
<path class="st0" d="M717.33,235.23c-0.24-0.33-0.2-0.74,0.13-1.23l22.6-41.07c0.16-0.25,0.16-0.49,0-0.74l-22.6-41.07
|
||||
l-0.24-0.74c0-0.66,0.44-0.99,1.32-0.99h18.41c0.88,0,1.48,0.33,1.8,0.99l13.51,25.91c0.24,0.33,0.47,0.33,0.72,0l13.62-25.91
|
||||
c0.32-0.66,0.91-0.99,1.8-0.99h18.41c0.55,0,0.91,0.17,1.07,0.49c0.16,0.33,0.13,0.74-0.13,1.23l-22.47,41.07
|
||||
c-0.08,0.25-0.08,0.49,0,0.74l22.47,41.07c0.16,0.33,0.24,0.58,0.24,0.74c0,0.66-0.39,0.99-1.2,0.99h-18.3
|
||||
c-0.8,0-1.4-0.33-1.8-0.99l-13.75-25.78c-0.24-0.41-0.47-0.41-0.72,0l-13.62,25.78c-0.39,0.66-0.99,0.99-1.8,0.99h-18.28
|
||||
C717.98,235.72,717.57,235.56,717.33,235.23z"/>
|
||||
</g>
|
||||
<path class="st0" d="M586.89,149.87c-0.28-0.28-0.61-0.43-1.02-0.43h-16.98c-0.39,0-0.74,0.14-1.01,0.43
|
||||
c-0.28,0.28-0.43,0.63-0.43,1.06v25.65c0,0.25-0.08,0.41-0.24,0.49c-0.16,0.08-0.32,0-0.47-0.25c-2.95-3.86-7.29-5.8-13.02-5.8
|
||||
c-5.73,0-10.61,1.56-14.65,4.69c-4.03,3.12-6.96,7.23-8.79,12.33c-1.67,4.6-2.5,9.78-2.5,15.54c0,5.1,0.63,9.87,1.91,14.32
|
||||
c1.75,5.84,4.72,10.44,8.9,13.81c4.19,3.37,8.99,5.06,14.41,5.06c5.98,0,10.57-2.27,13.75-6.79c0.16-0.16,0.32-0.22,0.47-0.19
|
||||
c0.16,0.05,0.24,0.19,0.24,0.43v4.08c0,0.41,0.14,0.76,0.43,1.04c0.28,0.28,0.61,0.43,1.01,0.43h16.98c0.39,0,0.74-0.14,1.02-0.43
|
||||
c0.28-0.28,0.43-0.63,0.43-1.04v-83.37C587.31,150.52,587.17,150.17,586.89,149.87z M567.06,208.9l-12.8,7.62
|
||||
c-4.14,2.47-9.32-0.61-9.32-5.56V195.7c0-4.95,5.18-8.02,9.32-5.56l12.8,7.62C571.2,200.25,571.2,206.42,567.06,208.9z"/>
|
||||
</g>
|
||||
<g>
|
||||
<g>
|
||||
|
||||
<linearGradient id="SVGID_1_" gradientUnits="userSpaceOnUse" x1="59.2757" y1="254.7704" x2="135.9786" y2="126.215" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="0" style="stop-color:#19D46A"/>
|
||||
<stop offset="1" style="stop-color:#0D00B5"/>
|
||||
</linearGradient>
|
||||
<path class="st1" d="M160.69,271.07h-42.1c-0.88,0-1.72-0.44-2.2-1.18L41.9,106.64c-1.2-1.76,0.08-4.16,2.2-4.16h42.1
|
||||
c0.88,0,1.72,0.44,2.2,1.18l74.49,163.25C164.09,268.68,162.83,271.07,160.69,271.07z"/>
|
||||
</g>
|
||||
<g>
|
||||
|
||||
<linearGradient id="SVGID_00000052810148107182236330000012848981250606710167_" gradientUnits="userSpaceOnUse" x1="146.394" y1="179.088" x2="200.0336" y2="58.4906" gradientTransform="matrix(1 0 0 1 0 76.8546)">
|
||||
<stop offset="4.256583e-04" style="stop-color:#D70066"/>
|
||||
<stop offset="1" style="stop-color:#0D00B5"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000052810148107182236330000012848981250606710167_);" d="M235.89,102.48h-42.66
|
||||
c-0.8,0-1.56,0.39-2.02,1.07l-74.81,163.72c-1.09,1.62,0.06,3.8,2.02,3.8h42.66c0.8,0,1.56-0.39,2.02-1.07l74.81-163.72
|
||||
C239,104.66,237.85,102.48,235.89,102.48z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path class="st3" d="M218.63,111.51c-7.17-0.43-15.24,4.27-26.77,18.35c-2.74,3.34-12.85,3.26-11.59,10.33
|
||||
c6.11,34.58-25.94,107.11-18.77,107.53s25.09-13.28,31.42-21.29c6.33-8.02,10.54-14.55,15.39-24.46
|
||||
c4.85-9.91,9.91-24.25,14.33-39.64c4.43-15.39,4.65-30.57,4.85-39.21C227.56,120.36,225.81,111.92,218.63,111.51z"/>
|
||||
<path d="M189.18,136.8c-1.69-9.28-8.93-28.14-13.78-34.05c-4.85-5.91-9.64-8.82-17.29-5.91c-4.22,1.61-14.13,7.32-21.29,29.31
|
||||
c-2.3,7.04-6.11,10.54-9.48,15.18c-3.37,4.65-8.85,10.96-8.85,19.83c0,8.85,3.62,16.43,8.47,22.11
|
||||
c4.85,5.69,25.78,22.46,30.25,27.1c4.47,4.65,16.74,16.77,19.91,22.27c3.17,5.48,3.17,5.48,3.17,5.48s6.02-13.92,7.91-18.35
|
||||
c1.89-4.43,5.37-27.83,5.69-41.12C194.17,165.38,190.39,143.46,189.18,136.8z"/>
|
||||
<path d="M164.13,255.42c-0.76,0.39-1.29-1.87,1.4-5.75c2.69-3.87,8.09-11.39,12.44-14.68c4.35-3.29,15.26-12.1,19.37-13.02
|
||||
c4.11-0.94,4.08,2.74,2.11,4.25c-3.76,2.91-4.46,2.2-9.28,6.55c-4.82,4.35-10.68,6.69-14.32,9.98
|
||||
C172.22,246.03,166.35,254.25,164.13,255.42z"/>
|
||||
<path d="M124.89,189.1c-12.46-3.64-32.79-7.59-40.54,20.88c0,0-5.46,2.39-9.62,11.84c-6.63,15.06,7.89,55.39,50.79,44.9
|
||||
c20.36-4.98,37.37-15.78,42.38-20.66c7.58-7.37,12.47-7.83,12.47-7.83C177.15,227.43,152.74,197.23,124.89,189.1z"/>
|
||||
<path d="M197.04,224.68c4.09-3.78,9.35-9.48,11.67-13.48c2.32-4,4.55-9.59,4.96-11.24c0.41-1.65-0.61-6.61,1.81-6.33
|
||||
c2.44,0.3,0.94,4.22-0.09,6.08c-1.4,2.52-3.06,6.35-4.17,9.13c-1.12,2.79-5.06,7.83-8.24,11.39c-3.18,3.56-6.33,6.09-6.33,6.09
|
||||
L197.04,224.68z"/>
|
||||
<path d="M195.8,225.98c0,0,5.06-3.04,9.17-5.95c4.09-2.91,12.38-10.39,14.44-13.06c2.06-2.66,2.39-5.24,3.48-4.55
|
||||
c1.92,1.23-1.26,3.98-2.63,5.31c-1.37,1.32-6.06,5.94-9.72,8.91C206.88,219.62,197.4,226.46,195.8,225.98z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000070798706056830547700000003974558205504861373_" gradientUnits="userSpaceOnUse" x1="175.0956" y1="228.5175" x2="187.4265" y2="140.3143" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="0" style="stop-color:#8890F4"/>
|
||||
<stop offset="0.2835" style="stop-color:#FFCFFD"/>
|
||||
<stop offset="0.6062" style="stop-color:#F4F0FE"/>
|
||||
<stop offset="0.7338" style="stop-color:#EFFFFF"/>
|
||||
<stop offset="1" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000070798706056830547700000003974558205504861373_);" d="M183.02,228.43
|
||||
c0.08,0.41,0.39,0.66,0.44,0.25c0.91-8.69,4.36-18.96,4.47-44.43c0.11-26.91,0-39.18-19.64-44.18c-1.29-0.33-2.54-0.03-2.33,1.28
|
||||
c0.47,3.01,7.29,17.84,10.69,37.18C180.85,202.38,182.2,224.39,183.02,228.43z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000148655111165341704870000015822622798822732439_" gradientUnits="userSpaceOnUse" x1="180.2612" y1="133.8598" x2="180.2585" y2="133.8624" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="0" style="stop-color:#00C6B8"/>
|
||||
<stop offset="0.8389" style="stop-color:#4B4BFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000148655111165341704870000015822622798822732439_);" d="M180.27,235.42
|
||||
C180.2,235.47,180.31,235.37,180.27,235.42L180.27,235.42z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000162335655504960668340000008042784563402819467_" gradientUnits="userSpaceOnUse" x1="91.8193" y1="162.3303" x2="182.7514" y2="137.3534" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.4148" style="stop-color:#FFCEAD"/>
|
||||
<stop offset="0.5961" style="stop-color:#EFEFB9"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="1" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000162335655504960668340000008042784563402819467_);" d="M98.06,216.36
|
||||
c-6.71,1.86,9.43,5.17,24.79,8.25c10.54,2.11,43.58,7.17,54.82,7.76c0.25,0.02,0.3-0.28,0.08-0.43
|
||||
c-11.17-7.34-36.02-23.21-43.1-23.21C121.93,208.71,103.08,214.98,98.06,216.36z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000183956648102935592230000000034257340353923987_" gradientUnits="userSpaceOnUse" x1="129.9311" y1="208.2058" x2="193.0715" y2="132.703" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.3821" style="stop-color:#FFAFFF"/>
|
||||
<stop offset="0.5127" style="stop-color:#F5D1E4"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="0.9718" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000183956648102935592230000000034257340353923987_);" d="M136.52,164.49
|
||||
c0,0,5.26,11.15,13.61,18.03c7.18,5.92,13.48,9.73,15.81,12.46c7.94,9.29,11.01,25.56,14.93,33.32
|
||||
c0.58,1.13-28.73-27.91-43.64-52.68C133.82,169.95,135.6,162.96,136.52,164.49z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000127743825081950367160000014214139936879980969_" gradientUnits="userSpaceOnUse" x1="120.5534" y1="180.132" x2="113.4298" y2="185.2324" gradientTransform="matrix(0.9663 0.2574 0.2574 -0.9663 1.6568 269.9926)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000127743825081950367160000014214139936879980969_);" d="M155.29,115.81
|
||||
c3.06-0.57,11.29,4.06,11.29,6.49c0,1.1-11.76-1.07-13.8-2.13C150.06,118.75,153.27,116.18,155.29,115.81z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000054976593115952198230000008003013541946111389_" gradientUnits="userSpaceOnUse" x1="168.498" y1="247.5718" x2="163.8415" y2="254.2492" gradientTransform="matrix(0.9987 -0.0503 -0.0503 -0.9987 0.9774 386.0547)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000054976593115952198230000008003013541946111389_);" d="M149.94,124.45
|
||||
c1.48-1.29,7.8,0.05,8.96,1.23c1.24,1.26-6.21,2.13-8.65,2.38C148.19,128.25,148.45,125.74,149.94,124.45z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000169534306769865638960000010624933101701591231_" gradientUnits="userSpaceOnUse" x1="110.6462" y1="167.6579" x2="99.1469" y2="175.8909" gradientTransform="matrix(0.9322 0.362 0.362 -0.9322 11.275 241.6989)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000169534306769865638960000010624933101701591231_);" d="M164.36,104.99
|
||||
c4.57,0.41,13.81,13.86,11.8,15.47c-1.06,0.85-15.4-8-17.48-10.52C156.88,107.77,160.94,104.67,164.36,104.99z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000079458906221052217860000002178591648295304114_" gradientUnits="userSpaceOnUse" x1="162.8531" y1="71.2364" x2="157.6109" y2="74.9896" gradientTransform="matrix(0.8296 0.5584 0.5584 -0.8296 38.9717 104.0816)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000079458906221052217860000002178591648295304114_);" d="M210.08,125.93
|
||||
c2.3,0.35,6.93,5.43,6.39,7.18c-0.24,0.79-8.03-3.48-9.21-4.72C205.69,126.74,208.57,125.71,210.08,125.93z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000111909030176016825680000003151376107357579152_" gradientUnits="userSpaceOnUse" x1="189.1343" y1="139.3362" x2="185.6595" y2="144.3191" gradientTransform="matrix(0.9754 0.2206 0.2206 -0.9754 -7.4036 231.1159)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000111909030176016825680000003151376107357579152_);" d="M204.3,131.19
|
||||
c1.1-0.69,5.64,1.8,5.7,3.02c0.06,0.99-4.39-0.06-6.19-0.38C202.28,133.57,203.06,131.96,204.3,131.19z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000124859780877684741540000003625974166378130101_" gradientUnits="userSpaceOnUse" x1="158.4721" y1="62.348" x2="151.1771" y2="67.5709" gradientTransform="matrix(0.7882 0.6154 0.6154 -0.7882 56.9403 86.3531)">
|
||||
<stop offset="0" style="stop-color:#FFA2FF"/>
|
||||
<stop offset="0.5569" style="stop-color:#FFFFB0"/>
|
||||
<stop offset="0.7133" style="stop-color:#DDFFCA"/>
|
||||
<stop offset="1" style="stop-color:#97FFFF"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000124859780877684741540000003625974166378130101_);" d="M217.7,120.44
|
||||
c3.17,1.28,5.95,10.27,4.88,11.2c-1.15,0.99-9.06-6.66-10-8.9C211.77,120.81,215.31,119.48,217.7,120.44z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000151505342341095507280000017826952898926496684_" gradientUnits="userSpaceOnUse" x1="177.0845" y1="184.1351" x2="142.9964" y2="223.8563" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.3291" style="stop-color:#FFD9AD"/>
|
||||
<stop offset="0.6346" style="stop-color:#EAFABD"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="1" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000151505342341095507280000017826952898926496684_);" d="M173.34,196.31
|
||||
c-0.98-4.96-2.65-25.72-7.8-38.06c-3.75-8.98-14.08-25.72-16.19-20.02c-3.09,8.35-10.11,16.16-7.72,22.36
|
||||
c2.38,6.16,10.5,14.69,16.58,20.03C164.09,185.75,173.43,196.75,173.34,196.31z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000039821377531076698740000016045699938930074037_" gradientUnits="userSpaceOnUse" x1="168.0164" y1="130.694" x2="81.7921" y2="131.8498" gradientTransform="matrix(0.9993 0.0362 0.0362 -0.9993 2.0222 360.6481)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.3302" style="stop-color:#FFECEC"/>
|
||||
<stop offset="0.5907" style="stop-color:#EDFACB"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="0.9718" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000039821377531076698740000016045699938930074037_);" d="M117.43,228.98
|
||||
c-7.92-1.15-19.95-3.64-27.47-5.12c-4.74-0.94,0.83,12.13,7.23,15.47c5.53,2.88,30.29,1.34,46.74-0.31
|
||||
c13.65-1.37,34.4-4.98,27.29-5.12C159.49,233.64,140.45,232.32,117.43,228.98z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000040549751368890075800000016406046424980702142_" gradientUnits="userSpaceOnUse" x1="104.442" y1="126.5369" x2="183.9285" y2="126.5369" gradientTransform="matrix(0.9986 -0.0528 -0.0528 -0.9986 -2.8922 382.0786)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.3302" style="stop-color:#FFECEC"/>
|
||||
<stop offset="0.5907" style="stop-color:#EDFACB"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="0.9718" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000040549751368890075800000016406046424980702142_);" d="M172.91,237.28
|
||||
c0.35-0.24,0.94-0.96,0.55-0.83c-6.13,1.97-28.62,8.25-53.23,10.87c-18.84,2-22.85,0.44-25.02,1.39
|
||||
c-2.6,1.13,4.63,10.88,20.68,10C134.38,257.67,162.38,244.1,172.91,237.28z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000163791769488651178610000006995591105667394192_" gradientUnits="userSpaceOnUse" x1="154.7833" y1="149.8526" x2="87.8201" y2="163.9461" gradientTransform="matrix(0.9991 0.0421 0.0421 -0.9991 1.5546 359.6285)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.601" style="stop-color:#FFCEAD"/>
|
||||
<stop offset="0.707" style="stop-color:#EFEFB9"/>
|
||||
<stop offset="0.7518" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="0.976" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000163791769488651178610000006995591105667394192_);" d="M99.81,207.05
|
||||
c-1.15,0-1.24-3.51,2.74-5.24c0,0,4.44-5.58,13.21-3.89c6.05,1.17,15.76,6.69,14.49,6.41
|
||||
C125.62,203.29,105.6,207.05,99.81,207.05z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000058564881533081704170000010681590285087118983_" gradientUnits="userSpaceOnUse" x1="205.4404" y1="213.7062" x2="190.2183" y2="136.673" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="0" style="stop-color:#8890F4"/>
|
||||
<stop offset="0.4248" style="stop-color:#FFB8FD"/>
|
||||
<stop offset="0.5488" style="stop-color:#F4D9FE"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFFF"/>
|
||||
<stop offset="1" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000058564881533081704170000010681590285087118983_);" d="M188.91,224.52
|
||||
c0-0.43,3.89-6.5,8.99-33.55c3.18-16.79,2.36-31.58,1.91-34.49c-0.33-2.06,6.9,2.05,7.04,10.02c0.09,5.69-1.17,28.49-11.53,49.01
|
||||
C191.09,223.87,188.93,224.98,188.91,224.52z"/>
|
||||
|
||||
<linearGradient id="SVGID_00000147192695983014756520000014062130175962991015_" gradientUnits="userSpaceOnUse" x1="204.0304" y1="167.3126" x2="218.5749" y2="220.9586" gradientTransform="matrix(1 0 0 -1 0 369.2756)">
|
||||
<stop offset="5.071865e-03" style="stop-color:#6F69F4"/>
|
||||
<stop offset="0.3291" style="stop-color:#FFCEAD"/>
|
||||
<stop offset="0.5707" style="stop-color:#EFEFB9"/>
|
||||
<stop offset="0.6727" style="stop-color:#E7FFBF"/>
|
||||
<stop offset="1" style="stop-color:#75D9DD"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_00000147192695983014756520000014062130175962991015_);" d="M203.78,203.43
|
||||
c-0.13,0.2-0.47,0.06-0.36-0.14c2.06-4.63,7.76-24.24,9.04-39.07c0.5-5.83,0.54-10.47,1.18-12.03c0.74-1.84,5.13-6.77,6.14-7.39
|
||||
c2.05-1.21,1.09,11.12-3.57,26.43C213.32,180.66,207.5,196.94,203.78,203.43z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 23 KiB |
BIN
resources/videos/1.mp4
Normal file
BIN
resources/videos/1.mp4
Normal file
Binary file not shown.
BIN
resources/videos/2.mp4
Normal file
BIN
resources/videos/2.mp4
Normal file
Binary file not shown.
BIN
resources/videos/3.mp4
Normal file
BIN
resources/videos/3.mp4
Normal file
Binary file not shown.
BIN
resources/videos/4.mp4
Normal file
BIN
resources/videos/4.mp4
Normal file
Binary file not shown.
BIN
resources/web_demo.png
Normal file
BIN
resources/web_demo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.6 MiB |
BIN
resources/wechat.jpg
Normal file
BIN
resources/wechat.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.6 KiB |
@ -1,2 +0,0 @@
|
||||
#!/bin/sh
|
||||
docker run --rm --name cog -it --gpus all -v "${PWD}":/workspace cog
|
182
sat/README.md
Normal file
182
sat/README.md
Normal file
@ -0,0 +1,182 @@
|
||||
# SAT CogVideoX-2B
|
||||
|
||||
This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
|
||||
fine-tuning code for SAT weights.
|
||||
|
||||
This code is the framework used by the team to train the model. It has few comments and requires careful study.
|
||||
|
||||
## Inference Model
|
||||
|
||||
1. Ensure that you have correctly installed the dependencies required by this folder.
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Download the model weights
|
||||
|
||||
First, go to the SAT mirror to download the dependencies.
|
||||
|
||||
```shell
|
||||
mkdir CogVideoX-2b-sat
|
||||
cd CogVideoX-2b-sat
|
||||
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
||||
mv 'index.html?dl=1' vae.zip
|
||||
unzip vae.zip
|
||||
wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
|
||||
mv 'index.html?dl=1' transformer.zip
|
||||
unzip transformer.zip
|
||||
```
|
||||
|
||||
After unzipping, the model directory structure should look like this:
|
||||
|
||||
```
|
||||
.
|
||||
├── transformer
|
||||
│ ├── 1000
|
||||
│ │ └── mp_rank_00_model_states.pt
|
||||
│ └── latest
|
||||
└── vae
|
||||
└── 3d-vae.pt
|
||||
```
|
||||
|
||||
Next, clone the T5 model. It is not trained or fine-tuned, but it is still required (it serves only as the text encoder).
|
||||
|
||||
```shell
|
||||
git lfs install
|
||||
git clone https://huggingface.co/google/t5-v1_1-xxl.git
|
||||
```
|
||||
|
||||
**The `tf_model.h5` file is not needed** and can be deleted.
|
||||
|
||||
3. Modify the file `configs/cogvideox_2b_infer.yaml`.
|
||||
|
||||
```yaml
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl" ## T5 model path
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
|
||||
```
|
||||
|
||||
+ If you use a txt file to provide multiple prompts, modify `configs/test.txt` accordingly, with one prompt per line. If
  you are not sure how to write prompts, you can first use [this code](../inference/convert_demo.py) to have an LLM
  refine them.
|
||||
+ If using the command line as input, modify
|
||||
|
||||
```yaml
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
so that prompts can be entered from the command line.
|
||||
|
||||
If you want to change the output video directory, you can modify:
|
||||
|
||||
```yaml
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
By default, outputs are saved in the `.outputs/` folder.
|
||||
|
||||
4. Run the inference code to start inference
|
||||
|
||||
```shell
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
## Fine-Tuning the Model
|
||||
|
||||
### Preparing the Dataset
|
||||
|
||||
The dataset format should be as follows:
|
||||
|
||||
```
|
||||
.
|
||||
├── labels
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
└── videos
|
||||
├── 1.mp4
|
||||
├── 2.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each txt file should have the same name as its corresponding video file and contain the labels for that video. Each
|
||||
video should have a one-to-one correspondence with a label. Typically, a video should not have multiple labels.
|
||||
|
||||
For style fine-tuning, please prepare at least 50 videos and labels with similar styles to facilitate fitting.
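
As a quick sanity check on this layout, the short sketch below (an illustration only, not part of the repository; `./my_dataset` is a placeholder path) verifies that every video has exactly one matching label file:

```python
from pathlib import Path


def check_dataset(dataset_root: str) -> None:
    # Collect file stems so that videos/1.mp4 is matched against labels/1.txt.
    root = Path(dataset_root)
    labels = {p.stem for p in (root / "labels").glob("*.txt")}
    videos = {p.stem for p in (root / "videos").glob("*.mp4")}
    missing_labels = sorted(videos - labels)
    missing_videos = sorted(labels - videos)
    if missing_labels:
        print("videos without a label:", missing_labels)
    if missing_videos:
        print("labels without a video:", missing_videos)
    print(f"{len(videos & labels)} matched video/label pairs")


check_dataset("./my_dataset")  # hypothetical dataset path
```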
|
||||
|
||||
### Modifying the Configuration File
|
||||
|
||||
We support both LoRA and full-parameter fine-tuning. Please note that both fine-tuning methods apply only to the `transformer` part; the `VAE` is not modified, and `T5` is used only as an encoder.

Modify `configs/cogvideox_2b_sft.yaml` (for full-parameter fine-tuning) as follows.
|
||||
|
||||
```yaml
|
||||
# checkpoint_activations: True ## using gradient checkpointing (both checkpoint_activations in the configuration file need to be set to True)
|
||||
model_parallel_size: 1 # Model parallel size
|
||||
experiment_name: lora-disney # Experiment name (do not change)
|
||||
mode: finetune # Mode (do not change)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" # Transformer model path
|
||||
no_load_rng: True # Whether to load the random seed
|
||||
train_iters: 1000 # Number of training iterations
|
||||
eval_iters: 1 # Number of evaluation iterations
|
||||
eval_interval: 100 # Evaluation interval
|
||||
eval_batch_size: 1 # Batch size for evaluation
|
||||
save: ckpts # Model save path
|
||||
save_interval: 100 # Model save interval
|
||||
log_interval: 20 # Log output interval
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # Training and validation sets can be the same
|
||||
split: 1,0,0 # Ratio of training, validation, and test sets
|
||||
num_workers: 8 # Number of worker threads for data loading
|
||||
```
|
||||
|
||||
If you wish to use Lora fine-tuning, you also need to modify:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## Uncomment
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
lora_config: ## Uncomment
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
```
|
||||
|
||||
### Fine-Tuning and Validation
|
||||
|
||||
1. Run the fine-tuning script to start fine-tuning.
|
||||
|
||||
```shell
|
||||
bash finetune.sh
|
||||
```
|
||||
|
||||
### Converting to Huggingface Diffusers Supported Weights
|
||||
|
||||
The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
|
||||
|
||||
```shell
|
||||
python ../tools/convert_weight_sat2hf.py
|
||||
```
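
If you want to inspect what the converter receives, a minimal sketch is shown below (it assumes the `CogVideoX-2b-sat` directory layout from the download step above; nesting the weights under a `module` key is typical for such DeepSpeed-style checkpoints, but may differ):

```python
import torch

# Path follows the CogVideoX-2b-sat layout shown earlier; adjust to your own download location.
ckpt = torch.load(
    "CogVideoX-2b-sat/transformer/1000/mp_rank_00_model_states.pt",
    map_location="cpu",
)
# Fall back to the top-level dict if the weights are not nested under "module".
state_dict = ckpt.get("module", ckpt)
for name, tensor in list(state_dict.items())[:20]:
    print(name, tuple(tensor.shape))
```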
|
||||
|
||||
**Note**: This content has not yet been tested with LORA fine-tuning models.
|
180
sat/README_zh.md
Normal file
180
sat/README_zh.md
Normal file
@ -0,0 +1,180 @@
|
||||
# SAT CogVideoX-2B
|
||||
|
||||
本文件夹包含了使用 [SAT](https://github.com/THUDM/SwissArmyTransformer) 权重的推理代码,以及 SAT 权重的微调代码。
|
||||
|
||||
该代码是团队训练模型时使用的框架。注释较少,需要认真研究。
|
||||
|
||||
## 推理模型
|
||||
|
||||
1. 确保你已经正确安装本文件夹中的要求的依赖
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. 下载模型权重
|
||||
|
||||
首先,前往 SAT 镜像下载依赖。
|
||||
|
||||
```shell
|
||||
mkdir CogVideoX-2b-sat
|
||||
cd CogVideoX-2b-sat
|
||||
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
||||
mv 'index.html?dl=1' vae.zip
|
||||
unzip vae.zip
|
||||
wget https://cloud.tsinghua.edu.cn/f/556a3e1329e74f1bac45/?dl=1
|
||||
mv 'index.html?dl=1' transformer.zip
|
||||
unzip transformer.zip
|
||||
```
|
||||
|
||||
然后,解压文件,模型结构应该如下
|
||||
|
||||
```
|
||||
.
|
||||
├── transformer
|
||||
│ ├── 1000
|
||||
│ │ └── mp_rank_00_model_states.pt
|
||||
│ └── latest
|
||||
└── vae
|
||||
└── 3d-vae.pt
|
||||
```
|
||||
|
||||
接着,克隆 T5 模型,该模型不用做训练和微调,但是必须使用。
|
||||
|
||||
```shell
|
||||
git lfs install
|
||||
git clone https://huggingface.co/google/t5-v1_1-xxl.git
|
||||
```
|
||||
|
||||
**我们不需要使用tf_model.h5**文件。该文件可以删除。
|
||||
|
||||
3. 修改`configs/cogvideox_2b_infer.yaml`中的文件。
|
||||
|
||||
```yaml
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl" ## T5 模型路径
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE 模型路径
|
||||
|
||||
```
|
||||
|
||||
+ 如果使用 txt 保存多个提示词,请参考`configs/test.txt`
|
||||
进行修改。每一行一个提示词。如果您不知道如何书写提示词,可以先使用[此代码](../inference/convert_demo.py)调用 LLM进行润色。
|
||||
+ 如果使用命令行作为输入,请修改
|
||||
|
||||
```yaml
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
这样就可以从命令行输入提示词。
|
||||
|
||||
如果你希望修改输出视频的地址,你可以修改:
|
||||
|
||||
```yaml
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
默认保存在`.outputs/`文件夹下。
|
||||
|
||||
4. 运行推理代码,即可推理
|
||||
|
||||
```shell
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
## 微调模型
|
||||
|
||||
### 准备数据集
|
||||
|
||||
数据集格式应该如下:
|
||||
|
||||
```
|
||||
.
|
||||
├── labels
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
└── videos
|
||||
├── 1.mp4
|
||||
├── 2.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
每个 txt 与视频同名,为视频的标签。视频与标签应该一一对应。通常情况下,不使用一个视频对应多个标签。
|
||||
|
||||
如果为风格微调,请准备至少 50 条风格相似的视频和标签,以利于拟合。
|
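下面是一个示意性的检查脚本(假设数据集目录正好是上述 `labels/` + `videos/` 的结构,路径仅作示例),用于确认视频与标签一一对应:

```python
# Illustrative helper (assumes the exact labels/ + videos/ layout shown above).
import os

def check_dataset(root: str) -> None:
    labels = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "labels")) if f.endswith(".txt")}
    videos = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "videos")) if f.endswith(".mp4")}
    print("缺少标签的视频:", sorted(videos - labels))
    print("缺少视频的标签:", sorted(labels - videos))

check_dataset("./disney")  # 路径仅作示例,与 SFT 配置中 train_data 的示例保持一致
```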
||||
|
||||
### 修改配置文件
|
||||
|
||||
我们支持 `Lora` 和全参数微调两种方式。请注意,两种微调方式都仅对 `transformer` 部分进行微调,不改动 `VAE` 部分,`T5` 仅作为 Encoder 使用。
请按照以下方式修改 `configs/cogvideox_2b_sft.yaml`(全量微调)中的文件:
|
||||
|
||||
```yaml
|
||||
# checkpoint_activations: True ## using gradient checkpointing (配置文件中的两个checkpoint_activations都需要设置为True)
|
||||
model_parallel_size: 1 # 模型并行大小
|
||||
experiment_name: lora-disney # 实验名称(不要改动)
|
||||
mode: finetune # 模式(不要改动)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer 模型路径
|
||||
no_load_rng: True # 是否加载随机数种子
|
||||
train_iters: 1000 # 训练迭代次数
|
||||
eval_iters: 1 # 验证迭代次数
|
||||
eval_interval: 100 # 验证间隔
|
||||
eval_batch_size: 1 # 验证集 batch size
|
||||
save: ckpts # 模型保存路径
|
||||
save_interval: 100 # 模型保存间隔
|
||||
log_interval: 20 # 日志输出间隔
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # 训练集和验证集可以相同
|
||||
split: 1,0,0 # 训练集,验证集,测试集比例
|
||||
num_workers: 8 # 数据加载器的工作线程数
|
||||
```
|
||||
|
||||
如果你希望使用 Lora 微调,你还需要修改:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## 解除注释
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
lora_config: ## 解除注释
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
```
|
||||
|
||||
### 微调和验证
|
||||
|
||||
1. 运行微调脚本,即可开始微调。
|
||||
|
||||
```shell
|
||||
bash finetune.sh
|
||||
```
|
||||
|
||||
### 转换到 Huggingface Diffusers 库支持的权重
|
||||
|
||||
SAT 权重格式与 Huggingface 的权重格式不同,需要转换。请运行
|
||||
|
||||
```shell
|
||||
python ../tools/convert_weight_sat2hf.py
|
||||
```
|
||||
|
||||
**注意**:本转换流程暂未测试 LoRA 微调模型。
|
281
sat/arguments.py
Normal file
@ -0,0 +1,281 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import warnings
|
||||
import omegaconf
|
||||
from omegaconf import OmegaConf
|
||||
from sat.helpers import print_rank0
|
||||
from sat import mpu
|
||||
from sat.arguments import set_random_seed
|
||||
from sat.arguments import add_training_args, add_evaluation_args, add_data_args
|
||||
import torch.distributed
|
||||
|
||||
|
||||
def add_model_config_args(parser):
|
||||
"""Model arguments"""
|
||||
|
||||
group = parser.add_argument_group("model", "model configuration")
|
||||
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
|
||||
group.add_argument(
|
||||
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
|
||||
)
|
||||
group.add_argument("--force-pretrain", action="store_true")
|
||||
group.add_argument("--device", type=int, default=-1)
|
||||
group.add_argument("--debug", action="store_true")
|
||||
group.add_argument("--log-image", type=bool, default=True)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_sampling_config_args(parser):
|
||||
"""Sampling configurations"""
|
||||
|
||||
group = parser.add_argument_group("sampling", "Sampling Configurations")
|
||||
group.add_argument("--output-dir", type=str, default="samples")
|
||||
group.add_argument("--input-dir", type=str, default=None)
|
||||
group.add_argument("--input-type", type=str, default="cli")
|
||||
group.add_argument("--input-file", type=str, default="input.txt")
|
||||
group.add_argument("--final-size", type=int, default=2048)
|
||||
group.add_argument("--sdedit", action="store_true")
|
||||
group.add_argument("--grid-num-rows", type=int, default=1)
|
||||
group.add_argument("--force-inference", action="store_true")
|
||||
group.add_argument("--lcm_steps", type=int, default=None)
|
||||
group.add_argument("--sampling-num-frames", type=int, default=32)
|
||||
group.add_argument("--sampling-fps", type=int, default=8)
|
||||
group.add_argument("--only-save-latents", type=bool, default=False)
|
||||
group.add_argument("--only-log-video-latents", type=bool, default=False)
|
||||
group.add_argument("--latent-channels", type=int, default=32)
|
||||
group.add_argument("--image2video", action="store_true")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_args(args_list=None, parser=None):
|
||||
"""Parse all the args."""
|
||||
if parser is None:
|
||||
parser = argparse.ArgumentParser(description="sat")
|
||||
else:
|
||||
assert isinstance(parser, argparse.ArgumentParser)
|
||||
parser = add_model_config_args(parser)
|
||||
parser = add_sampling_config_args(parser)
|
||||
parser = add_training_args(parser)
|
||||
parser = add_evaluation_args(parser)
|
||||
parser = add_data_args(parser)
|
||||
|
||||
import deepspeed
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
|
||||
args = parser.parse_args(args_list)
|
||||
args = process_config_to_args(args)
|
||||
|
||||
if not args.train_data:
|
||||
print_rank0("No training data specified", level="WARNING")
|
||||
|
||||
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
|
||||
if args.train_iters is None and args.epochs is None:
|
||||
args.train_iters = 10000 # default 10k iters
|
||||
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
|
||||
|
||||
args.cuda = torch.cuda.is_available()
|
||||
|
||||
args.rank = int(os.getenv("RANK", "0"))
|
||||
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
if args.local_rank is None:
|
||||
args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
|
||||
|
||||
if args.device == -1:
|
||||
if torch.cuda.device_count() == 0:
|
||||
args.device = "cpu"
|
||||
elif args.local_rank is not None:
|
||||
args.device = args.local_rank
|
||||
else:
|
||||
args.device = args.rank % torch.cuda.device_count()
|
||||
|
||||
if args.local_rank != args.device and args.mode != "inference":
|
||||
raise ValueError(
|
||||
"LOCAL_RANK (default 0) and args.device inconsistent. "
|
||||
"This can only happens in inference mode. "
|
||||
"Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
|
||||
)
|
||||
|
||||
if args.rank == 0:
|
||||
print_rank0("using world size: {}".format(args.world_size))
|
||||
|
||||
if args.train_data_weights is not None:
|
||||
assert len(args.train_data_weights) == len(args.train_data)
|
||||
|
||||
if args.mode != "inference": # training with deepspeed
|
||||
args.deepspeed = True
|
||||
if args.deepspeed_config is None: # not specified
|
||||
deepspeed_config_path = os.path.join(
|
||||
os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
|
||||
)
|
||||
with open(deepspeed_config_path) as file:
|
||||
args.deepspeed_config = json.load(file)
|
||||
override_deepspeed_config = True
|
||||
else:
|
||||
override_deepspeed_config = False
|
||||
|
||||
assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
|
||||
|
||||
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
|
||||
print_rank0("Automatically set fp16=True to use ZeRO.")
|
||||
args.fp16 = True
|
||||
args.bf16 = False
|
||||
|
||||
if args.deepspeed:
|
||||
if args.checkpoint_activations:
|
||||
args.deepspeed_activation_checkpointing = True
|
||||
else:
|
||||
args.deepspeed_activation_checkpointing = False
|
||||
if args.deepspeed_config is not None:
|
||||
deepspeed_config = args.deepspeed_config
|
||||
|
||||
if override_deepspeed_config: # not specify deepspeed_config, use args
|
||||
if args.fp16:
|
||||
deepspeed_config["fp16"]["enabled"] = True
|
||||
elif args.bf16:
|
||||
deepspeed_config["bf16"]["enabled"] = True
|
||||
deepspeed_config["fp16"]["enabled"] = False
|
||||
else:
|
||||
deepspeed_config["fp16"]["enabled"] = False
|
||||
deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
|
||||
deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
|
||||
optimizer_params_config = deepspeed_config["optimizer"]["params"]
|
||||
optimizer_params_config["lr"] = args.lr
|
||||
optimizer_params_config["weight_decay"] = args.weight_decay
|
||||
else: # override args with values in deepspeed_config
|
||||
if args.rank == 0:
|
||||
print_rank0("Will override arguments with manually specified deepspeed_config!")
|
||||
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
|
||||
args.fp16 = True
|
||||
else:
|
||||
args.fp16 = False
|
||||
if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
|
||||
args.bf16 = True
|
||||
else:
|
||||
args.bf16 = False
|
||||
if "train_micro_batch_size_per_gpu" in deepspeed_config:
|
||||
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
|
||||
if "gradient_accumulation_steps" in deepspeed_config:
|
||||
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
|
||||
else:
|
||||
args.gradient_accumulation_steps = None
|
||||
if "optimizer" in deepspeed_config:
|
||||
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
|
||||
args.lr = optimizer_params_config.get("lr", args.lr)
|
||||
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
|
||||
args.deepspeed_config = deepspeed_config
|
||||
|
||||
# initialize distributed and random seed because it always seems to be necessary.
|
||||
initialize_distributed(args)
|
||||
args.seed = args.seed + mpu.get_data_parallel_rank()
|
||||
set_random_seed(args.seed)
|
||||
return args
|
||||
|
||||
|
||||
def initialize_distributed(args):
|
||||
"""Initialize torch.distributed."""
|
||||
if torch.distributed.is_initialized():
|
||||
if mpu.model_parallel_is_initialized():
|
||||
if args.model_parallel_size != mpu.get_model_parallel_world_size():
|
||||
raise ValueError(
|
||||
"model_parallel_size is inconsistent with prior configuration."
|
||||
"We currently do not support changing model_parallel_size."
|
||||
)
|
||||
return False
|
||||
else:
|
||||
if args.model_parallel_size > 1:
|
||||
warnings.warn(
|
||||
"model_parallel_size > 1 but torch.distributed is not initialized via SAT."
|
||||
"Please carefully make sure the correctness on your own."
|
||||
)
|
||||
mpu.initialize_model_parallel(args.model_parallel_size)
|
||||
return True
|
||||
# the automatic assignment of devices has been moved to arguments.py
|
||||
if args.device == "cpu":
|
||||
pass
|
||||
else:
|
||||
torch.cuda.set_device(args.device)
|
||||
# Call the init process
|
||||
init_method = "tcp://"
|
||||
args.master_ip = os.getenv("MASTER_ADDR", "localhost")
|
||||
|
||||
if args.world_size == 1:
|
||||
from sat.helpers import get_free_port
|
||||
|
||||
default_master_port = str(get_free_port())
|
||||
else:
|
||||
default_master_port = "6000"
|
||||
args.master_port = os.getenv("MASTER_PORT", default_master_port)
|
||||
init_method += args.master_ip + ":" + args.master_port
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
||||
)
|
||||
|
||||
# Set the model-parallel / data-parallel communicators.
|
||||
mpu.initialize_model_parallel(args.model_parallel_size)
|
||||
|
||||
# Set vae context parallel group equal to model parallel group
|
||||
from sgm.util import set_context_parallel_group, initialize_context_parallel
|
||||
|
||||
if args.model_parallel_size <= 2:
|
||||
set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
|
||||
else:
|
||||
initialize_context_parallel(2)
|
||||
# mpu.initialize_model_parallel(1)
|
||||
# Optional DeepSpeed Activation Checkpointing Features
|
||||
if args.deepspeed:
|
||||
import deepspeed
|
||||
|
||||
deepspeed.init_distributed(
|
||||
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
||||
)
|
||||
# # It seems that it has no negative influence to configure it even without using checkpointing.
|
||||
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
|
||||
else:
|
||||
# in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
|
||||
try:
|
||||
import deepspeed
|
||||
from deepspeed.runtime.activation_checkpointing.checkpointing import (
|
||||
_CUDA_RNG_STATE_TRACKER,
|
||||
_MODEL_PARALLEL_RNG_TRACKER_NAME,
|
||||
)
|
||||
|
||||
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
|
||||
except Exception as e:
|
||||
from sat.helpers import print_rank0
|
||||
|
||||
print_rank0(str(e), level="DEBUG")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def process_config_to_args(args):
|
||||
"""Fetch args from only --base"""
|
||||
|
||||
configs = [OmegaConf.load(cfg) for cfg in args.base]
|
||||
config = OmegaConf.merge(*configs)
|
||||
|
||||
args_config = config.pop("args", OmegaConf.create())
|
||||
for key in args_config:
|
||||
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
|
||||
arg = OmegaConf.to_object(args_config[key])
|
||||
else:
|
||||
arg = args_config[key]
|
||||
if hasattr(args, key):
|
||||
setattr(args, key, arg)
|
||||
|
||||
if "model" in config:
|
||||
model_config = config.pop("model", OmegaConf.create())
|
||||
args.model_config = model_config
|
||||
if "deepspeed" in config:
|
||||
deepspeed_config = config.pop("deepspeed", OmegaConf.create())
|
||||
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
|
||||
if "data" in config:
|
||||
data_config = config.pop("data", OmegaConf.create())
|
||||
args.data_config = data_config
|
||||
|
||||
return args
|
166
sat/configs/cogvideox_2b_infer.yaml
Normal file
@ -0,0 +1,166 @@
|
||||
args:
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "CogVideoX-2b-sat/transformer"
|
||||
batch_size: 1
|
||||
input_type: txt
|
||||
input_file: test.txt
|
||||
sampling_num_frames: 13
|
||||
sampling_fps: 8
|
||||
fp16: True
|
||||
output_dir: outputs/
|
||||
force_inference: True
|
||||
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
disable_first_stage_autocast: true
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
quantize_c_noise: False
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
network_config:
|
||||
target: dit_video_concat.DiffusionTransformer
|
||||
params:
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 49
|
||||
time_compressed_rate: 4
|
||||
latent_width: 90
|
||||
latent_height: 60
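# latent 60x90 corresponds to the 480x720 input after the VAE's 8x spatial downsampling;
# 49 input frames compress to (49 - 1) / time_compressed_rate + 1 = 13 latent frames.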
|
||||
num_layers: 30
|
||||
patch_size: 2
|
||||
in_channels: 16
|
||||
out_channels: 16
|
||||
hidden_size: 1920
|
||||
adm_in_channels: 256
|
||||
num_attention_heads: 30
|
||||
|
||||
transformer_args:
|
||||
vocab_size: 1
|
||||
max_sequence_length: 64
|
||||
layernorm_order: pre
|
||||
skip_init: false
|
||||
model_parallel_size: 1
|
||||
is_decoder: false
|
||||
|
||||
modules:
|
||||
pos_embed_config:
|
||||
target: dit_video_concat.Basic3DPositionEmbeddingMixin
|
||||
params:
|
||||
text_length: 226
|
||||
height_interpolation: 1.875
|
||||
width_interpolation: 1.875
|
||||
|
||||
patch_embed_config:
|
||||
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||
params:
|
||||
text_hidden_size: 4096
|
||||
|
||||
adaln_layer_config:
|
||||
target: dit_video_concat.AdaLNMixin
|
||||
params:
|
||||
qk_ln: True
|
||||
|
||||
final_layer_config:
|
||||
target: dit_video_concat.FinalLayerMixin
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl"
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
|
||||
ignore_keys: [ 'loss' ]
|
||||
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
|
||||
regularizer_config:
|
||||
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
encoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||
params:
|
||||
double_z: true
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 2, 4 ]
|
||||
attn_resolutions: [ ]
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
decoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||
params:
|
||||
double_z: True
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 2, 4 ]
|
||||
attn_resolutions: [ ]
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: false
|
||||
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||
params:
|
||||
offset_noise_level: 0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
uniform_sampling: True
|
||||
num_idx: 1000
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
verbose: True
|
||||
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||
params:
|
||||
scale: 6
|
||||
exp: 5
|
||||
num_steps: 50
|
225
sat/configs/cogvideox_2b_sft.yaml
Normal file
@ -0,0 +1,225 @@
|
||||
args:
|
||||
checkpoint_activations: True ## using gradient checkpointing
|
||||
model_parallel_size: 1
|
||||
experiment_name: lora-disney
|
||||
mode: finetune
|
||||
load: "CogVideoX-2b-sat/transformer"
|
||||
no_load_rng: True
|
||||
train_iters: 1000
|
||||
eval_iters: 1
|
||||
eval_interval: 100
|
||||
eval_batch_size: 1
|
||||
save: ckpts
|
||||
save_interval: 100
|
||||
log_interval: 20
|
||||
train_data: ["disney"]
|
||||
valid_data: ["disney"]
|
||||
split: 1,0,0
|
||||
num_workers: 8
|
||||
force_train: True
|
||||
only_log_video_latents: True
|
||||
|
||||
data:
|
||||
target: data_video.SFTDataset
|
||||
params:
|
||||
video_size: [480, 720]
|
||||
fps: 8
|
||||
max_num_frames: 49
|
||||
skip_frms_num: 3.
|
||||
|
||||
deepspeed:
|
||||
train_micro_batch_size_per_gpu: 1
|
||||
gradient_accumulation_steps: 1
|
||||
steps_per_print: 50
|
||||
gradient_clipping: 0.1
|
||||
zero_optimization:
|
||||
stage: 2
|
||||
cpu_offload: false
|
||||
contiguous_gradients: false
|
||||
overlap_comm: true
|
||||
reduce_scatter: true
|
||||
reduce_bucket_size: 1000000000
|
||||
allgather_bucket_size: 1000000000
|
||||
load_from_fp32_weights: false
|
||||
zero_allow_untested_optimizer: true
|
||||
bf16:
|
||||
enabled: False
|
||||
fp16:
|
||||
enabled: True
|
||||
loss_scale: 0
|
||||
loss_scale_window: 400
|
||||
hysteresis: 2
|
||||
min_loss_scale: 1
|
||||
optimizer:
|
||||
type: sat.ops.FusedEmaAdam
|
||||
params:
|
||||
lr: 0.0002
|
||||
betas: [0.9, 0.95]
|
||||
eps: 1e-8
|
||||
weight_decay: 1e-4
|
||||
activation_checkpointing:
|
||||
partition_activations: false
|
||||
contiguous_memory_optimization: false
|
||||
wall_clock_breakdown: false
|
||||
|
||||
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: ['all'] ## Using Lora
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
quantize_c_noise: False
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
network_config:
|
||||
target: dit_video_concat.DiffusionTransformer
|
||||
params:
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 49
|
||||
time_compressed_rate: 4
|
||||
latent_width: 90
|
||||
latent_height: 60
|
||||
num_layers: 30
|
||||
patch_size: 2
|
||||
in_channels: 16
|
||||
out_channels: 16
|
||||
hidden_size: 1920
|
||||
adm_in_channels: 256
|
||||
num_attention_heads: 30
|
||||
|
||||
transformer_args:
|
||||
checkpoint_activations: True ## using gradient checkpointing
|
||||
vocab_size: 1
|
||||
max_sequence_length: 64
|
||||
layernorm_order: pre
|
||||
skip_init: false
|
||||
model_parallel_size: 1
|
||||
is_decoder: false
|
||||
|
||||
modules:
|
||||
pos_embed_config:
|
||||
target: dit_video_concat.Basic3DPositionEmbeddingMixin
|
||||
params:
|
||||
text_length: 226
|
||||
height_interpolation: 1.875
|
||||
width_interpolation: 1.875
|
||||
|
||||
lora_config: ## Using Lora
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
|
||||
patch_embed_config:
|
||||
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||
params:
|
||||
text_hidden_size: 4096
|
||||
|
||||
adaln_layer_config:
|
||||
target: dit_video_concat.AdaLNMixin
|
||||
params:
|
||||
qk_ln: True
|
||||
|
||||
final_layer_config:
|
||||
target: dit_video_concat.FinalLayerMixin
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl"
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt"
|
||||
ignore_keys: [ 'loss' ]
|
||||
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
|
||||
regularizer_config:
|
||||
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
encoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||
params:
|
||||
double_z: true
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 2, 4 ]
|
||||
attn_resolutions: [ ]
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
decoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||
params:
|
||||
double_z: True
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1, 2, 2, 4 ]
|
||||
attn_resolutions: [ ]
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: false
|
||||
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||
params:
|
||||
offset_noise_level: 0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
uniform_sampling: True
|
||||
num_idx: 1000
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
verbose: True
|
||||
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
params:
|
||||
shift_scale: 3.0
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||
params:
|
||||
scale: 6
|
||||
exp: 5
|
||||
num_steps: 50
|
2
sat/configs/test.txt
Normal file
@ -0,0 +1,2 @@
|
||||
A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.
|
||||
A cinematic view of Earth rotating in space, showcasing the planet's vibrant blue oceans and swirling white clouds, high quality, realistic. The scene transitions from day to night, highlighting the twinkling city lights and the soft glow of the moon reflecting on the surface. Stars and distant galaxies form a breathtaking backdrop, adding to the grandeur and beauty of the Earth seen from space.
|
451
sat/data_video.py
Normal file
@ -0,0 +1,451 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
import math
|
||||
import torchvision.transforms as TT
|
||||
from sgm.webds import MetaDistributedWebDataset
|
||||
import random
|
||||
from fractions import Fraction
|
||||
from typing import Union, Optional, Dict, Any, Tuple
|
||||
from torchvision.io.video import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.io import _video_opt
|
||||
from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
|
||||
from torchvision.transforms.functional import center_crop, resize
|
||||
from torchvision.transforms import InterpolationMode
|
||||
import decord
|
||||
from decord import VideoReader
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def read_video(
|
||||
filename: str,
|
||||
start_pts: Union[float, Fraction] = 0,
|
||||
end_pts: Optional[Union[float, Fraction]] = None,
|
||||
pts_unit: str = "pts",
|
||||
output_format: str = "THWC",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
"""
|
||||
Reads a video from a file, returning both the video frames and the audio frames
|
||||
|
||||
Args:
|
||||
filename (str): path to the video file
|
||||
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||
The start presentation time of the video
|
||||
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
|
||||
The end presentation time
|
||||
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
|
||||
either 'pts' or 'sec'. Defaults to 'pts'.
|
||||
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
|
||||
|
||||
Returns:
|
||||
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
|
||||
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
|
||||
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
|
||||
"""
|
||||
|
||||
output_format = output_format.upper()
|
||||
if output_format not in ("THWC", "TCHW"):
|
||||
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
|
||||
|
||||
_check_av_available()
|
||||
|
||||
if end_pts is None:
|
||||
end_pts = float("inf")
|
||||
|
||||
if end_pts < start_pts:
|
||||
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
||||
|
||||
info = {}
|
||||
audio_frames = []
|
||||
audio_timebase = _video_opt.default_timebase
|
||||
|
||||
with av.open(filename, metadata_errors="ignore") as container:
|
||||
if container.streams.audio:
|
||||
audio_timebase = container.streams.audio[0].time_base
|
||||
if container.streams.video:
|
||||
video_frames = _read_from_stream(
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.video[0],
|
||||
{"video": 0},
|
||||
)
|
||||
video_fps = container.streams.video[0].average_rate
|
||||
# guard against potentially corrupted files
|
||||
if video_fps is not None:
|
||||
info["video_fps"] = float(video_fps)
|
||||
|
||||
if container.streams.audio:
|
||||
audio_frames = _read_from_stream(
|
||||
container,
|
||||
start_pts,
|
||||
end_pts,
|
||||
pts_unit,
|
||||
container.streams.audio[0],
|
||||
{"audio": 0},
|
||||
)
|
||||
info["audio_fps"] = container.streams.audio[0].rate
|
||||
|
||||
aframes_list = [frame.to_ndarray() for frame in audio_frames]
|
||||
|
||||
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
|
||||
|
||||
if aframes_list:
|
||||
aframes = np.concatenate(aframes_list, 1)
|
||||
aframes = torch.as_tensor(aframes)
|
||||
if pts_unit == "sec":
|
||||
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
|
||||
if end_pts != float("inf"):
|
||||
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
|
||||
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
|
||||
else:
|
||||
aframes = torch.empty((1, 0), dtype=torch.float32)
|
||||
|
||||
if output_format == "TCHW":
|
||||
# [T,H,W,C] --> [T,C,H,W]
|
||||
vframes = vframes.permute(0, 3, 1, 2)
|
||||
|
||||
return vframes, aframes, info
|
||||
|
||||
|
||||
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
|
||||
def pad_last_frame(tensor, num_frames):
|
||||
# T, H, W, C
|
||||
if tensor.shape[0] < num_frames:
|
||||
# pad by repeating the trailing frames (T is dim 0 in a T,H,W,C tensor)
last_frame = tensor[-int(num_frames - tensor.shape[0]) :]
|
||||
padded_tensor = torch.cat([tensor, last_frame], dim=0)
|
||||
return padded_tensor
|
||||
else:
|
||||
return tensor[:num_frames]
|
||||
|
||||
|
||||
def load_video(
|
||||
video_data,
|
||||
sampling="uniform",
|
||||
duration=None,
|
||||
num_frames=4,
|
||||
wanted_fps=None,
|
||||
actual_fps=None,
|
||||
skip_frms_num=0.0,
|
||||
nb_read_frames=None,
|
||||
):
|
||||
decord.bridge.set_bridge("torch")
|
||||
vr = VideoReader(uri=video_data, height=-1, width=-1)
|
||||
if nb_read_frames is not None:
|
||||
ori_vlen = nb_read_frames
|
||||
else:
|
||||
ori_vlen = min(int(duration * actual_fps) - 1, len(vr))
|
||||
|
||||
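# pick a random start index such that num_frames sampled at wanted_fps still fit before the end of the clip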
max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
|
||||
start = random.randint(skip_frms_num, max_seek + 1)
|
||||
end = int(start + num_frames / wanted_fps * actual_fps)
|
||||
n_frms = num_frames
|
||||
|
||||
if sampling == "uniform":
|
||||
indices = np.arange(start, end, (end - start) / n_frms).astype(int)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# get_batch -> T, H, W, C
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
|
||||
return pad_last_frame(tensor_frms, num_frames)
|
||||
|
||||
|
||||
import threading
|
||||
|
||||
|
||||
def load_video_with_timeout(*args, **kwargs):
|
||||
video_container = {}
|
||||
|
||||
def target_function():
|
||||
video = load_video(*args, **kwargs)
|
||||
video_container["video"] = video
|
||||
|
||||
thread = threading.Thread(target=target_function)
|
||||
thread.start()
|
||||
timeout = 20
|
||||
thread.join(timeout)
|
||||
|
||||
if thread.is_alive():
|
||||
print("Loading video timed out")
|
||||
raise TimeoutError
|
||||
return video_container.get("video", None).contiguous()
|
||||
|
||||
|
||||
def process_video(
|
||||
video_path,
|
||||
image_size=None,
|
||||
duration=None,
|
||||
num_frames=4,
|
||||
wanted_fps=None,
|
||||
actual_fps=None,
|
||||
skip_frms_num=0.0,
|
||||
nb_read_frames=None,
|
||||
):
|
||||
"""
|
||||
video_path: str or io.BytesIO
|
||||
image_size: .
|
||||
duration: preknow the duration to speed up by seeking to sampled start. TODO by_pass if unknown.
|
||||
num_frames: wanted num_frames.
|
||||
wanted_fps: .
|
||||
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
||||
"""
|
||||
|
||||
video = load_video_with_timeout(
|
||||
video_path,
|
||||
duration=duration,
|
||||
num_frames=num_frames,
|
||||
wanted_fps=wanted_fps,
|
||||
actual_fps=actual_fps,
|
||||
skip_frms_num=skip_frms_num,
|
||||
nb_read_frames=nb_read_frames,
|
||||
)
|
||||
|
||||
# --- copy and modify the image process ---
|
||||
video = video.permute(0, 3, 1, 2) # [T, C, H, W]
|
||||
|
||||
# resize
|
||||
if image_size is not None:
|
||||
video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")
|
||||
|
||||
return video
|
||||
|
||||
|
||||
def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
|
||||
while True:
|
||||
r = next(src)
|
||||
if "mp4" in r:
|
||||
video_data = r["mp4"]
|
||||
elif "avi" in r:
|
||||
video_data = r["avi"]
|
||||
else:
|
||||
print("No video data found")
|
||||
continue
|
||||
|
||||
if txt_key not in r:
|
||||
txt = ""
|
||||
else:
|
||||
txt = r[txt_key]
|
||||
|
||||
if isinstance(txt, bytes):
|
||||
txt = txt.decode("utf-8")
|
||||
else:
|
||||
txt = str(txt)
|
||||
|
||||
duration = r.get("duration", None)
|
||||
if duration is not None:
|
||||
duration = float(duration)
|
||||
else:
|
||||
continue
|
||||
|
||||
actual_fps = r.get("fps", None)
|
||||
if actual_fps is not None:
|
||||
actual_fps = float(actual_fps)
|
||||
else:
|
||||
continue
|
||||
|
||||
required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
|
||||
required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
|
||||
|
||||
if duration is not None and duration < required_duration:
|
||||
continue
|
||||
|
||||
try:
|
||||
frames = process_video(
|
||||
io.BytesIO(video_data),
|
||||
num_frames=num_frames,
|
||||
wanted_fps=fps,
|
||||
image_size=image_size,
|
||||
duration=duration,
|
||||
actual_fps=actual_fps,
|
||||
skip_frms_num=skip_frms_num,
|
||||
)
|
||||
frames = (frames - 127.5) / 127.5
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
item = {
|
||||
"mp4": frames,
|
||||
"txt": txt,
|
||||
"num_frames": num_frames,
|
||||
"fps": fps,
|
||||
}
|
||||
|
||||
yield item
|
||||
|
||||
|
||||
class VideoDataset(MetaDistributedWebDataset):
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
image_size,
|
||||
num_frames,
|
||||
fps,
|
||||
skip_frms_num=0.0,
|
||||
nshards=sys.maxsize,
|
||||
seed=1,
|
||||
meta_names=None,
|
||||
shuffle_buffer=1000,
|
||||
include_dirs=None,
|
||||
txt_key="caption",
|
||||
**kwargs,
|
||||
):
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 1000000)
|
||||
if meta_names is None:
|
||||
meta_names = []
|
||||
|
||||
if path.startswith(";"):
|
||||
path, include_dirs = path.split(";", 1)
|
||||
super().__init__(
|
||||
path,
|
||||
partial(
|
||||
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
|
||||
),
|
||||
seed,
|
||||
meta_names=meta_names,
|
||||
shuffle_buffer=shuffle_buffer,
|
||||
nshards=nshards,
|
||||
include_dirs=include_dirs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_dataset_function(cls, path, args, **kwargs):
|
||||
return cls(path, **kwargs)
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
|
||||
"""
|
||||
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
|
||||
"""
|
||||
super(SFTDataset, self).__init__()
|
||||
|
||||
self.videos_list = []
|
||||
self.captions_list = []
|
||||
self.num_frames_list = []
|
||||
self.fps_list = []
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
for root, dirnames, filenames in os.walk(data_dir):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".mp4"):
|
||||
video_path = os.path.join(root, filename)
|
||||
vr = VideoReader(uri=video_path, height=-1, width=-1)
|
||||
actual_fps = vr.get_avg_fps()
|
||||
ori_vlen = len(vr)
|
||||
|
||||
if ori_vlen / actual_fps * fps > max_num_frames:
|
||||
num_frames = max_num_frames
|
||||
start = int(skip_frms_num)
|
||||
end = int(start + num_frames / fps * actual_fps)
|
||||
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
else:
|
||||
if ori_vlen > max_num_frames:
|
||||
num_frames = max_num_frames
|
||||
start = int(skip_frms_num)
|
||||
end = int(ori_vlen - skip_frms_num)
|
||||
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = (
|
||||
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
)
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
else:
|
||||
|
||||
def nearest_smaller_4k_plus_1(n):
|
||||
remainder = n % 4
|
||||
if remainder == 0:
|
||||
return n - 3
|
||||
else:
|
||||
return n - remainder + 1
|
||||
|
||||
start = int(skip_frms_num)
|
||||
end = int(ori_vlen - skip_frms_num)
|
||||
num_frames = nearest_smaller_4k_plus_1(
|
||||
end - start
|
||||
) # 3D VAE requires the number of frames to be 4k+1
|
||||
end = int(start + num_frames)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = (
|
||||
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
)
|
||||
|
||||
tensor_frms = pad_last_frame(
|
||||
tensor_frms, num_frames
|
||||
) # the len of indices may be less than num_frames, due to round error
|
||||
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
|
||||
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
|
||||
tensor_frms = (tensor_frms - 127.5) / 127.5
|
||||
self.videos_list.append(tensor_frms)
|
||||
|
||||
# caption
|
||||
caption_path = os.path.join(root, filename.replace('.mp4', '.txt')).replace('videos', 'labels')
|
||||
if os.path.exists(caption_path):
|
||||
caption = open(caption_path, "r").read().splitlines()[0]
|
||||
else:
|
||||
caption = ""
|
||||
self.captions_list.append(caption)
|
||||
self.num_frames_list.append(num_frames)
|
||||
self.fps_list.append(fps)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = {
|
||||
"mp4": self.videos_list[index],
|
||||
"txt": self.captions_list[index],
|
||||
"num_frames": self.num_frames_list[index],
|
||||
"fps": self.fps_list[index],
|
||||
}
|
||||
return item
|
||||
|
||||
def __len__(self):
|
||||
return len(self.fps_list)
|
||||
|
||||
@classmethod
|
||||
def create_dataset_function(cls, path, args, **kwargs):
|
||||
return cls(data_dir=path, **kwargs)
|
318
sat/diffusion_video.py
Normal file
@ -0,0 +1,318 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
from copy import deepcopy
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sat.helpers import print_rank0
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sgm.modules import UNCONDITIONAL_CONFIG
|
||||
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
|
||||
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
||||
from sgm.util import (
|
||||
default,
|
||||
disabled_train,
|
||||
get_obj_from_str,
|
||||
instantiate_from_config,
|
||||
log_txt_as_img,
|
||||
)
|
||||
import gc
|
||||
from sat import mpu
|
||||
import random
|
||||
|
||||
|
||||
class SATVideoDiffusionEngine(nn.Module):
|
||||
def __init__(self, args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
model_config = args.model_config
|
||||
# model args preprocess
|
||||
log_keys = model_config.get("log_keys", None)
|
||||
input_key = model_config.get("input_key", "mp4")
|
||||
network_config = model_config.get("network_config", None)
|
||||
network_wrapper = model_config.get("network_wrapper", None)
|
||||
denoiser_config = model_config.get("denoiser_config", None)
|
||||
sampler_config = model_config.get("sampler_config", None)
|
||||
conditioner_config = model_config.get("conditioner_config", None)
|
||||
first_stage_config = model_config.get("first_stage_config", None)
|
||||
loss_fn_config = model_config.get("loss_fn_config", None)
|
||||
scale_factor = model_config.get("scale_factor", 1.0)
|
||||
latent_input = model_config.get("latent_input", False)
|
||||
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
|
||||
no_cond_log = model_config.get("no_cond_log", False)
|
||||
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
|
||||
compile_model = model_config.get("compile_model", False)
|
||||
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
|
||||
lr_scale = model_config.get("lr_scale", None)
|
||||
lora_train = model_config.get("lora_train", False)
|
||||
self.use_pd = model_config.get("use_pd", False) # progressive distillation
|
||||
|
||||
self.log_keys = log_keys
|
||||
self.input_key = input_key
|
||||
self.not_trainable_prefixes = not_trainable_prefixes
|
||||
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
||||
self.lr_scale = lr_scale
|
||||
self.lora_train = lora_train
|
||||
self.noised_image_input = model_config.get("noised_image_input", False)
|
||||
self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
|
||||
self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
|
||||
if args.fp16:
|
||||
dtype = torch.float16
|
||||
dtype_str = "fp16"
|
||||
elif args.bf16:
|
||||
dtype = torch.bfloat16
|
||||
dtype_str = "bf16"
|
||||
else:
|
||||
dtype = torch.float32
|
||||
dtype_str = "fp32"
|
||||
self.dtype = dtype
|
||||
self.dtype_str = dtype_str
|
||||
|
||||
network_config["params"]["dtype"] = dtype_str
|
||||
model = instantiate_from_config(network_config)
|
||||
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
||||
model, compile_model=compile_model, dtype=dtype
|
||||
)
|
||||
|
||||
self.denoiser = instantiate_from_config(denoiser_config)
|
||||
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
|
||||
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
|
||||
|
||||
self._init_first_stage(first_stage_config)
|
||||
|
||||
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
|
||||
|
||||
self.latent_input = latent_input
|
||||
self.scale_factor = scale_factor
|
||||
self.disable_first_stage_autocast = disable_first_stage_autocast
|
||||
self.no_cond_log = no_cond_log
|
||||
self.device = args.device
|
||||
|
||||
def disable_untrainable_params(self):
|
||||
total_trainable = 0
|
||||
for n, p in self.named_parameters():
|
||||
if p.requires_grad == False:
|
||||
continue
|
||||
flag = False
|
||||
for prefix in self.not_trainable_prefixes:
|
||||
if n.startswith(prefix) or prefix == "all":
|
||||
flag = True
|
||||
break
|
||||
|
||||
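# LoRA factors (matrix_A / matrix_B) stay trainable even when the "all" prefix freezes everything else.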
lora_prefix = ["matrix_A", "matrix_B"]
|
||||
for prefix in lora_prefix:
|
||||
if prefix in n:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
p.requires_grad_(False)
|
||||
else:
|
||||
total_trainable += p.numel()
|
||||
|
||||
print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
# reload the initial params from previous trained modules
|
||||
# you can also get access to other mixins through parent_model.get_mixin().
|
||||
pass
|
||||
|
||||
def _init_first_stage(self, config):
|
||||
model = instantiate_from_config(config).eval()
|
||||
model.train = disabled_train
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
self.first_stage_model = model
|
||||
|
||||
def forward(self, x, batch):
|
||||
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
||||
loss_mean = loss.mean()
|
||||
loss_dict = {"loss": loss_mean}
|
||||
return loss_mean, loss_dict
|
||||
|
||||
def shared_step(self, batch: Dict) -> Any:
|
||||
x = self.get_input(batch)
|
||||
if self.lr_scale is not None:
|
||||
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
|
||||
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
|
||||
lr_z = self.encode_first_stage(lr_x, batch)
|
||||
batch["lr_input"] = lr_z
|
||||
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
x = self.encode_first_stage(x, batch)
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loss, loss_dict = self(x, batch)
|
||||
return loss, loss_dict
|
||||
|
||||
def get_input(self, batch):
|
||||
return batch[self.input_key].to(self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_first_stage(self, z):
|
||||
z = 1.0 / self.scale_factor * z
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
||||
n_rounds = math.ceil(z.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
for n in range(n_rounds):
|
||||
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
||||
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
||||
else:
|
||||
kwargs = {}
|
||||
use_cp = False
|
||||
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
|
||||
all_out.append(out)
|
||||
out = torch.cat(all_out, dim=0)
|
||||
return out
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_first_stage(self, x, batch):
|
||||
frame = x.shape[2]
|
||||
|
||||
if frame > 1 and self.latent_input:
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return x * self.scale_factor # already encoded
|
||||
|
||||
use_cp = False
|
||||
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
||||
n_rounds = math.ceil(x.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
for n in range(n_rounds):
|
||||
out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples], use_cp=use_cp)
|
||||
all_out.append(out)
|
||||
z = torch.cat(all_out, dim=0)
|
||||
z = self.scale_factor * z
|
||||
return z
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
cond: Dict,
|
||||
uc: Union[Dict, None] = None,
|
||||
batch_size: int = 16,
|
||||
shape: Union[None, Tuple, List] = None,
|
||||
prefix=None,
|
||||
concat_images=None,
|
||||
**kwargs,
|
||||
):
|
||||
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
|
||||
if hasattr(self, "seeded_noise"):
|
||||
randn = self.seeded_noise(randn)
|
||||
|
||||
if prefix is not None:
|
||||
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
|
||||
|
||||
# broadcast noise
|
||||
mp_size = mpu.get_model_parallel_world_size()
|
||||
if mp_size > 1:
|
||||
global_rank = torch.distributed.get_rank() // mp_size
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
|
||||
|
||||
scale = None
|
||||
scale_emb = None
|
||||
|
||||
denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser(
|
||||
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
|
||||
)
|
||||
|
||||
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
|
||||
samples = samples.to(self.dtype)
|
||||
return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
||||
"""
|
||||
Defines heuristics to log different conditionings.
|
||||
These can be lists of strings (text-to-image), tensors, ints, ...
|
||||
"""
|
||||
image_h, image_w = batch[self.input_key].shape[3:]
|
||||
log = dict()
|
||||
|
||||
for embedder in self.conditioner.embedders:
|
||||
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
|
||||
x = batch[embedder.input_key][:n]
|
||||
if isinstance(x, torch.Tensor):
|
||||
if x.dim() == 1:
|
||||
# class-conditional, convert integer to string
|
||||
x = [str(x[i].item()) for i in range(x.shape[0])]
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
||||
elif x.dim() == 2:
|
||||
# size and crop cond and the like
|
||||
x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif isinstance(x, (List, ListConfig)):
|
||||
if isinstance(x[0], str):
|
||||
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
log[embedder.input_key] = xc
|
||||
return log
|
||||
|
||||
@torch.no_grad()
|
||||
def log_video(
|
||||
self,
|
||||
batch: Dict,
|
||||
N: int = 8,
|
||||
ucg_keys: List[str] = None,
|
||||
only_log_video_latents=False,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
||||
if ucg_keys:
|
||||
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
||||
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
||||
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
||||
)
|
||||
else:
|
||||
ucg_keys = conditioner_input_keys
|
||||
log = dict()
|
||||
|
||||
x = self.get_input(batch)
|
||||
|
||||
c, uc = self.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
|
||||
)
|
||||
|
||||
sampling_kwargs = {}
|
||||
|
||||
N = min(x.shape[0], N)
|
||||
x = x.to(self.device)[:N]
|
||||
if not self.latent_input:
|
||||
log["inputs"] = x.to(torch.float32)
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
z = self.encode_first_stage(x, batch)
|
||||
if not only_log_video_latents:
|
||||
log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
|
||||
log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
|
||||
z = z.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
log.update(self.log_conditionings(batch, N))
|
||||
|
||||
for k in c:
|
||||
if isinstance(c[k], torch.Tensor):
|
||||
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
||||
|
||||
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||
if only_log_video_latents:
|
||||
latents = 1.0 / self.scale_factor * samples
|
||||
log["latents"] = latents
|
||||
else:
|
||||
samples = self.decode_first_stage(samples).to(torch.float32)
|
||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||
log["samples"] = samples
|
||||
return log
|
858
sat/dit_video_concat.py
Normal file
@ -0,0 +1,858 @@
|
||||
from functools import partial
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sat.model.base_model import BaseModel, non_conflict
|
||||
from sat.model.mixins import BaseMixin
|
||||
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
|
||||
from sat.mpu.layers import ColumnParallelLinear
|
||||
from sgm.util import instantiate_from_config
|
||||
|
||||
from sgm.modules.diffusionmodules.openaimodel import Timestep
|
||||
from sgm.modules.diffusionmodules.util import (
|
||||
linear,
|
||||
timestep_embedding,
|
||||
)
|
||||
from sat.ops.layernorm import LayerNorm, RMSNorm
|
||||
|
||||
|
||||
class ImagePatchEmbeddingMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
bias=True,
|
||||
text_hidden_size=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
if text_hidden_size is not None:
|
||||
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
|
||||
else:
|
||||
self.text_proj = None
|
||||
|
||||
def word_embedding_forward(self, input_ids, **kwargs):
|
||||
# now is 3d patch
|
||||
images = kwargs["images"] # (b,t,c,h,w)
|
||||
B, T = images.shape[:2]
|
||||
emb = images.view(-1, *images.shape[2:])
|
||||
emb = self.proj(emb) # ((b t),d,h/2,w/2)
|
||||
emb = emb.view(B, T, *emb.shape[1:])
|
||||
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
|
||||
emb = rearrange(emb, "b t n d -> b (t n) d")
|
||||
|
||||
if self.text_proj is not None:
|
||||
text_emb = self.text_proj(kwargs["encoder_outputs"])
|
||||
emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
|
||||
|
||||
emb = emb.contiguous()
|
||||
return emb # (b,n_t+t*n_i,d)
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
w = self.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.proj.bias, 0)
|
||||
del self.transformer.word_embeddings
|
||||
|
||||
|
||||
def get_3d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_height,
|
||||
grid_width,
|
||||
t_size,
|
||||
cls_token=False,
|
||||
height_interpolation=1.0,
|
||||
width_interpolation=1.0,
|
||||
time_interpolation=1.0,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
t_size: int of the temporal size
|
||||
return:
|
||||
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
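# embedding channels are split 3:1 between the spatial (H, W) and temporal (T) sin/cos components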
assert embed_dim % 4 == 0
|
||||
embed_dim_spatial = embed_dim // 4 * 3
|
||||
embed_dim_temporal = embed_dim // 4
|
||||
|
||||
# spatial
|
||||
grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
|
||||
grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_height, grid_width])
|
||||
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
|
||||
|
||||
# temporal
|
||||
grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
|
||||
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
|
||||
|
||||
# concate: [T, H, W] order
|
||||
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
|
||||
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
|
||||
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
|
||||
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
|
||||
|
||||
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
|
||||
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
|
||||
|
||||
return pos_embed # [T, H*W, D]
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_height, grid_width: int sizes of the grid
|
||||
return:
|
||||
pos_embed: [grid_height*grid_width, embed_dim] or [extra_tokens+grid_height*grid_width, embed_dim] (w/ or w/o extra tokens)
|
||||
"""
|
||||
grid_h = np.arange(grid_height, dtype=np.float32)
|
||||
grid_w = np.arange(grid_width, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_height, grid_width])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
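# Standalone usage sketch for the sinusoidal helpers above (illustrative sizes only,
# not taken from any config in this commit).
def _demo_3d_sincos_pos_embed():
    embed_dim, grid_h, grid_w, t_size = 64, 8, 6, 4
    pos = get_3d_sincos_pos_embed(embed_dim, grid_h, grid_w, t_size)
    # 3/4 of the channels encode (h, w) and 1/4 encodes time, concatenated per token.
    assert pos.shape == (t_size, grid_h * grid_w, embed_dim)
    return pos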
class Basic3DPositionEmbeddingMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
height,
|
||||
width,
|
||||
compressed_num_frames,
|
||||
hidden_size,
|
||||
text_length=0,
|
||||
height_interpolation=1.0,
|
||||
width_interpolation=1.0,
|
||||
time_interpolation=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.text_length = text_length
|
||||
self.compressed_num_frames = compressed_num_frames
|
||||
self.spatial_length = height * width
|
||||
self.num_patches = height * width * compressed_num_frames
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
|
||||
)
|
||||
self.height_interpolation = height_interpolation
|
||||
self.width_interpolation = width_interpolation
|
||||
self.time_interpolation = time_interpolation
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kwargs):
|
||||
if kwargs["images"].shape[1] == 1:
|
||||
return self.pos_embedding[:, : self.text_length + self.spatial_length]
|
||||
|
||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
del self.transformer.position_embeddings
|
||||
pos_embed = get_3d_sincos_pos_embed(
|
||||
self.pos_embedding.shape[-1],
|
||||
self.height,
|
||||
self.width,
|
||||
self.compressed_num_frames,
|
||||
height_interpolation=self.height_interpolation,
|
||||
width_interpolation=self.width_interpolation,
|
||||
time_interpolation=self.time_interpolation,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float()
|
||||
pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
|
||||
self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
|
||||
|
||||
|
||||
def broadcat(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatentation"
|
||||
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
||||
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
||||
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
|
||||
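# Standalone sketch of how broadcat() above merges the per-axis rotary tables into one
# (T, H, W, D) grid (illustrative sizes only).
def _demo_broadcat():
    freqs_t = torch.randn(4, 16)  # temporal axis
    freqs_h = torch.randn(6, 16)  # height axis
    freqs_w = torch.randn(8, 16)  # width axis
    freqs = broadcat(
        (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1
    )
    assert freqs.shape == (4, 6, 8, 48)  # the three axes are concatenated on the last dim
    return freqs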
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
height,
|
||||
width,
|
||||
compressed_num_frames,
|
||||
hidden_size,
|
||||
hidden_size_head,
|
||||
text_length,
|
||||
theta=10000,
|
||||
rot_v=False,
|
||||
pnp=False,
|
||||
learnable_pos_embed=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.rot_v = rot_v
|
||||
|
||||
dim_t = hidden_size_head // 4
|
||||
dim_h = hidden_size_head // 8 * 3
|
||||
dim_w = hidden_size_head // 8 * 3
|
||||
|
||||
# 'lang':
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
|
||||
|
||||
grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
|
||||
grid_h = torch.arange(height, dtype=torch.float32)
|
||||
grid_w = torch.arange(width, dtype=torch.float32)
|
||||
|
||||
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
|
||||
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
|
||||
|
||||
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
|
||||
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
# (T H W D)
|
||||
|
||||
self.pnp = pnp
|
||||
|
||||
if not self.pnp:
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
|
||||
freqs = freqs.contiguous()
|
||||
freqs_sin = freqs.sin()
|
||||
freqs_cos = freqs.cos()
|
||||
self.register_buffer("freqs_sin", freqs_sin)
|
||||
self.register_buffer("freqs_cos", freqs_cos)
|
||||
|
||||
self.text_length = text_length
|
||||
if learnable_pos_embed:
|
||||
num_patches = height * width * compressed_num_frames + text_length
|
||||
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
|
||||
else:
|
||||
self.pos_embedding = None
|
||||
|
||||
def rotary(self, t, **kwargs):
|
||||
if self.pnp:
|
||||
t_coords = kwargs["rope_position_ids"][:, :, 0]
|
||||
x_coords = kwargs["rope_position_ids"][:, :, 1]
|
||||
y_coords = kwargs["rope_position_ids"][:, :, 2]
|
||||
mask = (x_coords != -1) & (y_coords != -1) & (t_coords != -1)
|
||||
freqs = torch.zeros([t.shape[0], t.shape[2], t.shape[3]], dtype=t.dtype, device=t.device)
|
||||
freqs[mask] = self.freqs[t_coords[mask], x_coords[mask], y_coords[mask]]
|
||||
|
||||
else:
|
||||
|
||||
def reshape_freq(freqs):
|
||||
frame = t.shape[2]
|
||||
freqs = freqs[:frame].contiguous()
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
||||
return freqs
|
||||
|
||||
freqs_cos = reshape_freq(self.freqs_cos)
|
||||
freqs_sin = reshape_freq(self.freqs_sin)
|
||||
|
||||
return t * freqs_cos + rotate_half(t) * freqs_sin
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kwargs):
|
||||
if self.pos_embedding is not None:
|
||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
def attention_fn(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
attention_dropout=None,
|
||||
log_attention_weights=None,
|
||||
scaling_attention_score=True,
|
||||
**kwargs,
|
||||
):
|
||||
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
||||
|
||||
if self.pnp:
|
||||
query_layer = self.rotary(query_layer, **kwargs)
|
||||
key_layer = self.rotary(key_layer, **kwargs)
|
||||
if self.rot_v:
|
||||
value_layer = self.rotary(value_layer)
|
||||
else:
|
||||
query_layer = torch.cat(
|
||||
(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
key_layer = torch.cat(
|
||||
(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
if self.rot_v:
|
||||
value_layer = torch.cat(
|
||||
(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
]
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
|
||||
return attention_fn_default(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
attention_dropout=attention_dropout,
|
||||
log_attention_weights=log_attention_weights,
|
||||
scaling_attention_score=scaling_attention_score,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
|
||||
"""
|
||||
x: (N, T/2 * S, patch_size**3 * C)
|
||||
imgs: (N, T, C, H, W)
|
||||
"""
|
||||
if rope_position_ids is not None:
|
||||
assert NotImplementedError  # FIXME: asserting the exception class is always truthy, so this branch is not actually blocked
|
||||
# do pix2struct unpatchify
|
||||
L = x.shape[1]
|
||||
x = x.reshape(shape=(x.shape[0], L, p, p, c))
|
||||
x = torch.einsum("nlpqc->ncplq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
|
||||
else:
|
||||
b = x.shape[0]
|
||||
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
|
||||
|
||||
return imgs
|
||||
|
||||
|
||||
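# Standalone shape sketch for unpatchify() above (illustrative sizes only).
def _demo_unpatchify():
    b, t, h, w, p, c = 2, 3, 4, 6, 2, 16
    x = torch.randn(b, t * h * w, c * p * p)  # (N, tokens, patch_size**2 * C)
    imgs = unpatchify(x, c=c, p=p, w=w, h=h, rope_position_ids=None)
    assert imgs.shape == (b, t, c, h * p, w * p)  # tokens fold back into per-frame latents
    return imgs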
class FinalLayerMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
time_embed_dim,
|
||||
patch_size,
|
||||
out_channels,
|
||||
latent_width,
|
||||
latent_height,
|
||||
elementwise_affine,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
|
||||
|
||||
self.spatial_length = latent_width * latent_height // patch_size**2
|
||||
self.latent_width = latent_width
|
||||
self.latent_height = latent_height
|
||||
|
||||
def final_forward(self, logits, **kwargs):
|
||||
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
|
||||
|
||||
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
|
||||
return unpatchify(
|
||||
x,
|
||||
c=self.out_channels,
|
||||
p=self.patch_size,
|
||||
w=self.latent_width // self.patch_size,
|
||||
h=self.latent_height // self.patch_size,
|
||||
rope_position_ids=kwargs.get("rope_position_ids", None),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
nn.init.xavier_uniform_(self.linear.weight)
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
|
||||
class SwiGLUMixin(BaseMixin):
|
||||
def __init__(self, num_layers, in_features, hidden_features, bias=False):
|
||||
super().__init__()
|
||||
self.w2 = nn.ModuleList(
|
||||
[
|
||||
ColumnParallelLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
gather_output=False,
|
||||
bias=bias,
|
||||
module=self,
|
||||
name="dense_h_to_4h_gate",
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def mlp_forward(self, hidden_states, **kw_args):
|
||||
x = hidden_states
|
||||
origin = self.transformer.layers[kw_args["layer_id"]].mlp
|
||||
x1 = origin.dense_h_to_4h(x)
|
||||
x2 = self.w2[kw_args["layer_id"]](x)
|
||||
hidden = origin.activation_func(x2) * x1
|
||||
x = origin.dense_4h_to_h(hidden)
|
||||
return x
|
||||
|
||||
|
||||
class AdaLNMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
time_embed_dim,
|
||||
compressed_num_frames,
|
||||
qk_ln=True,
|
||||
hidden_size_head=None,
|
||||
elementwise_affine=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.compressed_num_frames = compressed_num_frames
|
||||
|
||||
self.adaLN_modulations = nn.ModuleList(
|
||||
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
|
||||
)
|
||||
|
||||
self.qk_ln = qk_ln
|
||||
if qk_ln:
|
||||
self.query_layernorm_list = nn.ModuleList(
|
||||
[
|
||||
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.key_layernorm_list = nn.ModuleList(
|
||||
[
|
||||
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def layer_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
mask,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
text_length = kwargs["text_length"]
|
||||
# hidden_states (b,(n_t+t*n_i),d)
|
||||
text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
|
||||
img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
|
||||
layer = self.transformer.layers[kwargs["layer_id"]]
|
||||
adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
|
||||
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
text_shift_msa,
|
||||
text_scale_msa,
|
||||
text_gate_msa,
|
||||
text_shift_mlp,
|
||||
text_scale_mlp,
|
||||
text_gate_mlp,
|
||||
) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
|
||||
gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
|
||||
gate_msa.unsqueeze(1),
|
||||
gate_mlp.unsqueeze(1),
|
||||
text_gate_msa.unsqueeze(1),
|
||||
text_gate_mlp.unsqueeze(1),
|
||||
)
|
||||
|
||||
# self full attention (b,(t n),d)
|
||||
img_attention_input = layer.input_layernorm(img_hidden_states)
|
||||
text_attention_input = layer.input_layernorm(text_hidden_states)
|
||||
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
|
||||
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
|
||||
|
||||
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
|
||||
attention_output = layer.attention(attention_input, mask, **kwargs)
|
||||
text_attention_output = attention_output[:, :text_length] # (b,n,d)
|
||||
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
|
||||
|
||||
if self.transformer.layernorm_order == "sandwich":
|
||||
text_attention_output = layer.third_layernorm(text_attention_output)
|
||||
img_attention_output = layer.third_layernorm(img_attention_output)
|
||||
img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
|
||||
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
|
||||
|
||||
# mlp (b,(t n),d)
|
||||
img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
|
||||
text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
|
||||
img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
|
||||
text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
|
||||
mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d
|
||||
mlp_output = layer.mlp(mlp_input, **kwargs)
|
||||
img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
|
||||
text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
|
||||
if self.transformer.layernorm_order == "sandwich":
|
||||
text_mlp_output = layer.fourth_layernorm(text_mlp_output)
|
||||
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
|
||||
|
||||
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
|
||||
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
|
||||
|
||||
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
|
||||
return hidden_states
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
for layer in self.adaLN_modulations:
|
||||
nn.init.constant_(layer[-1].weight, 0)
|
||||
nn.init.constant_(layer[-1].bias, 0)
|
||||
|
||||
@non_conflict
|
||||
def attention_fn(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
attention_dropout=None,
|
||||
log_attention_weights=None,
|
||||
scaling_attention_score=True,
|
||||
old_impl=attention_fn_default,
|
||||
**kwargs,
|
||||
):
|
||||
if self.qk_ln:
|
||||
query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
|
||||
key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
|
||||
query_layer = query_layernorm(query_layer)
|
||||
key_layer = key_layernorm(key_layer)
|
||||
|
||||
return old_impl(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
attention_dropout=attention_dropout,
|
||||
log_attention_weights=log_attention_weights,
|
||||
scaling_attention_score=scaling_attention_score,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
class DiffusionTransformer(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
transformer_args,
|
||||
num_frames,
|
||||
time_compressed_rate,
|
||||
latent_width,
|
||||
latent_height,
|
||||
patch_size,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
num_attention_heads,
|
||||
elementwise_affine,
|
||||
time_embed_dim=None,
|
||||
num_classes=None,
|
||||
modules={},
|
||||
input_time="adaln",
|
||||
adm_in_channels=None,
|
||||
parallel_output=True,
|
||||
height_interpolation=1.0,
|
||||
width_interpolation=1.0,
|
||||
time_interpolation=1.0,
|
||||
use_SwiGLU=False,
|
||||
use_RMSNorm=False,
|
||||
zero_init_y_embed=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.latent_width = latent_width
|
||||
self.latent_height = latent_height
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.time_compressed_rate = time_compressed_rate
|
||||
self.spatial_length = latent_width * latent_height // patch_size**2
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.model_channels = hidden_size
|
||||
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
|
||||
self.num_classes = num_classes
|
||||
self.adm_in_channels = adm_in_channels
|
||||
self.input_time = input_time
|
||||
self.num_layers = num_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.is_decoder = transformer_args.is_decoder
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.height_interpolation = height_interpolation
|
||||
self.width_interpolation = width_interpolation
|
||||
self.time_interpolation = time_interpolation
|
||||
self.inner_hidden_size = hidden_size * 4
|
||||
self.zero_init_y_embed = zero_init_y_embed
|
||||
try:
|
||||
self.dtype = str_to_dtype[kwargs.pop("dtype")]
|
||||
except KeyError:  # "dtype" missing or unrecognized: fall back to fp32
|
||||
self.dtype = torch.float32
|
||||
|
||||
if use_SwiGLU:
|
||||
kwargs["activation_func"] = F.silu
|
||||
elif "activation_func" not in kwargs:
|
||||
approx_gelu = nn.GELU(approximate="tanh")
|
||||
kwargs["activation_func"] = approx_gelu
|
||||
|
||||
if use_RMSNorm:
|
||||
kwargs["layernorm"] = RMSNorm
|
||||
else:
|
||||
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
|
||||
|
||||
transformer_args.num_layers = num_layers
|
||||
transformer_args.hidden_size = hidden_size
|
||||
transformer_args.num_attention_heads = num_attention_heads
|
||||
transformer_args.parallel_output = parallel_output
|
||||
super().__init__(args=transformer_args, transformer=None, **kwargs)
|
||||
|
||||
module_configs = modules
|
||||
self._build_modules(module_configs)
|
||||
|
||||
if use_SwiGLU:
|
||||
self.add_mixin(
|
||||
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
|
||||
)
|
||||
|
||||
def _build_modules(self, module_configs):
|
||||
model_channels = self.hidden_size
|
||||
# time_embed_dim = model_channels * 4
|
||||
time_embed_dim = self.time_embed_dim
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "timestep":
|
||||
self.label_emb = nn.Sequential(
|
||||
Timestep(model_channels),
|
||||
nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
),
|
||||
)
|
||||
elif self.num_classes == "sequential":
|
||||
assert self.adm_in_channels is not None
|
||||
self.label_emb = nn.Sequential(
|
||||
nn.Sequential(
|
||||
linear(self.adm_in_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
)
|
||||
if self.zero_init_y_embed:
|
||||
nn.init.constant_(self.label_emb[0][2].weight, 0)
|
||||
nn.init.constant_(self.label_emb[0][2].bias, 0)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
pos_embed_config = module_configs["pos_embed_config"]
|
||||
self.add_mixin(
|
||||
"pos_embed",
|
||||
instantiate_from_config(
|
||||
pos_embed_config,
|
||||
height=self.latent_height // self.patch_size,
|
||||
width=self.latent_width // self.patch_size,
|
||||
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
),
|
||||
reinit=True,
|
||||
)
|
||||
|
||||
patch_embed_config = module_configs["patch_embed_config"]
|
||||
self.add_mixin(
|
||||
"patch_embed",
|
||||
instantiate_from_config(
|
||||
patch_embed_config,
|
||||
patch_size=self.patch_size,
|
||||
hidden_size=self.hidden_size,
|
||||
in_channels=self.in_channels,
|
||||
),
|
||||
reinit=True,
|
||||
)
|
||||
if self.input_time == "adaln":
|
||||
adaln_layer_config = module_configs["adaln_layer_config"]
|
||||
self.add_mixin(
|
||||
"adaln_layer",
|
||||
instantiate_from_config(
|
||||
adaln_layer_config,
|
||||
height=self.latent_height // self.patch_size,
|
||||
width=self.latent_width // self.patch_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
||||
hidden_size_head=self.hidden_size // self.num_attention_heads,
|
||||
time_embed_dim=self.time_embed_dim,
|
||||
elementwise_affine=self.elementwise_affine,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
final_layer_config = module_configs["final_layer_config"]
|
||||
self.add_mixin(
|
||||
"final_layer",
|
||||
instantiate_from_config(
|
||||
final_layer_config,
|
||||
hidden_size=self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
out_channels=self.out_channels,
|
||||
time_embed_dim=self.time_embed_dim,
|
||||
latent_width=self.latent_width,
|
||||
latent_height=self.latent_height,
|
||||
elementwise_affine=self.elementwise_affine,
|
||||
),
|
||||
reinit=True,
|
||||
)
|
||||
|
||||
if "lora_config" in module_configs:
|
||||
lora_config = module_configs["lora_config"]
|
||||
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
|
||||
|
||||
return
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
b, t, d, h, w = x.shape
|
||||
if x.dtype != self.dtype:
|
||||
x = x.to(self.dtype)
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
# assert y.shape[0] == x.shape[0]
|
||||
assert x.shape[0] % y.shape[0] == 0
|
||||
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
kwargs["seq_length"] = t * h * w // (self.patch_size**2)
|
||||
kwargs["images"] = x
|
||||
kwargs["emb"] = emb
|
||||
kwargs["encoder_outputs"] = context
|
||||
kwargs["text_length"] = context.shape[1]
|
||||
|
||||
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
|
||||
output = super().forward(**kwargs)[0]
|
||||
|
||||
return output
|
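For orientation, a small arithmetic sketch of how forward() above sizes the token stream that the mixins slice by text_length; the frame count, latent resolution, patch size and text length below are illustrative values, not read from any config in this diff.

t, h, w, patch_size, text_length = 13, 60, 90, 2, 226
seq_length = t * h * w // (patch_size ** 2)  # image tokens, mirroring forward()
total_tokens = text_length + seq_length      # text tokens are prepended by ImagePatchEmbeddingMixin
print(seq_length, total_tokens)              # 17550 17776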
13
sat/finetune.sh
Normal file
@ -0,0 +1,13 @@
#! /bin/bash

module load cuda
echo "RUN on `hostname`, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python train_video.py --base configs/cogvideox_2b_sft.yaml --seed $RANDOM"

echo ${run_cmd}
eval ${run_cmd}

echo "DONE on `hostname`"
12
sat/inference.sh
Executable file
@ -0,0 +1,12 @@
#! /bin/bash

echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python sample_video.py --base configs/cogvideox_2b_infer.yaml"

echo ${run_cmd}
eval ${run_cmd}

echo "DONE on `hostname`"
17
sat/requirements.txt
Normal file
@ -0,0 +1,17 @@
git+https://github.com/THUDM/SwissArmyTransformer.git
diffusers>=0.29.2
omegaconf>=2.3.0
torch>=2.3.1
torchvision>=0.19.0
pytorch_lightning>=2.3.3
kornia>=0.7.3
beartype>=0.18.5
numpy>=2.0.1
fsspec>=2024.5.0
safetensors>=0.4.3
imageio-ffmpeg>=0.5.1
imageio>=2.34.2
scipy>=1.14.0
decord>=0.6.0
wandb>=0.17.5
deepspeed>=0.14.4
221
sat/sample_video.py
Normal file
@ -0,0 +1,221 @@
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
from typing import List, Union
|
||||
from tqdm import tqdm
|
||||
from omegaconf import ListConfig
|
||||
import imageio
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import torchvision.transforms as TT
|
||||
|
||||
from sat.model.base_model import get_model
|
||||
from sat.training.model_io import load_checkpoint
|
||||
from sat import mpu
|
||||
|
||||
from diffusion_video import SATVideoDiffusionEngine
|
||||
from arguments import get_args
|
||||
from torchvision.transforms.functional import center_crop, resize
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
|
||||
def read_from_cli():
|
||||
cnt = 0
|
||||
try:
|
||||
while True:
|
||||
x = input("Please input English text (Ctrl-D quit): ")
|
||||
yield x.strip(), cnt
|
||||
cnt += 1
|
||||
except EOFError as e:
|
||||
pass
|
||||
|
||||
|
||||
def read_from_file(p, rank=0, world_size=1):
|
||||
with open(p, "r") as fin:
|
||||
cnt = -1
|
||||
for l in fin:
|
||||
cnt += 1
|
||||
if cnt % world_size != rank:
|
||||
continue
|
||||
yield l.strip(), cnt
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
|
||||
batch = {}
|
||||
batch_uc = {}
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
if T is not None:
|
||||
batch["num_video_frames"] = T
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
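# Standalone sketch of get_batch() above with a hypothetical prompt; it mirrors what
# sampling_main() builds further down (values are illustrative).
def _demo_get_batch():
    value_dict = {
        "prompt": "a cat playing the piano",
        "negative_prompt": "",
        "num_frames": torch.tensor(13).unsqueeze(0),
    }
    batch, batch_uc = get_batch(["txt"], value_dict, N=[1])
    # batch["txt"] repeats the prompt prod(N) times; batch_uc["txt"] holds the negative prompt.
    return batch, batch_uc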
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
for i, vid in enumerate(video_batch):
|
||||
gif_frames = []
|
||||
for frame in vid:
|
||||
frame = rearrange(frame, "c h w -> h w c")
|
||||
frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
|
||||
gif_frames.append(frame)
|
||||
now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
|
||||
with imageio.get_writer(now_save_path, fps=fps) as writer:
|
||||
for frame in gif_frames:
|
||||
writer.append_data(frame)
|
||||
|
||||
|
||||
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
|
||||
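# Standalone sanity sketch for resize_for_rectangle_crop() above: a dummy 4-frame clip in
# (1, frames, h, w) layout, center-cropped to the 480x720 sampling size (illustrative sizes).
def _demo_rectangle_crop():
    clip = torch.randn(1, 4, 720, 1280)
    out = resize_for_rectangle_crop(clip, image_size=[480, 720], reshape_mode="center")
    assert out.shape == (4, 480, 720)
    return out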
def sampling_main(args, model_cls):
|
||||
if isinstance(model_cls, type):
|
||||
model = get_model(args, model_cls)
|
||||
else:
|
||||
model = model_cls
|
||||
|
||||
load_checkpoint(model, args)
|
||||
model.eval()
|
||||
|
||||
if args.input_type == "cli":
|
||||
rank, world_size = 0, 1  # defaults for CLI input, so the progress print below has a rank to report
data_iter = read_from_cli()
|
||||
elif args.input_type == "txt":
|
||||
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
|
||||
print("rank and world_size", rank, world_size)
|
||||
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
image_size = [480, 720]
|
||||
|
||||
sample_func = model.sample
|
||||
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
|
||||
num_samples = [1]
|
||||
force_uc_zero_embeddings = ["txt"]
|
||||
with torch.no_grad():
|
||||
for text, cnt in tqdm(data_iter):
|
||||
print("rank:", rank, "start to process", text, cnt)
|
||||
# TODO: broadcast image2video
|
||||
value_dict = {
|
||||
"prompt": text,
|
||||
"negative_prompt": "",
|
||||
"num_frames": torch.tensor(T).unsqueeze(0),
|
||||
}
|
||||
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
||||
)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
print(key, batch[key].shape)
|
||||
elif isinstance(batch[key], list):
|
||||
print(key, [len(l) for l in batch[key]])
|
||||
else:
|
||||
print(key, batch[key])
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
)
|
||||
|
||||
for k in c:
|
||||
if not k == "crossattn":
|
||||
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
|
||||
for index in range(args.batch_size):
|
||||
samples_z = sample_func(
|
||||
c,
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H // F, W // F),
|
||||
)
|
||||
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
latent = 1.0 / model.scale_factor * samples_z
|
||||
|
||||
recons = []
|
||||
for i in range(6):
|
||||
if i == 0:
|
||||
start_frame, end_frame = 0, 3
|
||||
else:
|
||||
start_frame, end_frame = i * 2 + 1, i * 2 + 3
|
||||
if i == 5:
|
||||
clear_fake_cp_cache = True
|
||||
else:
|
||||
clear_fake_cp_cache = False
|
||||
with torch.no_grad():
|
||||
recon = model.first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
recons.append(recon)
|
||||
|
||||
|
||||
recon = torch.cat(recons, dim=2).to(torch.float32)
|
||||
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||
|
||||
save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:120], str(index))
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
|
||||
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
|
||||
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
|
||||
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
|
||||
py_parser = argparse.ArgumentParser(add_help=False)
|
||||
known, args_list = py_parser.parse_known_args()
|
||||
|
||||
args = get_args(args_list)
|
||||
args = argparse.Namespace(**vars(args), **vars(known))
|
||||
del args.deepspeed_config
|
||||
args.model_config.first_stage_config.params.cp_size = 1
|
||||
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
|
||||
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
|
||||
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
|
||||
|
||||
sampling_main(args, model_cls=SATVideoDiffusionEngine)
|
4
sat/sgm/__init__.py
Normal file
@ -0,0 +1,4 @@
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config

__version__ = "0.1.0"
110
sat/sgm/lr_scheduler.py
Normal file
@ -0,0 +1,110 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warm_up_steps,
|
||||
lr_min,
|
||||
lr_max,
|
||||
lr_start,
|
||||
max_decay_steps,
|
||||
verbosity_interval=0,
|
||||
):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
"""
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
interval = 0
|
||||
for cl in self.cum_cycles[1:]:
|
||||
if n <= cl:
|
||||
return interval
|
||||
interval += 1
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
f = (
|
||||
self.f_min[cycle]
|
||||
+ (self.f_max[cycle] - self.f_min[cycle])
|
||||
* (self.cycle_lengths[cycle] - n)
|
||||
/ (self.cycle_lengths[cycle])
|
||||
)
|
||||
self.last_f = f
|
||||
return f
|
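A quick standalone sketch of the warm-up/cosine multiplier defined above; the step counts and bounds are illustrative, not taken from any training config in this diff.

sched = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=1000
)
# Ramps 0.0 -> 0.5 -> 1.0 across warm-up, then decays toward lr_min by step 1000.
print([round(sched(n), 3) for n in (0, 50, 100, 1000)])  # [0.0, 0.5, 1.0, 0.01]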
1
sat/sgm/models/__init__.py
Normal file
@ -0,0 +1 @@
from .autoencoder import AutoencodingEngine
630
sat/sgm/models/autoencoder.py
Normal file
@ -0,0 +1,630 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
||||
from ..modules.ema import LitEma
|
||||
from ..util import (
|
||||
default,
|
||||
get_nested_attribute,
|
||||
get_obj_from_str,
|
||||
instantiate_from_config,
|
||||
initialize_context_parallel,
|
||||
get_context_parallel_group,
|
||||
get_context_parallel_group_rank,
|
||||
is_context_parallel_initialized,
|
||||
)
|
||||
from ..modules.cp_enc_dec import _conv_split, _conv_gather
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractAutoencoder(pl.LightningModule):
|
||||
"""
|
||||
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
||||
unCLIP models, etc. Hence, it is fairly general, and specific features
|
||||
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ema_decay: Union[None, float] = None,
|
||||
monitor: Union[None, str] = None,
|
||||
input_key: str = "jpg",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_key = input_key
|
||||
self.use_ema = ema_decay is not None
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
self.automatic_optimization = False
|
||||
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
if isinstance(ckpt, str):
|
||||
ckpt = {
|
||||
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
||||
"params": {"ckpt_path": ckpt},
|
||||
}
|
||||
engine = instantiate_from_config(ckpt)
|
||||
engine(self)
|
||||
|
||||
@abstractmethod
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
# for EMA computation
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Restored training weights")
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("encode()-method of abstract base class called")
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AutoencodingEngine(AbstractAutoencoder):
|
||||
"""
|
||||
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
||||
(we also restore them explicitly as special cases for legacy reasons).
|
||||
Regularizations such as KL or VQ are moved to the regularizer class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
encoder_config: Dict,
|
||||
decoder_config: Dict,
|
||||
loss_config: Dict,
|
||||
regularizer_config: Dict,
|
||||
optimizer_config: Union[Dict, None] = None,
|
||||
lr_g_factor: float = 1.0,
|
||||
trainable_ae_params: Optional[List[List[str]]] = None,
|
||||
ae_optimizer_args: Optional[List[dict]] = None,
|
||||
trainable_disc_params: Optional[List[List[str]]] = None,
|
||||
disc_optimizer_args: Optional[List[dict]] = None,
|
||||
disc_start_iter: int = 0,
|
||||
diff_boost_factor: float = 3.0,
|
||||
ckpt_engine: Union[None, str, dict] = None,
|
||||
ckpt_path: Optional[str] = None,
|
||||
additional_decode_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False # pytorch lightning
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
|
||||
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
|
||||
self.diff_boost_factor = diff_boost_factor
|
||||
self.disc_start_iter = disc_start_iter
|
||||
self.lr_g_factor = lr_g_factor
|
||||
self.trainable_ae_params = trainable_ae_params
|
||||
if self.trainable_ae_params is not None:
|
||||
self.ae_optimizer_args = default(
|
||||
ae_optimizer_args,
|
||||
[{} for _ in range(len(self.trainable_ae_params))],
|
||||
)
|
||||
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
||||
else:
|
||||
self.ae_optimizer_args = [{}]  # makes type consistent
|
||||
|
||||
self.trainable_disc_params = trainable_disc_params
|
||||
if self.trainable_disc_params is not None:
|
||||
self.disc_optimizer_args = default(
|
||||
disc_optimizer_args,
|
||||
[{} for _ in range(len(self.trainable_disc_params))],
|
||||
)
|
||||
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
||||
else:
|
||||
self.disc_optimizer_args = [{}]  # makes type consistent
|
||||
|
||||
if ckpt_path is not None:
|
||||
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
||||
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
||||
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
||||
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
||||
|
||||
def get_input(self, batch: Dict) -> torch.Tensor:
|
||||
# assuming unified data format, dataloader returns a dict.
|
||||
# image tensors should be scaled to -1 ... 1 and in channels-first
|
||||
# format (e.g., bchw instead of bhwc)
|
||||
return batch[self.input_key]
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = []
|
||||
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
||||
params += list(self.loss.get_trainable_autoencoder_parameters())
|
||||
if hasattr(self.regularization, "get_trainable_parameters"):
|
||||
params += list(self.regularization.get_trainable_parameters())
|
||||
params = params + list(self.encoder.parameters())
|
||||
params = params + list(self.decoder.parameters())
|
||||
return params
|
||||
|
||||
def get_discriminator_params(self) -> list:
|
||||
if hasattr(self.loss, "get_trainable_parameters"):
|
||||
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
||||
else:
|
||||
params = []
|
||||
return params
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x, **kwargs)
|
||||
if unregularized:
|
||||
return z, dict()
|
||||
z, reg_log = self.regularization(z)
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
||||
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
"z": z,
|
||||
"optimizer_idx": optimizer_idx,
|
||||
"global_step": self.global_step,
|
||||
"last_layer": self.get_last_layer(),
|
||||
"split": "train",
|
||||
"regularization_log": regularization_log,
|
||||
"autoencoder": self,
|
||||
}
|
||||
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
||||
else:
|
||||
extra_info = dict()
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
out_loss = self.loss(x, xrec, **extra_info)
|
||||
if isinstance(out_loss, tuple):
|
||||
aeloss, log_dict_ae = out_loss
|
||||
else:
|
||||
# simple loss function
|
||||
aeloss = out_loss
|
||||
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
||||
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
sync_dist=False,
|
||||
)
|
||||
self.log(
|
||||
"loss",
|
||||
aeloss.mean().detach(),
|
||||
prog_bar=True,
|
||||
logger=False,
|
||||
on_epoch=False,
|
||||
on_step=True,
|
||||
)
|
||||
return aeloss
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
||||
# -> discriminator always needs to return a tuple
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
||||
|
||||
def training_step(self, batch: dict, batch_idx: int):
|
||||
opts = self.optimizers()
|
||||
if not isinstance(opts, list):
|
||||
# Non-adversarial case
|
||||
opts = [opts]
|
||||
optimizer_idx = batch_idx % len(opts)
|
||||
if self.global_step < self.disc_start_iter:
|
||||
optimizer_idx = 0
|
||||
opt = opts[optimizer_idx]
|
||||
opt.zero_grad()
|
||||
with opt.toggle_model():
|
||||
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
|
||||
self.manual_backward(loss)
|
||||
opt.step()
|
||||
|
||||
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
||||
log_dict.update(log_dict_ema)
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
||||
x = self.get_input(batch)
|
||||
|
||||
z, xrec, regularization_log = self(x)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
"z": z,
|
||||
"optimizer_idx": 0,
|
||||
"global_step": self.global_step,
|
||||
"last_layer": self.get_last_layer(),
|
||||
"split": "val" + postfix,
|
||||
"regularization_log": regularization_log,
|
||||
"autoencoder": self,
|
||||
}
|
||||
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
||||
else:
|
||||
extra_info = dict()
|
||||
out_loss = self.loss(x, xrec, **extra_info)
|
||||
if isinstance(out_loss, tuple):
|
||||
aeloss, log_dict_ae = out_loss
|
||||
else:
|
||||
# simple loss function
|
||||
aeloss = out_loss
|
||||
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
||||
full_log_dict = log_dict_ae
|
||||
|
||||
if "optimizer_idx" in extra_info:
|
||||
extra_info["optimizer_idx"] = 1
|
||||
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
||||
full_log_dict.update(log_dict_disc)
|
||||
self.log(
|
||||
f"val{postfix}/loss/rec",
|
||||
log_dict_ae[f"val{postfix}/loss/rec"],
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log_dict(full_log_dict, sync_dist=True)
|
||||
return full_log_dict
|
||||
|
||||
def get_param_groups(
|
||||
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
groups = []
|
||||
num_params = 0
|
||||
for names, args in zip(parameter_names, optimizer_args):
|
||||
params = []
|
||||
for pattern_ in names:
|
||||
pattern_params = []
|
||||
pattern = re.compile(pattern_)
|
||||
for p_name, param in self.named_parameters():
|
||||
if re.match(pattern, p_name):
|
||||
pattern_params.append(param)
|
||||
num_params += param.numel()
|
||||
if len(pattern_params) == 0:
|
||||
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
||||
params.extend(pattern_params)
|
||||
groups.append({"params": params, **args})
|
||||
return groups, num_params
|
||||
|
||||
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
||||
if self.trainable_ae_params is None:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
else:
|
||||
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||
if self.trainable_disc_params is None:
|
||||
disc_params = self.get_discriminator_params()
|
||||
else:
|
||||
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
||||
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
||||
opt_ae = self.instantiate_optimizer_from_config(
|
||||
ae_params,
|
||||
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
||||
self.optimizer_config,
|
||||
)
|
||||
opts = [opt_ae]
|
||||
if len(disc_params) > 0:
|
||||
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
||||
opts.append(opt_disc)
|
||||
|
||||
return opts
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
log = dict()
|
||||
additional_decode_kwargs = {}
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
||||
|
||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
||||
diff.clamp_(0, 1.0)
|
||||
log["diff"] = 2.0 * diff - 1.0
|
||||
# diff_boost shows location of small errors, by boosting their
|
||||
# brightness.
|
||||
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
||||
if hasattr(self.loss, "log_images"):
|
||||
log.update(self.loss.log_images(x, xrec))
|
||||
with self.ema_scope():
|
||||
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||
diff_ema.clamp_(0, 1.0)
|
||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
if additional_log_kwargs:
|
||||
additional_decode_kwargs.update(additional_log_kwargs)
|
||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||
log_str = "reconstructions-" + "-".join(
|
||||
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
||||
)
|
||||
log[log_str] = xrec_add
|
||||
return log
|
||||
|
||||
|
||||
class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
def __init__(self, embed_dim: int, **kwargs):
|
||||
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
||||
ddconfig = kwargs.pop("ddconfig")
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
||||
super().__init__(
|
||||
encoder_config={
|
||||
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
||||
"params": ddconfig,
|
||||
},
|
||||
decoder_config={
|
||||
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
||||
"params": ddconfig,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
self.quant_conv = torch.nn.Conv2d(
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
else:
|
||||
N = x.shape[0]
|
||||
bs = self.max_batch_size
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
z = list()
|
||||
for i_batch in range(n_batches):
|
||||
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
||||
z_batch = self.quant_conv(z_batch)
|
||||
z.append(z_batch)
|
||||
z = torch.cat(z, 0)
|
||||
|
||||
z, reg_log = self.regularization(z)
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||
if self.max_batch_size is None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec = self.decoder(dec, **decoder_kwargs)
|
||||
else:
|
||||
N = z.shape[0]
|
||||
bs = self.max_batch_size
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
dec = list()
|
||||
for i_batch in range(n_batches):
|
||||
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
||||
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
||||
dec.append(dec_batch)
|
||||
dec = torch.cat(dec, 0)
|
||||
|
||||
return dec
|
||||
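# Minimal sketch of the max_batch_size chunking used by encode/decode above: the batch is
# processed in slices of size bs and re-concatenated; the doubling stands in for encoder + quant_conv.
import math
import torch

x = torch.randn(10, 3, 8, 8)
bs = 4  # assumed max_batch_size
z = torch.cat([x[i * bs : (i + 1) * bs] * 2.0 for i in range(math.ceil(x.shape[0] / bs))], 0)
assert z.shape[0] == x.shape[0]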
|
||||
|
||||
class IdentityFirstStage(AbstractAutoencoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_input(self, x: Any) -> Any:
|
||||
return x
|
||||
|
||||
def encode(self, x: Any, *args, **kwargs) -> Any:
|
||||
return x
|
||||
|
||||
def decode(self, x: Any, *args, **kwargs) -> Any:
|
||||
return x
|
||||
|
||||
|
||||
class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
def __init__(
|
||||
self,
|
||||
ckpt_path: Union[None, str] = None,
|
||||
ignore_keys: Union[Tuple, list] = (),
|
||||
image_video_weights=[1, 1],
|
||||
only_train_decoder=False,
|
||||
context_parallel_size=0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.context_parallel_size = context_parallel_size
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
||||
|
||||
def get_input(self, batch: dict) -> torch.Tensor:
|
||||
if self.context_parallel_size > 0:
|
||||
if not is_context_parallel_initialized():
|
||||
initialize_context_parallel(self.context_parallel_size)
|
||||
|
||||
batch = batch[self.input_key]
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
||||
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
|
||||
|
||||
batch = _conv_split(batch, dim=2, kernel_size=1)
|
||||
return batch
|
||||
|
||||
return batch[self.input_key]
|
||||
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
self.init_from_ckpt(ckpt)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
del sd[k]
|
||||
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
||||
print("Missing keys: ", missing_keys)
|
||||
print("Unexpected keys: ", unexpected_keys)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
|
||||
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
def __init__(
|
||||
self,
|
||||
cp_size=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.cp_size = cp_size
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
input_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
use_cp: bool = True,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.cp_size <= 1:
|
||||
use_cp = False
|
||||
if self.cp_size > 0 and use_cp and not input_cp:
|
||||
if not is_context_parallel_initialized():
|
||||
initialize_context_parallel(self.cp_size)
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
||||
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
|
||||
|
||||
x = _conv_split(x, dim=2, kernel_size=1)
|
||||
|
||||
if return_reg_log:
|
||||
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
||||
else:
|
||||
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
||||
|
||||
if self.cp_size > 0 and use_cp and not output_cp:
|
||||
z = _conv_gather(z, dim=2, kernel_size=1)
|
||||
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(
|
||||
self,
|
||||
z: torch.Tensor,
|
||||
input_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
use_cp: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if self.cp_size <= 1:
|
||||
use_cp = False
|
||||
if self.cp_size > 0 and use_cp and not input_cp:
|
||||
if not is_context_parallel_initialized():
|
||||
initialize_context_parallel(self.cp_size)
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
||||
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
|
||||
|
||||
z = _conv_split(z, dim=2, kernel_size=1)
|
||||
|
||||
x = super().decode(z, use_cp=use_cp, **kwargs)
|
||||
|
||||
if self.cp_size > 0 and use_cp and not output_cp:
|
||||
x = _conv_gather(x, dim=2, kernel_size=1)
|
||||
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
input_cp: bool = False,
|
||||
latent_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
**additional_decode_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
|
||||
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
6
sat/sgm/modules/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .encoders.modules import GeneralConditioner
|
||||
|
||||
UNCONDITIONAL_CONFIG = {
|
||||
"target": "sgm.modules.GeneralConditioner",
|
||||
"params": {"emb_models": []},
|
||||
}
|
572
sat/sgm/modules/attention.py
Normal file
@ -0,0 +1,572 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
SDP_IS_AVAILABLE = True
|
||||
from torch.backends.cuda import SDPBackend, sdp_kernel
|
||||
|
||||
BACKEND_MAP = {
|
||||
SDPBackend.MATH: {
|
||||
"enable_math": True,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.FLASH_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": True,
|
||||
"enable_mem_efficient": False,
|
||||
},
|
||||
SDPBackend.EFFICIENT_ATTENTION: {
|
||||
"enable_math": False,
|
||||
"enable_flash": False,
|
||||
"enable_mem_efficient": True,
|
||||
},
|
||||
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
||||
}
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
|
||||
SDP_IS_AVAILABLE = False
|
||||
sdp_kernel = nullcontext
|
||||
BACKEND_MAP = {}
|
||||
print(
|
||||
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
||||
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
||||
)
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c")
|
||||
k = rearrange(k, "b c h w -> b c (h w)")
|
||||
w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, "b c h w -> b c (h w)")
|
||||
w_ = rearrange(w_, "b i j -> b j i")
|
||||
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.backend = backend
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
additional_tokens=None,
|
||||
n_times_crossframe_attn_in_self=0,
|
||||
):
|
||||
h = self.heads
|
||||
|
||||
if additional_tokens is not None:
|
||||
# get the number of masked tokens at the beginning of the output sequence
|
||||
n_tokens_to_mask = additional_tokens.shape[1]
|
||||
# add additional token
|
||||
x = torch.cat([additional_tokens, x], dim=1)
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
if n_times_crossframe_attn_in_self:
|
||||
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
||||
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
||||
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
||||
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
||||
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
## old
|
||||
"""
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||
"""
|
||||
## new
|
||||
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
||||
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
||||
|
||||
del q, k, v
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
|
||||
if additional_tokens is not None:
|
||||
# remove additional token
|
||||
out = out[:, n_tokens_to_mask:]
|
||||
return self.to_out(out)
|
||||
|
||||
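# Minimal sketch of the scaled-dot-product path used by CrossAttention.forward above,
# assuming PyTorch >= 2.0; with backend=None every kernel is enabled, matching BACKEND_MAP[None].
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)        # scale defaults to d ** -0.5
out = out.transpose(1, 2).reshape(b, n, h * d)       # "b h n d -> b n (h d)"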
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
|
||||
super().__init__()
|
||||
print(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads with a dimension of {dim_head}."
|
||||
)
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
additional_tokens=None,
|
||||
n_times_crossframe_attn_in_self=0,
|
||||
):
|
||||
if additional_tokens is not None:
|
||||
# get the number of masked tokens at the beginning of the output sequence
|
||||
n_tokens_to_mask = additional_tokens.shape[1]
|
||||
# add additional token
|
||||
x = torch.cat([additional_tokens, x], dim=1)
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
if n_times_crossframe_attn_in_self:
|
||||
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
||||
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
||||
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
||||
k = repeat(
|
||||
k[::n_times_crossframe_attn_in_self],
|
||||
"b ... -> (b n) ...",
|
||||
n=n_times_crossframe_attn_in_self,
|
||||
)
|
||||
v = repeat(
|
||||
v[::n_times_crossframe_attn_in_self],
|
||||
"b ... -> (b n) ...",
|
||||
n=n_times_crossframe_attn_in_self,
|
||||
)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
# TODO: Use this directly in the attention operation, as a bias
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
)
|
||||
if additional_tokens is not None:
|
||||
# remove additional token
|
||||
out = out[:, n_tokens_to_mask:]
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
attn_mode="softmax",
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
||||
print(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
)
|
||||
attn_mode = "softmax"
|
||||
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
||||
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
|
||||
if not XFORMERS_IS_AVAILABLE:
|
||||
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
else:
|
||||
print("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
||||
else:
|
||||
assert sdp_backend is None
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None,
|
||||
backend=sdp_backend,
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
backend=sdp_backend,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
kwargs = {"x": x}
|
||||
|
||||
if context is not None:
|
||||
kwargs.update({"context": context})
|
||||
|
||||
if additional_tokens is not None:
|
||||
kwargs.update({"additional_tokens": additional_tokens})
|
||||
|
||||
if n_times_crossframe_attn_in_self:
|
||||
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
|
||||
|
||||
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
x = (
|
||||
self.attn1(
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
additional_tokens=additional_tokens,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
|
||||
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
attn_mode="softmax",
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
x = self.ff(self.norm2(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
disable_self_attn=False,
|
||||
use_linear=False,
|
||||
attn_type="softmax",
|
||||
use_checkpoint=True,
|
||||
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
context_dim = [context_dim]
|
||||
if exists(context_dim) and isinstance(context_dim, list):
|
||||
if depth != len(context_dim):
|
||||
print(
|
||||
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
||||
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
||||
)
|
||||
# depth does not match context dims.
|
||||
assert all(
|
||||
map(lambda x: x == context_dim[0], context_dim)
|
||||
), "need homogenous context_dim to match depth automatically"
|
||||
context_dim = depth * [context_dim[0]]
|
||||
elif context_dim is None:
|
||||
context_dim = [None] * depth
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn,
|
||||
attn_mode=attn_type,
|
||||
checkpoint=use_checkpoint,
|
||||
sdp_backend=sdp_backend,
|
||||
)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
else:
|
||||
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
if not isinstance(context, list):
|
||||
context = [context]
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if i > 0 and len(context) == 1:
|
||||
i = 0 # use same context for each block
|
||||
x = block(x, context=context[i])
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
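# Minimal usage sketch for SpatialTransformer, assuming the `sat` package is on PYTHONPATH
# and PyTorch >= 2.0 so the default "softmax" attention path is available.
import torch
from sgm.modules.attention import SpatialTransformer

block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, use_linear=True)
x = torch.randn(2, 64, 32, 32)      # (b, c, h, w)
out = block(x)                      # no context given, so cross-attention acts as self-attention
assert out.shape == x.shape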
0
sat/sgm/modules/autoencoding/__init__.py
Normal file
8
sat/sgm/modules/autoencoding/losses/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
__all__ = [
|
||||
"GeneralLPIPSWithDiscriminator",
|
||||
"LatentLPIPS",
|
||||
]
|
||||
|
||||
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
||||
from .lpips import LatentLPIPS
|
||||
from .video_loss import VideoAutoencoderLoss
|
301
sat/sgm/modules/autoencoding/losses/discriminator_loss.py
Normal file
@ -0,0 +1,301 @@
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from einops import rearrange
|
||||
from matplotlib import colormaps
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
from ..lpips.model.model import weights_init
|
||||
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
disc_start: int,
|
||||
logvar_init: float = 0.0,
|
||||
disc_num_layers: int = 3,
|
||||
disc_in_channels: int = 3,
|
||||
disc_factor: float = 1.0,
|
||||
disc_weight: float = 1.0,
|
||||
perceptual_weight: float = 1.0,
|
||||
disc_loss: str = "hinge",
|
||||
scale_input_to_tgt_size: bool = False,
|
||||
dims: int = 2,
|
||||
learn_logvar: bool = False,
|
||||
regularization_weights: Union[None, Dict[str, float]] = None,
|
||||
additional_log_keys: Optional[List[str]] = None,
|
||||
discriminator_config: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
if self.dims > 2:
|
||||
print(
|
||||
f"running with dims={dims}. This means that for perceptual loss "
|
||||
f"calculation, the LPIPS loss will be applied to each frame "
|
||||
f"independently."
|
||||
)
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
# output log variance
|
||||
self.logvar = nn.Parameter(torch.full((), logvar_init), requires_grad=learn_logvar)
|
||||
self.learn_logvar = learn_logvar
|
||||
|
||||
discriminator_config = default(
|
||||
discriminator_config,
|
||||
{
|
||||
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
||||
"params": {
|
||||
"input_nc": disc_in_channels,
|
||||
"n_layers": disc_num_layers,
|
||||
"use_actnorm": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
self.discriminator = instantiate_from_config(discriminator_config).apply(weights_init)
|
||||
self.discriminator_iter_start = disc_start
|
||||
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.regularization_weights = default(regularization_weights, {})
|
||||
|
||||
self.forward_keys = [
|
||||
"optimizer_idx",
|
||||
"global_step",
|
||||
"last_layer",
|
||||
"split",
|
||||
"regularization_log",
|
||||
]
|
||||
|
||||
self.additional_log_keys = set(default(additional_log_keys, []))
|
||||
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
||||
|
||||
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
||||
return self.discriminator.parameters()
|
||||
|
||||
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
||||
if self.learn_logvar:
|
||||
yield self.logvar
|
||||
yield from ()
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
# calc logits of real/fake
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
if len(logits_real.shape) < 4:
|
||||
# Non patch-discriminator
|
||||
return dict()
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
# -> (b, 1, h, w)
|
||||
|
||||
# parameters for colormapping
|
||||
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
||||
cmap = colormaps["PiYG"] # diverging colormap
|
||||
|
||||
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""(b, 1, ...) -> (b, 3, ...)"""
|
||||
logits = (logits + high) / (2 * high)
|
||||
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
||||
# -> (b, 1, ..., 3)
|
||||
logits = torch.from_numpy(logits_np).to(logits.device)
|
||||
return rearrange(logits, "b 1 ... c -> b c ...")
|
||||
|
||||
logits_real = torch.nn.functional.interpolate(
|
||||
logits_real,
|
||||
size=inputs.shape[-2:],
|
||||
mode="nearest",
|
||||
antialias=False,
|
||||
)
|
||||
logits_fake = torch.nn.functional.interpolate(
|
||||
logits_fake,
|
||||
size=reconstructions.shape[-2:],
|
||||
mode="nearest",
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
# alpha value of logits for overlay
|
||||
alpha_real = torch.abs(logits_real) / high
|
||||
alpha_fake = torch.abs(logits_fake) / high
|
||||
# -> (b, 1, h, w) in range [0, 0.5]
|
||||
# alpha value of lines don't really matter, since the values are the same
|
||||
# for both images and logits anyway
|
||||
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
||||
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
||||
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
||||
# -> (1, h, w)
|
||||
# blend logits and images together
|
||||
|
||||
# prepare logits for plotting
|
||||
logits_real = to_colormap(logits_real)
|
||||
logits_fake = to_colormap(logits_fake)
|
||||
# resize logits
|
||||
# -> (b, 3, h, w)
|
||||
|
||||
# make some grids
|
||||
# add all logits to one plot
|
||||
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
||||
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
||||
# I just love how torchvision calls the number of columns `nrow`
|
||||
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
||||
# -> (3, h, w)
|
||||
|
||||
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
||||
grid_images_fake = torchvision.utils.make_grid(0.5 * reconstructions + 0.5, nrow=4)
|
||||
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
||||
# -> (3, h, w) in range [0, 1]
|
||||
|
||||
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
||||
|
||||
# Create labeled colorbar
|
||||
dpi = 100
|
||||
height = 128 / dpi
|
||||
width = grid_logits.shape[2] / dpi
|
||||
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
||||
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
||||
plt.colorbar(
|
||||
img,
|
||||
cax=ax,
|
||||
orientation="horizontal",
|
||||
fraction=0.9,
|
||||
aspect=width / height,
|
||||
pad=0.0,
|
||||
)
|
||||
img.set_visible(False)
|
||||
fig.tight_layout()
|
||||
fig.canvas.draw()
|
||||
# manually convert figure to numpy
|
||||
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
||||
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
||||
|
||||
# Add colorbar to plot
|
||||
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
||||
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
||||
return {
|
||||
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
||||
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
||||
}
|
||||
|
||||
def calculate_adaptive_weight(
|
||||
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
reconstructions: torch.Tensor,
|
||||
*, # added because I changed the order here
|
||||
regularization_log: Dict[str, torch.Tensor],
|
||||
optimizer_idx: int,
|
||||
global_step: int,
|
||||
last_layer: torch.Tensor,
|
||||
split: str = "train",
|
||||
weights: Union[None, float, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
if self.scale_input_to_tgt_size:
|
||||
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
|
||||
|
||||
if self.dims > 2:
|
||||
inputs, reconstructions = map(
|
||||
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
||||
(inputs, reconstructions),
|
||||
)
|
||||
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
p_loss = torch.tensor(0.0, device=inputs.device)  # logged below even when the perceptual term is disabled
if self.perceptual_weight > 0:
|
||||
frame_indices = torch.randn((inputs.shape[0], inputs.shape[2])).topk(1, dim=-1).indices
|
||||
|
||||
from sgm.modules.autoencoding.losses.video_loss import pick_video_frame
|
||||
|
||||
input_frames = pick_video_frame(inputs, frame_indices)
|
||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
if global_step >= self.discriminator_iter_start or not self.training:
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
if self.training:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
||||
else:
|
||||
d_weight = torch.tensor(1.0)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
g_loss = torch.tensor(0.0, requires_grad=True)
|
||||
|
||||
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
||||
log = dict()
|
||||
for k in regularization_log:
|
||||
if k in self.regularization_weights:
|
||||
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
||||
if k in self.additional_log_keys:
|
||||
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
||||
|
||||
log.update(
|
||||
{
|
||||
f"{split}/loss/total": loss.clone().detach().mean(),
|
||||
f"{split}/loss/nll": nll_loss.detach().mean(),
|
||||
f"{split}/loss/rec": rec_loss.detach().mean(),
|
||||
f"{split}/loss/percep": p_loss.detach().mean(),
|
||||
f"{split}/loss/rec": rec_loss.detach().mean(),
|
||||
f"{split}/loss/g": g_loss.detach().mean(),
|
||||
f"{split}/scalars/logvar": self.logvar.detach(),
|
||||
f"{split}/scalars/d_weight": d_weight.detach(),
|
||||
}
|
||||
)
|
||||
|
||||
return loss, log
|
||||
elif optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
|
||||
if global_step >= self.discriminator_iter_start or not self.training:
|
||||
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
else:
|
||||
d_loss = torch.tensor(0.0, requires_grad=True)
|
||||
|
||||
log = {
|
||||
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
||||
f"{split}/logits/real": logits_real.detach().mean(),
|
||||
f"{split}/logits/fake": logits_fake.detach().mean(),
|
||||
}
|
||||
return d_loss, log
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
||||
|
||||
def get_nll_loss(
|
||||
self,
|
||||
rec_loss: torch.Tensor,
|
||||
weights: Optional[Union[float, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights * nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
|
||||
return nll_loss, weighted_nll_loss
|
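# Minimal sketch of calculate_adaptive_weight above: the generator loss is rescaled by the
# ratio of reconstruction-gradient norm to GAN-gradient norm at the last decoder layer.
import torch

last_layer = torch.randn(4, 4, requires_grad=True)
nll_loss = (last_layer ** 2).mean()
g_loss = last_layer.sum()

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.clamp(nll_grads.norm() / (g_grads.norm() + 1e-4), 0.0, 1e4).detach()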
64
sat/sgm/modules/autoencoding/losses/lpips.py
Normal file
@ -0,0 +1,64 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ....util import default, instantiate_from_config
|
||||
from ..lpips.loss.lpips import LPIPS
|
||||
|
||||
|
||||
class LatentLPIPS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_config,
|
||||
perceptual_weight=1.0,
|
||||
latent_weight=1.0,
|
||||
scale_input_to_tgt_size=False,
|
||||
scale_tgt_to_input_size=False,
|
||||
perceptual_weight_on_inputs=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
||||
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
||||
self.init_decoder(decoder_config)
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
self.latent_weight = latent_weight
|
||||
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
||||
|
||||
def init_decoder(self, config):
|
||||
self.decoder = instantiate_from_config(config)
|
||||
if hasattr(self.decoder, "encoder"):
|
||||
del self.decoder.encoder
|
||||
|
||||
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
||||
log = dict()
|
||||
loss = (latent_inputs - latent_predictions) ** 2
|
||||
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
||||
image_reconstructions = None
|
||||
if self.perceptual_weight > 0.0:
|
||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
||||
image_targets = self.decoder.decode(latent_inputs)
|
||||
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
|
||||
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
||||
|
||||
if self.perceptual_weight_on_inputs > 0.0:
|
||||
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
|
||||
if self.scale_input_to_tgt_size:
|
||||
image_inputs = torch.nn.functional.interpolate(
|
||||
image_inputs,
|
||||
image_reconstructions.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
elif self.scale_tgt_to_input_size:
|
||||
image_reconstructions = torch.nn.functional.interpolate(
|
||||
image_reconstructions,
|
||||
image_inputs.shape[2:],
|
||||
mode="bicubic",
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
|
||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
||||
return loss, log
|
712
sat/sgm/modules/autoencoding/losses/video_loss.py
Normal file
@ -0,0 +1,712 @@
|
||||
from typing import Any, Union
|
||||
from math import log2
|
||||
from beartype import beartype
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.autograd import grad as torch_grad
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
import torchvision
|
||||
from torchvision.models import VGG16_Weights
|
||||
from einops import rearrange, einsum, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from kornia.filters import filter3d
|
||||
|
||||
from ..magvit2_pytorch import Residual, FeedForward, LinearSpaceAttention
|
||||
from .lpips import LPIPS
|
||||
|
||||
from sgm.modules.autoencoding.vqvae.movq_enc_3d import CausalConv3d, DownSample3D
|
||||
from sgm.util import instantiate_from_config
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
|
||||
def leaky_relu(p=0.1):
|
||||
return nn.LeakyReLU(p)
|
||||
|
||||
|
||||
def hinge_discr_loss(fake, real):
|
||||
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
|
||||
|
||||
|
||||
def hinge_gen_loss(fake):
|
||||
return -fake.mean()
|
||||
|
||||
|
||||
@autocast(enabled=False)
|
||||
@beartype
|
||||
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
||||
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
|
||||
|
||||
|
||||
def pick_video_frame(video, frame_indices):
|
||||
batch, device = video.shape[0], video.device
|
||||
video = rearrange(video, "b c f ... -> b f c ...")
|
||||
batch_indices = torch.arange(batch, device=device)
|
||||
batch_indices = rearrange(batch_indices, "b -> b 1")
|
||||
images = video[batch_indices, frame_indices]
|
||||
images = rearrange(images, "b 1 c ... -> b c ...")
|
||||
return images
|
||||
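# Minimal sketch of how pick_video_frame is driven later in this file: one random frame
# index per clip via randn(...).topk(1), gathered from a (b, c, f, h, w) video tensor.
import torch

video = torch.randn(2, 3, 8, 16, 16)                        # (b, c, f, h, w)
frame_indices = torch.randn(2, 8).topk(1, dim=-1).indices   # one frame index per clip
frames = video.transpose(1, 2)[torch.arange(2)[:, None], frame_indices].squeeze(1)
assert frames.shape == (2, 3, 16, 16)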
|
||||
|
||||
def gradient_penalty(images, output):
|
||||
batch_size = images.shape[0]
|
||||
|
||||
gradients = torch_grad(
|
||||
outputs=output,
|
||||
inputs=images,
|
||||
grad_outputs=torch.ones(output.size(), device=images.device),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True,
|
||||
)[0]
|
||||
|
||||
gradients = rearrange(gradients, "b ... -> b (...)")
|
||||
return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
|
||||
|
||||
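# Minimal sketch of the hinge objectives defined above, evaluated on dummy logits.
import torch
import torch.nn.functional as F

real_logits = torch.randn(4)
fake_logits = torch.randn(4)
d_loss = (F.relu(1 + fake_logits) + F.relu(1 - real_logits)).mean()   # hinge_discr_loss
g_loss = -fake_logits.mean()                                          # hinge_gen_loss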
|
||||
# discriminator with anti-aliased downsampling (blurpool Zhang et al.)
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
f = torch.Tensor([1, 2, 1])
|
||||
self.register_buffer("f", f)
|
||||
|
||||
def forward(self, x, space_only=False, time_only=False):
|
||||
assert not (space_only and time_only)
|
||||
|
||||
f = self.f
|
||||
|
||||
if space_only:
|
||||
f = einsum("i, j -> i j", f, f)
|
||||
f = rearrange(f, "... -> 1 1 ...")
|
||||
elif time_only:
|
||||
f = rearrange(f, "f -> 1 f 1 1")
|
||||
else:
|
||||
f = einsum("i, j, k -> i j k", f, f, f)
|
||||
f = rearrange(f, "... -> 1 ...")
|
||||
|
||||
is_images = x.ndim == 4
|
||||
|
||||
if is_images:
|
||||
x = rearrange(x, "b c h w -> b c 1 h w")
|
||||
|
||||
out = filter3d(x, f, normalized=True)
|
||||
|
||||
if is_images:
|
||||
out = rearrange(out, "b c 1 h w -> b c h w")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DiscriminatorBlock(nn.Module):
|
||||
def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
|
||||
super().__init__()
|
||||
self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(input_channels, filters, 3, padding=1),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(filters, filters, 3, padding=1),
|
||||
leaky_relu(),
|
||||
)
|
||||
|
||||
self.maybe_blur = Blur() if antialiased_downsample else None
|
||||
|
||||
self.downsample = (
|
||||
nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
|
||||
)
|
||||
if downsample
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = self.conv_res(x)
|
||||
|
||||
x = self.net(x)
|
||||
|
||||
if exists(self.downsample):
|
||||
if exists(self.maybe_blur):
|
||||
x = self.maybe_blur(x, space_only=True)
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
x = (x + res) * (2**-0.5)
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
channels=3,
|
||||
max_dim=512,
|
||||
attn_heads=8,
|
||||
attn_dim_head=32,
|
||||
linear_attn_dim_head=8,
|
||||
linear_attn_heads=16,
|
||||
ff_mult=4,
|
||||
antialiased_downsample=False,
|
||||
):
|
||||
super().__init__()
|
||||
image_size = pair(image_size)
|
||||
min_image_resolution = min(image_size)
|
||||
|
||||
num_layers = int(log2(min_image_resolution) - 2)
|
||||
|
||||
blocks = []
|
||||
|
||||
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
|
||||
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
|
||||
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
|
||||
|
||||
blocks = []
|
||||
attn_blocks = []
|
||||
|
||||
image_resolution = min_image_resolution
|
||||
|
||||
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
|
||||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(
|
||||
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
|
||||
)
|
||||
|
||||
attn_block = nn.Sequential(
|
||||
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
|
||||
blocks.append(nn.ModuleList([block, attn_block]))
|
||||
|
||||
image_resolution //= 2
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
dim_last = layer_dims[-1]
|
||||
|
||||
downsample_factor = 2**num_layers
|
||||
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
|
||||
|
||||
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
nn.Conv2d(dim_last, dim_last, 3, padding=1),
|
||||
leaky_relu(),
|
||||
Rearrange("b ... -> b (...)"),
|
||||
nn.Linear(latent_dim, 1),
|
||||
Rearrange("b 1 -> b"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block, attn_block in self.blocks:
|
||||
x = block(x)
|
||||
x = attn_block(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
|
||||
class DiscriminatorBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels,
|
||||
filters,
|
||||
antialiased_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_res = nn.Conv3d(input_channels, filters, 1, stride=2)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv3d(input_channels, filters, 3, padding=1),
|
||||
leaky_relu(),
|
||||
nn.Conv3d(filters, filters, 3, padding=1),
|
||||
leaky_relu(),
|
||||
)
|
||||
|
||||
self.maybe_blur = Blur() if antialiased_downsample else None
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
Rearrange("b c (f p1) (h p2) (w p3) -> b (c p1 p2 p3) f h w", p1=2, p2=2, p3=2),
|
||||
nn.Conv3d(filters * 8, filters, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = self.conv_res(x)
|
||||
|
||||
x = self.net(x)
|
||||
|
||||
if exists(self.downsample):
|
||||
if exists(self.maybe_blur):
|
||||
x = self.maybe_blur(x, space_only=True)
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
x = (x + res) * (2**-0.5)
|
||||
return x
|
||||
|
||||
|
||||
class DiscriminatorBlock3DWithfirstframe(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels,
|
||||
filters,
|
||||
antialiased_downsample=True,
|
||||
pad_mode="first",
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_res = DownSample3D(
|
||||
in_channels=input_channels,
|
||||
out_channels=filters,
|
||||
with_conv=True,
|
||||
compress_time=True,
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
CausalConv3d(input_channels, filters, kernel_size=3, pad_mode=pad_mode),
|
||||
leaky_relu(),
|
||||
CausalConv3d(filters, filters, kernel_size=3, pad_mode=pad_mode),
|
||||
leaky_relu(),
|
||||
)
|
||||
|
||||
self.maybe_blur = Blur() if antialiased_downsample else None
|
||||
|
||||
self.downsample = DownSample3D(
|
||||
in_channels=filters,
|
||||
out_channels=filters,
|
||||
with_conv=True,
|
||||
compress_time=True,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = self.downsample_res(x)
|
||||
|
||||
x = self.net(x)
|
||||
|
||||
if exists(self.downsample):
|
||||
if exists(self.maybe_blur):
|
||||
x = self.maybe_blur(x, space_only=True)
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
x = (x + res) * (2**-0.5)
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator3D(nn.Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
frame_num,
|
||||
channels=3,
|
||||
max_dim=512,
|
||||
linear_attn_dim_head=8,
|
||||
linear_attn_heads=16,
|
||||
ff_mult=4,
|
||||
antialiased_downsample=False,
|
||||
):
|
||||
super().__init__()
|
||||
image_size = pair(image_size)
|
||||
min_image_resolution = min(image_size)
|
||||
|
||||
num_layers = int(log2(min_image_resolution) - 2)
|
||||
temporal_num_layers = int(log2(frame_num))
|
||||
self.temporal_num_layers = temporal_num_layers
|
||||
|
||||
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
|
||||
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
|
||||
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
|
||||
|
||||
blocks = []
|
||||
|
||||
image_resolution = min_image_resolution
|
||||
frame_resolution = frame_num
|
||||
|
||||
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
|
||||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||
|
||||
if ind < temporal_num_layers:
|
||||
block = DiscriminatorBlock3D(
|
||||
in_chan,
|
||||
out_chan,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
frame_resolution //= 2
|
||||
else:
|
||||
block = DiscriminatorBlock(
|
||||
in_chan,
|
||||
out_chan,
|
||||
downsample=is_not_last,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
attn_block = nn.Sequential(
|
||||
Residual(
|
||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
|
||||
blocks.append(nn.ModuleList([block, attn_block]))
|
||||
|
||||
image_resolution //= 2
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
dim_last = layer_dims[-1]
|
||||
|
||||
downsample_factor = 2**num_layers
|
||||
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
|
||||
|
||||
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
nn.Conv2d(dim_last, dim_last, 3, padding=1),
|
||||
leaky_relu(),
|
||||
Rearrange("b ... -> b (...)"),
|
||||
nn.Linear(latent_dim, 1),
|
||||
Rearrange("b 1 -> b"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.blocks):
|
||||
if i < self.temporal_num_layers:
|
||||
x = layer(x)
|
||||
if i == self.temporal_num_layers - 1:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
else:
|
||||
block, attn_block = layer
|
||||
x = block(x)
|
||||
x = attn_block(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
|
||||
class Discriminator3DWithfirstframe(nn.Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
frame_num,
|
||||
channels=3,
|
||||
max_dim=512,
|
||||
linear_attn_dim_head=8,
|
||||
linear_attn_heads=16,
|
||||
ff_mult=4,
|
||||
antialiased_downsample=False,
|
||||
):
|
||||
super().__init__()
|
||||
image_size = pair(image_size)
|
||||
min_image_resolution = min(image_size)
|
||||
|
||||
num_layers = int(log2(min_image_resolution) - 2)
|
||||
temporal_num_layers = int(log2(frame_num))
|
||||
self.temporal_num_layers = temporal_num_layers
|
||||
|
||||
layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
|
||||
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
|
||||
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))
|
||||
|
||||
blocks = []
|
||||
|
||||
image_resolution = min_image_resolution
|
||||
frame_resolution = frame_num
|
||||
|
||||
for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
|
||||
num_layer = ind + 1
|
||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||
|
||||
if ind < temporal_num_layers:
|
||||
block = DiscriminatorBlock3DWithfirstframe(
|
||||
in_chan,
|
||||
out_chan,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
frame_resolution //= 2
|
||||
else:
|
||||
block = DiscriminatorBlock(
|
||||
in_chan,
|
||||
out_chan,
|
||||
downsample=is_not_last,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
attn_block = nn.Sequential(
|
||||
Residual(
|
||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
|
||||
blocks.append(nn.ModuleList([block, attn_block]))
|
||||
|
||||
image_resolution //= 2
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
dim_last = layer_dims[-1]
|
||||
|
||||
downsample_factor = 2**num_layers
|
||||
last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
|
||||
|
||||
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
|
||||
|
||||
self.to_logits = nn.Sequential(
|
||||
nn.Conv2d(dim_last, dim_last, 3, padding=1),
|
||||
leaky_relu(),
|
||||
Rearrange("b ... -> b (...)"),
|
||||
nn.Linear(latent_dim, 1),
|
||||
Rearrange("b 1 -> b"),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.blocks):
|
||||
if i < self.temporal_num_layers:
|
||||
x = layer(x)
|
||||
if i == self.temporal_num_layers - 1:
|
||||
x = x.mean(dim=2)
|
||||
# x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
else:
|
||||
block, attn_block = layer
|
||||
x = block(x)
|
||||
x = attn_block(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
|
||||
class VideoAutoencoderLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
disc_start,
|
||||
perceptual_weight=1,
|
||||
adversarial_loss_weight=0,
|
||||
multiscale_adversarial_loss_weight=0,
|
||||
grad_penalty_loss_weight=0,
|
||||
quantizer_aux_loss_weight=0,
|
||||
vgg_weights=VGG16_Weights.DEFAULT,
|
||||
discr_kwargs=None,
|
||||
discr_3d_kwargs=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.disc_start = disc_start
|
||||
self.perceptual_weight = perceptual_weight
|
||||
self.adversarial_loss_weight = adversarial_loss_weight
|
||||
self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight
|
||||
self.grad_penalty_loss_weight = grad_penalty_loss_weight
|
||||
self.quantizer_aux_loss_weight = quantizer_aux_loss_weight
|
||||
|
||||
if self.perceptual_weight > 0:
|
||||
self.perceptual_model = LPIPS().eval()
|
||||
# self.vgg = torchvision.models.vgg16(pretrained = True)
|
||||
# self.vgg.requires_grad_(False)
|
||||
# if self.adversarial_loss_weight > 0:
|
||||
# self.discr = Discriminator(**discr_kwargs)
|
||||
# else:
|
||||
# self.discr = None
|
||||
# if self.multiscale_adversarial_loss_weight > 0:
|
||||
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
|
||||
# else:
|
||||
# self.multiscale_discrs = None
|
||||
if discr_kwargs is not None:
|
||||
self.discr = Discriminator(**discr_kwargs)
|
||||
else:
|
||||
self.discr = None
|
||||
if discr_3d_kwargs is not None:
|
||||
# self.discr_3d = Discriminator3D(**discr_3d_kwargs)
|
||||
self.discr_3d = instantiate_from_config(discr_3d_kwargs)
|
||||
else:
|
||||
self.discr_3d = None
|
||||
# self.multiscale_discrs = nn.ModuleList([*multiscale_discrs])
|
||||
|
||||
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
|
||||
|
||||
def get_trainable_params(self) -> Any:
|
||||
params = []
|
||||
if self.discr is not None:
|
||||
params += list(self.discr.parameters())
|
||||
if self.discr_3d is not None:
|
||||
params += list(self.discr_3d.parameters())
|
||||
# if self.multiscale_discrs is not None:
|
||||
# for discr in self.multiscale_discrs:
|
||||
# params += list(discr.parameters())
|
||||
return params
|
||||
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
return self.get_trainable_params()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
reconstructions,
|
||||
optimizer_idx,
|
||||
global_step,
|
||||
aux_losses=None,
|
||||
last_layer=None,
|
||||
split="train",
|
||||
):
|
||||
batch, channels, frames = inputs.shape[:3]
|
||||
|
||||
if optimizer_idx == 0:
|
||||
recon_loss = F.mse_loss(inputs, reconstructions)
|
||||
|
||||
if self.perceptual_weight > 0:
|
||||
frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
|
||||
|
||||
input_frames = pick_video_frame(inputs, frame_indices)
|
||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
||||
else:
|
||||
perceptual_loss = self.zero
|
||||
|
||||
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
|
||||
gen_loss = self.zero
|
||||
adaptive_weight = 0
|
||||
else:
|
||||
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
# fake_logits = self.discr(recon_video_frames)
|
||||
fake_logits = self.discr_3d(reconstructions)
|
||||
gen_loss = hinge_gen_loss(fake_logits)
|
||||
|
||||
adaptive_weight = 1
|
||||
if self.perceptual_weight > 0 and last_layer is not None:
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
|
||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
|
||||
adaptive_weight.clamp_(max=1e3)
|
||||
|
||||
if torch.isnan(adaptive_weight).any():
|
||||
adaptive_weight = 1
|
||||
|
||||
# multiscale discriminator losses
|
||||
|
||||
# multiscale_gen_losses = []
|
||||
# multiscale_gen_adaptive_weights = []
|
||||
# if self.multiscale_adversarial_loss_weight > 0:
|
||||
# if not exists(recon_video_frames):
|
||||
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
# recon_video_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
# for discr in self.multiscale_discrs:
|
||||
# fake_logits = recon_video_frames
|
||||
|
||||
# multiscale_gen_loss = hinge_gen_loss(fake_logits)
|
||||
# multiscale_gen_losses.append(multiscale_gen_loss)
|
||||
|
||||
# multiscale_adaptive_weight = 1.
|
||||
|
||||
# if exists(norm_grad_wrt_perceptual_loss):
|
||||
# norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_layer).norm(p = 2)
|
||||
# multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min = 1e-5)
|
||||
# multiscale_adaptive_weight.clamp_(max = 1e3)
|
||||
|
||||
# multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
|
||||
# weighted_multiscale_gen_losses = sum(loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights))
|
||||
# else:
|
||||
# weighted_multiscale_gen_losses = self.zero
|
||||
|
||||
if aux_losses is None:
|
||||
aux_losses = self.zero
|
||||
|
||||
total_loss = (
|
||||
recon_loss
|
||||
+ aux_losses * self.quantizer_aux_loss_weight
|
||||
+ perceptual_loss * self.perceptual_weight
|
||||
+ gen_loss * self.adversarial_loss_weight
|
||||
)
|
||||
# gen_loss * adaptive_weight * self.adversarial_loss_weight + \
|
||||
# weighted_multiscale_gen_losses * self.multiscale_adversarial_loss_weight
|
||||
|
||||
log = {
|
||||
"{}/total_loss".format(split): total_loss.detach(),
|
||||
"{}/recon_loss".format(split): recon_loss.detach(),
|
||||
"{}/perceptual_loss".format(split): perceptual_loss.detach(),
|
||||
"{}/gen_loss".format(split): gen_loss.detach(),
|
||||
"{}/aux_losses".format(split): aux_losses.detach(),
|
||||
# "{}/weighted_multiscale_gen_losses".format(split): weighted_multiscale_gen_losses.detach(),
|
||||
"{}/adaptive_weight".format(split): adaptive_weight,
|
||||
# "{}/multiscale_adaptive_weights".format(split): sum(multiscale_gen_adaptive_weights),
|
||||
}
|
||||
|
||||
return total_loss, log
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# frame_indices = torch.randn((batch, frames)).topk(1, dim = -1).indices
|
||||
|
||||
# real = pick_video_frame(inputs, frame_indices)
|
||||
# fake = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
# apply_gradient_penalty = self.grad_penalty_loss_weight > 0
|
||||
# if apply_gradient_penalty:
|
||||
# real = real.requires_grad_()
|
||||
|
||||
# real_logits = self.discr(real)
|
||||
# fake_logits = self.discr(fake.detach())
|
||||
|
||||
apply_gradient_penalty = self.grad_penalty_loss_weight > 0
|
||||
if apply_gradient_penalty:
|
||||
inputs = inputs.requires_grad_()
|
||||
real_logits = self.discr_3d(inputs)
|
||||
fake_logits = self.discr_3d(reconstructions.detach())
|
||||
|
||||
discr_loss = hinge_discr_loss(fake_logits, real_logits)
|
||||
|
||||
# # multiscale discriminators
|
||||
# multiscale_discr_losses = []
|
||||
# if self.multiscale_adversarial_loss_weight > 0:
|
||||
# for discr in self.multiscale_discrs:
|
||||
# multiscale_real_logits = discr(inputs)
|
||||
# multiscale_fake_logits = discr(reconstructions.detach())
|
||||
|
||||
# multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
|
||||
# multiscale_discr_losses.append(multiscale_discr_loss)
|
||||
# else:
|
||||
# multiscale_discr_losses.append(self.zero)
|
||||
|
||||
# gradient penalty
|
||||
if apply_gradient_penalty:
|
||||
# gradient_penalty_loss = gradient_penalty(real, real_logits)
|
||||
gradient_penalty_loss = gradient_penalty(inputs, real_logits)
|
||||
else:
|
||||
gradient_penalty_loss = self.zero
|
||||
|
||||
total_loss = discr_loss + self.grad_penalty_loss_weight * gradient_penalty_loss
|
||||
# self.grad_penalty_loss_weight * gradient_penalty_loss + \
|
||||
# sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
|
||||
|
||||
log = {
|
||||
"{}/total_disc_loss".format(split): total_loss.detach(),
|
||||
"{}/discr_loss".format(split): discr_loss.detach(),
|
||||
"{}/grad_penalty_loss".format(split): gradient_penalty_loss.detach(),
|
||||
# "{}/multiscale_discr_loss".format(split): sum(multiscale_discr_losses).detach(),
|
||||
"{}/logits_real".format(split): real_logits.detach().mean(),
|
||||
"{}/logits_fake".format(split): fake_logits.detach().mean(),
|
||||
}
|
||||
return total_loss, log
|
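For orientation, here is a minimal sketch of how this two-branch loss could be driven from a training loop. Every name in it (autoencoder, loss_fn, ae_opt, disc_opt, videos, step) is a hypothetical placeholder rather than an identifier from this diff; the sketch relies only on the forward() signature above and assumes discr_3d_kwargs was supplied so the 3D discriminator exists.

# Generator/autoencoder update (optimizer_idx=0): reconstruction + perceptual + adversarial terms
recon, aux = autoencoder(videos)                   # hypothetical model returning (recon, aux losses)
ae_loss, ae_log = loss_fn(videos, recon, optimizer_idx=0, global_step=step, aux_losses=aux)
ae_opt.zero_grad(); ae_loss.backward(); ae_opt.step()

# Discriminator update (optimizer_idx=1): hinge loss, plus gradient penalty if its weight > 0
d_loss, d_log = loss_fn(videos, recon.detach(), optimizer_idx=1, global_step=step)
disc_opt.zero_grad(); d_loss.backward(); disc_opt.step()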
0 sat/sgm/modules/autoencoding/lpips/__init__.py Normal file
1 sat/sgm/modules/autoencoding/lpips/loss/.gitignore vendored Normal file
@ -0,0 +1 @@
vgg.pth
23 sat/sgm/modules/autoencoding/lpips/loss/LICENSE Normal file
@ -0,0 +1,23 @@
|
||||
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
0 sat/sgm/modules/autoencoding/lpips/loss/__init__.py Normal file
132 sat/sgm/modules/autoencoding/lpips/loss/lpips.py Normal file
@ -0,0 +1,132 @@
|
||||
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from ..util import get_ckpt_path
|
||||
|
||||
|
||||
class LPIPS(nn.Module):
|
||||
# Learned perceptual metric
|
||||
def __init__(self, use_dropout=True):
|
||||
super().__init__()
|
||||
self.scaling_layer = ScalingLayer()
|
||||
self.chns = [64, 128, 256, 512, 512]  # vgg16 features
|
||||
self.net = vgg16(pretrained=True, requires_grad=False)
|
||||
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
||||
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
||||
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
||||
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
||||
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
||||
self.load_from_pretrained()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_from_pretrained(self, name="vgg_lpips"):
|
||||
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
||||
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
||||
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, name="vgg_lpips"):
|
||||
if name != "vgg_lpips":
|
||||
raise NotImplementedError
|
||||
model = cls()
|
||||
ckpt = get_ckpt_path(name)
|
||||
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
||||
return model
|
||||
|
||||
def forward(self, input, target):
|
||||
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
||||
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
||||
feats0, feats1, diffs = {}, {}, {}
|
||||
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
||||
for kk in range(len(self.chns)):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||
|
||||
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
||||
val = res[0]
|
||||
for l in range(1, len(self.chns)):
|
||||
val += res[l]
|
||||
return val
|
||||
|
||||
|
||||
class ScalingLayer(nn.Module):
|
||||
def __init__(self):
|
||||
super(ScalingLayer, self).__init__()
|
||||
self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
|
||||
self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
|
||||
|
||||
def forward(self, inp):
|
||||
return (inp - self.shift) / self.scale
|
||||
|
||||
|
||||
class NetLinLayer(nn.Module):
|
||||
"""A single linear layer which does a 1x1 conv"""
|
||||
|
||||
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
||||
super(NetLinLayer, self).__init__()
|
||||
layers = (
|
||||
[
|
||||
nn.Dropout(),
|
||||
]
|
||||
if (use_dropout)
|
||||
else []
|
||||
)
|
||||
layers += [
|
||||
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
||||
]
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
|
||||
class vgg16(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(vgg16, self).__init__()
|
||||
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(4):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(4, 9):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(9, 16):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(16, 23):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(23, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1_2 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2_2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3_3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4_3 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5_3 = h
|
||||
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
|
||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||
return out
|
||||
|
||||
|
||||
def normalize_tensor(x, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
||||
return x / (norm_factor + eps)
|
||||
|
||||
|
||||
def spatial_average(x, keepdim=True):
|
||||
return x.mean([2, 3], keepdim=keepdim)
|
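A minimal usage sketch of the LPIPS metric defined above (not part of the diff). It assumes RGB batches scaled to [-1, 1], which is what the ScalingLayer constants are calibrated for, and that the vgg.pth checkpoint referenced in util.py can be downloaded.

import torch

lpips = LPIPS().eval()                       # downloads/loads vgg.pth on first use
a = torch.rand(4, 3, 256, 256) * 2 - 1       # two batches of images in [-1, 1]
b = torch.rand(4, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(a, b)                       # shape (4, 1, 1, 1); larger = more different
print(dist.flatten())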
58 sat/sgm/modules/autoencoding/lpips/model/LICENSE Normal file
@ -0,0 +1,58 @@
|
||||
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
--------------------------- LICENSE FOR pix2pix --------------------------------
|
||||
BSD License
|
||||
|
||||
For pix2pix software
|
||||
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
----------------------------- LICENSE FOR DCGAN --------------------------------
|
||||
BSD License
|
||||
|
||||
For dcgan.torch software
|
||||
|
||||
Copyright (c) 2015, Facebook, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
89 sat/sgm/modules/autoencoding/lpips/model/model.py Normal file
@ -0,0 +1,89 @@
|
||||
import functools
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..util import ActNorm
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
try:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
except:
|
||||
nn.init.normal_(m.conv.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
"""Defines a PatchGAN discriminator as in Pix2Pix
|
||||
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
||||
"""
|
||||
|
||||
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
||||
"""Construct a PatchGAN discriminator
|
||||
Parameters:
|
||||
input_nc (int) -- the number of channels in input images
|
||||
ndf (int) -- the number of filters in the last conv layer
|
||||
n_layers (int) -- the number of conv layers in the discriminator
|
||||
norm_layer -- normalization layer
|
||||
"""
|
||||
super(NLayerDiscriminator, self).__init__()
|
||||
if not use_actnorm:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
else:
|
||||
norm_layer = ActNorm
|
||||
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
||||
use_bias = norm_layer.func != nn.BatchNorm2d
|
||||
else:
|
||||
use_bias = norm_layer != nn.BatchNorm2d
|
||||
|
||||
kw = 4
|
||||
padw = 1
|
||||
sequence = [
|
||||
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
]
|
||||
nf_mult = 1
|
||||
nf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(
|
||||
ndf * nf_mult_prev,
|
||||
ndf * nf_mult,
|
||||
kernel_size=kw,
|
||||
stride=2,
|
||||
padding=padw,
|
||||
bias=use_bias,
|
||||
),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
]
|
||||
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n_layers, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(
|
||||
ndf * nf_mult_prev,
|
||||
ndf * nf_mult,
|
||||
kernel_size=kw,
|
||||
stride=1,
|
||||
padding=padw,
|
||||
bias=use_bias,
|
||||
),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
]
|
||||
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
||||
] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, input):
|
||||
"""Standard forward."""
|
||||
return self.main(input)
|
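As a quick illustration (not from the diff), the PatchGAN discriminator above can be instantiated directly; it emits one logit per spatial patch rather than one per image.

import torch

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
imgs = torch.randn(2, 3, 256, 256)
patch_logits = disc(imgs)        # (2, 1, 30, 30): a grid of real/fake scores over patches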
114 sat/sgm/modules/autoencoding/lpips/util.py Normal file
@ -0,0 +1,114 @@
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
|
||||
|
||||
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
|
||||
|
||||
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
|
||||
|
||||
|
||||
def download(url, local_path, chunk_size=1024):
|
||||
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
||||
with requests.get(url, stream=True) as r:
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
||||
with open(local_path, "wb") as f:
|
||||
for data in r.iter_content(chunk_size=chunk_size):
|
||||
if data:
|
||||
f.write(data)
|
||||
pbar.update(chunk_size)
|
||||
|
||||
|
||||
def md5_hash(path):
|
||||
with open(path, "rb") as f:
|
||||
content = f.read()
|
||||
return hashlib.md5(content).hexdigest()
|
||||
|
||||
|
||||
def get_ckpt_path(name, root, check=False):
|
||||
assert name in URL_MAP
|
||||
path = os.path.join(root, CKPT_MAP[name])
|
||||
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
||||
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
||||
download(URL_MAP[name], path)
|
||||
md5 = md5_hash(path)
|
||||
assert md5 == MD5_MAP[name], md5
|
||||
return path
|
||||
|
||||
|
||||
class ActNorm(nn.Module):
|
||||
def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
|
||||
assert affine
|
||||
super().__init__()
|
||||
self.logdet = logdet
|
||||
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
||||
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
||||
self.allow_reverse_init = allow_reverse_init
|
||||
|
||||
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
||||
|
||||
def initialize(self, input):
|
||||
with torch.no_grad():
|
||||
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
||||
mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
|
||||
std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
|
||||
|
||||
self.loc.data.copy_(-mean)
|
||||
self.scale.data.copy_(1 / (std + 1e-6))
|
||||
|
||||
def forward(self, input, reverse=False):
|
||||
if reverse:
|
||||
return self.reverse(input)
|
||||
if len(input.shape) == 2:
|
||||
input = input[:, :, None, None]
|
||||
squeeze = True
|
||||
else:
|
||||
squeeze = False
|
||||
|
||||
_, _, height, width = input.shape
|
||||
|
||||
if self.training and self.initialized.item() == 0:
|
||||
self.initialize(input)
|
||||
self.initialized.fill_(1)
|
||||
|
||||
h = self.scale * (input + self.loc)
|
||||
|
||||
if squeeze:
|
||||
h = h.squeeze(-1).squeeze(-1)
|
||||
|
||||
if self.logdet:
|
||||
log_abs = torch.log(torch.abs(self.scale))
|
||||
logdet = height * width * torch.sum(log_abs)
|
||||
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
||||
return h, logdet
|
||||
|
||||
return h
|
||||
|
||||
def reverse(self, output):
|
||||
if self.training and self.initialized.item() == 0:
|
||||
if not self.allow_reverse_init:
|
||||
raise RuntimeError(
|
||||
"Initializing ActNorm in reverse direction is "
|
||||
"disabled by default. Use allow_reverse_init=True to enable."
|
||||
)
|
||||
else:
|
||||
self.initialize(output)
|
||||
self.initialized.fill_(1)
|
||||
|
||||
if len(output.shape) == 2:
|
||||
output = output[:, :, None, None]
|
||||
squeeze = True
|
||||
else:
|
||||
squeeze = False
|
||||
|
||||
h = output / self.scale - self.loc
|
||||
|
||||
if squeeze:
|
||||
h = h.squeeze(-1).squeeze(-1)
|
||||
return h
|
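A short sketch of ActNorm's data-dependent initialization (illustrative only): the first forward call in training mode sets loc/scale from batch statistics, after which reverse() inverts the transform.

import torch

norm = ActNorm(num_features=64, logdet=True)
feats = torch.randn(8, 64, 16, 16)
out, logdet = norm(feats)            # first call initializes loc/scale from this batch
recovered = norm.reverse(out)        # inverse transform, equal to feats up to float error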
16 sat/sgm/modules/autoencoding/lpips/vqperceptual.py Normal file
@ -0,0 +1,16 @@
import torch
import torch.nn.functional as F


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss
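A worked example of hinge_d_loss with hand-picked logits (illustrative numbers only):

import torch

logits_real = torch.tensor([2.0, 0.5])    # relu(1 - x) -> [0.0, 0.5], mean 0.25
logits_fake = torch.tensor([-1.5, 0.2])   # relu(1 + x) -> [0.0, 1.2], mean 0.60
print(hinge_d_loss(logits_real, logits_fake))   # 0.5 * (0.25 + 0.60) = 0.425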
1762 sat/sgm/modules/autoencoding/magvit2_pytorch.py Normal file
File diff suppressed because it is too large
30 sat/sgm/modules/autoencoding/regularizers/__init__.py Normal file
@ -0,0 +1,30 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ....modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from .base import AbstractRegularizer
|
||||
|
||||
|
||||
class DiagonalGaussianRegularizer(AbstractRegularizer):
|
||||
def __init__(self, sample: bool = True):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
yield from ()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
log = dict()
|
||||
posterior = DiagonalGaussianDistribution(z)
|
||||
if self.sample:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
kl_loss = posterior.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
log["kl_loss"] = kl_loss
|
||||
return z, log
|
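A minimal sketch of the KL regularizer above (not part of the diff). It assumes the encoder output stacks mean and log-variance along the channel axis, which is the convention DiagonalGaussianDistribution expects.

import torch

reg = DiagonalGaussianRegularizer(sample=True)
enc_out = torch.randn(2, 2 * 16, 32, 32)     # 16 latent channels: (mean, logvar) stacked on dim 1
z, log = reg(enc_out)                        # z: (2, 16, 32, 32); log["kl_loss"]: scalar KL term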
36 sat/sgm/modules/autoencoding/regularizers/base.py Normal file
@ -0,0 +1,36 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class AbstractRegularizer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class IdentityRegularizer(AbstractRegularizer):
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||
return z, dict()
|
||||
|
||||
def get_trainable_parameters(self) -> Any:
|
||||
yield from ()
|
||||
|
||||
|
||||
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
avg_probs = encodings.mean(0)
|
||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
||||
cluster_use = torch.sum(avg_probs > 0)
|
||||
return perplexity, cluster_use
|
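A worked example of measure_perplexity: when all centroids are hit equally often the perplexity equals num_centroids, and cluster_use counts the distinct codes.

import torch

indices = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
perplexity, cluster_use = measure_perplexity(indices, num_centroids=4)
print(perplexity, cluster_use)    # ~tensor(4.) and tensor(4)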
@ -0,0 +1,180 @@
|
||||
"""
|
||||
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
||||
Code adapted from Jax version in Appendix A.1
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Module
|
||||
from torch import Tensor, int32
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from einops import rearrange, pack, unpack
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def default(*args):
|
||||
for arg in args:
|
||||
if exists(arg):
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def pack_one(t, pattern):
|
||||
return pack([t], pattern)
|
||||
|
||||
|
||||
def unpack_one(t, ps, pattern):
|
||||
return unpack(t, ps, pattern)[0]
|
||||
|
||||
|
||||
# tensor helpers
|
||||
|
||||
|
||||
def round_ste(z: Tensor) -> Tensor:
|
||||
"""Round with straight through gradients."""
|
||||
zhat = z.round()
|
||||
return z + (zhat - z).detach()
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
class FSQ(Module):
|
||||
def __init__(
|
||||
self,
|
||||
levels: List[int],
|
||||
dim: Optional[int] = None,
|
||||
num_codebooks=1,
|
||||
keep_num_codebooks_dim: Optional[bool] = None,
|
||||
scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
_levels = torch.tensor(levels, dtype=int32)
|
||||
self.register_buffer("_levels", _levels, persistent=False)
|
||||
|
||||
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
||||
self.register_buffer("_basis", _basis, persistent=False)
|
||||
|
||||
self.scale = scale
|
||||
|
||||
codebook_dim = len(levels)
|
||||
self.codebook_dim = codebook_dim
|
||||
|
||||
effective_codebook_dim = codebook_dim * num_codebooks
|
||||
self.num_codebooks = num_codebooks
|
||||
self.effective_codebook_dim = effective_codebook_dim
|
||||
|
||||
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
||||
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
||||
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
||||
|
||||
self.dim = default(dim, len(_levels) * num_codebooks)
|
||||
|
||||
has_projections = self.dim != effective_codebook_dim
|
||||
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
||||
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
||||
self.has_projections = has_projections
|
||||
|
||||
self.codebook_size = self._levels.prod().item()
|
||||
|
||||
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
|
||||
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
|
||||
|
||||
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
|
||||
"""Bound `z`, an array of shape (..., d)."""
|
||||
half_l = (self._levels - 1) * (1 + eps) / 2
|
||||
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
||||
shift = (offset / half_l).atanh()
|
||||
return (z + shift).tanh() * half_l - offset
|
||||
|
||||
def quantize(self, z: Tensor) -> Tensor:
|
||||
"""Quantizes z, returns quantized zhat, same shape as z."""
|
||||
quantized = round_ste(self.bound(z))
|
||||
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
||||
return quantized / half_width
|
||||
|
||||
def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
|
||||
half_width = self._levels // 2
|
||||
return (zhat_normalized * half_width) + half_width
|
||||
|
||||
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
|
||||
half_width = self._levels // 2
|
||||
return (zhat - half_width) / half_width
|
||||
|
||||
def codes_to_indices(self, zhat: Tensor) -> Tensor:
|
||||
"""Converts a `code` to an index in the codebook."""
|
||||
assert zhat.shape[-1] == self.codebook_dim
|
||||
zhat = self._scale_and_shift(zhat)
|
||||
return (zhat * self._basis).sum(dim=-1).to(int32)
|
||||
|
||||
def indices_to_codes(self, indices: Tensor, project_out=True) -> Tensor:
|
||||
"""Inverse of `codes_to_indices`."""
|
||||
|
||||
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
||||
|
||||
indices = rearrange(indices, "... -> ... 1")
|
||||
codes_non_centered = (indices // self._basis) % self._levels
|
||||
codes = self._scale_and_shift_inverse(codes_non_centered)
|
||||
|
||||
if self.keep_num_codebooks_dim:
|
||||
codes = rearrange(codes, "... c d -> ... (c d)")
|
||||
|
||||
if project_out:
|
||||
codes = self.project_out(codes)
|
||||
|
||||
if is_img_or_video:
|
||||
codes = rearrange(codes, "b ... d -> b d ...")
|
||||
|
||||
return codes
|
||||
|
||||
@autocast(enabled=False)
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
n - sequence (or flattened spatial dimensions)
|
||||
d - feature dimension
|
||||
c - number of codebook dim
|
||||
"""
|
||||
|
||||
is_img_or_video = z.ndim >= 4
|
||||
|
||||
# standardize image or video into (batch, seq, dimension)
|
||||
|
||||
if is_img_or_video:
|
||||
z = rearrange(z, "b d ... -> b ... d")
|
||||
z, ps = pack_one(z, "b * d")
|
||||
|
||||
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
||||
|
||||
z = self.project_in(z)
|
||||
|
||||
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
||||
|
||||
codes = self.quantize(z)
|
||||
indices = self.codes_to_indices(codes)
|
||||
|
||||
codes = rearrange(codes, "b n c d -> b n (c d)")
|
||||
|
||||
out = self.project_out(codes)
|
||||
|
||||
# reconstitute image or video dimensions
|
||||
|
||||
if is_img_or_video:
|
||||
out = unpack_one(out, ps, "b * d")
|
||||
out = rearrange(out, "b ... d -> b d ...")
|
||||
|
||||
indices = unpack_one(indices, ps, "b * c")
|
||||
|
||||
if not self.keep_num_codebooks_dim:
|
||||
indices = rearrange(indices, "... 1 -> ...")
|
||||
|
||||
return out, indices
|
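A minimal sketch of using the FSQ module above (shapes assumed from its forward pass): levels [8, 5, 5, 5] yields an implicit codebook of 8*5*5*5 = 1000 entries, and image-like inputs are given channels-first.

import torch

fsq = FSQ(levels=[8, 5, 5, 5])          # codebook_size = 1000, dim defaults to len(levels)
x = torch.randn(2, 4, 16, 16)           # (batch, dim, height, width)
quantized, indices = fsq(x)             # quantized: (2, 4, 16, 16); indices: (2, 16, 16)
assert torch.allclose(quantized, fsq.indices_to_codes(indices))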
@ -0,0 +1,309 @@
|
||||
"""
|
||||
Lookup Free Quantization
|
||||
Proposed in https://arxiv.org/abs/2310.05737
|
||||
|
||||
In the simplest setup, each dimension is quantized into {-1, 1}.
|
||||
An entropy penalty is used to encourage utilization.
|
||||
"""
|
||||
|
||||
from math import log2, ceil
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from einops import rearrange, reduce, pack, unpack
|
||||
|
||||
# constants
|
||||
|
||||
Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"])
|
||||
|
||||
LossBreakdown = namedtuple("LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"])
|
||||
|
||||
# helper functions
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def default(*args):
|
||||
for arg in args:
|
||||
if exists(arg):
|
||||
return arg() if callable(arg) else arg
|
||||
return None
|
||||
|
||||
|
||||
def pack_one(t, pattern):
|
||||
return pack([t], pattern)
|
||||
|
||||
|
||||
def unpack_one(t, ps, pattern):
|
||||
return unpack(t, ps, pattern)[0]
|
||||
|
||||
|
||||
# entropy
|
||||
|
||||
|
||||
def log(t, eps=1e-5):
|
||||
return t.clamp(min=eps).log()
|
||||
|
||||
|
||||
def entropy(prob):
|
||||
return (-prob * log(prob)).sum(dim=-1)
|
||||
|
||||
|
||||
# class
|
||||
|
||||
|
||||
class LFQ(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim=None,
|
||||
codebook_size=None,
|
||||
entropy_loss_weight=0.1,
|
||||
commitment_loss_weight=0.25,
|
||||
diversity_gamma=1.0,
|
||||
straight_through_activation=nn.Identity(),
|
||||
num_codebooks=1,
|
||||
keep_num_codebooks_dim=None,
|
||||
codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer
|
||||
frac_per_sample_entropy=1.0, # make less than 1. to only use a random fraction of the probs for per sample entropy
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# some assert validations
|
||||
|
||||
assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
|
||||
assert (
|
||||
not exists(codebook_size) or log2(codebook_size).is_integer()
|
||||
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
|
||||
|
||||
codebook_size = default(codebook_size, lambda: 2**dim)
|
||||
codebook_dim = int(log2(codebook_size))
|
||||
|
||||
codebook_dims = codebook_dim * num_codebooks
|
||||
dim = default(dim, codebook_dims)
|
||||
|
||||
has_projections = dim != codebook_dims
|
||||
self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
|
||||
self.has_projections = has_projections
|
||||
|
||||
self.dim = dim
|
||||
self.codebook_dim = codebook_dim
|
||||
self.num_codebooks = num_codebooks
|
||||
|
||||
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
||||
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
||||
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
||||
|
||||
# straight through activation
|
||||
|
||||
self.activation = straight_through_activation
|
||||
|
||||
# entropy aux loss related weights
|
||||
|
||||
assert 0 < frac_per_sample_entropy <= 1.0
|
||||
self.frac_per_sample_entropy = frac_per_sample_entropy
|
||||
|
||||
self.diversity_gamma = diversity_gamma
|
||||
self.entropy_loss_weight = entropy_loss_weight
|
||||
|
||||
# codebook scale
|
||||
|
||||
self.codebook_scale = codebook_scale
|
||||
|
||||
# commitment loss
|
||||
|
||||
self.commitment_loss_weight = commitment_loss_weight
|
||||
|
||||
# for no auxiliary loss, during inference
|
||||
|
||||
self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1))
|
||||
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
|
||||
|
||||
# codes
|
||||
|
||||
all_codes = torch.arange(codebook_size)
|
||||
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
|
||||
codebook = self.bits_to_codes(bits)
|
||||
|
||||
self.register_buffer("codebook", codebook, persistent=False)
|
||||
|
||||
def bits_to_codes(self, bits):
|
||||
return bits * self.codebook_scale * 2 - self.codebook_scale
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.codebook.dtype
|
||||
|
||||
def indices_to_codes(self, indices, project_out=True):
|
||||
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
||||
|
||||
if not self.keep_num_codebooks_dim:
|
||||
indices = rearrange(indices, "... -> ... 1")
|
||||
|
||||
# indices to codes, which are bits of either -1 or 1
|
||||
|
||||
bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
|
||||
|
||||
codes = self.bits_to_codes(bits)
|
||||
|
||||
codes = rearrange(codes, "... c d -> ... (c d)")
|
||||
|
||||
# whether to project codes out to original dimensions
|
||||
# if the input feature dimensions were not log2(codebook size)
|
||||
|
||||
if project_out:
|
||||
codes = self.project_out(codes)
|
||||
|
||||
# rearrange codes back to original shape
|
||||
|
||||
if is_img_or_video:
|
||||
codes = rearrange(codes, "b ... d -> b d ...")
|
||||
|
||||
return codes
|
||||
|
||||
@autocast(enabled=False)
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
inv_temperature=100.0,
|
||||
return_loss_breakdown=False,
|
||||
mask=None,
|
||||
):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
n - sequence (or flattened spatial dimensions)
|
||||
d - feature dimension, which is also log2(codebook size)
|
||||
c - number of codebook dim
|
||||
"""
|
||||
|
||||
x = x.float()
|
||||
|
||||
is_img_or_video = x.ndim >= 4
|
||||
|
||||
# standardize image or video into (batch, seq, dimension)
|
||||
|
||||
if is_img_or_video:
|
||||
x = rearrange(x, "b d ... -> b ... d")
|
||||
x, ps = pack_one(x, "b * d")
|
||||
|
||||
assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
|
||||
|
||||
x = self.project_in(x)
|
||||
|
||||
# split out number of codebooks
|
||||
|
||||
x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks)
|
||||
|
||||
# quantize by eq 3.
|
||||
|
||||
original_input = x
|
||||
|
||||
codebook_value = torch.ones_like(x) * self.codebook_scale
|
||||
quantized = torch.where(x > 0, codebook_value, -codebook_value)
|
||||
|
||||
# use straight-through gradients (optionally with custom activation fn) if training
|
||||
|
||||
if self.training:
|
||||
x = self.activation(x)
|
||||
x = x + (quantized - x).detach()
|
||||
else:
|
||||
x = quantized
|
||||
|
||||
# calculate indices
|
||||
|
||||
indices = reduce((x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
|
||||
|
||||
# entropy aux loss
|
||||
|
||||
if self.training:
|
||||
# the same as euclidean distance up to a constant
|
||||
distance = -2 * einsum("... i d, j d -> ... i j", original_input, self.codebook)
|
||||
|
||||
prob = (-distance * inv_temperature).softmax(dim=-1)
|
||||
|
||||
# account for mask
|
||||
|
||||
if exists(mask):
|
||||
prob = prob[mask]
|
||||
else:
|
||||
prob = rearrange(prob, "b n ... -> (b n) ...")
|
||||
|
||||
# whether to only use a fraction of probs, for reducing memory
|
||||
|
||||
if self.frac_per_sample_entropy < 1.0:
|
||||
num_tokens = prob.shape[0]
|
||||
num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
|
||||
rand_mask = torch.randn(num_tokens).argsort(dim=-1) < num_sampled_tokens
|
||||
per_sample_probs = prob[rand_mask]
|
||||
else:
|
||||
per_sample_probs = prob
|
||||
|
||||
# calculate per sample entropy
|
||||
|
||||
per_sample_entropy = entropy(per_sample_probs).mean()
|
||||
|
||||
# distribution over all available tokens in the batch
|
||||
|
||||
avg_prob = reduce(per_sample_probs, "... c d -> c d", "mean")
|
||||
codebook_entropy = entropy(avg_prob).mean()
|
||||
|
||||
# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
|
||||
# 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
|
||||
|
||||
entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
|
||||
else:
|
||||
# if not training, just return dummy 0
|
||||
entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
|
||||
|
||||
# commit loss
|
||||
|
||||
if self.training:
|
||||
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction="none")
|
||||
|
||||
if exists(mask):
|
||||
commit_loss = commit_loss[mask]
|
||||
|
||||
commit_loss = commit_loss.mean()
|
||||
else:
|
||||
commit_loss = self.zero
|
||||
|
||||
# merge back codebook dim
|
||||
|
||||
x = rearrange(x, "b n c d -> b n (c d)")
|
||||
|
||||
# project out to feature dimension if needed
|
||||
|
||||
x = self.project_out(x)
|
||||
|
||||
# reconstitute image or video dimensions
|
||||
|
||||
if is_img_or_video:
|
||||
x = unpack_one(x, ps, "b * d")
|
||||
x = rearrange(x, "b ... d -> b d ...")
|
||||
|
||||
indices = unpack_one(indices, ps, "b * c")
|
||||
|
||||
# whether to remove single codebook dim
|
||||
|
||||
if not self.keep_num_codebooks_dim:
|
||||
indices = rearrange(indices, "... 1 -> ...")
|
||||
|
||||
# complete aux loss
|
||||
|
||||
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
||||
|
||||
ret = Return(x, indices, aux_loss)
|
||||
|
||||
if not return_loss_breakdown:
|
||||
return ret
|
||||
|
||||
return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
|
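A minimal sketch of the LFQ module above (shapes assumed from its forward pass); when dim differs from log2(codebook_size), the input is linearly projected in and out around the binary quantization.

import torch

lfq = LFQ(codebook_size=2 ** 10, dim=32)      # 10-bit codes behind a 32-dim projection
feats = torch.randn(2, 32, 16, 16)            # (batch, dim, height, width)
quantized, indices, aux_loss = lfq(feats)     # quantized: (2, 32, 16, 16); indices: (2, 16, 16)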
453 sat/sgm/modules/autoencoding/regularizers/quantize.py Normal file
@ -0,0 +1,453 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
from .base import AbstractRegularizer, measure_perplexity
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractQuantizer(AbstractRegularizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Define these in your init
|
||||
# shape (N,)
|
||||
self.used: Optional[torch.Tensor]
|
||||
self.re_embed: int
|
||||
self.unknown_index: Union[Literal["random"], int]
|
||||
|
||||
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
|
||||
assert self.used is not None, "You need to define used indices for remap"
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
|
||||
assert self.used is not None, "You need to define used indices for remap"
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
@abstractmethod
|
||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
||||
yield from self.parameters()
|
||||
|
||||
|
||||
class GumbelQuantizer(AbstractQuantizer):
|
||||
"""
|
||||
credit to @karpathy:
|
||||
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
||||
Gumbel Softmax trick quantizer
|
||||
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
||||
https://arxiv.org/abs/1611.01144
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_hiddens: int,
|
||||
embedding_dim: int,
|
||||
n_embed: int,
|
||||
straight_through: bool = True,
|
||||
kl_weight: float = 5e-4,
|
||||
temp_init: float = 1.0,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
loss_key: str = "loss/vq",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.loss_key = loss_key
|
||||
self.embedding_dim = embedding_dim
|
||||
self.n_embed = n_embed
|
||||
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
|
||||
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
||||
self.embed = nn.Embedding(n_embed, embedding_dim)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_embed
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
# force hard = True when we are in eval mode, as we must quantize.
|
||||
# actually, always true seems to work
|
||||
hard = self.straight_through if self.training else True
|
||||
temp = self.temperature if temp is None else temp
|
||||
out_dict = {}
|
||||
logits = self.proj(z)
|
||||
if self.remap is not None:
|
||||
# continue only with used logits
|
||||
full_zeros = torch.zeros_like(logits)
|
||||
logits = logits[:, self.used, ...]
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
||||
if self.remap is not None:
|
||||
# go back to all entries but unused set to zero
|
||||
full_zeros[:, self.used, ...] = soft_one_hot
|
||||
soft_one_hot = full_zeros
|
||||
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
||||
out_dict[self.loss_key] = diff
|
||||
|
||||
ind = soft_one_hot.argmax(dim=1)
|
||||
out_dict["indices"] = ind
|
||||
if self.remap is not None:
|
||||
ind = self.remap_to_used(ind)
|
||||
|
||||
if return_logits:
|
||||
out_dict["logits"] = logits
|
||||
|
||||
return z_q, out_dict
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# TODO: shape not yet optional
|
||||
b, h, w, c = shape
|
||||
assert b * h * w == indices.shape[0]
|
||||
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
|
||||
if self.remap is not None:
|
||||
indices = self.unmap_to_all(indices)
|
||||
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
||||
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
|
||||
return z_q
|
||||
|
||||
|
||||
class VectorQuantizer(AbstractQuantizer):
|
||||
"""
|
||||
____________________________________________
|
||||
Discretization bottleneck part of the VQ-VAE.
|
||||
Inputs:
|
||||
- n_e : number of embeddings
|
||||
- e_dim : dimension of embedding
|
||||
- beta : commitment cost used in loss term,
|
||||
beta * ||z_e(x)-sg[e]||^2
|
||||
_____________________________________________
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_e: int,
|
||||
e_dim: int,
|
||||
beta: float = 0.25,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
sane_index_shape: bool = False,
|
||||
log_perplexity: bool = False,
|
||||
embedding_weight_norm: bool = False,
|
||||
loss_key: str = "loss/vq",
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.loss_key = loss_key
|
||||
|
||||
if not embedding_weight_norm:
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
else:
|
||||
self.embedding = torch.nn.utils.weight_norm(nn.Embedding(self.n_e, self.e_dim), dim=1)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_e
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
self.log_perplexity = log_perplexity
|
||||
|
||||
def forward(
|
||||
self,
|
||||
z: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict]:
|
||||
do_reshape = z.ndim == 4
|
||||
if do_reshape:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
||||
|
||||
else:
|
||||
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
|
||||
z = z.contiguous()
|
||||
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
loss_dict = {}
|
||||
if self.log_perplexity:
|
||||
perplexity, cluster_usage = measure_perplexity(min_encoding_indices.detach(), self.n_e)
|
||||
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
||||
loss_dict[self.loss_key] = loss
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
if do_reshape:
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
if do_reshape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
else:
|
||||
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
|
||||
|
||||
loss_dict["min_encoding_indices"] = min_encoding_indices
|
||||
|
||||
return z_q, loss_dict
|
||||
|
||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
assert shape is not None, "Need to give shape for remap"
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
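# Illustrative usage sketch (added for clarity, not part of the original file); shapes follow
# the forward() above:
#
#     vq = VectorQuantizer(n_e=1024, e_dim=16, beta=0.25, sane_index_shape=True)
#     z = torch.randn(2, 16, 32, 32)            # (batch, e_dim, height, width)
#     z_q, out = vq(z)                          # z_q keeps the input shape
#     out["loss/vq"]                            # scalar codebook/commitment loss
#     out["min_encoding_indices"].shape         # torch.Size([2, 32, 32])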
class EmbeddingEMA(nn.Module):
|
||||
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
weight = torch.randn(num_tokens, codebook_dim)
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
||||
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
||||
self.update = True
|
||||
|
||||
def forward(self, embed_id):
|
||||
return F.embedding(embed_id, self.weight)
|
||||
|
||||
def cluster_size_ema_update(self, new_cluster_size):
|
||||
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
||||
|
||||
def embed_avg_ema_update(self, new_embed_avg):
|
||||
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
||||
|
||||
def weight_update(self, num_tokens):
|
||||
n = self.cluster_size.sum()
|
||||
smoothed_cluster_size = (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
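# Laplace smoothing of the per-code counts keeps rarely used codes from dividing by ~0
# while preserving the total count n.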
|
||||
# normalize embedding average with smoothed cluster size
|
||||
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
||||
self.weight.data.copy_(embed_normalized)
|
||||
|
||||
|
||||
class EMAVectorQuantizer(AbstractQuantizer):
|
||||
def __init__(
|
||||
self,
|
||||
n_embed: int,
|
||||
embedding_dim: int,
|
||||
beta: float,
|
||||
decay: float = 0.99,
|
||||
eps: float = 1e-5,
|
||||
remap: Optional[str] = None,
|
||||
unknown_index: str = "random",
|
||||
loss_key: str = "loss/vq",
|
||||
):
|
||||
super().__init__()
|
||||
self.codebook_dim = embedding_dim
|
||||
self.num_tokens = n_embed
|
||||
self.beta = beta
|
||||
self.loss_key = loss_key
|
||||
|
||||
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
else:
|
||||
self.used = None
|
||||
self.re_embed = n_embed
|
||||
if unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
else:
|
||||
assert unknown_index == "random" or isinstance(
|
||||
unknown_index, int
|
||||
), "unknown index needs to be 'random', 'extra' or any integer"
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.remap is not None:
|
||||
logpy.info(
|
||||
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
# z, 'b c h w -> b h w c'
|
||||
z = rearrange(z, "b c h w -> b h w c")
|
||||
z_flattened = z.reshape(-1, self.codebook_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (
|
||||
z_flattened.pow(2).sum(dim=1, keepdim=True)
|
||||
+ self.embedding.weight.pow(2).sum(dim=1)
|
||||
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
|
||||
) # 'n d -> d n'
|
||||
|
||||
encoding_indices = torch.argmin(d, dim=1)
|
||||
|
||||
z_q = self.embedding(encoding_indices).view(z.shape)
|
||||
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
||||
avg_probs = torch.mean(encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
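# Perplexity = exp(entropy of the average code usage); it ranges from 1 (codebook
# collapse onto a single code) up to num_tokens (perfectly uniform usage).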
|
||||
|
||||
if self.training and self.embedding.update:
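# The codebook is not optimized by gradients; instead, EMA estimates of the per-code
# assignment counts and of the summed encoder outputs are updated here and then used
# to recompute the codebook weights.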
|
||||
# EMA cluster size
|
||||
encodings_sum = encodings.sum(0)
|
||||
self.embedding.cluster_size_ema_update(encodings_sum)
|
||||
# EMA embedding average
|
||||
embed_sum = encodings.transpose(0, 1) @ z_flattened
|
||||
self.embedding.embed_avg_ema_update(embed_sum)
|
||||
# normalize embed_avg and update weight
|
||||
self.embedding.weight_update(self.num_tokens)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * F.mse_loss(z_q.detach(), z)
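# Only the commitment term is needed here: the codebook itself is maintained by the EMA
# updates above rather than by this loss.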
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
# z_q, 'b h w c -> b c h w'
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w")
|
||||
|
||||
out_dict = {
|
||||
self.loss_key: loss,
|
||||
"encodings": encodings,
|
||||
"encoding_indices": encoding_indices,
|
||||
"perplexity": perplexity,
|
||||
}
|
||||
|
||||
return z_q, out_dict
|
||||
|
||||
|
||||
class VectorQuantizerWithInputProjection(VectorQuantizer):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
n_codes: int,
|
||||
codebook_dim: int,
|
||||
beta: float = 1.0,
|
||||
output_dim: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(n_codes, codebook_dim, beta, **kwargs)
|
||||
self.proj_in = nn.Linear(input_dim, codebook_dim)
|
||||
self.output_dim = output_dim
|
||||
if output_dim is not None:
|
||||
self.proj_out = nn.Linear(codebook_dim, output_dim)
|
||||
else:
|
||||
self.proj_out = nn.Identity()
|
||||
|
||||
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
||||
rearr = False
|
||||
in_shape = z.shape
|
||||
|
||||
if z.ndim > 3:
|
||||
rearr = self.output_dim is not None
|
||||
z = rearrange(z, "b c ... -> b (...) c")
|
||||
z = self.proj_in(z)
|
||||
z_q, loss_dict = super().forward(z)
|
||||
|
||||
z_q = self.proj_out(z_q)
|
||||
if rearr:
|
||||
if len(in_shape) == 4:
|
||||
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
|
||||
elif len(in_shape) == 5:
|
||||
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
|
||||
else:
|
||||
raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
|
||||
|
||||
return z_q, loss_dict
|
331
sat/sgm/modules/autoencoding/temporal_ae.py
Normal file
@ -0,0 +1,331 @@
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from sgm.modules.diffusionmodules.model import (
|
||||
XFORMERS_IS_AVAILABLE,
|
||||
AttnBlock,
|
||||
Decoder,
|
||||
MemoryEfficientAttnBlock,
|
||||
ResnetBlock,
|
||||
)
|
||||
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
||||
from sgm.modules.video_attention import VideoTransformerBlock
|
||||
from sgm.util import partialclass
|
||||
|
||||
|
||||
class VideoResBlock(ResnetBlock):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
*args,
|
||||
dropout=0.0,
|
||||
video_kernel_size=3,
|
||||
alpha=0.0,
|
||||
merge_strategy="learned",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
||||
if video_kernel_size is None:
|
||||
video_kernel_size = [3, 1, 1]
|
||||
self.time_stack = ResBlock(
|
||||
channels=out_channels,
|
||||
emb_channels=0,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=False,
|
||||
skip_t_emb=True,
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, bs):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, temb, skip_video=False, timesteps=None):
|
||||
if timesteps is None:
|
||||
timesteps = self.timesteps
|
||||
|
||||
b, c, h, w = x.shape
|
||||
|
||||
x = super().forward(x, temb)
|
||||
|
||||
if not skip_video:
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = self.time_stack(x, temb)
|
||||
|
||||
alpha = self.get_alpha(bs=b // timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
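# Blend the temporally mixed features with the per-frame (spatial-only) features; with
# merge_strategy="learned", alpha = sigmoid(mix_factor) is a single learnable scalar.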
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class AE3DConv(torch.nn.Conv2d):
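# A Conv2d applied per frame, followed by a Conv3d (time_mix_conv) that mixes information
# across the time axis; frames are flattened into the batch dimension on input and output.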
|
||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||
if isinstance(video_kernel_size, Iterable):
|
||||
padding = [int(k // 2) for k in video_kernel_size]
|
||||
else:
|
||||
padding = int(video_kernel_size // 2)
|
||||
|
||||
self.time_mix_conv = torch.nn.Conv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=video_kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def forward(self, input, timesteps, skip_video=False):
|
||||
x = super().forward(input)
|
||||
if skip_video:
|
||||
return x
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
x = self.time_mix_conv(x)
|
||||
return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
|
||||
class VideoBlock(AttnBlock):
|
||||
def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = VideoTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
attn_mode="softmax",
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps, skip_video=False):
|
||||
if skip_video:
|
||||
return super().forward(x)
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
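# Each frame index within the clip receives a sinusoidal embedding, projected by a small
# MLP and broadcast over the spatial tokens, so the temporal block knows the frame order.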
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
|
||||
def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = VideoTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
attn_mode="softmax-xformers",
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps, skip_time_block=False):
|
||||
if skip_time_block:
|
||||
return super().forward(x)
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
def make_time_attn(
|
||||
in_channels,
|
||||
attn_type="vanilla",
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
):
|
||||
assert attn_type in [
|
||||
"vanilla",
|
||||
"vanilla-xformers",
|
||||
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
||||
print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
||||
print(
|
||||
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
)
|
||||
attn_type = "vanilla"
|
||||
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return partialclass(
|
||||
MemoryEfficientVideoBlock,
|
||||
in_channels,
|
||||
alpha=alpha,
|
||||
merge_strategy=merge_strategy,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Conv2DWrapper(torch.nn.Conv2d):
|
||||
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
class VideoDecoder(Decoder):
|
||||
available_time_modes = ["all", "conv-only", "attn-only"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
video_kernel_size: Union[int, list] = 3,
|
||||
alpha: float = 0.0,
|
||||
merge_strategy: str = "learned",
|
||||
time_mode: str = "conv-only",
|
||||
**kwargs,
|
||||
):
|
||||
self.video_kernel_size = video_kernel_size
|
||||
self.alpha = alpha
|
||||
self.merge_strategy = merge_strategy
|
||||
self.time_mode = time_mode
|
||||
assert (
|
||||
self.time_mode in self.available_time_modes
|
||||
), f"time_mode parameter has to be in {self.available_time_modes}"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_last_layer(self, skip_time_mix=False, **kwargs):
|
||||
if self.time_mode == "attn-only":
|
||||
raise NotImplementedError("TODO")
|
||||
else:
|
||||
return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight
|
||||
|
||||
def _make_attn(self) -> Callable:
|
||||
if self.time_mode not in ["conv-only", "only-last-conv"]:
|
||||
return partialclass(
|
||||
make_time_attn,
|
||||
alpha=self.alpha,
|
||||
merge_strategy=self.merge_strategy,
|
||||
)
|
||||
else:
|
||||
return super()._make_attn()
|
||||
|
||||
def _make_conv(self) -> Callable:
|
||||
if self.time_mode != "attn-only":
|
||||
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
|
||||
else:
|
||||
return Conv2DWrapper
|
||||
|
||||
def _make_resblock(self) -> Callable:
|
||||
if self.time_mode not in ["attn-only", "only-last-conv"]:
|
||||
return partialclass(
|
||||
VideoResBlock,
|
||||
video_kernel_size=self.video_kernel_size,
|
||||
alpha=self.alpha,
|
||||
merge_strategy=self.merge_strategy,
|
||||
)
|
||||
else:
|
||||
return super()._make_resblock()
|
495
sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py
Normal file
@ -0,0 +1,495 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
f_channels,
|
||||
zq_channels,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters():
|
||||
p.requires_grad = False
|
||||
self.add_conv = add_conv
|
||||
if self.add_conv:
|
||||
self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, f, zq):
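# Spatially conditioned normalization: group-normalize f, then modulate it with per-pixel
# scale (conv_y) and shift (conv_b) maps predicted from the quantized latent zq, which is
# first interpolated to f's resolution; the first frame is interpolated separately so it
# stays aligned with the causal first-frame convention.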
|
||||
if zq.shape[2] > 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
def Normalize3D(in_channels, zq_ch, add_conv):
|
||||
return SpatialNorm3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=add_conv,
|
||||
num_groups=32,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
h = x
|
||||
h = self.norm1(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
h = self.norm2(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock2D(nn.Module):
|
||||
def __init__(self, in_channels, zq_ch=None, add_conv=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
h_ = self.norm(h_, zq)
|
||||
|
||||
t = h_.shape[2]
|
||||
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class MOVQDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
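# The temporal_compress_level coarsest-resolution upsampling stages also upsample along
# time (compress_time=True below); the remaining, finer stages upsample spatially only.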
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, z, use_cp=False):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
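# The latent z is also passed as zq to every SpatialNorm3D layer, so the decoder's
# normalization is modulated by the latent itself (MoVQ-style conditioning).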
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
# h = self.mid.attn_1(h, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
||||
|
||||
|
||||
class NewDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
post_quant_conv=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
if post_quant_conv:
|
||||
self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
else:
|
||||
self.post_quant_conv = None
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
# block_in,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
# remove attention block
|
||||
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
|
||||
# self.conv_out = torch.nn.Conv3d(block_in,
|
||||
# out_ch,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
# h = self.mid.attn_1(h, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
535
sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py
Normal file
@ -0,0 +1,535 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
from .movq_enc_3d import CausalConv3d, Upsample3D, DownSample3D
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
f_channels,
|
||||
zq_channels,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters():
|
||||
p.requires_grad = False
|
||||
self.add_conv = add_conv
|
||||
if self.add_conv:
|
||||
# self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = CausalConv3d(zq_channels, zq_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
# self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, f, zq):
|
||||
if zq.shape[2] > 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
def Normalize3D(in_channels, zq_ch, add_conv):
|
||||
return SpatialNorm3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=add_conv,
|
||||
num_groups=32,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
# self.conv1 = torch.nn.Conv3d(in_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
# self.conv2 = torch.nn.Conv3d(out_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
h = x
|
||||
h = self.norm1(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
h = self.norm2(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock2D(nn.Module):
|
||||
def __init__(self, in_channels, zq_ch=None, add_conv=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
h_ = self.norm(h_, zq)
|
||||
|
||||
t = h_.shape[2]
|
||||
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class MOVQDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
# block_in,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
# remove attention block
|
||||
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
|
||||
# self.conv_out = torch.nn.Conv3d(block_in,
|
||||
# out_ch,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, z, use_cp=False):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
# h = self.mid.attn_1(h, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
||||
|
||||
|
||||
class NewDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
post_quant_conv=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
if post_quant_conv:
|
||||
self.post_quant_conv = CausalConv3d(zq_ch, z_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
else:
|
||||
self.post_quant_conv = None
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
# block_in,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
# remove attention block
|
||||
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock2D(block_in, zq_ch, add_conv=add_conv))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
|
||||
# self.conv_out = torch.nn.Conv3d(block_in,
|
||||
# out_ch,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
# h = self.mid.attn_1(h, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
413
sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py
Normal file
@ -0,0 +1,413 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
||||
|
||||
dilation = kwargs.pop("dilation", 1)
|
||||
stride = kwargs.pop("stride", 1)
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
|
||||
height_pad = height_kernel_size // 2
|
||||
width_pad = width_kernel_size // 2
|
||||
|
||||
self.height_pad = height_pad
|
||||
self.width_pad = width_pad
|
||||
self.time_pad = time_pad
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
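# F.pad expects padding for the last dimension first: (W_left, W_right, H_left, H_right,
# T_front, T_back). Padding only the front of the time axis makes the convolution causal:
# the output at frame t depends only on frames <= t.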
|
||||
|
||||
stride = (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
if self.pad_mode == "constant":
|
||||
causal_padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)  # (W, W, H, H, T_front, T_back), matching self.time_causal_padding
|
||||
x = F.pad(x, causal_padding_3d, mode="constant", value=0)
|
||||
elif self.pad_mode == "first":
|
||||
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
|
||||
x = torch.cat([pad_x, x], dim=2)
|
||||
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
x = F.pad(x, causal_padding_2d, mode="constant", value=0)
|
||||
elif self.pad_mode == "reflect":
|
||||
# reflect padding
|
||||
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
|
||||
if reflect_x.shape[2] < self.time_pad:
|
||||
reflect_x = torch.cat(
|
||||
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
|
||||
)
|
||||
x = torch.cat([reflect_x, x], dim=2)
|
||||
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
x = F.pad(x, causal_padding_2d, mode="constant", value=0)
|
||||
else:
|
||||
raise ValueError("Invalid pad mode")
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def Normalize3D(in_channels): # same for 3D and 2D
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, compress_time=False):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] > 1:
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
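# With compress_time=True the first frame is upsampled only spatially, while interpolating
# the remaining 5D block with scale_factor=2.0 doubles time, height and width, so t input
# frames become 2*t - 1 output frames (the causal "first frame" stays single).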
|
||||
else:
|
||||
x = x.squeeze(2)
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = x[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.with_conv:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample3D(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
else:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize3D(in_channels)
|
||||
# self.conv1 = torch.nn.Conv3d(in_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize3D(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
# self.conv2 = torch.nn.Conv3d(out_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
|
||||
# out_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock2D(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize3D(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
||||
t = h_.shape[2]
|
||||
h_ = rearrange(h_, "b c t h w -> (b t) c h w")
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
|
||||
# # original version, nan in fp16
|
||||
# w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
# w_ = w_ * (int(c)**(-0.5))
|
||||
# # implement c**-0.5 on q
|
||||
q = q * (int(c) ** (-0.5))
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Encoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
# downsampling
|
||||
# self.conv_in = torch.nn.Conv3d(in_channels,
|
||||
# self.ch,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_in = CausalConv3d(in_channels, self.ch, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock2D(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if i_level < self.temporal_compress_level:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
||||
)
|
||||
# remove attention block
|
||||
# self.mid.attn_1 = AttnBlock2D(block_in)
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize3D(block_in)
|
||||
# self.conv_out = torch.nn.Conv3d(block_in,
|
||||
# 2*z_channels if double_z else z_channels,
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_out = CausalConv3d(
|
||||
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, pad_mode=pad_mode
|
||||
)
|
||||
|
||||
def forward(self, x, use_cp=False):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
# h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
368 sat/sgm/modules/autoencoding/vqvae/movq_modules.py (new file)
@@ -0,0 +1,368 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
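A quick usage sketch of the helper above (values are illustrative):

import torch
t = torch.arange(4)                      # four integer diffusion timesteps
emb = get_timestep_embedding(t, 128)     # -> (4, 128) sinusoidal features, [sin | cos] halves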
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SpatialNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
f_channels,
|
||||
zq_channels,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters():
|
||||
p.requires_grad = False
|
||||
self.add_conv = add_conv
|
||||
if self.add_conv:
|
||||
self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, f, zq):
|
||||
f_size = f.shape[-2:]
|
||||
zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
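A minimal sketch of how SpatialNorm is typically called: the quantized latent zq is upsampled to the feature map's spatial size and modulates a GroupNorm of f. Channel counts and shapes here are assumptions for illustration only.

import torch
sn = SpatialNorm(f_channels=128, zq_channels=64, num_groups=32, eps=1e-6, affine=True)
f = torch.randn(1, 128, 32, 32)          # decoder feature map
zq = torch.randn(1, 64, 8, 8)            # quantized latent at lower resolution
out = sn(f, zq)                          # zq -> 32x32, then norm(f) * conv_y(zq) + conv_b(zq)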
|
||||
|
||||
|
||||
def Normalize(in_channels, zq_ch, add_conv):
|
||||
return SpatialNorm(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=add_conv,
|
||||
num_groups=32,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
h = x
|
||||
h = self.norm1(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels, zq_ch=None, add_conv=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
h_ = self.norm(h_, zq)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class MOVQDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z, zq):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
h = self.mid.attn_1(h, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def forward_with_features_output(self, z, zq):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
output_features = {}
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
output_features["conv_in"] = h
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
output_features["mid_block_1"] = h
|
||||
h = self.mid.attn_1(h, zq)
|
||||
output_features["mid_attn_1"] = h
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
output_features["mid_block_2"] = h
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
output_features[f"up_{i_level}_block_{i_block}"] = h
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
output_features[f"up_{i_level}_attn_{i_block}"] = h
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
output_features[f"up_{i_level}_upsample"] = h
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
output_features["norm_out"] = h
|
||||
h = nonlinearity(h)
|
||||
output_features["nonlinearity"] = h
|
||||
h = self.conv_out(h)
|
||||
output_features["conv_out"] = h
|
||||
|
||||
return h, output_features
|
241 sat/sgm/modules/autoencoding/vqvae/quantize.py (new file)
@@ -0,0 +1,241 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch import einsum
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class VectorQuantizer2(nn.Module):
|
||||
"""
|
||||
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
||||
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
||||
"""
|
||||
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
print(
|
||||
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
else:
|
||||
self.re_embed = n_e
|
||||
|
||||
self.sane_index_shape = sane_index_shape
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
||||
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
||||
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
||||
assert return_logits == False, "Only for interface compatible with Gumbel"
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
||||
perplexity = None
|
||||
min_encodings = None
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
||||
|
||||
if self.remap is not None:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
||||
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
indices = indices.reshape(shape[0], -1) # add batch axis
|
||||
indices = self.unmap_to_all(indices)
|
||||
indices = indices.reshape(-1) # flatten again
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = self.embedding(indices)
|
||||
|
||||
if shape is not None:
|
||||
z_q = z_q.view(shape)
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
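A minimal usage sketch of the quantizer above (codebook size, code dimension and beta are illustrative assumptions):

import torch
vq = VectorQuantizer2(n_e=8192, e_dim=256, beta=0.25)
z = torch.randn(2, 256, 16, 16)                        # (batch, e_dim, height, width) latents
z_q, vq_loss, (perplexity, min_encodings, indices) = vq(z)
# z_q matches z's shape; gradients flow via the straight-through estimator z + (z_q - z).detach()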
|
||||
|
||||
|
||||
class GumbelQuantize(nn.Module):
|
||||
"""
|
||||
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
||||
Gumbel Softmax trick quantizer
|
||||
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
||||
https://arxiv.org/abs/1611.01144
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_hiddens,
|
||||
embedding_dim,
|
||||
n_embed,
|
||||
straight_through=True,
|
||||
kl_weight=5e-4,
|
||||
temp_init=1.0,
|
||||
use_vqinterface=True,
|
||||
remap=None,
|
||||
unknown_index="random",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.n_embed = n_embed
|
||||
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
|
||||
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
||||
self.embed = nn.Embedding(n_embed, embedding_dim)
|
||||
|
||||
self.use_vqinterface = use_vqinterface
|
||||
|
||||
self.remap = remap
|
||||
if self.remap is not None:
|
||||
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
||||
self.re_embed = self.used.shape[0]
|
||||
self.unknown_index = unknown_index # "random" or "extra" or integer
|
||||
if self.unknown_index == "extra":
|
||||
self.unknown_index = self.re_embed
|
||||
self.re_embed = self.re_embed + 1
|
||||
print(
|
||||
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
||||
f"Using {self.unknown_index} for unknown indices."
|
||||
)
|
||||
else:
|
||||
self.re_embed = n_embed
|
||||
|
||||
def remap_to_used(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
match = (inds[:, :, None] == used[None, None, ...]).long()
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
||||
def unmap_to_all(self, inds):
|
||||
ishape = inds.shape
|
||||
assert len(ishape) > 1
|
||||
inds = inds.reshape(ishape[0], -1)
|
||||
used = self.used.to(inds)
|
||||
if self.re_embed > self.used.shape[0]: # extra token
|
||||
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
||||
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
||||
return back.reshape(ishape)
|
||||
|
||||
def forward(self, z, temp=None, return_logits=False):
|
||||
# force hard=True in eval mode, since we must quantize; in practice hard=True always seems to work
|
||||
hard = self.straight_through if self.training else True
|
||||
temp = self.temperature if temp is None else temp
|
||||
|
||||
logits = self.proj(z)
|
||||
if self.remap is not None:
|
||||
# continue only with used logits
|
||||
full_zeros = torch.zeros_like(logits)
|
||||
logits = logits[:, self.used, ...]
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
||||
if self.remap is not None:
|
||||
# go back to all entries but unused set to zero
|
||||
full_zeros[:, self.used, ...] = soft_one_hot
|
||||
soft_one_hot = full_zeros
|
||||
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
||||
|
||||
ind = soft_one_hot.argmax(dim=1)
|
||||
if self.remap is not None:
|
||||
ind = self.remap_to_used(ind)
|
||||
if self.use_vqinterface:
|
||||
if return_logits:
|
||||
return z_q, diff, (None, None, ind), logits
|
||||
return z_q, diff, (None, None, ind)
|
||||
return z_q, diff, ind
|
||||
|
||||
def get_codebook_entry(self, indices, shape):
|
||||
b, h, w, c = shape
|
||||
assert b * h * w == indices.shape[0]
|
||||
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
|
||||
if self.remap is not None:
|
||||
indices = self.unmap_to_all(indices)
|
||||
one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
||||
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
|
||||
return z_q
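And a corresponding sketch for the Gumbel-softmax quantizer above (hidden width, code dimension and codebook size are assumed values):

import torch
gq = GumbelQuantize(num_hiddens=128, embedding_dim=64, n_embed=1024)
h = torch.randn(2, 128, 16, 16)          # encoder features with num_hiddens channels
z_q, kl_loss, (_, _, indices) = gq(h)    # z_q: (2, 64, 16, 16); indices: (2, 16, 16) code ids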
|
402 sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py (new file)
@@ -0,0 +1,402 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
|
||||
# # original version, nan in fp16
|
||||
# w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
# w_ = w_ * (int(c)**(-0.5))
|
||||
# # implement c**-0.5 on q
|
||||
q = q * (int(c) ** (-0.5))
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def forward_with_features_output(self, x):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
output_features = {}
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
output_features["conv_in"] = hs[-1]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
output_features["down{}_block{}".format(i_level, i_block)] = h
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
output_features["down{}_attn{}".format(i_level, i_block)] = h
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
output_features["down{}_downsample".format(i_level)] = hs[-1]
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
output_features["mid_block_1"] = h
|
||||
h = self.mid.attn_1(h)
|
||||
output_features["mid_attn_1"] = h
|
||||
h = self.mid.block_2(h, temb)
|
||||
output_features["mid_block_2"] = h
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
output_features["norm_out"] = h
|
||||
h = nonlinearity(h)
|
||||
output_features["nonlinearity"] = h
|
||||
h = self.conv_out(h)
|
||||
output_features["conv_out"] = h
|
||||
|
||||
return h, output_features
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
897 sat/sgm/modules/cp_enc_dec.py (new file)
@@ -0,0 +1,897 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
from ..util import (
|
||||
get_context_parallel_group,
|
||||
get_context_parallel_rank,
|
||||
get_context_parallel_world_size,
|
||||
get_context_parallel_group_rank,
|
||||
)
|
||||
|
||||
# try:
|
||||
from ..util import SafeConv3d as Conv3d
|
||||
# except:
|
||||
# # Degrade to normal Conv3d if SafeConv3d is not available
|
||||
# from torch.nn import Conv3d
|
||||
|
||||
_USE_CP = True
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def leaky_relu(p=0.1):
|
||||
return nn.LeakyReLU(p)
|
||||
|
||||
|
||||
def _split(input_, dim):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
if cp_world_size == 1:
|
||||
return input_
|
||||
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
dim_size = input_.size()[dim] // cp_world_size
|
||||
|
||||
input_list = torch.split(input_, dim_size, dim=dim)
|
||||
output = input_list[cp_rank]
|
||||
|
||||
if cp_rank == 0:
|
||||
output = torch.cat([input_first_frame_, output], dim=dim)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
|
||||
|
||||
return output
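Conceptually, the split above distributes frames along the temporal dimension across context-parallel ranks, with rank 0 also keeping the very first frame. A single-process sketch (world size and shapes are illustrative; the real code selects the local shard via get_context_parallel_rank()):

import torch
cp_world_size = 4
x = torch.randn(1, 8, 1 + 4 * cp_world_size, 32, 32)    # frames = 1 + k * world_size
first_frame, rest = x[:, :, :1], x[:, :, 1:]
chunks = list(torch.split(rest, rest.shape[2] // cp_world_size, dim=2))
chunks[0] = torch.cat([first_frame, chunks[0]], dim=2)  # rank 0's shard includes frame 0
print([c.shape[2] for c in chunks])                     # [5, 4, 4, 4]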
|
||||
|
||||
|
||||
def _gather(input_, dim):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# Bypass the function if context parallel is 1
|
||||
if cp_world_size == 1:
|
||||
return input_
|
||||
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
|
||||
if cp_rank == 0:
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
|
||||
torch.empty_like(input_) for _ in range(cp_world_size - 1)
|
||||
]
|
||||
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_first_frame_, input_], dim=dim)
|
||||
|
||||
tensor_list[cp_rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=group)
|
||||
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
# print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _conv_split(input_, dim, kernel_size):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# Bypass the function if context parallel is 1
|
||||
if cp_world_size == 1:
|
||||
return input_
|
||||
|
||||
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
|
||||
|
||||
if cp_rank == 0:
|
||||
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
||||
else:
|
||||
output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
|
||||
dim, 0
|
||||
)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _conv_gather(input_, dim, kernel_size):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# Bypass the function if context parallel is 1
|
||||
if cp_world_size == 1:
|
||||
return input_
|
||||
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
|
||||
if cp_rank == 0:
|
||||
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
|
||||
else:
|
||||
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()
|
||||
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
|
||||
torch.empty_like(input_) for _ in range(cp_world_size - 1)
|
||||
]
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
|
||||
|
||||
tensor_list[cp_rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=group)
|
||||
|
||||
# Note: torch.cat already creates a contiguous tensor.
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
# Bypass the function if kernel size is 1
|
||||
if kernel_size == 1:
|
||||
return input_
|
||||
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
cp_group_rank = get_context_parallel_group_rank()
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
global_rank = torch.distributed.get_rank()
|
||||
global_world_size = torch.distributed.get_world_size()
|
||||
|
||||
input_ = input_.transpose(0, dim)
|
||||
|
||||
# pass from last rank
|
||||
send_rank = global_rank + 1
|
||||
recv_rank = global_rank - 1
|
||||
if send_rank % cp_world_size == 0:
|
||||
send_rank -= cp_world_size
|
||||
if recv_rank % cp_world_size == cp_world_size - 1:
|
||||
recv_rank += cp_world_size
|
||||
|
||||
if cp_rank < cp_world_size - 1:
|
||||
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
|
||||
if cp_rank > 0:
|
||||
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
|
||||
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
|
||||
else:
|
||||
req_recv.wait()
|
||||
input_ = torch.cat([recv_buffer, input_], dim=0)
|
||||
|
||||
input_ = input_.transpose(0, dim).contiguous()
|
||||
|
||||
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _drop_from_previous_rank(input_, dim, kernel_size):
|
||||
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
|
||||
return input_
|
||||
|
||||
|
||||
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _conv_split(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _conv_gather(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
ctx.kernel_size = kernel_size
|
||||
return _pass_from_previous_rank(input_, dim, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
|
||||
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_pass_from_last_rank(input_, dim, kernel_size):
|
||||
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
class ContextParallelCausalConv3d(nn.Module):
|
||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
|
||||
|
||||
time_pad = time_kernel_size - 1
|
||||
height_pad = height_kernel_size // 2
|
||||
width_pad = width_kernel_size // 2
|
||||
|
||||
self.height_pad = height_pad
|
||||
self.width_pad = width_pad
|
||||
self.time_pad = time_pad
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.temporal_dim = 2
|
||||
|
||||
stride = (stride, stride, stride)
|
||||
dilation = (1, 1, 1)
|
||||
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input_):
|
||||
# temporal padding inside
|
||||
if _USE_CP:
|
||||
input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
|
||||
else:
|
||||
input_ = input_.transpose(0, self.temporal_dim)
|
||||
input_parallel = torch.cat([input_[:1]] * (self.time_kernel_size - 1) + [input_], dim=0)
|
||||
input_parallel = input_parallel.transpose(0, self.temporal_dim)
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
return output
|
||||
|
||||
|
||||
class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
def forward(self, input_):
|
||||
if _USE_CP:
|
||||
input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
|
||||
output = super().forward(input_)
|
||||
if _USE_CP:
|
||||
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
|
||||
return output
|
||||
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
f_channels,
|
||||
zq_channels,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
pad_mode="constant",
|
||||
gather=False,
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
if gather:
|
||||
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
else:
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.add_conv = add_conv
|
||||
if add_conv:
|
||||
self.conv = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=zq_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
self.conv_y = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=f_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
self.conv_b = ContextParallelCausalConv3d(
|
||||
chan_in=zq_channels,
|
||||
chan_out=f_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f, zq):
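# zq is the latent conditioning tensor; it is resized to match f, interpolating the first frame separately to respect the causal temporal layout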
|
||||
if f.shape[2] == 1 and not _USE_CP:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
elif get_context_parallel_rank() == 0:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
|
||||
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
|
||||
norm_f = self.norm_layer(f)
|
||||
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
|
||||
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
def Normalize3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
add_conv,
|
||||
gather=False,
|
||||
):
|
||||
return SpatialNorm3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
gather=gather,
|
||||
# norm_layer=nn.GroupNorm,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=add_conv,
|
||||
num_groups=32,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
)
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
with_conv,
|
||||
compress_time=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
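# with compress_time, the first frame is upsampled only spatially while the remaining frames are also upsampled along the time axis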
|
||||
if self.compress_time:
|
||||
if x.shape[2] == 1 and not _USE_CP:
|
||||
x = torch.nn.functional.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")[:, :, None, :, :]
|
||||
elif get_context_parallel_rank() == 0:
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.with_conv:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample3D(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
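# temporal downsampling uses 1-D average pooling along t; when the frame count is odd the first frame is kept unpooled so the causal first frame survives compression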
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
else:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class ContextParallelResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
gather_norm=False,
|
||||
normalization=Normalize,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = normalization(
|
||||
in_channels,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
gather=gather_norm,
|
||||
)
|
||||
|
||||
self.conv1 = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = normalization(
|
||||
out_channels,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
gather=gather_norm,
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = ContextParallelCausalConv3d(
|
||||
chan_in=out_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=out_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = Conv3d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq=None):
|
||||
h = x
|
||||
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm1(h, zq)
|
||||
else:
|
||||
h = self.norm1(h)
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm2(h, zq)
|
||||
else:
|
||||
h = self.norm2(h)
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class ContextParallelEncoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=in_channels,
|
||||
chan_out=self.ch,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dropout=dropout,
|
||||
temb_channels=self.temb_ch,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if i_level < self.temporal_compress_level:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
self.mid.block_2 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, gather=gather_norm)
|
||||
|
||||
self.conv_out = ContextParallelCausalConv3d(
|
||||
chan_in=block_in,
|
||||
chan_out=2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, x, use_cp=True):
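# set the module-level context-parallel flag for this call; every ContextParallelCausalConv3d / ContextParallelGroupNorm below reads it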
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
h = self.norm_out(h)
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class ContextParallelDecoder3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
|
||||
# log2 of temporal_compress_times
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=z_channels,
|
||||
chan_out=block_in,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
self.mid.block_2 = ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ContextParallelResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
normalization=Normalize3D,
|
||||
gather_norm=gather_norm,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
|
||||
|
||||
self.conv_out = ContextParallelCausalConv3d(
|
||||
chan_in=block_in,
|
||||
chan_out=out_ch,
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, z, use_cp=True):
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
_USE_CP = True
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.conv.weight
|
6
sat/sgm/modules/diffusionmodules/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .denoiser import Denoiser
|
||||
from .discretizer import Discretization
|
||||
from .model import Decoder, Encoder, Model
|
||||
from .openaimodel import UNetModel
|
||||
from .sampling import BaseDiffusionSampler
|
||||
from .wrappers import OpenAIWrapper
|
72
sat/sgm/modules/diffusionmodules/denoiser.py
Normal file
@ -0,0 +1,72 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
|
||||
|
||||
class Denoiser(nn.Module):
|
||||
def __init__(self, weighting_config, scaling_config):
|
||||
super().__init__()
|
||||
|
||||
self.weighting = instantiate_from_config(weighting_config)
|
||||
self.scaling = instantiate_from_config(scaling_config)
|
||||
|
||||
def possibly_quantize_sigma(self, sigma):
|
||||
return sigma
|
||||
|
||||
def possibly_quantize_c_noise(self, c_noise):
|
||||
return c_noise
|
||||
|
||||
def w(self, sigma):
|
||||
return self.weighting(sigma)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network: nn.Module,
|
||||
input: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
cond: Dict,
|
||||
**additional_model_inputs,
|
||||
) -> torch.Tensor:
|
||||
sigma = self.possibly_quantize_sigma(sigma)
|
||||
sigma_shape = sigma.shape
|
||||
sigma = append_dims(sigma, input.ndim)
|
||||
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
|
||||
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
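# EDM-style preconditioning: D(x, sigma) = c_out * F(c_in * x, c_noise) + c_skip * x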
|
||||
return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class DiscreteDenoiser(Denoiser):
|
||||
def __init__(
|
||||
self,
|
||||
weighting_config,
|
||||
scaling_config,
|
||||
num_idx,
|
||||
discretization_config,
|
||||
do_append_zero=False,
|
||||
quantize_c_noise=True,
|
||||
flip=True,
|
||||
):
|
||||
super().__init__(weighting_config, scaling_config)
|
||||
sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
self.sigmas = sigmas
|
||||
# self.register_buffer("sigmas", sigmas)
|
||||
self.quantize_c_noise = quantize_c_noise
|
||||
|
||||
def sigma_to_idx(self, sigma):
|
||||
dists = sigma - self.sigmas.to(sigma.device)[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
return self.sigmas.to(idx.device)[idx]
|
||||
|
||||
def possibly_quantize_sigma(self, sigma):
|
||||
return self.idx_to_sigma(self.sigma_to_idx(sigma))
|
||||
|
||||
def possibly_quantize_c_noise(self, c_noise):
|
||||
if self.quantize_c_noise:
|
||||
return self.sigma_to_idx(c_noise)
|
||||
else:
|
||||
return c_noise
|
60
sat/sgm/modules/diffusionmodules/denoiser_scaling.py
Normal file
@ -0,0 +1,60 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DenoiserScaling(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class EDMScaling:
|
||||
def __init__(self, sigma_data: float = 0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
||||
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
c_noise = 0.25 * sigma.log()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class EpsScaling:
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = torch.ones_like(sigma, device=sigma.device)
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = sigma.clone()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class VScaling:
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = sigma.clone()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class VScalingWithEDMcNoise(DenoiserScaling):
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
c_noise = 0.25 * sigma.log()
|
||||
return c_skip, c_out, c_in, c_noise
|
||||
|
||||
|
||||
class VideoScaling: # similar to VScaling
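# v-prediction scaling parameterized by sqrt(alpha_cumprod) instead of sigma; the conditioning "noise level" is the integer timestep index from additional_model_inputs["idx"]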
|
||||
def __call__(
|
||||
self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = alphas_cumprod_sqrt
|
||||
c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
|
||||
c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
|
||||
c_noise = additional_model_inputs["idx"].clone()
|
||||
return c_skip, c_out, c_in, c_noise
|
24
sat/sgm/modules/diffusionmodules/denoiser_weighting.py
Normal file
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
class UnitWeighting:
|
||||
def __call__(self, sigma):
|
||||
return torch.ones_like(sigma, device=sigma.device)
|
||||
|
||||
|
||||
class EDMWeighting:
|
||||
def __init__(self, sigma_data=0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma):
|
||||
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
||||
|
||||
|
||||
class VWeighting(EDMWeighting):
|
||||
def __init__(self):
|
||||
super().__init__(sigma_data=1.0)
|
||||
|
||||
|
||||
class EpsWeighting:
|
||||
def __call__(self, sigma):
|
||||
return sigma**-2.0
|
126
sat/sgm/modules/diffusionmodules/discretizer.py
Normal file
@ -0,0 +1,126 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...modules.diffusionmodules.util import make_beta_schedule
|
||||
from ...util import append_zero
|
||||
|
||||
|
||||
def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
|
||||
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
|
||||
|
||||
|
||||
class Discretization:
|
||||
def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
|
||||
if return_idx:
|
||||
sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
|
||||
else:
|
||||
sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
|
||||
sigmas = append_zero(sigmas) if do_append_zero else sigmas
|
||||
if return_idx:
|
||||
return sigmas if not flip else torch.flip(sigmas, (0,)), idx
|
||||
else:
|
||||
return sigmas if not flip else torch.flip(sigmas, (0,))
|
||||
|
||||
@abstractmethod
|
||||
def get_sigmas(self, n, device):
|
||||
pass
|
||||
|
||||
|
||||
class EDMDiscretization(Discretization):
|
||||
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
||||
self.sigma_min = sigma_min
|
||||
self.sigma_max = sigma_max
|
||||
self.rho = rho
|
||||
|
||||
def get_sigmas(self, n, device="cpu"):
|
||||
ramp = torch.linspace(0, 1, n, device=device)
|
||||
min_inv_rho = self.sigma_min ** (1 / self.rho)
|
||||
max_inv_rho = self.sigma_max ** (1 / self.rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
|
||||
return sigmas
|
||||
|
||||
|
||||
class LegacyDDPMDiscretization(Discretization):
|
||||
def __init__(
|
||||
self,
|
||||
linear_start=0.00085,
|
||||
linear_end=0.0120,
|
||||
num_timesteps=1000,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_timesteps = num_timesteps
|
||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
def get_sigmas(self, n, device="cpu"):
|
||||
if n < self.num_timesteps:
|
||||
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
|
||||
alphas_cumprod = self.alphas_cumprod[timesteps]
|
||||
elif n == self.num_timesteps:
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
|
||||
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
|
||||
|
||||
|
||||
class ZeroSNRDDPMDiscretization(Discretization):
|
||||
def __init__(
|
||||
self,
|
||||
linear_start=0.00085,
|
||||
linear_end=0.0120,
|
||||
num_timesteps=1000,
|
||||
shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
|
||||
keep_start=False,
|
||||
post_shift=False,
|
||||
):
|
||||
super().__init__()
|
||||
if keep_start and not post_shift:
|
||||
linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
|
||||
self.num_timesteps = num_timesteps
|
||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
# SNR shift
|
||||
if not post_shift:
|
||||
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
|
||||
|
||||
self.post_shift = post_shift
|
||||
self.shift_scale = shift_scale
|
||||
|
||||
def get_sigmas(self, n, device="cpu", return_idx=False):
|
||||
if n < self.num_timesteps:
|
||||
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
|
||||
alphas_cumprod = self.alphas_cumprod[timesteps]
|
||||
elif n == self.num_timesteps:
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
|
||||
alphas_cumprod = to_torch(alphas_cumprod)
|
||||
alphas_cumprod_sqrt = alphas_cumprod.sqrt()
|
||||
alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
|
||||
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
|
||||
|
||||
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
|
||||
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
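# rescale so sqrt(alpha_cumprod) reaches exactly 0 at the last timestep (zero terminal SNR) while the first value is left unchanged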
|
||||
|
||||
if self.post_shift:
|
||||
alphas_cumprod_sqrt = (
|
||||
alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
|
||||
) ** 0.5
|
||||
|
||||
if return_idx:
|
||||
return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
|
||||
else:
|
||||
return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
|
87
sat/sgm/modules/diffusionmodules/guiders.py
Normal file
@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ...util import append_dims, default, instantiate_from_config
|
||||
|
||||
|
||||
class Guider(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
|
||||
pass
|
||||
|
||||
|
||||
class VanillaCFG:
|
||||
"""
|
||||
implements parallelized CFG
|
||||
"""
|
||||
|
||||
def __init__(self, scale, dyn_thresh_config=None):
|
||||
self.scale = scale
|
||||
scale_schedule = lambda scale, sigma: scale # independent of step
|
||||
self.scale_schedule = partial(scale_schedule, scale)
|
||||
self.dyn_thresh = instantiate_from_config(
|
||||
default(
|
||||
dyn_thresh_config,
|
||||
{"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x, sigma, scale=None):
|
||||
x_u, x_c = x.chunk(2)
|
||||
scale_value = default(scale, self.scale_schedule(sigma))
|
||||
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
|
||||
return x_pred
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
if k in ["vector", "crossattn", "concat"]:
|
||||
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
||||
else:
|
||||
assert c[k] == uc[k]
|
||||
c_out[k] = c[k]
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
|
||||
class DynamicCFG(VanillaCFG):
|
||||
def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
|
||||
super().__init__(scale, dyn_thresh_config)
|
||||
scale_schedule = (
|
||||
lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
|
||||
)
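# guidance scale ramps from 1 up to 1 + scale over the sampling steps following a cosine curve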
|
||||
self.scale_schedule = partial(scale_schedule, scale)
|
||||
self.dyn_thresh = instantiate_from_config(
|
||||
default(
|
||||
dyn_thresh_config,
|
||||
{"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"},
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x, sigma, step_index, scale=None):
|
||||
x_u, x_c = x.chunk(2)
|
||||
scale_value = self.scale_schedule(sigma, step_index.item())
|
||||
x_pred = self.dyn_thresh(x_u, x_c, scale_value)
|
||||
return x_pred
|
||||
|
||||
|
||||
class IdentityGuider:
|
||||
def __call__(self, x, sigma):
|
||||
return x
|
||||
|
||||
def prepare_inputs(self, x, s, c, uc):
|
||||
c_out = dict()
|
||||
|
||||
for k in c:
|
||||
c_out[k] = c[k]
|
||||
|
||||
return x, s, c_out
|
362
sat/sgm/modules/diffusionmodules/lora.py
Normal file
@ -0,0 +1,362 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
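# up is zero-initialized so the LoRA branch starts as a no-op and the adapted layer initially equals the frozen base layer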
|
||||
|
||||
def forward(self, hidden_states):
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
down_hidden_states = self.down(hidden_states.to(dtype))
|
||||
up_hidden_states = self.up(down_hidden_states)
|
||||
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
||||
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
down_hidden_states = self.down(hidden_states.to(dtype))
|
||||
up_hidden_states = self.up(down_hidden_states)
|
||||
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
|
||||
|
||||
class LoRACompatibleConv(nn.Conv2d):
|
||||
"""
|
||||
A convolutional layer that can be used with LoRA.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
self.scale = scale
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale=1.0):
|
||||
if self.lora_layer is None:
|
||||
return
|
||||
|
||||
dtype, device = self.weight.data.dtype, self.weight.data.device
|
||||
|
||||
w_orig = self.weight.data.float()
|
||||
w_up = self.lora_layer.up.weight.data.float()
|
||||
w_down = self.lora_layer.down.weight.data.float()
|
||||
|
||||
if self.lora_layer.network_alpha is not None:
|
||||
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
||||
|
||||
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
|
||||
fusion = fusion.reshape((w_orig.shape))
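# fused weight = W + lora_scale * (up @ down): the conv kernels are flattened to 2-D for the low-rank product, then reshaped back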
|
||||
fused_weight = w_orig + (lora_scale * fusion)
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
self.lora_layer = None
|
||||
|
||||
# offload the up and down matrices to CPU to not blow the memory
|
||||
self.w_up = w_up.cpu()
|
||||
self.w_down = w_down.cpu()
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
def _unfuse_lora(self):
|
||||
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
|
||||
return
|
||||
|
||||
fused_weight = self.weight.data
|
||||
dtype, device = fused_weight.data.dtype, fused_weight.data.device
|
||||
|
||||
self.w_up = self.w_up.to(device=device).float()
|
||||
self.w_down = self.w_down.to(device).float()
|
||||
|
||||
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
|
||||
fusion = fusion.reshape((fused_weight.shape))
|
||||
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
|
||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states, scale: float = None):
|
||||
if scale is None:
|
||||
scale = self.scale
|
||||
if self.lora_layer is None:
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
else:
|
||||
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
||||
|
||||
|
||||
class LoRACompatibleLinear(nn.Linear):
|
||||
"""
|
||||
A Linear layer that can be used with LoRA.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
self.scale = scale
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale=1.0):
|
||||
if self.lora_layer is None:
|
||||
return
|
||||
|
||||
dtype, device = self.weight.data.dtype, self.weight.data.device
|
||||
|
||||
w_orig = self.weight.data.float()
|
||||
w_up = self.lora_layer.up.weight.data.float()
|
||||
w_down = self.lora_layer.down.weight.data.float()
|
||||
|
||||
if self.lora_layer.network_alpha is not None:
|
||||
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
||||
|
||||
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
self.lora_layer = None
|
||||
|
||||
# offload the up and down matrices to CPU to not blow the memory
|
||||
self.w_up = w_up.cpu()
|
||||
self.w_down = w_down.cpu()
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
def _unfuse_lora(self):
|
||||
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
|
||||
return
|
||||
|
||||
fused_weight = self.weight.data
|
||||
dtype, device = fused_weight.dtype, fused_weight.device
|
||||
|
||||
w_up = self.w_up.to(device=device).float()
|
||||
w_down = self.w_down.to(device).float()
|
||||
|
||||
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states, scale: float = None):
|
||||
if scale is None:
|
||||
scale = self.scale
|
||||
if self.lora_layer is None:
|
||||
out = super().forward(hidden_states)
|
||||
return out
|
||||
else:
|
||||
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
||||
return out
|
||||
|
||||
|
||||
def _find_children(
|
||||
model,
|
||||
search_class: List[Type[nn.Module]] = [nn.Linear],
|
||||
):
|
||||
"""
|
||||
Find all modules of a certain class (or union of classes).
|
||||
|
||||
Returns all matching modules, along with the parent of those modules and the
|
||||
names they are referenced by.
|
||||
"""
|
||||
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
||||
for parent in model.modules():
|
||||
for name, module in parent.named_children():
|
||||
if any([isinstance(module, _class) for _class in search_class]):
|
||||
yield parent, name, module
|
||||
|
||||
|
||||
def _find_modules_v2(
|
||||
model,
|
||||
ancestor_class: Optional[Set[str]] = None,
|
||||
search_class: List[Type[nn.Module]] = [nn.Linear],
|
||||
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
||||
LoRACompatibleLinear,
|
||||
LoRACompatibleConv,
|
||||
LoRALinearLayer,
|
||||
LoRAConv2dLayer,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Find all modules of a certain class (or union of classes) that are direct or
|
||||
indirect descendants of other modules of a certain class (or union of classes).
|
||||
|
||||
Returns all matching modules, along with the parent of those modules and the
|
||||
names they are referenced by.
|
||||
"""
|
||||
|
||||
# Get the targets we should replace all linears under
|
||||
if ancestor_class is not None:
|
||||
ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class)
|
||||
else:
|
||||
# this, in case you want to naively iterate over all modules.
|
||||
ancestors = [module for module in model.modules()]
|
||||
|
||||
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
||||
for ancestor in ancestors:
|
||||
for fullname, module in ancestor.named_modules():
|
||||
if any([isinstance(module, _class) for _class in search_class]):
|
||||
# Find the direct parent if this is a descendant, not a child, of target
|
||||
*path, name = fullname.split(".")
|
||||
parent = ancestor
|
||||
flag = False
|
||||
while path:
|
||||
try:
|
||||
parent = parent.get_submodule(path.pop(0))
|
||||
except AttributeError:
|
||||
flag = True
|
||||
break
|
||||
if flag:
|
||||
continue
|
||||
# Skip this linear if it's a child of a LoraInjectedLinear
|
||||
if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]):
|
||||
continue
|
||||
# Otherwise, yield it
|
||||
yield parent, name, module
|
||||
|
||||
|
||||
_find_modules = _find_modules_v2
|
||||
|
||||
|
||||
def inject_trainable_lora_extended(
|
||||
model: nn.Module,
|
||||
target_replace_module: Set[str] = None,
|
||||
rank: int = 4,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
for _module, name, _child_module in _find_modules(
|
||||
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
|
||||
):
|
||||
if _child_module.__class__ == nn.Linear:
|
||||
weight = _child_module.weight
|
||||
bias = _child_module.bias
|
||||
lora_layer = LoRALinearLayer(
|
||||
in_features=_child_module.in_features,
|
||||
out_features=_child_module.out_features,
|
||||
rank=rank,
|
||||
)
|
||||
_tmp = (
|
||||
LoRACompatibleLinear(
|
||||
_child_module.in_features,
|
||||
_child_module.out_features,
|
||||
lora_layer=lora_layer,
|
||||
scale=scale,
|
||||
)
|
||||
.to(weight.dtype)
|
||||
.to(weight.device)
|
||||
)
|
||||
_tmp.weight = weight
|
||||
if bias is not None:
|
||||
_tmp.bias = bias
|
||||
elif _child_module.__class__ == nn.Conv2d:
|
||||
weight = _child_module.weight
|
||||
bias = _child_module.bias
|
||||
lora_layer = LoRAConv2dLayer(
|
||||
in_features=_child_module.in_channels,
|
||||
out_features=_child_module.out_channels,
|
||||
rank=rank,
|
||||
kernel_size=_child_module.kernel_size,
|
||||
stride=_child_module.stride,
|
||||
padding=_child_module.padding,
|
||||
)
|
||||
_tmp = (
|
||||
LoRACompatibleConv(
|
||||
_child_module.in_channels,
|
||||
_child_module.out_channels,
|
||||
kernel_size=_child_module.kernel_size,
|
||||
stride=_child_module.stride,
|
||||
padding=_child_module.padding,
|
||||
lora_layer=lora_layer,
|
||||
scale=scale,
|
||||
)
|
||||
.to(weight.dtype)
|
||||
.to(weight.device)
|
||||
)
|
||||
_tmp.weight = weight
|
||||
if bias is not None:
|
||||
_tmp.bias = bias
|
||||
else:
|
||||
continue
|
||||
|
||||
_module._modules[name] = _tmp
|
||||
# print('injecting lora layer to', _module, name)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def update_lora_scale(
|
||||
model: nn.Module,
|
||||
target_module: Set[str] = None,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
for _module, name, _child_module in _find_modules(
|
||||
model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv]
|
||||
):
|
||||
_child_module.scale = scale
|
||||
|
||||
return
|
132
sat/sgm/modules/diffusionmodules/loss.py
Normal file
@ -0,0 +1,132 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import ListConfig
|
||||
import math
|
||||
|
||||
from ...modules.diffusionmodules.sampling import VideoDDIMSampler, VPSDEDPMPP2MSampler
|
||||
from ...util import append_dims, instantiate_from_config
|
||||
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
|
||||
|
||||
# import rearrange
|
||||
from einops import rearrange
|
||||
import random
|
||||
from sat import mpu
|
||||
|
||||
|
||||
class StandardDiffusionLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sigma_sampler_config,
|
||||
type="l2",
|
||||
offset_noise_level=0.0,
|
||||
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert type in ["l2", "l1", "lpips"]
|
||||
|
||||
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
|
||||
|
||||
self.type = type
|
||||
self.offset_noise_level = offset_noise_level
|
||||
|
||||
if type == "lpips":
|
||||
self.lpips = LPIPS().eval()
|
||||
|
||||
if not batch2model_keys:
|
||||
batch2model_keys = []
|
||||
|
||||
if isinstance(batch2model_keys, str):
|
||||
batch2model_keys = [batch2model_keys]
|
||||
|
||||
self.batch2model_keys = set(batch2model_keys)
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
cond = conditioner(batch)
|
||||
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
|
||||
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
|
||||
noise = torch.randn_like(input)
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = (
|
||||
noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
|
||||
)
|
||||
noise = noise.to(input.dtype)
|
||||
noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
|
||||
model_output = denoiser(network, noised_input, sigmas, cond, **additional_model_inputs)
|
||||
w = append_dims(denoiser.w(sigmas), input.ndim)
|
||||
return self.get_loss(model_output, input, w)
|
||||
|
||||
def get_loss(self, model_output, target, w):
|
||||
if self.type == "l2":
|
||||
return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
|
||||
elif self.type == "l1":
|
||||
return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
|
||||
elif self.type == "lpips":
|
||||
loss = self.lpips(model_output, target).reshape(-1)
|
||||
return loss
|
||||
|
||||
|
||||
class VideoDiffusionLoss(StandardDiffusionLoss):
|
||||
def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
|
||||
self.fixed_frames = fixed_frames
|
||||
self.block_scale = block_scale
|
||||
self.block_size = block_size
|
||||
self.min_snr_value = min_snr_value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
cond = conditioner(batch)
|
||||
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
|
||||
|
||||
alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
|
||||
alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
|
||||
idx = idx.to(input.device)
|
||||
|
||||
noise = torch.randn_like(input)
|
||||
|
||||
# broadcast noise
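# broadcast noise, timestep indices and alphas from the first rank of each model-parallel group so all ranks in the group train on an identical diffusion sample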
|
||||
mp_size = mpu.get_model_parallel_world_size()
|
||||
global_rank = torch.distributed.get_rank() // mp_size
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())
|
||||
|
||||
additional_model_inputs["idx"] = idx
|
||||
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = (
|
||||
noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
|
||||
)
|
||||
|
||||
noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
|
||||
(1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
|
||||
)
|
||||
|
||||
model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
|
||||
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
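# 1 / (1 - alpha_cumprod) equals SNR + 1, the standard loss weight for v-prediction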
|
||||
|
||||
if self.min_snr_value is not None:
|
||||
w = torch.clamp(w, max=self.min_snr_value) # elementwise min(w, min_snr_value)
|
||||
return self.get_loss(model_output, input, w)
|
||||
|
||||
def get_loss(self, model_output, target, w):
|
||||
if self.type == "l2":
|
||||
return torch.mean((w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1)
|
||||
elif self.type == "l1":
|
||||
return torch.mean((w * (model_output - target).abs()).reshape(target.shape[0], -1), 1)
|
||||
elif self.type == "lpips":
|
||||
loss = self.lpips(model_output, target).reshape(-1)
|
||||
return loss
|
||||
|
||||
|
||||
def get_3d_position_ids(frame_len, h, w):
|
||||
i = torch.arange(frame_len).view(frame_len, 1, 1).expand(frame_len, h, w)
|
||||
j = torch.arange(h).view(1, h, 1).expand(frame_len, h, w)
|
||||
k = torch.arange(w).view(1, 1, w).expand(frame_len, h, w)
|
||||
position_ids = torch.stack([i, j, k], dim=-1).reshape(-1, 3)
|
||||
return position_ids
|
683
sat/sgm/modules/diffusionmodules/model.py
Normal file
@ -0,0 +1,683 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
|
||||
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
|
||||
h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
# compute attention
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
h_ = x
|
||||
h_ = self.attention(h_)
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
|
||||
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
h_ = x
|
||||
h_ = self.attention(h_)
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
def forward(self, x, context=None, mask=None, **unused_kwargs):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
out = super().forward(x, context=context, mask=mask)
|
||||
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
|
||||
return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
assert attn_type in [
|
||||
"vanilla",
|
||||
"vanilla-xformers",
|
||||
"memory-efficient-cross-attn",
|
||||
"linear",
|
||||
"none",
|
||||
], f"attn_type {attn_type} unknown"
|
||||
if version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none":
|
||||
assert XFORMERS_IS_AVAILABLE, (
|
||||
f"We do not support vanilla attention in {torch.__version__} anymore, "
|
||||
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
)
|
||||
attn_type = "vanilla-xformers"
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif attn_type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
return LinAttnBlock(in_channels)
|
||||
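# make_attn dispatch, summarized: "vanilla" -> AttnBlock (torch SDPA, hence the torch>=2.0
# check above), "vanilla-xformers" -> MemoryEfficientAttnBlock, "memory-efficient-cross-attn"
# -> MemoryEfficientCrossAttentionWrapper(**attn_kwargs), "none" -> nn.Identity, and anything
# else falls through to LinAttnBlock. Illustrative call (the channel count is made up):
#   attn = make_attn(256, attn_type="vanilla")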
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
use_timestep=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
]
|
||||
)
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
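# Sketch of the resolution bookkeeping above, with illustrative numbers: ch_mult=(1, 2, 4, 8)
# gives 4 resolution levels and 3 Downsample steps, so a 256x256 input reaches a 32x32
# bottleneck; the upsampling path mirrors this with num_res_blocks + 1 blocks per level so it
# can consume the skip connections accumulated in hs.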
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
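# When double_z=True the encoder emits 2 * z_channels feature maps, typically split into the
# mean/logvar of a diagonal Gaussian posterior by the caller (the split is not done in this
# module). Illustrative shapes, with made-up hyperparameters:
#   enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
#                 attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
#   h = enc(torch.randn(1, 3, 256, 256))  # -> (1, 8, 64, 64): 2 downsamples, 2 * z_channels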
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type="vanilla",
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
make_resblock_cls = self._make_resblock()
|
||||
make_conv_cls = self._make_conv()
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
make_resblock_cls(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def _make_attn(self) -> Callable:
|
||||
return make_attn
|
||||
|
||||
def _make_resblock(self) -> Callable:
|
||||
return ResnetBlock
|
||||
|
||||
def _make_conv(self) -> Callable:
|
||||
return torch.nn.Conv2d
|
||||
|
||||
def get_last_layer(self, **kwargs):
|
||||
return self.conv_out.weight
|
||||
|
||||
def forward(self, z, **kwargs):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
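# Note: give_pre_end returns the feature map before the final norm/act/conv, and tanh_out
# squashes the reconstruction to [-1, 1]. The _make_attn / _make_resblock / _make_conv hooks
# presumably exist so subclasses can swap in specialized (e.g. temporal) building blocks
# without rewriting this forward pass.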
|
1249
sat/sgm/modules/diffusionmodules/openaimodel.py
Normal file
File diff suppressed because it is too large
763
sat/sgm/modules/diffusionmodules/sampling.py
Normal file
@ -0,0 +1,763 @@
|
||||
"""
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
"""
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...modules.diffusionmodules.sampling_utils import (
|
||||
get_ancestral_step,
|
||||
linear_multistep_coeff,
|
||||
to_d,
|
||||
to_neg_log_sigma,
|
||||
to_sigma,
|
||||
)
|
||||
from ...util import append_dims, default, instantiate_from_config
|
||||
from ...util import SeededNoise
|
||||
|
||||
from .guiders import DynamicCFG
|
||||
|
||||
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
|
||||
|
||||
class BaseDiffusionSampler:
|
||||
def __init__(
|
||||
self,
|
||||
discretization_config: Union[Dict, ListConfig, OmegaConf],
|
||||
num_steps: Union[int, None] = None,
|
||||
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
|
||||
verbose: bool = False,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.num_steps = num_steps
|
||||
self.discretization = instantiate_from_config(discretization_config)
|
||||
self.guider = instantiate_from_config(
|
||||
default(
|
||||
guider_config,
|
||||
DEFAULT_GUIDER,
|
||||
)
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.device = device
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
|
||||
uc = default(uc, cond)
|
||||
|
||||
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
num_sigmas = len(sigmas)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).float()
|
||||
|
||||
return x, s_in, sigmas, num_sigmas, cond, uc
|
||||
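# The incoming x is assumed to be unit-variance Gaussian noise; scaling it by
# sqrt(1 + sigma_0^2) appears to follow the convention where a noisy sample at level sigma
# has variance 1 + sigma^2 (unit-variance data plus sigma-scaled noise), so the loop starts
# at the first (largest) sigma of the schedule.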
|
||||
def denoise(self, x, denoiser, sigma, cond, uc):
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
|
||||
denoised = self.guider(denoised, sigma)
|
||||
return denoised
|
||||
|
||||
def get_sigma_gen(self, num_sigmas):
|
||||
sigma_generator = range(num_sigmas - 1)
|
||||
if self.verbose:
|
||||
print("#" * 30, " Sampling setting ", "#" * 30)
|
||||
print(f"Sampler: {self.__class__.__name__}")
|
||||
print(f"Discretization: {self.discretization.__class__.__name__}")
|
||||
print(f"Guider: {self.guider.__class__.__name__}")
|
||||
sigma_generator = tqdm(
|
||||
sigma_generator,
|
||||
total=num_sigmas,
|
||||
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
|
||||
)
|
||||
return sigma_generator
|
||||
|
||||
|
||||
class SingleStepDiffusionSampler(BaseDiffusionSampler):
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def euler_step(self, x, d, dt):
|
||||
return x + dt * d
|
||||
|
||||
|
||||
class EDMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_churn = s_churn
|
||||
self.s_tmin = s_tmin
|
||||
self.s_tmax = s_tmax
|
||||
self.s_noise = s_noise
|
||||
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
|
||||
sigma_hat = sigma * (gamma + 1.0)
|
||||
if gamma > 0:
|
||||
eps = torch.randn_like(x) * self.s_noise
|
||||
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
|
||||
|
||||
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = append_dims(next_sigma - sigma_hat, x.ndim)
|
||||
|
||||
euler_step = self.euler_step(x, d, dt)
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
gamma = (
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
|
||||
)
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
gamma,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DDIMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(self, s_noise=0.1, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_noise = s_noise
|
||||
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
d = to_d(x, sigma, denoised)
|
||||
dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
|
||||
|
||||
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
|
||||
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
self.s_noise,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class AncestralSampler(SingleStepDiffusionSampler):
|
||||
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.eta = eta
|
||||
self.s_noise = s_noise
|
||||
self.noise_sampler = lambda x: torch.randn_like(x)
|
||||
|
||||
def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
|
||||
d = to_d(x, sigma, denoised)
|
||||
dt = append_dims(sigma_down - sigma, x.ndim)
|
||||
|
||||
return self.euler_step(x, d, dt)
|
||||
|
||||
def ancestral_step(self, x, sigma, next_sigma, sigma_up):
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0,
|
||||
x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
|
||||
x,
|
||||
)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LinearMultistepSampler(BaseDiffusionSampler):
|
||||
def __init__(
|
||||
self,
|
||||
order=4,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.order = order
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
ds = []
|
||||
sigmas_cpu = sigmas.detach().cpu().numpy()
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
sigma = s_in * sigmas[i]
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
|
||||
denoised = self.guider(denoised, sigma)
|
||||
d = to_d(x, sigma, denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > self.order:
|
||||
ds.pop(0)
|
||||
cur_order = min(i + 1, self.order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerEDMSampler(EDMSampler):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
return euler_step
|
||||
|
||||
|
||||
class HeunEDMSampler(EDMSampler):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
if torch.sum(next_sigma) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0
|
||||
return euler_step
|
||||
else:
|
||||
denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
|
||||
d_new = to_d(euler_step, next_sigma, denoised)
|
||||
d_prime = (d + d_new) / 2.0
|
||||
|
||||
# apply correction if noise level is not 0
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
|
||||
return x
|
||||
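# Heun correction, in equation form: after the Euler proposal x_e = x + dt * d, the slope is
# re-evaluated at (x_e, sigma_{i+1}) and averaged,
#   x_{i+1} = x + dt * (d(x, sigma_i) + d(x_e, sigma_{i+1})) / 2,
# which costs one extra denoiser call per step but yields a second-order update.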
|
||||
|
||||
class EulerAncestralSampler(AncestralSampler):
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
|
||||
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
|
||||
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DPMPP2SAncestralSampler(AncestralSampler):
|
||||
def get_variables(self, sigma, sigma_down):
|
||||
t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
|
||||
h = t_next - t
|
||||
s = t + 0.5 * h
|
||||
return h, s, t, t_next
|
||||
|
||||
def get_mult(self, h, s, t, t_next):
|
||||
mult1 = to_sigma(s) / to_sigma(t)
|
||||
mult2 = (-0.5 * h).expm1()
|
||||
mult3 = to_sigma(t_next) / to_sigma(t)
|
||||
mult4 = (-h).expm1()
|
||||
|
||||
return mult1, mult2, mult3, mult4
|
||||
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
|
||||
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
|
||||
|
||||
if torch.sum(sigma_down) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0
|
||||
x = x_euler
|
||||
else:
|
||||
h, s, t, t_next = self.get_variables(sigma, sigma_down)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
|
||||
|
||||
x2 = mult[0] * x - mult[1] * denoised
|
||||
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
|
||||
x_dpmpp2s = mult[2] * x - mult[3] * denoised2
|
||||
|
||||
# apply correction if noise level is not 0
|
||||
x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
|
||||
|
||||
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
|
||||
return x
|
||||
|
||||
|
||||
class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
def get_variables(self, sigma, next_sigma, previous_sigma=None):
|
||||
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
|
||||
h = t_next - t
|
||||
|
||||
if previous_sigma is not None:
|
||||
h_last = t - to_neg_log_sigma(previous_sigma)
|
||||
r = h_last / h
|
||||
return h, r, t, t_next
|
||||
else:
|
||||
return h, None, t, t_next
|
||||
|
||||
def get_mult(self, h, r, t, t_next, previous_sigma):
|
||||
mult1 = to_sigma(t_next) / to_sigma(t)
|
||||
mult2 = (-h).expm1()
|
||||
|
||||
if previous_sigma is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
mult4 = 1 / (2 * r)
|
||||
return mult1, mult2, mult3, mult4
|
||||
else:
|
||||
return mult1, mult2
|
||||
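# DPM-Solver++(2M) works in half-log-SNR space: t = -log(sigma), h = t_next - t, and
# r = h_last / h compares the current and previous step sizes. The first-order update is
#   x <- (sigma_next / sigma) * x - expm1(-h) * denoised,
# and the multistep correction below replaces `denoised` with
#   (1 + 1/(2r)) * denoised - (1/(2r)) * old_denoised.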
|
||||
def sampler_step(
|
||||
self,
|
||||
old_denoised,
|
||||
previous_sigma,
|
||||
sigma,
|
||||
next_sigma,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
|
||||
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised
|
||||
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0 or on the first step
|
||||
return x_standard, denoised
|
||||
else:
|
||||
denoised_d = mult[2] * denoised - mult[3] * old_denoised
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * sigmas[i - 1],
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=uc,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
def get_variables(self, sigma, next_sigma, previous_sigma=None):
|
||||
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
|
||||
h = t_next - t
|
||||
|
||||
if previous_sigma is not None:
|
||||
h_last = t - to_neg_log_sigma(previous_sigma)
|
||||
r = h_last / h
|
||||
return h, r, t, t_next
|
||||
else:
|
||||
return h, None, t, t_next
|
||||
|
||||
def get_mult(self, h, r, t, t_next, previous_sigma):
|
||||
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
|
||||
mult2 = (-2 * h).expm1()
|
||||
|
||||
if previous_sigma is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
mult4 = 1 / (2 * r)
|
||||
return mult1, mult2, mult3, mult4
|
||||
else:
|
||||
return mult1, mult2
|
||||
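# SDE variant of DPM-Solver++(2M): relative to the ODE version above, mult1 gains an extra
# exp(-h) factor, mult2 uses expm1(-2h), and each step injects fresh noise with std
# next_sigma * sqrt(1 - exp(-2h)) (mult_noise below), so the marginal noise level still
# matches next_sigma.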
|
||||
def sampler_step(
|
||||
self,
|
||||
old_denoised,
|
||||
previous_sigma,
|
||||
sigma,
|
||||
next_sigma,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
|
||||
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
|
||||
mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
|
||||
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0 or on the first step
|
||||
return x_standard, denoised
|
||||
else:
|
||||
denoised_d = mult[2] * denoised - mult[3] * old_denoised
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * sigmas[i - 1],
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=uc,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SdeditEDMSampler(EulerEDMSampler):
|
||||
def __init__(self, edit_ratio=0.5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.edit_ratio = edit_ratio
|
||||
|
||||
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
|
||||
randn_unit = randn.clone()
|
||||
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
|
||||
|
||||
if num_steps is None:
|
||||
num_steps = self.num_steps
|
||||
if edit_ratio is None:
|
||||
edit_ratio = self.edit_ratio
|
||||
x = None
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
if i / num_steps < edit_ratio:
|
||||
continue
|
||||
if x is None:
|
||||
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
|
||||
|
||||
gamma = (
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
|
||||
)
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
s_in * sigmas[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
gamma,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.fixed_frames = fixed_frames
|
||||
self.sdedit = sdedit
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
alpha_cumprod_sqrt, timesteps = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps,
|
||||
device=self.device,
|
||||
return_idx=True,
|
||||
do_append_zero=False,
|
||||
)
|
||||
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
|
||||
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
|
||||
|
||||
uc = default(uc, cond)
|
||||
|
||||
num_sigmas = len(alpha_cumprod_sqrt)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
|
||||
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
|
||||
additional_model_inputs = {}
|
||||
|
||||
if not isinstance(scale, torch.Tensor) and scale == 1:
|
||||
additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
|
||||
if scale_emb is not None:
|
||||
additional_model_inputs["scale_emb"] = scale_emb
|
||||
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
|
||||
else:
|
||||
additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
|
||||
).to(torch.float32)
|
||||
if isinstance(self.guider, DynamicCFG):
|
||||
denoised = self.guider(
|
||||
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
|
||||
)
|
||||
else:
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
|
||||
return denoised
|
||||
|
||||
def sampler_step(
|
||||
self,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
):
|
||||
denoised = self.denoise(
|
||||
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
|
||||
).to(torch.float32)
|
||||
|
||||
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
|
||||
|
||||
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
|
||||
return x
|
||||
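# The step above is deterministic DDIM written in sqrt(alpha_cumprod) terms:
#   a_t = sqrt((1 - abar_next) / (1 - abar)),  b_t = sqrt(abar_next) - sqrt(abar) * a_t,
#   x <- a_t * x + b_t * x0_pred,
# i.e. the predicted x0 is re-noised to the next (lower) noise level with no added noise.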
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
s_in * alpha_cumprod_sqrt[i],
|
||||
s_in * alpha_cumprod_sqrt[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
else:
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
|
||||
mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
mult4 = 1 / (2 * r)
|
||||
return mult1, mult2, mult3, mult4
|
||||
else:
|
||||
return mult1, mult2
|
||||
|
||||
def sampler_step(
|
||||
self,
|
||||
old_denoised,
|
||||
previous_alpha_cumprod_sqrt,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
):
|
||||
denoised = self.denoise(
|
||||
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
|
||||
).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
]
|
||||
mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
|
||||
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0 or on the first step
|
||||
return x_standard, denoised
|
||||
else:
|
||||
denoised_d = mult[2] * denoised - mult[3] * old_denoised
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
|
||||
|
||||
x = x_advanced
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
prefix_frames = x[:, : self.fixed_frames]
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
if self.fixed_frames > 0:
|
||||
if self.sdedit:
|
||||
rd = torch.randn_like(prefix_frames)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
|
||||
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
|
||||
)
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
else:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
|
||||
s_in * alpha_cumprod_sqrt[i],
|
||||
s_in * alpha_cumprod_sqrt[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
else:
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
mult4 = 1 / (2 * r)
|
||||
return mult1, mult2, mult3, mult4
|
||||
else:
|
||||
return mult1, mult2
|
||||
|
||||
def sampler_step(
|
||||
self,
|
||||
old_denoised,
|
||||
previous_alpha_cumprod_sqrt,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
]
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised
|
||||
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0 or on the first step
|
||||
return x_standard, denoised
|
||||
else:
|
||||
denoised_d = mult[2] * denoised - mult[3] * old_denoised
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d
|
||||
|
||||
x = x_advanced
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
|
||||
s_in * alpha_cumprod_sqrt[i],
|
||||
s_in * alpha_cumprod_sqrt[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
)
|
||||
|
||||
return x
|
155
sat/sgm/modules/diffusionmodules/sampling_utils.py
Normal file
@ -0,0 +1,155 @@
|
||||
import torch
|
||||
from scipy import integrate
|
||||
|
||||
from ...util import append_dims
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class NoDynamicThresholding:
|
||||
def __call__(self, uncond, cond, scale):
|
||||
scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
|
||||
return uncond + scale * (cond - uncond)
|
||||
|
||||
|
||||
class StaticThresholding:
|
||||
def __call__(self, uncond, cond, scale):
|
||||
result = uncond + scale * (cond - uncond)
|
||||
result = torch.clamp(result, min=-1.0, max=1.0)
|
||||
return result
|
||||
|
||||
|
||||
def dynamic_threshold(x, p=0.95):
|
||||
N, T, C, H, W = x.shape
|
||||
x = rearrange(x, "n t c h w -> n c (t h w)")
|
||||
l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
|
||||
s = torch.maximum(-l, r)
|
||||
threshold_mask = (s > 1).expand(-1, -1, H * W * T)
|
||||
if threshold_mask.any():
|
||||
x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
|
||||
x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W)
|
||||
return x
|
||||
|
||||
|
||||
def dynamic_thresholding2(x0):
|
||||
p = 0.995  # A hyperparameter from the Imagen paper (dynamic thresholding percentile).
|
||||
origin_dtype = x0.dtype
|
||||
x0 = x0.to(torch.float32)
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
|
||||
x0 = torch.clamp(x0, -s, s) # / s
|
||||
return x0.to(origin_dtype)
|
||||
|
||||
|
||||
def latent_dynamic_thresholding(x0):
|
||||
p = 0.9995
|
||||
origin_dtype = x0.dtype
|
||||
x0 = x0.to(torch.float32)
|
||||
s = torch.quantile(torch.abs(x0), p, dim=2)
|
||||
s = append_dims(s, x0.dim())
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0.to(origin_dtype)
|
||||
|
||||
|
||||
def dynamic_thresholding3(x0):
|
||||
p = 0.995  # A hyperparameter from the Imagen paper (dynamic thresholding percentile).
|
||||
origin_dtype = x0.dtype
|
||||
x0 = x0.to(torch.float32)
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
|
||||
x0 = torch.clamp(x0, -s, s) # / s
|
||||
return x0.to(origin_dtype)
|
||||
|
||||
|
||||
class DynamicThresholding:
|
||||
def __call__(self, uncond, cond, scale):
|
||||
mean = uncond.mean()
|
||||
std = uncond.std()
|
||||
result = uncond + scale * (cond - uncond)
|
||||
result_mean, result_std = result.mean(), result.std()
|
||||
result = (result - result_mean) / result_std * std
|
||||
# result = dynamic_thresholding3(result)
|
||||
return result
|
||||
|
||||
|
||||
class DynamicThresholdingV1:
|
||||
def __init__(self, scale_factor):
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def __call__(self, uncond, cond, scale):
|
||||
result = uncond + scale * (cond - uncond)
|
||||
unscaled_result = result / self.scale_factor
|
||||
B, T, C, H, W = unscaled_result.shape
|
||||
flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)")
|
||||
means = flattened.mean(dim=2).unsqueeze(2)
|
||||
recentered = flattened - means
|
||||
magnitudes = recentered.abs().max()
|
||||
normalized = recentered / magnitudes
|
||||
thresholded = latent_dynamic_thresholding(normalized)
|
||||
denormalized = thresholded * magnitudes
|
||||
uncentered = denormalized + means
|
||||
unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
|
||||
scaled_result = unflattened * self.scale_factor
|
||||
return scaled_result
|
||||
|
||||
|
||||
class DynamicThresholdingV2:
|
||||
def __call__(self, uncond, cond, scale):
|
||||
B, T, C, H, W = uncond.shape
|
||||
diff = cond - uncond
|
||||
mim_target = uncond + diff * 4.0
|
||||
cfg_target = uncond + diff * 8.0
|
||||
|
||||
mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)")
|
||||
cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)")
|
||||
mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
|
||||
cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
|
||||
mim_centered = mim_flattened - mim_means
|
||||
cfg_centered = cfg_flattened - cfg_means
|
||||
|
||||
mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
|
||||
cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
|
||||
|
||||
cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref
|
||||
|
||||
result = cfg_renormalized + cfg_means
|
||||
unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
|
||||
|
||||
return unflattened
|
||||
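# Reading of DynamicThresholdingV2: it builds a strong CFG result (scale 8) and a mild one
# (scale 4), then rescales the strong result's per-channel deviation to match the mild one's
# std, which tames over-saturation at high guidance. Note the `scale` argument passed in is
# not used; the 4.0 / 8.0 factors are hard-coded here.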
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f"Order {order} too high for step {i}")
|
||||
|
||||
def fn(tau):
|
||||
prod = 1.0
|
||||
for k in range(order):
|
||||
if j == k:
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
|
||||
if not eta:
|
||||
return sigma_to, 0.0
|
||||
sigma_up = torch.minimum(
|
||||
sigma_to,
|
||||
eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
|
||||
)
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
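# Ancestral split, as in k-diffusion:
#   sigma_up = min(sigma_to, eta * sqrt(sigma_to^2 * (sigma_from^2 - sigma_to^2) / sigma_from^2))
#   sigma_down = sqrt(sigma_to^2 - sigma_up^2)
# so stepping deterministically to sigma_down and then adding sigma_up-scaled noise lands at
# an overall noise level of sigma_to.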
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
def to_neg_log_sigma(sigma):
|
||||
return sigma.log().neg()
|
||||
|
||||
|
||||
def to_sigma(neg_log_sigma):
|
||||
return neg_log_sigma.neg().exp()
|
80
sat/sgm/modules/diffusionmodules/sigma_sampling.py
Normal file
@ -0,0 +1,80 @@
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from sat import mpu
|
||||
|
||||
from ...util import default, instantiate_from_config
|
||||
|
||||
|
||||
class EDMSampling:
|
||||
def __init__(self, p_mean=-1.2, p_std=1.2):
|
||||
self.p_mean = p_mean
|
||||
self.p_std = p_std
|
||||
|
||||
def __call__(self, n_samples, rand=None):
|
||||
log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
|
||||
return log_sigma.exp()
|
||||
|
||||
|
||||
class DiscreteSampling:
|
||||
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
|
||||
self.num_idx = num_idx
|
||||
self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
world_size = mpu.get_data_parallel_world_size()
|
||||
self.uniform_sampling = uniform_sampling
|
||||
if self.uniform_sampling:
|
||||
i = 1
|
||||
while True:
|
||||
if world_size % i != 0 or num_idx % (world_size // i) != 0:
|
||||
i += 1
|
||||
else:
|
||||
self.group_num = world_size // i
|
||||
break
|
||||
|
||||
assert self.group_num > 0
|
||||
assert world_size % self.group_num == 0
|
||||
self.group_width = world_size // self.group_num  # the number of ranks in one group
|
||||
self.sigma_interval = self.num_idx // self.group_num
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
return self.sigmas[idx]
|
||||
|
||||
def __call__(self, n_samples, rand=None, return_idx=False):
|
||||
if self.uniform_sampling:
|
||||
rank = mpu.get_data_parallel_rank()
|
||||
group_index = rank // self.group_width
|
||||
idx = default(
|
||||
rand,
|
||||
torch.randint(
|
||||
group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
|
||||
),
|
||||
)
|
||||
else:
|
||||
idx = default(
|
||||
rand,
|
||||
torch.randint(0, self.num_idx, (n_samples,)),
|
||||
)
|
||||
if return_idx:
|
||||
return self.idx_to_sigma(idx), idx
|
||||
else:
|
||||
return self.idx_to_sigma(idx)
|
||||
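# Illustrative numbers for uniform_sampling (not from a shipped config): with a data-parallel
# world size of 8 and num_idx=1000 the loop picks group_num=8, so group_width=1 and
# sigma_interval=125; rank r then only draws timestep indices in [r * 125, (r + 1) * 125),
# spreading the sampled noise levels evenly across ranks.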
|
||||
|
||||
class PartialDiscreteSampling:
|
||||
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
|
||||
self.total_num_idx = total_num_idx
|
||||
self.partial_num_idx = partial_num_idx
|
||||
self.sigmas = instantiate_from_config(discretization_config)(
|
||||
total_num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
)
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
return self.sigmas[idx]
|
||||
|
||||
def __call__(self, n_samples, rand=None):
|
||||
idx = default(
|
||||
rand,
|
||||
# torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
|
||||
torch.randint(0, self.partial_num_idx, (n_samples,)),
|
||||
)
|
||||
return self.idx_to_sigma(idx)
|
328
sat/sgm/modules/diffusionmodules/util.py
Normal file
@ -0,0 +1,328 @@
|
||||
"""
|
||||
adopted from
|
||||
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
and
|
||||
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
and
|
||||
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
|
||||
thanks!
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
schedule,
|
||||
n_timestep,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
):
|
||||
if schedule == "linear":
|
||||
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def mixed_checkpoint(func, inputs: dict, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
|
||||
borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
|
||||
it also works with non-tensor inputs
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument dictionary to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
|
||||
tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
|
||||
non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
|
||||
non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
|
||||
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
|
||||
return MixedCheckpointFunction.apply(
|
||||
func,
|
||||
len(tensor_inputs),
|
||||
len(non_tensor_inputs),
|
||||
tensor_keys,
|
||||
non_tensor_keys,
|
||||
*args,
|
||||
)
|
||||
else:
|
||||
return func(**inputs)
|
||||
|
||||
|
||||
class MixedCheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
run_function,
|
||||
length_tensors,
|
||||
length_non_tensors,
|
||||
tensor_keys,
|
||||
non_tensor_keys,
|
||||
*args,
|
||||
):
|
||||
ctx.end_tensors = length_tensors
|
||||
ctx.end_non_tensors = length_tensors + length_non_tensors
|
||||
ctx.gpu_autocast_kwargs = {
|
||||
"enabled": torch.is_autocast_enabled(),
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled(),
|
||||
}
|
||||
assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
|
||||
|
||||
ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
|
||||
ctx.input_non_tensors = {
|
||||
key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
|
||||
}
|
||||
ctx.run_function = run_function
|
||||
ctx.input_params = list(args[ctx.end_non_tensors :])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
|
||||
ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
|
||||
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
|
||||
# shallow_copies.update(additional_args)
|
||||
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
list(ctx.input_tensors.values()) + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (
|
||||
(None, None, None, None, None)
|
||||
+ input_grads[: ctx.end_tensors]
|
||||
+ (None,) * (ctx.end_non_tensors - ctx.end_tensors)
|
||||
+ input_grads[ctx.end_tensors :]
|
||||
)
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
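checkpoint covers the simpler positional case where every input is a tensor. A hedged sketch with a made-up two-layer net: parameters the callable uses internally are forwarded via params so they still receive gradients.

# Illustrative only; the layer sizes are assumptions.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# With flag=True the forward runs under no_grad and is recomputed in backward.
y = checkpoint(net, (x,), tuple(net.parameters()), flag=True)
y.mean().backward()
assert x.grad is not None and net[0].weight.grad is not None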
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and simply repeat
                        each timestep value across the embedding dimension.
    :param dtype: dtype of the returned embedding.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=timesteps.device
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding.to(dtype)
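The first half of the output carries cosine terms and the second half sine terms over frequencies max_period**(-i/half). A quick shape check with arbitrary example values:

import torch

t = torch.tensor([0.0, 10.0, 999.0])   # one (possibly fractional) timestep per sample
emb = timestep_embedding(t, dim=320)    # -> shape [3, 320]
# An odd `dim` would be zero-padded by one column to reach the requested width.
assert emb.shape == (3, 320) and emb.dtype == torch.float32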
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
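zero_module is commonly used to zero-initialize a block's final projection so the block starts out as a no-op on its residual branch. A small sketch with an assumed layer size:

import torch.nn as nn

out_proj = zero_module(nn.Conv2d(128, 128, kernel_size=3, padding=1))
# All weights and biases start at zero, so the branch initially contributes nothing.
assert all((p == 0).all() for p in out_proj.parameters())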
def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
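mean_flat reduces a loss tensor to one value per sample, which makes per-example weighting straightforward. For example, with assumed shapes:

import torch

mse = (torch.randn(4, 3, 32, 32) - torch.randn(4, 3, 32, 32)) ** 2
per_sample = mean_flat(mse)   # shape [4]: mean over channel and spatial dims only
assert per_sample.shape == (4,)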
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x).type(x.dtype)
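normalization wraps GroupNorm with 32 groups; the only change in GroupNorm32 is casting the output back to the input's dtype, which matters under mixed precision. A minimal check with an assumed channel count (it must be divisible by 32):

import torch

norm = normalization(64)   # -> GroupNorm32(num_groups=32, num_channels=64)
h = torch.randn(2, 64, 8, 8)
assert norm(h).shape == h.shape
# Under mixed precision, the .type(x.dtype) cast keeps the output dtype
# matched to the input's dtype.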
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
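conv_nd, linear and avg_pool_nd let one block definition serve 1D, 2D and 3D data by switching a single dims argument. An illustrative use with assumed channel sizes:

import torch

conv2d = conv_nd(2, 64, 128, 3, padding=1)   # nn.Conv2d(64, 128, kernel_size=3, padding=1)
conv3d = conv_nd(3, 64, 128, 3, padding=1)   # nn.Conv3d(...), e.g. for video latents
pool = avg_pool_nd(2, kernel_size=2, stride=2)

x = torch.randn(1, 64, 32, 32)
assert conv2d(x).shape == (1, 128, 32, 32)
assert pool(x).shape == (1, 64, 16, 16)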
class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)
        elif self.merge_strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            alpha = rearrange(alpha, self.rearrange_pattern)
        else:
            raise NotImplementedError
        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator)
        x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        return x
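AlphaBlender mixes a spatial branch and a temporal branch as alpha * x_spatial + (1 - alpha) * x_temporal; with the learned_with_images strategy, frames flagged in image_only_indicator use the spatial branch alone (alpha = 1). A sketch with assumed shapes (2 clips of 4 frames, flattened to (b t) rows by the default rearrange pattern):

import torch

blender = AlphaBlender(alpha=0.0, merge_strategy="learned_with_images")

b, t, tokens, ch = 2, 4, 1024, 320
x_spatial = torch.randn(b * t, tokens, ch)
x_temporal = torch.randn(b * t, tokens, ch)

# 1 marks "image-only" frames that should ignore the temporal branch.
image_only_indicator = torch.zeros(b, t)
image_only_indicator[:, 0] = 1.0

mixed = blender(x_spatial, x_temporal, image_only_indicator=image_only_indicator)
assert mixed.shape == x_spatial.shape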
41
sat/sgm/modules/diffusionmodules/wrappers.py
Normal file
41
sat/sgm/modules/diffusionmodules/wrappers.py
Normal file
@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from packaging import version

OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


class IdentityWrapper(nn.Module):
    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        compile = (
            torch.compile
            if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
            else lambda x: x
        )
        self.diffusion_model = compile(diffusion_model)
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)


class OpenAIWrapper(IdentityWrapper):
    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
        for key in c:
            c[key] = c[key].to(self.dtype)

        if x.dim() == 4:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        elif x.dim() == 5:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
        else:
            raise ValueError("Input tensor must be 4D or 5D")

        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **kwargs,
        )
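OpenAIWrapper unpacks the conditioning dict before calling the backbone: "concat" is stacked onto the latent (dim 1 for 4D inputs, dim 2 for 5D inputs), "crossattn" becomes the attention context and "vector" the pooled y embedding. A hedged sketch with a stand-in backbone; the DummyUNet, the [B, T, C, H, W] layout and all sizes below are assumptions, not taken from this diff:

import torch
import torch.nn as nn

class DummyUNet(nn.Module):
    # Stand-in for the real diffusion backbone; it only mirrors the call signature.
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        return x

model = OpenAIWrapper(DummyUNet(), compile_model=False, dtype=torch.float32)

# Assumed 5D latent layout [B, T, C, H, W]; "concat" is stacked onto dim 2.
x = torch.randn(1, 13, 16, 60, 90)
t = torch.tensor([500])
c = {
    "crossattn": torch.randn(1, 226, 4096),   # e.g. text-encoder hidden states
    "concat": torch.randn(1, 13, 4, 60, 90),  # extra channels, e.g. an image condition
}

out = model(x, t, c)
assert out.shape[2] == 16 + 4   # the dummy backbone just echoes the concatenated input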