mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-02 10:32:09 +08:00
batch size
This commit is contained in:
parent
2d43d2ff70
commit
0e90671806
12
predict.py
12
predict.py
@ -439,6 +439,7 @@ class Predictor(BasePredictor):
|
|||||||
args.coglm_temperature2 = 0.89
|
args.coglm_temperature2 = 0.89
|
||||||
args.generate_frame_num = 5
|
args.generate_frame_num = 5
|
||||||
args.stage1_max_inference_batch_size = -1
|
args.stage1_max_inference_batch_size = -1
|
||||||
|
args.max_inference_batch_size = 8
|
||||||
args.top_k = 12
|
args.top_k = 12
|
||||||
args.use_guidance_stage1 = True
|
args.use_guidance_stage1 = True
|
||||||
args.use_guidance_stage2 = False
|
args.use_guidance_stage2 = False
|
||||||
@ -500,8 +501,11 @@ class Predictor(BasePredictor):
|
|||||||
yield Path(file)
|
yield Path(file)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return
|
return
|
||||||
|
@torch.no_grad()
|
||||||
def run(self):
|
def run(self):
|
||||||
|
torch.manual_seed(self.args.seed)
|
||||||
|
random.seed(self.args.seed)
|
||||||
|
|
||||||
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
|
||||||
strategy_cogview2 = CoglmStrategy(invalid_slices,
|
strategy_cogview2 = CoglmStrategy(invalid_slices,
|
||||||
temperature=1.0, top_k=16)
|
temperature=1.0, top_k=16)
|
||||||
@ -509,8 +513,6 @@ class Predictor(BasePredictor):
|
|||||||
temperature=self.args.temperature, top_k=self.args.top_k,
|
temperature=self.args.temperature, top_k=self.args.top_k,
|
||||||
temperature2=self.args.coglm_temperature2)
|
temperature2=self.args.coglm_temperature2)
|
||||||
|
|
||||||
torch.manual_seed(self.args.seed)
|
|
||||||
random.seed(self.args.seed)
|
|
||||||
workdir = tempfile.mkdtemp()
|
workdir = tempfile.mkdtemp()
|
||||||
os.makedirs(f"{workdir}/output/stage1", exist_ok=True)
|
os.makedirs(f"{workdir}/output/stage1", exist_ok=True)
|
||||||
os.makedirs(f"{workdir}/output/stage2", exist_ok=True)
|
os.makedirs(f"{workdir}/output/stage2", exist_ok=True)
|
||||||
@ -562,8 +564,8 @@ class Predictor(BasePredictor):
|
|||||||
)
|
)
|
||||||
logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
|
logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
|
||||||
output_tokens_1st = torch.cat(output_list_1st, dim=0)
|
output_tokens_1st = torch.cat(output_list_1st, dim=0)
|
||||||
given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
|
given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
|
||||||
|
|
||||||
# generate subsequent frames:
|
# generate subsequent frames:
|
||||||
total_frames = self.generate_frame_num
|
total_frames = self.generate_frame_num
|
||||||
enc_duration = tokenizer.encode(str(float(duration))+"秒")
|
enc_duration = tokenizer.encode(str(float(duration))+"秒")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user