mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-11 09:39:18 +08:00
update
This commit is contained in:
parent
300fc75c49
commit
e7835cd79c
@ -323,10 +323,8 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
if isinstance(c[k], torch.Tensor):
|
||||
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
||||
|
||||
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
if self.noised_image_input:
|
||||
print("Adding noise to first frame")
|
||||
image = x[:, :, 0:1]
|
||||
image = self.add_noise_to_first_frame(image)
|
||||
image = self.encode_first_stage(image, batch)
|
||||
@ -344,6 +342,8 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||
log["samples"] = samples
|
||||
else:
|
||||
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
|
||||
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
|
||||
if only_log_video_latents:
|
||||
latents = 1.0 / self.scale_factor * samples
|
||||
log["latents"] = latents
|
||||
|
@ -1,8 +1,8 @@
|
||||
#! /bin/bash
|
||||
|
||||
echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
echo "RUN on $(hostname), CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
|
||||
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_5b_i2v_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
Loading…
x
Reference in New Issue
Block a user