fix 0.30.1 update

This commit is contained in:
zR 2024-08-30 14:28:47 +08:00
parent bd5d36ac38
commit 43f8451893

View File

@ -71,14 +71,7 @@ def decode_video(model_path, encoded_tensor_path, dtype, device):
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
encoded_frames = torch.load(encoded_tensor_path, weights_only=True).to(device).to(dtype)
with torch.no_grad():
decoded_frames = []
for i in range(6): # 6 seconds
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = model.decode(encoded_frames[:, :, start_frame:end_frame]).sample
decoded_frames.append(current_frames)
model.clear_fake_context_parallel_cache()
decoded_frames = torch.cat(decoded_frames, dim=2)
decoded_frames = model.decode(encoded_frames).sample
return decoded_frames
@ -94,7 +87,7 @@ def save_video(tensor, output_path):
frames = np.clip(frames, 0, 1) * 255
frames = frames.astype(np.uint8)
writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
writer = imageio.get_writer(output_path + "/output.mp4", fps=8)
for frame in frames:
writer.append_data(frame)
writer.close()