google-labs-jules[bot] ebc9d39c02 feat: Add knowledge distillation for logo generation with VGG teacher
This commit introduces a knowledge distillation module to enhance logo generation in the CogVideoX-2B text-to-video model.

The key changes include:

- A new `CogVideoXT2VKdTrainer` class that inherits from `CogVideoXT2VLoraTrainer`. This trainer loads a teacher model and computes a knowledge distillation loss to guide the student model.
- The teacher model loading logic has been updated to support a VGG16-based Faster R-CNN model, to be compatible with user-provided weights. This includes a custom construction of the Faster R-CNN model with a VGG16 backbone and appropriate RoI heads.
- The `kd` training type is now supported, allowing users to select it from the command line.
- New command-line arguments (`teacher_model_path`, `teacher_model_num_classes`, `kd_loss_weight`) have been added to configure the knowledge distillation process.
- A new configuration file (`cogvideox_2b_kd.yaml`) is provided as an example for running a `kd` training session.
2025-08-21 10:13:06 +00:00

130 lines
5.3 KiB
Python

import torch
import torchvision
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register
from typing_extensions import override
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models import vgg16
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
import torch.nn as nn
from collections import OrderedDict
class VGG16BackboneWrapper(nn.Module):
    """Wrap a feature extractor so it returns a named feature-map mapping.

    The wrapped module's output is exposed under the single key ``"0"`` and
    the channel count of that output is advertised via ``out_channels``.
    """

    def __init__(self, vgg_features):
        super().__init__()
        # Channel count reported to consumers of this backbone (fixed at 512).
        self.out_channels = 512
        self.features = vgg_features

    def forward(self, x):
        """Run the wrapped features and return them as ``OrderedDict({"0": ...})``."""
        feature_maps = OrderedDict()
        feature_maps["0"] = self.features(x)
        return feature_maps
