stable version

Author: LittleNyima
Date:   2025-02-20 05:03:15 +00:00
Parent: 58d66c8a08
Commit: 250a0bce45


@@ -160,6 +160,7 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
         query, query_reference = query.chunk(2)
         key, key_reference = key.chunk(2)
         value, value_reference = value.chunk(2)
+        batch_size = batch_size // 2
         hidden_states, encoder_hidden_states = self.calculate_attention(
             query=query,
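
Note on the added `batch_size = batch_size // 2`: once query/key/value are split into a main half and a reference half along the batch dimension, the bookkeeping used for later reshapes has to follow the split. A minimal sketch of the pattern with made-up shapes (not the script's real tensor sizes):

import torch

# Hypothetical sizes: the processor receives tensors whose batch dimension
# stacks the main branch and the reference branch.
batch_size, heads, seq_len, head_dim = 4, 8, 16, 64
query = torch.randn(batch_size, heads, seq_len, head_dim)

# Split off the reference half; each chunk keeps batch_size // 2 samples.
query, query_reference = query.chunk(2)
batch_size = batch_size // 2  # keep downstream reshapes consistent with the split

assert query.shape[0] == batch_size == query_reference.shape[0]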
@@ -295,6 +296,8 @@ def sample(
     )
     if do_classifier_free_guidance:
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+    if reference_latents is not None:
+        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
 
     # 4. Prepare timesteps
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
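
The extra duplication of `prompt_embeds` keeps the text-embedding batch aligned with the latent batch: classifier-free guidance doubles it once (negative + positive), and the reference branch concatenated onto `latent_model_input` further down doubles it again. A rough shape trace under assumed CogVideoX-like sizes:

import torch

B, L, D = 1, 226, 4096  # assumed batch size, text length, embedding dim
negative_prompt_embeds = torch.randn(B, L, D)
prompt_embeds = torch.randn(B, L, D)

# Classifier-free guidance: negative + positive -> 2 * B
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# Reference branch: the latents are later concatenated with reference latents,
# so the embeddings are doubled once more -> 4 * B
prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
print(prompt_embeds.shape)  # torch.Size([4, 226, 4096])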
@@ -305,13 +308,14 @@ def sample(
     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
+    if isinstance(scheduler, DDIMInverseScheduler):  # Inverse scheduler does not accept extra kwargs
+        extra_step_kwargs = {}
 
     # 7. Create rotary embeds if required
-    spatial_scaling_factor = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size
     image_rotary_emb = (
         pipeline._prepare_rotary_positional_embeddings(
-            height=latents.size(3) * spatial_scaling_factor,
-            width=latents.size(4) * spatial_scaling_factor,
+            height=latents.size(3) * pipeline.vae_scale_factor_spatial,
+            width=latents.size(4) * pipeline.vae_scale_factor_spatial,
             num_frames=latents.size(1),
             device=device,
         )
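
The `isinstance` guard is needed because `prepare_extra_step_kwargs` inspects the pipeline's own scheduler when deciding whether to pass `eta`/`generator`, while `sample` may actually be stepping a `DDIMInverseScheduler`, whose `step` does not take those keywords in current diffusers. A small check that illustrates the mismatch; treat the exact signatures as version-dependent:

import inspect

from diffusers import DDIMInverseScheduler, DDIMScheduler

ddim = DDIMScheduler()
inverse = DDIMInverseScheduler()

# The forward DDIM scheduler advertises eta/generator in step(); the inverse
# scheduler does not, so kwargs prepared for one can break the other.
print("eta" in inspect.signature(ddim.step).parameters)     # True
print("eta" in inspect.signature(inverse.step).parameters)  # expected: False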
@@ -332,7 +336,7 @@ def sample(
             if reference_latents is not None:
                 reference = reference_latents[i]
                 reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
-                latent_model_input = torch.cat([latent_model_input, reference])
+                latent_model_input = torch.cat([latent_model_input, reference], dim=0)
             latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -349,6 +353,9 @@ def sample(
             )[0]
             noise_pred = noise_pred.float()
 
+            if reference_latents is not None:  # Recover the original batch size
+                noise_pred, _ = noise_pred.chunk(2)
+
             # perform guidance
             if use_dynamic_cfg:
                 pipeline._guidance_scale = 1 + guidance_scale * (
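
The added chunk restores the original batch before guidance: the transformer ran on the concatenation of latents and reference latents, but only the first half should go through the CFG combination and the scheduler step. A minimal sketch with hypothetical shapes:

import torch

# Hypothetical (4 * B, frames, channels, height, width): CFG doubled the batch,
# and the reference branch doubled it again.
noise_pred = torch.randn(4, 13, 16, 60, 90)

# Drop the reference half; it is only consumed by the overridden attention processors.
noise_pred, _ = noise_pred.chunk(2)

# Standard classifier-free guidance on what remains.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guidance_scale = 6.0
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 13, 16, 60, 90])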
@@ -410,20 +417,20 @@ def ddim_inversion(
         prompt="",
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
-        generator=torch.Generator(device=device).manual_seed(seed=seed),
+        generator=torch.Generator(device=device).manual_seed(seed),
     )
     with OverrideAttnProcessors(transformer=pipeline.transformer):
         recon_latents = sample(
             pipeline=pipeline,
             latents=torch.randn_like(video_latents),
-            scheduler=inverse_scheduler,
+            scheduler=pipeline.scheduler,
             prompt=prompt,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
-            generator=torch.Generator(device=device).manual_seed(seed=seed),
+            generator=torch.Generator(device=device).manual_seed(seed),
             reference_latents=reversed(inverse_latents),
         )
-    filename = os.path.splitext(os.path.basename(video_path))[0]
+    filename, _ = os.path.splitext(os.path.basename(video_path))
     inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
     recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
     export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
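
Two of the fixes in this last hunk are worth calling out: the reconstruction pass now steps `pipeline.scheduler` (the forward scheduler) instead of the inverse scheduler, consuming the inversion trajectory in reverse order via `reversed(inverse_latents)` so that `reference_latents[i]` lines up with denoising step i, and the generator seeding drops the keyword form, since `torch.Generator.manual_seed` takes its seed positionally. A tiny sanity check of the seeding fix:

import torch

seed = 42
# manual_seed takes the seed positionally; the keyword form manual_seed(seed=seed)
# fails against the C-bound Generator API in current PyTorch builds.
generator = torch.Generator(device="cpu").manual_seed(seed)
print(generator.initial_seed())  # 42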