Mirror of https://github.com/THUDM/CogVideo.git, synced 2025-12-02 18:43:38 +08:00
Merge ebc9d39c02a50dcbc03b21f132e0c12bb39c1909 into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2
This commit is contained in:
commit 21f51ec232

finetune/models/cogvideox_t2v/kd_trainer.py (new file, 129 lines added)

@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import torchvision
from collections import OrderedDict

from torchvision.models import vgg16
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.rpn import AnchorGenerator
from typing_extensions import override

from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class VGG16BackboneWrapper(nn.Module):
    """Wraps the VGG16 feature extractor so it matches the backbone interface FasterRCNN expects."""

    def __init__(self, vgg_features):
        super().__init__()
        self.features = vgg_features
        # FasterRCNN reads `out_channels` from the backbone to size the RPN and box heads.
        self.out_channels = 512

    def forward(self, x):
        x = self.features(x)
        # FasterRCNN expects an ordered dict of feature maps keyed by level name.
        return OrderedDict([("0", x)])


class CogVideoXT2VKdTrainer(CogVideoXT2VLoraTrainer):
    # Keep the VAE loaded (only unload the text encoder) so it is available in compute_loss.
    UNLOAD_LIST = ["text_encoder"]

    def __init__(self, args):
        super().__init__(args)
        self.teacher_model = self.load_teacher_model()

    def load_teacher_model(self):
        teacher_model_path = getattr(self.args, "teacher_model_path", None)
        if not teacher_model_path:
            print("Warning: teacher_model_path is not provided. Knowledge distillation will be skipped.")
            return None

        try:
            # Build a VGG16-based Faster R-CNN detector.
            # 1. VGG16 backbone
            vgg_features = vgg16(weights=None).features
            # torchvision's VGG16 ends its feature extractor with a max-pool layer. Many Faster R-CNN
            # implementations with a VGG backbone drop it, so remove it to stay closer to the original
            # Caffe implementation.
            backbone_features = vgg_features[:-1]
            backbone = VGG16BackboneWrapper(backbone_features)

            # 2. Region proposal network
            anchor_generator = AnchorGenerator(sizes=((128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))

            # 3. RoI heads
            roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)

            # The teacher checkpoint has fc6/fc7 layers, which corresponds to a TwoMLPHead.
            # The backbone output flattened over the 7x7 RoI pool is 512 * 7 * 7 = 25088.
            box_head = TwoMLPHead(in_channels=25088, representation_size=4096)

            num_classes = getattr(self.args, "teacher_model_num_classes", None) or 91  # COCO default
            box_predictor = FastRCNNPredictor(in_channels=4096, num_classes=num_classes)

            # 4. Assemble the Faster R-CNN model. num_classes must stay None here because a custom
            # box_predictor is supplied; passing both makes FasterRCNN raise a ValueError.
            model = FasterRCNN(
                backbone,
                num_classes=None,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
                box_head=box_head,
                box_predictor=box_predictor,
            )

            # Load the pre-trained weights from the converted checkpoint, then move to the training device.
            print(f"Loading teacher model from: {teacher_model_path}")
            state_dict = torch.load(teacher_model_path, map_location="cpu")
            model.load_state_dict(state_dict)
            model.eval()
            model.to(self.accelerator.device)
            print("Teacher model loaded successfully.")
            return model
        except Exception as e:
            print(f"Error loading teacher model: {e}")
            return None

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        # Start from the original diffusion loss.
        diffusion_loss = super().compute_loss(batch)

        if self.teacher_model is None:
            return diffusion_loss

        latents = batch["encoded_videos"]

        # Decode the latents back to pixel-space video frames.
        video_frames = self.components.vae.decode(latents / self.components.vae.config.scaling_factor).sample

        # The VAE output is in [-1, 1]; normalize it to [0, 1] for the teacher detector.
        video_frames = (video_frames + 1) / 2

        # Move the frame dimension in front of channels so each video can be iterated as image frames.
        video_frames = video_frames.permute(0, 2, 1, 3, 4)

        # Knowledge distillation penalty: the fraction of decoded frames in which the frozen teacher
        # detector finds no objects at all.
        kd_loss = 0.0
        num_frames_processed = 0
        for i in range(video_frames.shape[0]):  # for each video in the batch
            frames = list(video_frames[i])  # list of frames of the i-th video
            if not frames:
                continue

            num_frames_processed += len(frames)
            teacher_output = self.teacher_model(frames)

            for output in teacher_output:
                if len(output["boxes"]) == 0:
                    kd_loss += 1

        if num_frames_processed > 0:
            kd_loss /= num_frames_processed
        else:
            kd_loss = 0.0

        # Combine the losses.
        kd_loss_weight = getattr(self.args, "kd_loss_weight", 0.1)
        # The penalty is a plain Python number, so wrap it in a tensor matching the diffusion loss.
        kd_loss_tensor = torch.tensor(kd_loss, device=self.accelerator.device, dtype=diffusion_loss.dtype)
        total_loss = diffusion_loss + kd_loss_weight * kd_loss_tensor

        self.accelerator.log(
            {"kd_loss": kd_loss, "diffusion_loss": diffusion_loss.item(), "total_loss": total_loss.item()}
        )

        return total_loss


register("cogvideox-t2v", "kd", CogVideoXT2VKdTrainer)
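The kd penalty above is a simple detection-coverage score: the fraction of decoded frames for which the teacher detector returns no boxes. A minimal, self-contained sketch of just that term, using dummy teacher outputs (the tensors below are placeholders, not real detections):

import torch

# Dummy per-frame detector outputs: two frames with boxes, one frame with none.
teacher_outputs = [
    {"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]])},
    {"boxes": torch.empty((0, 4))},
    {"boxes": torch.tensor([[5.0, 5.0, 20.0, 20.0]])},
]

# Penalty = share of frames with zero detections, as in compute_loss.
kd_loss = sum(1 for out in teacher_outputs if len(out["boxes"]) == 0) / len(teacher_outputs)
print(kd_loss)  # 0.333..., grows toward 1.0 as more frames contain no detected objects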
@@ -6,7 +6,7 @@ from finetune.trainer import Trainer
 SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
 
 
-def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
+def register(model_name: str, training_type: Literal["lora", "sft", "kd"], trainer_cls: Trainer):
     """Register a model and its associated functions for a specific training type.
 
     Args:
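The new "kd" literal lets the distillation trainer register itself alongside the existing lora and sft trainers. A hypothetical sketch of how the registry is assumed to behave; the real SUPPORTED_MODELS dict and register() live in the finetune models package and are not reproduced here:

from typing import Dict

SUPPORTED_MODELS: Dict[str, Dict[str, type]] = {}

def register(model_name: str, training_type: str, trainer_cls: type) -> None:
    # Assumed behavior: file the trainer class under its model name and training type.
    SUPPORTED_MODELS.setdefault(model_name, {})[training_type] = trainer_cls

class CogVideoXT2VKdTrainer:  # stand-in for the real trainer class
    pass

register("cogvideox-t2v", "kd", CogVideoXT2VKdTrainer)
assert SUPPORTED_MODELS["cogvideox-t2v"]["kd"] is CogVideoXT2VKdTrainer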
@@ -12,7 +12,12 @@ class Args(BaseModel):
     model_path: Path
     model_name: str
     model_type: Literal["i2v", "t2v"]
-    training_type: Literal["lora", "sft"] = "lora"
+    training_type: Literal["lora", "sft", "kd"] = "lora"
+
+    ########## KD ##########
+    teacher_model_path: Path | None = None
+    teacher_model_num_classes: int | None = None
+    kd_loss_weight: float = 0.1
 
     ########## Output ##########
     output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
@@ -239,6 +244,11 @@ class Args(BaseModel):
         # Validation
         parser.add_argument("--do_validation", type=lambda x: x.lower() == 'true', default=False)
         parser.add_argument("--validation_steps", type=int, default=None)
+
+        # KD parameters
+        parser.add_argument("--teacher_model_path", type=str, default=None)
+        parser.add_argument("--teacher_model_num_classes", type=int, default=None)
+        parser.add_argument("--kd_loss_weight", type=float, default=0.1)
         parser.add_argument("--validation_dir", type=str, default=None)
         parser.add_argument("--validation_prompts", type=str, default=None)
         parser.add_argument("--validation_images", type=str, default=None)
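A short sketch of how the new flags behave once parsed. Only the three KD arguments from the hunk above are reproduced; the surrounding parser setup and the rest of the Args plumbing are omitted, and the checkpoint path is a placeholder:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--teacher_model_path", type=str, default=None)
parser.add_argument("--teacher_model_num_classes", type=int, default=None)
parser.add_argument("--kd_loss_weight", type=float, default=0.1)

ns = parser.parse_args([
    "--teacher_model_path", "checkpoints/faster_rcnn_vgg16_converted.pth",  # placeholder path
    "--teacher_model_num_classes", "91",
    "--kd_loss_weight", "0.05",
])
print(ns.teacher_model_path, ns.teacher_model_num_classes, ns.kd_loss_weight)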