mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
update testing
This commit is contained in:
parent
b538be32be
commit
adf65ff0eb
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,4 +1,3 @@
|
||||
output/
|
||||
*__pycache__/
|
||||
samples*/
|
||||
runs/
|
||||
@ -6,4 +5,5 @@ checkpoints/
|
||||
master_ip
|
||||
logs/
|
||||
*.DS_Store
|
||||
.idea
|
||||
.idea
|
||||
output*
|
@ -58,7 +58,7 @@ def generate_video(
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.vae.enable_tiling()
|
||||
|
@ -147,7 +147,9 @@ def pad_last_frame(tensor, num_frames):
|
||||
# T, H, W, C
|
||||
if len(tensor) < num_frames:
|
||||
pad_length = num_frames - len(tensor)
|
||||
pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device)
|
||||
# Use the last frame to pad instead of zero
|
||||
last_frame = tensor[-1]
|
||||
pad_tensor = last_frame.unsqueeze(0).expand(pad_length, *tensor.shape[1:])
|
||||
padded_tensor = torch.cat([tensor, pad_tensor], dim=0)
|
||||
return padded_tensor
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
||||
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_2b.yaml configs/inference.yaml --seed $RANDOM"
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
|
||||
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
Loading…
x
Reference in New Issue
Block a user