From e7835cd79cfad3d0c8cc574e438701c8cac1f374 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 14 Sep 2024 18:29:38 +0800 Subject: [PATCH] update --- sat/diffusion_video.py | 6 +++--- sat/finetune_multi_gpus.sh | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 830f9b6..8329e9d 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -323,10 +323,8 @@ class SATVideoDiffusionEngine(nn.Module): if isinstance(c[k], torch.Tensor): c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) - samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w - samples = samples.permute(0, 2, 1, 3, 4).contiguous() + if self.noised_image_input: - print("Adding noise to first frame") image = x[:, :, 0:1] image = self.add_noise_to_first_frame(image) image = self.encode_first_stage(image, batch) @@ -344,6 +342,8 @@ class SATVideoDiffusionEngine(nn.Module): samples = samples.permute(0, 2, 1, 3, 4).contiguous() log["samples"] = samples else: + samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w + samples = samples.permute(0, 2, 1, 3, 4).contiguous() if only_log_video_latents: latents = 1.0 / self.scale_factor * samples log["latents"] = latents diff --git a/sat/finetune_multi_gpus.sh b/sat/finetune_multi_gpus.sh index ef56701..a9a8ad2 100644 --- a/sat/finetune_multi_gpus.sh +++ b/sat/finetune_multi_gpus.sh @@ -1,8 +1,8 @@ #! /bin/bash -echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" -run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" +run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b_i2v_lora.yaml configs/sft.yaml --seed $RANDOM" echo ${run_cmd} eval ${run_cmd}