mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
fix 0.30.1 update
This commit is contained in:
parent
bd5d36ac38
commit
43f8451893
@ -71,14 +71,7 @@ def decode_video(model_path, encoded_tensor_path, dtype, device):
|
||||
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
|
||||
encoded_frames = torch.load(encoded_tensor_path, weights_only=True).to(device).to(dtype)
|
||||
with torch.no_grad():
|
||||
decoded_frames = []
|
||||
for i in range(6): # 6 seconds
|
||||
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
|
||||
current_frames = model.decode(encoded_frames[:, :, start_frame:end_frame]).sample
|
||||
decoded_frames.append(current_frames)
|
||||
model.clear_fake_context_parallel_cache()
|
||||
|
||||
decoded_frames = torch.cat(decoded_frames, dim=2)
|
||||
decoded_frames = model.decode(encoded_frames).sample
|
||||
return decoded_frames
|
||||
|
||||
|
||||
@ -94,7 +87,7 @@ def save_video(tensor, output_path):
|
||||
frames = np.clip(frames, 0, 1) * 255
|
||||
frames = frames.astype(np.uint8)
|
||||
|
||||
writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
|
||||
writer = imageio.get_writer(output_path + "/output.mp4", fps=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user