Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)

Commit 250a0bce45 ("stable version"), parent 58d66c8a08
@@ -160,6 +160,7 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
         query, query_reference = query.chunk(2)
         key, key_reference = key.chunk(2)
         value, value_reference = value.chunk(2)
+        batch_size = batch_size // 2

         hidden_states, encoder_hidden_states = self.calculate_attention(
             query=query,
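Context for this hunk: during inversion-guided sampling, each attention tensor carries the current sample batch and a reference batch stacked along dim 0, so the processor splits them in half and, with this fix, also halves batch_size so downstream reshapes line up. A minimal hedged sketch of that splitting step (only query/key/value/batch_size mirror the hunk; the helper and shapes are illustrative):

import torch

def split_sample_and_reference(query, key, value, batch_size):
    # First half of each tensor: the batch being denoised; second half: the reference batch.
    query, query_reference = query.chunk(2)
    key, key_reference = key.chunk(2)
    value, value_reference = value.chunk(2)
    # The two halves are attended separately, so the effective batch size is halved too.
    return (query, key, value), (query_reference, key_reference, value_reference), batch_size // 2

# Illustrative shapes: (batch, heads, tokens, head_dim) with sample and reference stacked on dim 0.
q = k = v = torch.randn(4, 8, 226, 64)
sample_qkv, reference_qkv, bsz = split_sample_and_reference(q, k, v, batch_size=4)
print(sample_qkv[0].shape, bsz)  # torch.Size([2, 8, 226, 64]) 2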
@@ -295,6 +296,8 @@ def sample(
     )
     if do_classifier_free_guidance:
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+    if reference_latents is not None:
+        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)

     # 4. Prepare timesteps
     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
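The added lines keep the text conditioning aligned with the latent batch: when reference latents are supplied, the latents are later concatenated with the reference along dim 0, so prompt_embeds must be repeated once to cover that second half. A small hedged sketch with illustrative shapes (the flag and tensor sizes are stand-ins, not the repo's values):

import torch

do_classifier_free_guidance = True
reference_latents_provided = True  # stands in for `reference_latents is not None`

prompt_embeds = torch.randn(1, 226, 4096)            # (batch, seq_len, dim) — illustrative
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

if do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)  # batch -> 2
if reference_latents_provided:
    prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)                      # batch -> 4
print(prompt_embeds.shape[0])  # 4: one row each for sample_uncond, sample_cond, ref_uncond, ref_cond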
@@ -305,13 +308,14 @@ def sample(

     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
+    if isinstance(scheduler, DDIMInverseScheduler):  # Inverse scheduler does not accept extra kwargs
+        extra_step_kwargs = {}

     # 7. Create rotary embeds if required
-    spatial_scaling_factor = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size
     image_rotary_emb = (
         pipeline._prepare_rotary_positional_embeddings(
-            height=latents.size(3) * spatial_scaling_factor,
-            width=latents.size(4) * spatial_scaling_factor,
+            height=latents.size(3) * pipeline.vae_scale_factor_spatial,
+            width=latents.size(4) * pipeline.vae_scale_factor_spatial,
             num_frames=latents.size(1),
             device=device,
         )
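The first addition guards against the inverse scheduler: prepare_extra_step_kwargs builds eta/generator kwargs against the pipeline's own scheduler, but DDIMInverseScheduler.step() declares neither, so when the inverse scheduler is active the extras are dropped to avoid a TypeError. A hedged sketch of the underlying signature check (this helper is illustrative, not the repo's code; exact defaults depend on the diffusers version):

import inspect

from diffusers import DDIMInverseScheduler, DDIMScheduler

def extra_step_kwargs_for(scheduler, generator=None, eta=0.0):
    # Keep only the kwargs that this scheduler's step() actually declares.
    accepted = set(inspect.signature(scheduler.step).parameters)
    kwargs = {}
    if "eta" in accepted:
        kwargs["eta"] = eta
    if "generator" in accepted:
        kwargs["generator"] = generator
    return kwargs

print(extra_step_kwargs_for(DDIMScheduler()))         # e.g. {'eta': 0.0, 'generator': None}
print(extra_step_kwargs_for(DDIMInverseScheduler()))  # {} — no extra kwargs are accepted

The second change drops the patch-size multiplier from the rotary-embedding height/width; the stable version sizes them with pipeline.vae_scale_factor_spatial alone.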
@@ -332,7 +336,7 @@ def sample(
             if reference_latents is not None:
                 reference = reference_latents[i]
                 reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
-                latent_model_input = torch.cat([latent_model_input, reference])
+                latent_model_input = torch.cat([latent_model_input, reference], dim=0)
             latent_model_input = scheduler.scale_model_input(latent_model_input, t)

             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -349,6 +353,9 @@ def sample(
             )[0]
             noise_pred = noise_pred.float()

+            if reference_latents is not None:  # Recover the original batch size
+                noise_pred, _ = noise_pred.chunk(2)
+
             # perform guidance
             if use_dynamic_cfg:
                 pipeline._guidance_scale = 1 + guidance_scale * (
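Because the reference batch was stacked onto the model input, the transformer's prediction comes back with twice the batch; the new lines keep only the first (sample) half before classifier-free guidance runs. A hedged sketch with illustrative shapes:

import torch

reference_latents_provided = True
guidance_scale = 6.0

# Doubled twice: once for CFG (uncond/cond), once for the appended reference batch.
noise_pred = torch.randn(4, 13, 16, 60, 90)   # (batch, frames, channels, height, width) — illustrative

if reference_latents_provided:                # recover the original batch size
    noise_pred, _ = noise_pred.chunk(2)       # keep the sample half, drop the reference half

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape[0])  # 1 — back to the caller's original batch size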
@@ -410,20 +417,20 @@ def ddim_inversion(
         prompt="",
         num_inference_steps=num_inference_steps,
         guidance_scale=guidance_scale,
-        generator=torch.Generator(device=device).manual_seed(seed=seed),
+        generator=torch.Generator(device=device).manual_seed(seed),
     )
     with OverrideAttnProcessors(transformer=pipeline.transformer):
         recon_latents = sample(
             pipeline=pipeline,
             latents=torch.randn_like(video_latents),
-            scheduler=inverse_scheduler,
+            scheduler=pipeline.scheduler,
             prompt=prompt,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
-            generator=torch.Generator(device=device).manual_seed(seed=seed),
+            generator=torch.Generator(device=device).manual_seed(seed),
             reference_latents=reversed(inverse_latents),
         )
-    filename = os.path.splitext(os.path.basename(video_path))[0]
+    filename, _ = os.path.splitext(os.path.basename(video_path))
     inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
     recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
     export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
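Two of these changes are behavioural. First, torch.Generator.manual_seed takes its seed positionally in current PyTorch releases, so manual_seed(seed=seed) raises a TypeError; the stable version passes it positionally (the sketch below uses only the torch API). Second, the reconstruction pass now runs with pipeline.scheduler rather than the inverse scheduler, and the inversion trajectory is fed back (reversed) as reference_latents while the forward scheduler does the actual sampling. The remaining edits (dim=0, the splitext unpacking) are cosmetic.

import torch

device = "cpu"
seed = 42

# Keyword form (the old code) raises a TypeError, because manual_seed
# only accepts a positional argument in current PyTorch releases:
# torch.Generator(device=device).manual_seed(seed=seed)

# Positional form, as in the stable version:
generator = torch.Generator(device=device).manual_seed(seed)
print(generator.initial_seed())  # 42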