mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 10:32:09 +08:00
batch size
This commit is contained in:
parent
2d43d2ff70
commit
0e90671806
@ -439,6 +439,7 @@ class Predictor(BasePredictor):
|
||||
args.coglm_temperature2 = 0.89
|
||||
args.generate_frame_num = 5
|
||||
args.stage1_max_inference_batch_size = -1
|
||||
args.max_inference_batch_size = 8
|
||||
args.top_k = 12
|
||||
args.use_guidance_stage1 = True
|
||||
args.use_guidance_stage2 = False
|
||||
@ -500,8 +501,11 @@ class Predictor(BasePredictor):
|
||||
yield Path(file)
|
||||
torch.cuda.empty_cache()
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self):
|
||||
torch.manual_seed(self.args.seed)
|
||||
random.seed(self.args.seed)
|
||||
|
||||
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
||||
strategy_cogview2 = CoglmStrategy(invalid_slices,
|
||||
temperature=1.0, top_k=16)
|
||||
@ -509,8 +513,6 @@ class Predictor(BasePredictor):
|
||||
temperature=self.args.temperature, top_k=self.args.top_k,
|
||||
temperature2=self.args.coglm_temperature2)
|
||||
|
||||
torch.manual_seed(self.args.seed)
|
||||
random.seed(self.args.seed)
|
||||
workdir = tempfile.mkdtemp()
|
||||
os.makedirs(f"{workdir}/output/stage1", exist_ok=True)
|
||||
os.makedirs(f"{workdir}/output/stage2", exist_ok=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user