From adf65ff0eb662586e35b097be71e7f8c52dcedb6 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Mon, 26 Aug 2024 17:56:00 +0800
Subject: [PATCH] update testing

---
 .gitignore                         | 4 ++--
 inference/cli_demo_quantization.py | 2 +-
 sat/data_video.py                  | 4 +++-
 sat/inference.sh                   | 2 +-
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6ff20ae..695a4e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,3 @@
-output/
 *__pycache__/
 samples*/
 runs/
@@ -6,4 +5,5 @@ checkpoints/
 master_ip
 logs/
 *.DS_Store
-.idea
\ No newline at end of file
+.idea
+output*
\ No newline at end of file
diff --git a/inference/cli_demo_quantization.py b/inference/cli_demo_quantization.py
index ed00158..7fb6b1a 100644
--- a/inference/cli_demo_quantization.py
+++ b/inference/cli_demo_quantization.py
@@ -58,7 +58,7 @@ def generate_video(
         text_encoder=text_encoder,
         transformer=transformer,
         vae=vae,
-        torch_dtype=torch.bfloat16,
+        torch_dtype=dtype,
     )
     pipe.enable_model_cpu_offload()
     pipe.vae.enable_tiling()
diff --git a/sat/data_video.py b/sat/data_video.py
index d16667f..04a11ca 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -147,7 +147,9 @@ def pad_last_frame(tensor, num_frames):
     # T, H, W, C
     if len(tensor) < num_frames:
         pad_length = num_frames - len(tensor)
-        pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
+        # Use the last frame to pad instead of zero
+        last_frame = tensor[-1]
+        pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
         padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
         return padded_tensor
     else:
diff --git a/sat/inference.sh b/sat/inference.sh
index 11c50a6..c798fa5 100755
--- a/sat/inference.sh
+++ b/sat/inference.sh
@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 
 environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
 
-run_cmd="$environs python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed $RANDOM"
+run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
 
 echo ${run_cmd}
 
 eval ${run_cmd}