mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-06-01 09:04:08 +08:00
working, hopefully
This commit is contained in:
parent
a9f705b286
commit
7fa88d90b5
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "Image-Local-Attention"]
|
||||
path = Image-Local-Attention
|
||||
url = https://github.com/Sleepychord/Image-Local-Attention
|
||||
1
Image-Local-Attention
Submodule
1
Image-Local-Attention
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 43fee310cb1c6f64fb0ed77404ba3b01fa586026
|
||||
31
cog.yaml
Normal file
31
cog.yaml
Normal file
@ -0,0 +1,31 @@
|
||||
# Configuration for Cog ⚙️
|
||||
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
||||
|
||||
build:
|
||||
gpu: true
|
||||
|
||||
system_packages:
|
||||
- "ffmpeg"
|
||||
- "libsm6"
|
||||
- "libxext6"
|
||||
- "libglib2.0-0"
|
||||
|
||||
python_version: "3.8"
|
||||
python_packages:
|
||||
- "deep-translator==1.8.3"
|
||||
- "gifmaker==1.5"
|
||||
- "icetk==0.0.4"
|
||||
- "opencv-python==4.6.0.66"
|
||||
- "SwissArmyTransformer==0.2.9"
|
||||
- "torch==1.9.0"
|
||||
- "torchvision==0.10.0"
|
||||
|
||||
run:
|
||||
- "mkdir -p /sharefs/cogview-new; cd /sharefs/cogview-new; wget https://models.nmb.ai/cogvideo/cogvideo-stage1.tar.gz -O - | tar xz"
|
||||
- "cd /sharefs/cogview-new; wget https://models.nmb.ai/cogvideo/cogvideo-stage2.tar.gz -O - | tar xz"
|
||||
- "cd /sharefs/cogview-new; wget https://models.nmb.ai/cogview2/cogview2-dsr.tar.gz -O - | tar xz"
|
||||
- "mkdir -p /root/.icetk_models; wget -O /root/.icetk_models/ice_text.model https://models.nmb.ai/cogvideo/ice_text.model"
|
||||
- "mkdir -p /root/.tcetk_models; wget -O /root/.icetk_models/ice_image.pt https://models.nmb.ai/cogvideo/ice_image.pt"
|
||||
|
||||
predict: "predict.py:Predictor"
|
||||
image: "r8.im/nightmareai/cogvideo"
|
||||
85
predict.py
Normal file
85
predict.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
from random import randint
|
||||
import subprocess
|
||||
import tempfile
|
||||
import glob
|
||||
import typing
|
||||
from deep_translator import GoogleTranslator
|
||||
from cog import BasePredictor, Input, Path
|
||||
|
||||
|
||||
class Predictor(BasePredictor):
|
||||
def setup(self):
|
||||
subprocess.call("python setup.py install", cwd="/src/Image-Local-Attention", shell=True)
|
||||
self.translator = GoogleTranslator(source="en", target="zh-CN")
|
||||
|
||||
def predict(
|
||||
self,
|
||||
prompt: str = Input(description="Prompt"),
|
||||
seed: int = Input(description="Seed (leave empty to use a random seed)", default=None, le=(2**32 - 1), ge=0),
|
||||
translate: bool = Input(
|
||||
description="Translate prompt from English to Simplified Chinese (required if not entering Chinese text)",
|
||||
default=True,
|
||||
),
|
||||
# both_stages: bool = Input(
|
||||
# description="Run both stages (uncheck to run only stage 1 for quicker results)", default=True
|
||||
# ),
|
||||
use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True),
|
||||
) -> typing.List[Path]:
|
||||
if translate:
|
||||
prompt = self.translator.translate(prompt)
|
||||
workdir = tempfile.mkdtemp()
|
||||
os.makedirs(f"{workdir}/output")
|
||||
with open(f"{workdir}/input.txt", "w") as f:
|
||||
f.write(prompt)
|
||||
if seed is None:
|
||||
seed = randint(0, 2**32)
|
||||
args = [
|
||||
"python",
|
||||
"cogvideo_pipeline.py",
|
||||
"--input-source",
|
||||
f"{workdir}/input.txt",
|
||||
"--output-path",
|
||||
f"{workdir}/output",
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--parallel-size",
|
||||
"1",
|
||||
"--guidance-alpha",
|
||||
"3.0",
|
||||
"--generate-frame-num",
|
||||
"4",
|
||||
"--tokenizer-type",
|
||||
"fake",
|
||||
"--mode",
|
||||
"inference",
|
||||
"--distributed-backend",
|
||||
"nccl",
|
||||
"--fp16",
|
||||
"--model-parallel-size",
|
||||
"1",
|
||||
"--temperature",
|
||||
"1.05",
|
||||
"--coglm-temperature",
|
||||
"0.89",
|
||||
"--top_k",
|
||||
"12",
|
||||
"--sandwich-ln",
|
||||
"--seed",
|
||||
str(seed),
|
||||
"--num-workers",
|
||||
"0",
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--max-inference-batch-size",
|
||||
"8",
|
||||
"--both-stages",
|
||||
]
|
||||
if use_guidance:
|
||||
args.append("--use-guidance-stage1")
|
||||
print(args)
|
||||
os.environ["SAT_HOME"] = "/sharefs/cogview-new"
|
||||
if subprocess.check_output(args, shell=False, cwd="/src"):
|
||||
output = glob.glob(f"{workdir}/output/**/*.gif")
|
||||
for f in output:
|
||||
yield Path(f)
|
||||
Loading…
x
Reference in New Issue
Block a user