From 24b2053596d71fe4350e73625997b346e94066f6 Mon Sep 17 00:00:00 2001
From: glide-the
Date: Sat, 12 Oct 2024 12:28:52 +0800
Subject: [PATCH] Add new command line arguments for LoRA weights and prompt

---
 tools/load_cogvideox_lora.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/tools/load_cogvideox_lora.py b/tools/load_cogvideox_lora.py
index 449320a..1b12975 100644
--- a/tools/load_cogvideox_lora.py
+++ b/tools/load_cogvideox_lora.py
@@ -57,6 +57,23 @@ def get_args():
         The formula for lora_scale is: lora_r / alpha.
         """,
     )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=1,
+        help="""The LoRA alpha value, used together with the rank to compute lora_scale.
+        The rank (lora_r) defaults to 128 for the 2B transformer and 256 for the 5B transformer.
+        Dividing by the rank keeps learning stable and prevents underflow; in the SAT training
+        framework, alpha is set to 1 by default. A higher rank gives better expressive capability
+        but requires more memory and training time, so increasing it blindly isn't always better.
+        The formula for lora_scale is: lora_alpha / lora_r.
+        """,
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        help="Text prompt used to generate the video.",
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -69,17 +86,18 @@
 if __name__ == "__main__":
     args = get_args()
     pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
-    pipe.load_lora_weights(args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
-    pipe.fuse_lora(lora_scale=1/128)
+    pipe.load_lora_weights(args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
+    # pipe.fuse_lora(components=["transformer"], lora_scale=args.lora_alpha / args.lora_r)
+    lora_scaling = args.lora_alpha / args.lora_r
+    pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
     pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
     os.makedirs(args.output_dir, exist_ok=True)
-    prompt="""In the heart of a bustling city, a young woman with long, flowing brown hair and a radiant smile stands out. She's donned in a cozy white beanie adorned with playful animal ears, adding a touch of whimsy to her appearance. Her eyes sparkle with joy as she looks directly into the camera, her expression inviting and warm. The background is a blur of activity, with indistinct figures moving about, suggesting a lively public space. The lighting is soft and diffused, casting a gentle glow on her face and highlighting her features. The overall mood is cheerful and vibrant, capturing a moment of happiness in the midst of urban life.
-    """
+
     latents = pipe(
         prompt=args.prompt,
         num_videos_per_prompt=1,
         num_inference_steps=50,
         num_frames=49,
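
A minimal standalone sketch (not part of the patch) of how the new flags combine,
assuming diffusers' CogVideoXPipeline LoRA API; the model id, LoRA path, and the
rank/alpha values below are illustrative placeholders. The effective weight passed
to set_adapters is lora_alpha / lora_r, matching PEFT's alpha-over-rank scaling
convention, so the default alpha of 1 with a 2B rank of 128 yields a scale of 1/128.

    import torch
    from diffusers import CogVideoXPipeline

    # Load the base pipeline (placeholder model id).
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Register the trained LoRA weights under a named adapter.
    pipe.load_lora_weights(
        "path/to/lora",  # placeholder: directory holding the weights
        weight_name="pytorch_lora_weights.safetensors",
        adapter_name="cogvideox-lora",
    )

    # Scale the adapter at inference time instead of fusing it: scale = alpha / rank.
    lora_alpha, lora_r = 1, 128  # illustrative SAT defaults for the 2B transformer
    pipe.set_adapters(["cogvideox-lora"], [lora_alpha / lora_r])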