mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
stable version
This commit is contained in:
parent
58d66c8a08
commit
250a0bce45
@ -160,6 +160,7 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
|
||||
query, query_reference = query.chunk(2)
|
||||
key, key_reference = key.chunk(2)
|
||||
value, value_reference = value.chunk(2)
|
||||
batch_size = batch_size // 2
|
||||
|
||||
hidden_states, encoder_hidden_states = self.calculate_attention(
|
||||
query=query,
|
||||
@ -295,6 +296,8 @@ def sample(
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
if reference_latents is not None:
|
||||
prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
|
||||
@ -305,13 +308,14 @@ def sample(
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
|
||||
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
|
||||
extra_step_kwargs = {}
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
spatial_scaling_factor = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size
|
||||
image_rotary_emb = (
|
||||
pipeline._prepare_rotary_positional_embeddings(
|
||||
height=latents.size(3) * spatial_scaling_factor,
|
||||
width=latents.size(4) * spatial_scaling_factor,
|
||||
height=latents.size(3) * pipeline.vae_scale_factor_spatial,
|
||||
width=latents.size(4) * pipeline.vae_scale_factor_spatial,
|
||||
num_frames=latents.size(1),
|
||||
device=device,
|
||||
)
|
||||
@ -332,7 +336,7 @@ def sample(
|
||||
if reference_latents is not None:
|
||||
reference = reference_latents[i]
|
||||
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
|
||||
latent_model_input = torch.cat([latent_model_input, reference])
|
||||
latent_model_input = torch.cat([latent_model_input, reference], dim=0)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
@ -349,6 +353,9 @@ def sample(
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if reference_latents is not None: # Recover the original batch size
|
||||
noise_pred, _ = noise_pred.chunk(2)
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
pipeline._guidance_scale = 1 + guidance_scale * (
|
||||
@ -410,20 +417,20 @@ def ddim_inversion(
|
||||
prompt="",
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=torch.Generator(device=device).manual_seed(seed=seed),
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
with OverrideAttnProcessors(transformer=pipeline.transformer):
|
||||
recon_latents = sample(
|
||||
pipeline=pipeline,
|
||||
latents=torch.randn_like(video_latents),
|
||||
scheduler=inverse_scheduler,
|
||||
scheduler=pipeline.scheduler,
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=torch.Generator(device=device).manual_seed(seed=seed),
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
reference_latents=reversed(inverse_latents),
|
||||
)
|
||||
filename = os.path.splitext(os.path.basename(video_path))[0]
|
||||
filename, _ = os.path.splitext(os.path.basename(video_path))
|
||||
inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
|
||||
recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
|
||||
export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
|
||||
|
Loading…
x
Reference in New Issue
Block a user