diff --git a/sat/README.md b/sat/README.md
index 2293dd9..e70b488 100644
--- a/sat/README.md
+++ b/sat/README.md
@@ -11,13 +11,13 @@ This code is the framework used by the team to train the model. It has few comme
 
 ## Inference Model
 
-1. Ensure that you have correctly installed the dependencies required by this folder.
+### 1. Ensure that you have correctly installed the dependencies required by this folder.
 
 ```shell
 pip install -r requirements.txt
 ```
 
-2. Download the model weights
+### 2. Download the model weights
 
 First, go to the SAT mirror to download the dependencies.
 
@@ -44,9 +44,12 @@ Then unzip, the model structure should look like this:
 └── 3d-vae.pt
 ```
 
-Next, clone the T5 model, which is not used for training and fine-tuning, but must be used.
+Due to the large size of the model weight files, using `git lfs` is recommended. Installation instructions for `git lfs` can be found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
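+After installing it, enable it once:
+
+```shell
+git lfs install
+```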
"{absolute_path/to/your/t5-v1_1-xxl}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 3.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 +``` + +### 4. Modify the file in `configs/inference.yaml`. + +```yaml +args: + latent_channels: 16 + mode: inference + load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + + batch_size: 1 + input_type: txt # You can choose txt for pure text input, or change to cli for command line input + input_file: configs/test.txt # Pure text file, which can be edited + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 + fp16: True # For CogVideoX-2B +# bf16: True # For CogVideoX-5B + output_dir: outputs/ + force_inference: True +``` + ++ Modify `configs/test.txt` if multiple prompts is required, in which each line makes a prompt. ++ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the OPENAI_API_KEY as your environmental variable. ++ Modify `input_type` in `configs/inference.yaml` if use command line as prompt iuput. + +```yaml +input_type: cli +``` + +This allows input from the command line as prompts. + +Change `output_dir` if you wish to modify the address of the output video + +```yaml +output_dir: outputs/ +``` + +It is saved by default in the `.outputs/` folder. + +### 5. Run the inference code to perform inference. + +```shell +bash inference.sh +``` + +## Fine-tuning the Model + +### Preparing the Dataset + +The dataset format should be as follows: + +``` +. +├── labels +│   ├── 1.txt +│   ├── 2.txt +│   ├── ... +└── videos + ├── 1.mp4 + ├── 2.mp4 + ├── ... +``` + Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels should be matched one-to-one. Generally, a single video should not be associated with multiple labels. 
+### 3. Modify the `configs/cogvideox_2b.yaml` file.
+
+```yaml
+model:
+  scale_factor: 1.15258426
+  disable_first_stage_autocast: true
+  log_keys:
+    - txt
+
+  denoiser_config:
+    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+    params:
+      num_idx: 1000
+      quantize_c_noise: False
+
+      weighting_config:
+        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+      scaling_config:
+        target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+  network_config:
+    target: dit_video_concat.DiffusionTransformer
+    params:
+      time_embed_dim: 512
+      elementwise_affine: True
+      num_frames: 49
+      time_compressed_rate: 4
+      latent_width: 90
+      latent_height: 60
+      num_layers: 30
+      patch_size: 2
+      in_channels: 16
+      out_channels: 16
+      hidden_size: 1920
+      adm_in_channels: 256
+      num_attention_heads: 30
+
+      transformer_args:
+        checkpoint_activations: True ## using gradient checkpointing
+        vocab_size: 1
+        max_sequence_length: 64
+        layernorm_order: pre
+        skip_init: false
+        model_parallel_size: 1
+        is_decoder: false
+
+      modules:
+        pos_embed_config:
+          target: dit_video_concat.Basic3DPositionEmbeddingMixin
+          params:
+            text_length: 226
+            height_interpolation: 1.875
+            width_interpolation: 1.875
+
+        patch_embed_config:
+          target: dit_video_concat.ImagePatchEmbeddingMixin
+          params:
+            text_hidden_size: 4096
+
+        adaln_layer_config:
+          target: dit_video_concat.AdaLNMixin
+          params:
+            qk_ln: True
+
+        final_layer_config:
+          target: dit_video_concat.FinalLayerMixin
+
+  conditioner_config:
+    target: sgm.modules.GeneralConditioner
+    params:
+      emb_models:
+        - is_trainable: false
+          input_key: txt
+          ucg_rate: 0.1
+          target: sgm.modules.encoders.modules.FrozenT5Embedder
+          params:
+            model_dir: "{absolute_path/to/your}/t5-v1_1-xxl" # Absolute path to the t5-v1_1-xxl weights folder assembled in step 2
+            max_length: 226
+
+  first_stage_config:
+    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+    params:
+      cp_size: 1
+      ckpt_path: "{absolute_path/to/your}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
+      ignore_keys: [ 'loss' ]
+
+      loss_config:
+        target: torch.nn.Identity
+
+      regularizer_config:
+        target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+      encoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+        params:
+          double_z: true
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: True
+
+      decoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+        params:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: False
+
+  loss_fn_config:
+    target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+    params:
+      offset_noise_level: 0
+      sigma_sampler_config:
+        target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+        params:
+          uniform_sampling: True
+          num_idx: 1000
+          discretization_config:
+            target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+            params:
+              shift_scale: 3.0
+
+  sampler_config:
+    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+    params:
+      num_steps: 50
+      verbose: True
+
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+      guider_config:
+        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+        params:
+          scale: 6
+          exp: 5
+          num_steps: 50
+```
+
+### 4. Modify the `configs/inference.yaml` file.
+
+```yaml
+args:
+  latent_channels: 16
+  mode: inference
+  load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
+  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # Use this instead to load a LoRA fine-tuned checkpoint
+
+  batch_size: 1
+  input_type: txt # You can choose txt for pure text input, or change to cli for command line input
+  input_file: configs/test.txt # Pure text file, which can be edited
+  sampling_num_frames: 13 # Must be 13, 11 or 9; with the 4x temporal compression, 13 latent frames cover the full 49-frame clip
+  sampling_fps: 8
+  fp16: True # For CogVideoX-2B
+#  bf16: True # For CogVideoX-5B
+  output_dir: outputs/
+  force_inference: True
+```
+
++ Modify `configs/test.txt` if multiple prompts are required; each line is one prompt (see the example after this section).
++ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), which requires the OPENAI_API_KEY environment variable to be set.
++ Modify `input_type` in `configs/inference.yaml` to take prompts from the command line instead:
+
+```yaml
+input_type: cli
+```
+
+This allows prompts to be entered from the command line.
+
+To change where the generated videos are saved, modify `output_dir`:
+
+```yaml
+output_dir: outputs/
+```
+
+By default, videos are saved in the `outputs/` folder.
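+
+For reference, a `configs/test.txt` holding one prompt per line could look like this (illustrative prompts, not the shipped file):
+
+```
+A panda eating bamboo on a wooden bench in a sunlit forest, photorealistic style.
+A paper boat drifting down a rain-soaked city street at dusk, cinematic lighting.
+```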
+
+### 5. Run the inference code to perform inference.
+
+```shell
+bash inference.sh
+```
+
+## Fine-tuning the Model
+
+### Preparing the Dataset
+
+The dataset format should be as follows:
+
+```
+.
+├── labels
+│   ├── 1.txt
+│   ├── 2.txt
+│   ├── ...
+└── videos
+    ├── 1.mp4
+    ├── 2.mp4
+    ├── ...
+```
+
 Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels should be matched one-to-one. Generally, a single video should not be associated with multiple labels.
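
A quick way to verify the one-to-one pairing before training is the following minimal sketch; it assumes the `labels/`/`videos/` layout and file extensions shown above, and `my_dataset` is a placeholder for your dataset root:

```python
from pathlib import Path

root = Path("my_dataset")  # placeholder: your dataset root containing labels/ and videos/
labels = {p.stem for p in (root / "labels").glob("*.txt")}
videos = {p.stem for p in (root / "videos").glob("*.mp4")}

if labels == videos:
    print(f"OK: {len(videos)} matched video/label pairs")
else:
    print("videos missing a label:", sorted(videos - labels))
    print("labels missing a video:", sorted(labels - videos))
```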
diff --git a/sat/README_ja.md b/sat/README_ja.md
index 5c0e852..3867cfa 100644
--- a/sat/README_ja.md
+++ b/sat/README_ja.md
@@ -11,13 +11,13 @@
 
 ## Inference Model
 
-1. Make sure the dependencies required by this folder are installed correctly.
+### 1. Make sure the dependencies required by this folder are installed correctly.
 
 ```shell
 pip install -r requirements.txt
 ```
 
-2. Download the model weights
+### 2. Download the model weights
 
 First, go to the SAT mirror to download the dependencies.
 
@@ -44,10 +44,18 @@ unzip transformer.zip
 └── 3d-vae.pt
 ```
 
-Next, clone the T5 model. It is not used for training or fine-tuning, but it is required.
+Since the model weight files are large, using `git lfs` is recommended. For installing `git lfs`, see [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
+```shell
+git lfs install
 ```
-
-git clone https://huggingface.co/THUDM/CogVideoX-2b.git
+
+Next, clone the T5 model. It is not used for training or fine-tuning, but it is required.
+> When cloning, you can also use the model files hosted on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
+
+```shell
+git clone https://huggingface.co/THUDM/CogVideoX-2b.git # download the model from Hugging Face
+# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # download the model from Modelscope
 mkdir t5-v1_1-xxl
 mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
 ```
@@ -67,28 +75,182 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
 0 directories, 8 files
 ```
 
-3. Modify the `configs/cogvideox_2b_infer.yaml` file.
+### 3. Modify the `configs/cogvideox_2b.yaml` file.
 
 ```yaml
-load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
-
-conditioner_config:
-  target: sgm.modules.GeneralConditioner
-  params:
-    emb_models:
-      - is_trainable: false
-        input_key: txt
-        ucg_rate: 0.1
-        target: sgm.modules.encoders.modules.FrozenT5Embedder
-        params:
-          model_dir: "google/t5-v1_1-xxl" ## T5 model path
-          max_length: 226
-
-first_stage_config:
-  target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
-  params:
-    cp_size: 1
-    ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
+model:
+  scale_factor: 1.15258426
+  disable_first_stage_autocast: true
+  log_keys:
+    - txt
+
+  denoiser_config:
+    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+    params:
+      num_idx: 1000
+      quantize_c_noise: False
+
+      weighting_config:
+        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+      scaling_config:
+        target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+  network_config:
+    target: dit_video_concat.DiffusionTransformer
+    params:
+      time_embed_dim: 512
+      elementwise_affine: True
+      num_frames: 49
+      time_compressed_rate: 4
+      latent_width: 90
+      latent_height: 60
+      num_layers: 30
+      patch_size: 2
+      in_channels: 16
+      out_channels: 16
+      hidden_size: 1920
+      adm_in_channels: 256
+      num_attention_heads: 30
+
+      transformer_args:
+        checkpoint_activations: True ## use gradient checkpointing
+        vocab_size: 1
+        max_sequence_length: 64
+        layernorm_order: pre
+        skip_init: false
+        model_parallel_size: 1
+        is_decoder: false
+
+      modules:
+        pos_embed_config:
+          target: dit_video_concat.Basic3DPositionEmbeddingMixin
+          params:
+            text_length: 226
+            height_interpolation: 1.875
+            width_interpolation: 1.875
+
+        patch_embed_config:
+          target: dit_video_concat.ImagePatchEmbeddingMixin
+          params:
+            text_hidden_size: 4096
+
+        adaln_layer_config:
+          target: dit_video_concat.AdaLNMixin
+          params:
+            qk_ln: True
+
+        final_layer_config:
+          target: dit_video_concat.FinalLayerMixin
+
+  conditioner_config:
+    target: sgm.modules.GeneralConditioner
+    params:
+      emb_models:
+        - is_trainable: false
+          input_key: txt
+          ucg_rate: 0.1
+          target: sgm.modules.encoders.modules.FrozenT5Embedder
+          params:
+            model_dir: "{absolute_path/to/your}/t5-v1_1-xxl" # Absolute path to the t5-v1_1-xxl weights folder
+            max_length: 226
+
+  first_stage_config:
+    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+    params:
+      cp_size: 1
+      ckpt_path: "{absolute_path/to/your}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
+      ignore_keys: [ 'loss' ]
+
+      loss_config:
+        target: torch.nn.Identity
+
+      regularizer_config:
+        target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+      encoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+        params:
+          double_z: true
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: True
+
+      decoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+        params:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: False
+
+  loss_fn_config:
+    target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+    params:
+      offset_noise_level: 0
+      sigma_sampler_config:
+        target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+        params:
+          uniform_sampling: True
+          num_idx: 1000
+          discretization_config:
+            target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+            params:
+              shift_scale: 3.0
+
+  sampler_config:
+    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+    params:
+      num_steps: 50
+      verbose: True
+
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+      guider_config:
+        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+        params:
+          scale: 6
+          exp: 5
+          num_steps: 50
 ```
 
+### 4. Modify the `configs/inference.yaml` file.
+
+```yaml
+args:
+  latent_channels: 16
+  mode: inference
+  load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
+  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # Use this instead to load a LoRA fine-tuned checkpoint
+
+  batch_size: 1
+  input_type: txt # choose txt to read prompts from a text file, or change to cli for command-line input
+  input_file: configs/test.txt # path to the prompt text file, which can be edited
+  sampling_num_frames: 13 # Must be 13, 11 or 9
+  sampling_fps: 8
+  fp16: True # For CogVideoX-2B
+#  bf16: True # For CogVideoX-5B
+  output_dir: outputs/
+  force_inference: True
+```
 
 + To store multiple prompts in a txt file, refer to `configs/test.txt`
@@ -110,7 +272,7 @@ output_dir: outputs/
 
 By default, the output is saved in the `.outputs/` folder.
 
-4. Run the inference code to start inference.
+### 5. Run the inference code to start inference.
 
 ```shell
 bash inference.sh
diff --git a/sat/README_zh.md b/sat/README_zh.md
index 807b133..6fc1c16 100644
--- a/sat/README_zh.md
+++ b/sat/README_zh.md
@@ -10,13 +10,13 @@
 
 ## Inference Model
 
-1. Make sure you have correctly installed the dependencies required by this folder.
+### 1. Make sure you have correctly installed the dependencies required by this folder.
 
 ```shell
 pip install -r requirements.txt
 ```
 
-2. Download the model weights
+### 2. Download the model weights
 
 First, go to the SAT mirror to download the dependencies.
 
@@ -43,10 +43,17 @@ unzip transformer.zip
 └── 3d-vae.pt
 ```
 
-Next, clone the T5 model. It is not used for training or fine-tuning, but it is required.
-
+Since the model weight files are large, using `git lfs` is recommended. For installing `git lfs`, see [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing).
+```shell
+git lfs install
 ```
-git clone https://huggingface.co/THUDM/CogVideoX-2b.git
+
+Next, clone the T5 model. It is not used for training or fine-tuning, but it is required.
+> When cloning, you can also use the model files hosted on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
+
+```shell
+git clone https://huggingface.co/THUDM/CogVideoX-2b.git # download the model from Hugging Face
+# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # download the model from Modelscope
 mkdir t5-v1_1-xxl
 mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
 ```
@@ -66,29 +73,183 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
 0 directories, 8 files
 ```
 
-3. Modify the `configs/cogvideox_2b_infer.yaml` file.
+### 3. Modify the `configs/cogvideox_2b.yaml` file.
 
 ```yaml
-load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
-
-conditioner_config:
-  target: sgm.modules.GeneralConditioner
-  params:
-    emb_models:
-      - is_trainable: false
-        input_key: txt
-        ucg_rate: 0.1
-        target: sgm.modules.encoders.modules.FrozenT5Embedder
-        params:
-          model_dir: "google/t5-v1_1-xxl" ## T5 model path
-          max_length: 226
-
-first_stage_config:
-  target: sgm.models.autoencoder.VideoAutoencoderInferenceWrapper
-  params:
-    cp_size: 1
-    ckpt_path: "{your_CogVideoX-2b-sat_path}/vae/3d-vae.pt" ## VAE model path
+model:
+  scale_factor: 1.15258426
+  disable_first_stage_autocast: true
+  log_keys:
+    - txt
+
+  denoiser_config:
+    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+    params:
+      num_idx: 1000
+      quantize_c_noise: False
+
+      weighting_config:
+        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+      scaling_config:
+        target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+  network_config:
+    target: dit_video_concat.DiffusionTransformer
+    params:
+      time_embed_dim: 512
+      elementwise_affine: True
+      num_frames: 49
+      time_compressed_rate: 4
+      latent_width: 90
+      latent_height: 60
+      num_layers: 30
+      patch_size: 2
+      in_channels: 16
+      out_channels: 16
+      hidden_size: 1920
+      adm_in_channels: 256
+      num_attention_heads: 30
+
+      transformer_args:
+        checkpoint_activations: True ## using gradient checkpointing
+        vocab_size: 1
+        max_sequence_length: 64
+        layernorm_order: pre
+        skip_init: false
+        model_parallel_size: 1
+        is_decoder: false
+
+      modules:
+        pos_embed_config:
+          target: dit_video_concat.Basic3DPositionEmbeddingMixin
+          params:
+            text_length: 226
+            height_interpolation: 1.875
+            width_interpolation: 1.875
+
+        patch_embed_config:
+          target: dit_video_concat.ImagePatchEmbeddingMixin
+          params:
+            text_hidden_size: 4096
+
+        adaln_layer_config:
+          target: dit_video_concat.AdaLNMixin
+          params:
+            qk_ln: True
+
+        final_layer_config:
+          target: dit_video_concat.FinalLayerMixin
+
+  conditioner_config:
+    target: sgm.modules.GeneralConditioner
+    params:
+      emb_models:
+        - is_trainable: false
+          input_key: txt
+          ucg_rate: 0.1
+          target: sgm.modules.encoders.modules.FrozenT5Embedder
+          params:
+            model_dir: "{absolute_path/to/your}/t5-v1_1-xxl" # Absolute path to the t5-v1_1-xxl weights folder
+            max_length: 226
+
+  first_stage_config:
+    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+    params:
+      cp_size: 1
+      ckpt_path: "{absolute_path/to/your}/CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
+      ignore_keys: [ 'loss' ]
+
+      loss_config:
+        target: torch.nn.Identity
+
+      regularizer_config:
+        target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+      encoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+        params:
+          double_z: true
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: True
+
+      decoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+        params:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: False
+
+  loss_fn_config:
+    target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+    params:
+      offset_noise_level: 0
+      sigma_sampler_config:
+        target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+        params:
+          uniform_sampling: True
+          num_idx: 1000
+          discretization_config:
+            target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+            params:
+              shift_scale: 3.0
+
+  sampler_config:
+    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+    params:
+      num_steps: 50
+      verbose: True
+
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 3.0
+
+      guider_config:
+        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+        params:
+          scale: 6
+          exp: 5
+          num_steps: 50
 ```
 
+### 4. Modify the `configs/inference.yaml` file.
+
+```yaml
+args:
+  latent_channels: 16
+  mode: inference
+  load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
+  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # Use this instead to load a LoRA fine-tuned checkpoint
+
+  batch_size: 1
+  input_type: txt # choose txt to read prompts from a plain text file, or change to cli for command-line input
+  input_file: configs/test.txt # plain text file of prompts, which can be edited
+  sampling_num_frames: 13 # Must be 13, 11 or 9
+  sampling_fps: 8
+  fp16: True # For CogVideoX-2B
+#  bf16: True # For CogVideoX-5B
+  output_dir: outputs/
+  force_inference: True
+```
 
 + To use a txt file with multiple prompts, refer to `configs/test.txt`
@@ -109,7 +270,7 @@ output_dir: outputs/
 
 By default, videos are saved in the `.outputs/` folder.
 
-4. Run the inference code to perform inference.
+### 5. Run the inference code to perform inference.
 
 ```shell
 bash inference.sh
diff --git a/sat/data_video.py b/sat/data_video.py
index 3783340..d16667f 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -145,9 +145,10 @@ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
 
 def pad_last_frame(tensor, num_frames):
     # T, H, W, C
-    if tensor.shape[0] < num_frames:
-        last_frame = tensor[-int(num_frames - tensor.shape[1]) :]
-        padded_tensor = torch.cat([tensor, last_frame], dim=0)
+    if len(tensor) < num_frames:
+        pad_length = num_frames - len(tensor)
+        pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
+        padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
         return padded_tensor
     else:
         return tensor[:num_frames]
@@ -378,8 +379,9 @@ class SFTDataset(Dataset):
             num_frames = max_num_frames
             start = int(skip_frms_num)
             end = int(start + num_frames / fps * actual_fps)
-            indices = np.arange(start, end, (end - start) / num_frames).astype(int)
-            temp_frms = vr.get_batch(np.arange(start, end))
+            end_safe = min(end, int(ori_vlen))  # clamp so we never read past the end of the video
+            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+            temp_frms = vr.get_batch(np.arange(start, end_safe))
             assert temp_frms is not None
             tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
             tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
@@ -388,7 +390,7 @@ class SFTDataset(Dataset):
             num_frames = max_num_frames
             start = int(skip_frms_num)
             end = int(ori_vlen - skip_frms_num)
-            indices = np.arange(start, end, (end - start) / num_frames).astype(int)
+            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
             temp_frms = vr.get_batch(np.arange(start, end))
             assert temp_frms is not None
             tensor_frms = (
@@ -417,7 +419,7 @@ class SFTDataset(Dataset):
 
         tensor_frms = pad_last_frame(
-            tensor_frms, num_frames
+            tensor_frms, max_num_frames
         )  # the len of indices may be less than num_frames, due to round error
         tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
         tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
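
To see the behavioral fix in `pad_last_frame` in isolation, here is a standalone sketch. The function body is copied from the hunk above, and only the demo tensor is invented. The old code computed the padding length from `tensor.shape[1]` (the height axis, not the frame axis) and repeated a slice of trailing frames; the new code zero-pads along the time axis:

```python
import torch

def pad_last_frame(tensor, num_frames):
    # T, H, W, C layout; new logic from the patch: zero-pad along the time axis
    if len(tensor) < num_frames:
        pad_length = num_frames - len(tensor)
        pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, pad_tensor], dim=0)
    return tensor[:num_frames]

frames = torch.ones(3, 4, 4, 3)     # a tiny clip: 3 frames of 4x4 RGB
padded = pad_last_frame(frames, 5)  # request 5 frames
print(padded.shape)                 # torch.Size([5, 4, 4, 3])
print(padded[3:].sum().item())      # 0.0 -> the appended frames are zeros, not copies
```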