diff --git a/tools/parallel_inference/parallel_inference_xdit.py b/tools/parallel_inference/parallel_inference_xdit.py index 03bc268..d0b8029 100644 --- a/tools/parallel_inference/parallel_inference_xdit.py +++ b/tools/parallel_inference/parallel_inference_xdit.py @@ -78,8 +78,9 @@ def main(): num_frames=input_config.num_frames, prompt=input_config.prompt, num_inference_steps=input_config.num_inference_steps, - generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + generator=torch.Generator().manual_seed(input_config.seed), guidance_scale=6, + use_dynamic_cfg=True, ).frames[0] end_time = time.time()