update testing

This commit is contained in:
zR 2024-08-26 17:56:00 +08:00
parent b538be32be
commit adf65ff0eb
4 changed files with 7 additions and 5 deletions

4
.gitignore vendored
View File

@ -1,4 +1,3 @@
output/
*__pycache__/ *__pycache__/
samples*/ samples*/
runs/ runs/
@ -6,4 +5,5 @@ checkpoints/
master_ip master_ip
logs/ logs/
*.DS_Store *.DS_Store
.idea .idea
output*

View File

@ -58,7 +58,7 @@ def generate_video(
text_encoder=text_encoder, text_encoder=text_encoder,
transformer=transformer, transformer=transformer,
vae=vae, vae=vae,
torch_dtype=torch.bfloat16, torch_dtype=dtype,
) )
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling() pipe.vae.enable_tiling()

View File

@ -147,7 +147,9 @@ def pad_last_frame(tensor, num_frames):
# T, H, W, C # T, H, W, C
if len(tensor) < num_frames: if len(tensor) < num_frames:
pad_length = num_frames - len(tensor) pad_length = num_frames - len(tensor)
pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device) # Use the last frame to pad instead of zero
last_frame = tensor[-1]
pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
padded_tensor = torch.cat([tensor, pad_tensor], dim=0) padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
return padded_tensor return padded_tensor
else: else:

View File

@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed $RANDOM" run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
echo ${run_cmd} echo ${run_cmd}
eval ${run_cmd} eval ${run_cmd}