update testing

This commit is contained in:
zR 2024-08-26 17:56:00 +08:00
parent b538be32be
commit adf65ff0eb
4 changed files with 7 additions and 5 deletions

4
.gitignore vendored
View File

@ -1,4 +1,3 @@
output/
*__pycache__/
samples*/
runs/
@ -6,4 +5,5 @@ checkpoints/
master_ip
logs/
*.DS_Store
.idea
.idea
output*

View File

@ -58,7 +58,7 @@ def generate_video(
text_encoder=text_encoder,
transformer=transformer,
vae=vae,
torch_dtype=torch.bfloat16,
torch_dtype=dtype,
)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

View File

@ -147,7 +147,9 @@ def pad_last_frame(tensor, num_frames):
# T, H, W, C
if len(tensor) < num_frames:
pad_length = num_frames - len(tensor)
pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
# Use the last frame to pad instead of zero
last_frame = tensor[-1]
pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
return padded_tensor
else:

View File

@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed $RANDOM"
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
echo ${run_cmd}
eval ${run_cmd}