mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
Merge remote-tracking branch 'upstream/main' into dev
This commit is contained in:
commit
cfaca91cde
@ -1246,11 +1246,11 @@ def main(args):
|
|||||||
|
|
||||||
use_deepspeed_optimizer = (
|
use_deepspeed_optimizer = (
|
||||||
accelerator.state.deepspeed_plugin is not None
|
accelerator.state.deepspeed_plugin is not None
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none"
|
and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() != "none"
|
||||||
)
|
)
|
||||||
use_deepspeed_scheduler = (
|
use_deepspeed_scheduler = (
|
||||||
accelerator.state.deepspeed_plugin is not None
|
accelerator.state.deepspeed_plugin is not None
|
||||||
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none"
|
and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() != "none"
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||||
@ -1283,7 +1283,7 @@ def main(args):
|
|||||||
|
|
||||||
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
|
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
|
||||||
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
|
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
|
||||||
noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
|
noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]
|
||||||
image_latent_dist = vae.encode(noisy_image).latent_dist
|
image_latent_dist = vae.encode(noisy_image).latent_dist
|
||||||
|
|
||||||
return latent_dist, image_latent_dist
|
return latent_dist, image_latent_dist
|
||||||
|
Loading…
x
Reference in New Issue
Block a user