diff --git a/.gitignore b/.gitignore index 6ff20ae..695a4e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -output/ *__pycache__/ samples*/ runs/ @@ -6,4 +5,5 @@ checkpoints/ master_ip logs/ *.DS_Store -.idea \ No newline at end of file +.idea +output* \ No newline at end of file diff --git a/inference/cli_demo_quantization.py b/inference/cli_demo_quantization.py index ed00158..7fb6b1a 100644 --- a/inference/cli_demo_quantization.py +++ b/inference/cli_demo_quantization.py @@ -58,7 +58,7 @@ def generate_video( text_encoder=text_encoder, transformer=transformer, vae=vae, - torch_dtype=torch.bfloat16, + torch_dtype=dtype, ) pipe.enable_model_cpu_offload() pipe.vae.enable_tiling() diff --git a/sat/data_video.py b/sat/data_video.py index d16667f..04a11ca 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -147,7 +147,9 @@ def pad_last_frame(tensor, num_frames): # T, H, W, C if len(tensor) < num_frames: pad_length = num_frames - len(tensor) - pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device) + # Use the last frame to pad instead of zero + last_frame = tensor[-1] + pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:]) padded_tensor = torch.cat([tensor, pad_tensor], dim=0) return padded_tensor else: diff --git a/sat/inference.sh b/sat/inference.sh index 11c50a6..c798fa5 100755 --- a/sat/inference.sh +++ b/sat/inference.sh @@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" -run_cmd="$environs python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed $RANDOM" +run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM" echo ${run_cmd} eval ${run_cmd}