mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
feat: add warning for fp16 mixed precision training
This commit is contained in:
parent
36427274d6
commit
96e511b413
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Literal, Tuple
|
||||
|
||||
@ -157,6 +158,15 @@ class Args(BaseModel):
|
||||
raise ValueError("train_resolution must be in format 'frames x height x width'")
|
||||
raise e
|
||||
|
||||
@field_validator("mixed_precision")
|
||||
def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
|
||||
if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
|
||||
logging.warning(
|
||||
"All CogVideoX models except cogvideox-2b were trained with bfloat16. "
|
||||
"Using fp16 precision may lead to training instability."
|
||||
)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def parse_args(cls):
|
||||
"""Parse command line arguments and return Args instance"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user