class CogVideoXT2VKdTrainer(CogVideoXT2VLoraTrainer):
    """LoRA trainer that adds a detector-based penalty to the diffusion loss.

    A frozen VGG16 Faster R-CNN "teacher" is run on the frames decoded from
    each batch's latents. The fraction of frames on which the teacher reports
    no boxes is scaled by ``args.kd_loss_weight`` and added to the loss
    returned by the parent trainer.

    Options read from ``args`` (all optional):
        teacher_model_path: path to the teacher's state dict. When missing or
            empty, no teacher is loaded and the parent loss is used unchanged.
        teacher_model_num_classes: number of detector classes (default 91).
        kd_loss_weight: multiplier applied to the penalty (default 0.1).
    """

    # Remove vae from the unload list to make it available in compute_loss
    UNLOAD_LIST = ["text_encoder"]

    # Fallbacks used when an option is absent from ``args`` or set to None.
    _DEFAULT_TEACHER_NUM_CLASSES = 91  # COCO default
    _DEFAULT_KD_LOSS_WEIGHT = 0.1
    # Side length of the RoI-aligned feature map fed to the box head.
    _ROI_OUTPUT_SIZE = 7

    def __init__(self, args):
        super().__init__(args)
        self.teacher_model = self.load_teacher_model()

    def _arg_or_default(self, name, default):
        """Return ``self.args.<name>``, or ``default`` if it is missing or None."""
        value = getattr(self.args, name, None)
        return default if value is None else value

    def load_teacher_model(self):
        """Build the VGG16 Faster R-CNN teacher and load its weights.

        Returns:
            The teacher in eval mode with frozen parameters, moved to the
            accelerator device, or ``None`` when no path is configured or
            construction/loading fails (the error is printed, not raised, so
            training continues without the distillation term).
        """
        teacher_model_path = getattr(self.args, "teacher_model_path", None)
        if not teacher_model_path:
            print("Warning: teacher_model_path is not provided. Knowledge distillation will be skipped.")
            return None

        num_classes = self._arg_or_default(
            "teacher_model_num_classes", self._DEFAULT_TEACHER_NUM_CLASSES
        )

        try:
            # 1. VGG16 backbone. torchvision's VGG16 ends `features` with a
            #    max-pool; drop it to be closer to the original Caffe
            #    Faster R-CNN implementation.
            backbone = VGG16BackboneWrapper(vgg16(weights=None).features[:-1])

            # 2. RPN: one feature map, 3 sizes x 3 aspect ratios per location.
            anchor_generator = AnchorGenerator(
                sizes=((128, 256, 512),),
                aspect_ratios=((0.5, 1.0, 2.0),),
            )

            # 3. RoI heads. The fc6/fc7 layers of the user's model correspond
            #    to a TwoMLPHead fed with out_channels * 7 * 7 = 25088 inputs.
            roi_pooler = torchvision.ops.MultiScaleRoIAlign(
                featmap_names=["0"],
                output_size=self._ROI_OUTPUT_SIZE,
                sampling_ratio=2,
            )
            box_head = TwoMLPHead(
                in_channels=backbone.out_channels * self._ROI_OUTPUT_SIZE ** 2,
                representation_size=4096,
            )
            box_predictor = FastRCNNPredictor(in_channels=4096, num_classes=num_classes)

            # 4. Faster R-CNN model. `num_classes` must NOT be passed here:
            #    torchvision raises ValueError when both `num_classes` and
            #    `box_predictor` are given, which previously made every load
            #    attempt fail and silently disabled distillation.
            model = FasterRCNN(
                backbone,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
                box_head=box_head,
                box_predictor=box_predictor,
            )

            # Load the pre-trained weights from the converted file. Load onto
            # the CPU first so the checkpoint's saved device does not matter,
            # and restrict unpickling to tensors/containers.
            print(f"Loading teacher model from: {teacher_model_path}")
            state_dict = torch.load(teacher_model_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict)

            # The teacher is inference-only: eval mode makes it return
            # detections, and freezing avoids tracking gradients for it.
            model.eval()
            model.requires_grad_(False)
            model.to(self.accelerator.device)
            print("Teacher model loaded successfully.")
            return model
        except Exception as e:
            # Deliberate best-effort: report the problem and train without KD.
            print(f"Error loading teacher model: {e}")
            return None

    @torch.no_grad()
    def _empty_detection_rate(self, latents) -> float:
        """Return the fraction of decoded frames on which the teacher finds no box.

        Runs without gradient tracking because the result is a plain count.
        Returns 0.0 when there are no frames to inspect.
        """
        vae = self.components.vae
        # Decode the latents to get video frames.
        video_frames = vae.decode(latents / vae.config.scaling_factor).sample
        # The VAE output is nominally in [-1, 1]; map it to the [0, 1] range
        # the detector expects and clamp any overshoot.
        video_frames = ((video_frames + 1) / 2).clamp(0, 1)
        # Swap dims 1 and 2 so that iterating a video yields per-frame tensors
        # (assumes decode returns (batch, channels, frames, H, W) - TODO confirm).
        video_frames = video_frames.permute(0, 2, 1, 3, 4)

        # Match the teacher's device/dtype (the VAE may run in reduced precision).
        teacher_param = next(self.teacher_model.parameters())
        video_frames = video_frames.to(teacher_param.device, dtype=teacher_param.dtype)

        empty_count = 0
        frame_count = 0
        for video in video_frames:  # one video of the batch at a time
            frames = list(video)
            if not frames:
                continue
            frame_count += len(frames)
            detections = self.teacher_model(frames)
            empty_count += sum(1 for det in detections if det["boxes"].shape[0] == 0)

        return empty_count / frame_count if frame_count else 0.0

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        """Return the parent diffusion loss plus the weighted teacher penalty.

        When no teacher is loaded the parent loss is returned unchanged.
        Otherwise the three loss values are also sent to ``accelerator.log``.
        """
        # Get the original diffusion loss
        diffusion_loss = super().compute_loss(batch)
        if self.teacher_model is None:
            return diffusion_loss

        # NOTE(review): this term is derived from the batch's stored latents
        # and is a plain number, so it carries no gradient; it shifts the
        # reported loss but does not steer the student's parameters.
        kd_loss = self._empty_detection_rate(batch["encoded_videos"])

        # Combine the losses
        kd_loss_weight = self._arg_or_default("kd_loss_weight", self._DEFAULT_KD_LOSS_WEIGHT)
        # Same device and dtype as the diffusion loss.
        kd_loss_tensor = diffusion_loss.new_tensor(kd_loss)
        total_loss = diffusion_loss + kd_loss_weight * kd_loss_tensor

        self.accelerator.log({"kd_loss": kd_loss, "diffusion_loss": diffusion_loss.item(), "total_loss": total_loss.item()})
        return total_loss
# Register the KD trainer for the "cogvideox-t2v" model under the "kd" training type
# (presumably this is what makes `kd` selectable as a training type - see ..utils.register).
register("cogvideox-t2v", "kd", CogVideoXT2VKdTrainer)