stable version

LittleNyima 2025-02-20 05:03:15 +00:00
parent 58d66c8a08
commit 250a0bce45


@@ -160,6 +160,7 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
query, query_reference = query.chunk(2)
key, key_reference = key.chunk(2)
value, value_reference = value.chunk(2)
batch_size = batch_size // 2
hidden_states, encoder_hidden_states = self.calculate_attention(
query=query,
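Note on this hunk: the attention processor receives a batch where the sample and reference streams are stacked along dim 0, so query/key/value are split in half and the added `batch_size = batch_size // 2` keeps the later reshapes consistent with the halved tensors. A minimal sketch of the split-and-halve pattern, with illustrative shapes (only the query/key/value names and `batch_size` come from the hunk):

    import torch

    batch_size = 4                               # doubled batch: [sample, reference] stacked on dim 0
    query = torch.randn(batch_size, 8, 226, 64)  # (batch, heads, tokens, head_dim), illustrative

    query, query_reference = query.chunk(2)      # first half: samples, second half: references
    batch_size = batch_size // 2                 # downstream code now sees the original batch size

    assert query.shape[0] == batch_size == query_reference.shape[0]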
@@ -295,6 +296,8 @@ def sample(
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
if reference_latents is not None:
prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
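Note on this hunk: when `reference_latents` are provided, the latent batch is later doubled with a reference half, so the prompt embeddings are doubled here to keep the text conditioning aligned with the latents. A sketch of how the conditioning batch grows under both flags (tensor sizes are illustrative, not the pipeline's real shapes):

    import torch

    prompt_embeds = torch.randn(1, 226, 4096)                 # cond embeddings, illustrative size
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)  # uncond embeddings
    do_classifier_free_guidance = True
    reference_latents = torch.randn(1)                        # stand-in: reference latents were provided

    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # batch 2: [uncond, cond]
    if reference_latents is not None:
        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)                      # batch 4: matches [sample, reference]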
@@ -305,13 +308,14 @@ def sample(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
extra_step_kwargs = {}
# 7. Create rotary embeds if required
spatial_scaling_factor = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size
image_rotary_emb = (
pipeline._prepare_rotary_positional_embeddings(
height=latents.size(3) * spatial_scaling_factor,
width=latents.size(4) * spatial_scaling_factor,
height=latents.size(3) * pipeline.vae_scale_factor_spatial,
width=latents.size(4) * pipeline.vae_scale_factor_spatial,
num_frames=latents.size(1),
device=device,
)
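Note on this hunk: assuming the helper follows the diffusers CogVideoX pipeline, `_prepare_rotary_positional_embeddings` expects pixel-space height/width and divides by `vae_scale_factor_spatial * patch_size` internally, so multiplying the latent size by the old `spatial_scaling_factor` (which already includes `patch_size`) over-scaled the grid; the fix multiplies by `vae_scale_factor_spatial` only. A small arithmetic sketch under CogVideoX defaults (scale factor 8, patch size 2) and latents shaped (B, F, C, H, W):

    vae_scale_factor_spatial = 8
    patch_size = 2
    latent_height, latent_width = 60, 90  # latents.size(3), latents.size(4), illustrative

    height = latent_height * vae_scale_factor_spatial  # 480: pixel-space value the helper expects
    width = latent_width * vae_scale_factor_spatial    # 720
    grid_height = height // (vae_scale_factor_spatial * patch_size)  # 30, as derived inside the helper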
@@ -332,7 +336,7 @@ def sample(
if reference_latents is not None:
reference = reference_latents[i]
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
latent_model_input = torch.cat([latent_model_input, reference])
latent_model_input = torch.cat([latent_model_input, reference], dim=0)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
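Note on this hunk: within the denoising loop, the per-step reference latent is duplicated for classifier-free guidance and appended to the model input; adding `dim=0` makes the batch-dimension concatenation explicit. A compact sketch of how the input batch is assembled (shapes illustrative):

    import torch

    do_classifier_free_guidance = True
    latents = torch.randn(1, 13, 16, 60, 90)  # (B, F, C, H, W), illustrative
    reference = torch.randn_like(latents)     # reference_latents[i] for the current step

    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
    latent_model_input = torch.cat([latent_model_input, reference], dim=0)  # [uncond, cond, uncond_ref, cond_ref]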
@@ -349,6 +353,9 @@ def sample(
)[0]
noise_pred = noise_pred.float()
if reference_latents is not None: # Recover the original batch size
noise_pred, _ = noise_pred.chunk(2)
# perform guidance
if use_dynamic_cfg:
pipeline._guidance_scale = 1 + guidance_scale * (
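Note on this hunk: after the transformer forward pass, the reference half of the batch has already done its job inside the attention processors, so the added `chunk(2)` drops it before guidance is applied. A sketch of that recovery step followed by plain classifier-free guidance (the dynamic-CFG branch is truncated in the hunk and omitted here; shapes illustrative):

    import torch

    noise_pred = torch.randn(4, 13, 16, 60, 90)  # [uncond, cond, uncond_ref, cond_ref], illustrative
    reference_latents = torch.randn(1)           # stand-in: reference latents were provided
    do_classifier_free_guidance = True
    guidance_scale = 6.0

    if reference_latents is not None:
        noise_pred, _ = noise_pred.chunk(2)      # keep [uncond, cond], drop the reference half

    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)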
@@ -410,20 +417,20 @@ def ddim_inversion(
prompt="",
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator(device=device).manual_seed(seed=seed),
generator=torch.Generator(device=device).manual_seed(seed),
)
with OverrideAttnProcessors(transformer=pipeline.transformer):
recon_latents = sample(
pipeline=pipeline,
latents=torch.randn_like(video_latents),
scheduler=inverse_scheduler,
scheduler=pipeline.scheduler,
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator(device=device).manual_seed(seed=seed),
generator=torch.Generator(device=device).manual_seed(seed),
reference_latents=reversed(inverse_latents),
)
filename = os.path.splitext(os.path.basename(video_path))[0]
filename, _ = os.path.splitext(os.path.basename(video_path))
inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
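Note on this hunk: the generator fix passes the seed positionally to `torch.Generator.manual_seed` instead of as a keyword; the reconstruction pass now runs with `pipeline.scheduler` (the forward scheduler) while feeding the reversed inversion trajectory as reference latents; and the output filename stem is taken by tuple unpacking from `os.path.splitext`. A short sketch of the seeding and filename handling, with an illustrative input path and output directory:

    import os
    import torch

    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"
    video_path = "/data/example_video.mp4"  # illustrative input
    output_path = "outputs"                 # illustrative output directory

    generator = torch.Generator(device=device).manual_seed(seed)  # seed passed positionally

    filename, _ = os.path.splitext(os.path.basename(video_path))  # "example_video"
    inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
    recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")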