diff --git a/inference/ddim_inversion.py b/inference/ddim_inversion.py
index fc19b29..309c374 100644
--- a/inference/ddim_inversion.py
+++ b/inference/ddim_inversion.py
@@ -160,6 +160,7 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
         query, query_reference = query.chunk(2)
         key, key_reference = key.chunk(2)
         value, value_reference = value.chunk(2)
+        batch_size = batch_size // 2
 
         hidden_states, encoder_hidden_states = self.calculate_attention(
             query=query,
@@ -295,6 +296,8 @@ def sample(
     )
     if do_classifier_free_guidance:
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+    if reference_latents is not None:
+        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
 
     # 4. Prepare timesteps
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
@@ -305,13 +308,14 @@ def sample(
 
     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
+    if isinstance(scheduler, DDIMInverseScheduler):  # Inverse scheduler does not accept extra kwargs
+        extra_step_kwargs = {}
 
     # 7. Create rotary embeds if required
-    spatial_scaling_factor = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size
     image_rotary_emb = (
         pipeline._prepare_rotary_positional_embeddings(
-            height=latents.size(3) * spatial_scaling_factor,
-            width=latents.size(4) * spatial_scaling_factor,
+            height=latents.size(3) * pipeline.vae_scale_factor_spatial,
+            width=latents.size(4) * pipeline.vae_scale_factor_spatial,
             num_frames=latents.size(1),
             device=device,
         )
@@ -332,7 +336,7 @@ def sample(
             if reference_latents is not None:
                 reference = reference_latents[i]
                 reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
-                latent_model_input = torch.cat([latent_model_input, reference])
+                latent_model_input = torch.cat([latent_model_input, reference], dim=0)
             latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -349,6 +353,9 @@ def sample(
             )[0]
             noise_pred = noise_pred.float()
 
+            if reference_latents is not None:  # Recover the original batch size
+                noise_pred, _ = noise_pred.chunk(2)
+
             # perform guidance
             if use_dynamic_cfg:
                 pipeline._guidance_scale = 1 + guidance_scale * (
@@ -410,20 +417,20 @@ def ddim_inversion(
         prompt="",
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
-        generator=torch.Generator(device=device).manual_seed(seed=seed),
+        generator=torch.Generator(device=device).manual_seed(seed),
     )
     with OverrideAttnProcessors(transformer=pipeline.transformer):
         recon_latents = sample(
             pipeline=pipeline,
             latents=torch.randn_like(video_latents),
-            scheduler=inverse_scheduler,
+            scheduler=pipeline.scheduler,
             prompt=prompt,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
-            generator=torch.Generator(device=device).manual_seed(seed=seed),
+            generator=torch.Generator(device=device).manual_seed(seed),
             reference_latents=reversed(inverse_latents),
         )
-    filename = os.path.splitext(os.path.basename(video_path))[0]
+    filename, _ = os.path.splitext(os.path.basename(video_path))
     inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
     recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
     export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)