From 0e90671806e6fbfaa8e803b057285ed96ccdd19e Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Tue, 26 Jul 2022 02:49:23 -0700 Subject: [PATCH] batch size --- predict.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/predict.py b/predict.py index 36b192a..9da9bba 100644 --- a/predict.py +++ b/predict.py @@ -439,6 +439,7 @@ class Predictor(BasePredictor): args.coglm_temperature2 = 0.89 args.generate_frame_num = 5 args.stage1_max_inference_batch_size = -1 + args.max_inference_batch_size = 8 args.top_k = 12 args.use_guidance_stage1 = True args.use_guidance_stage2 = False @@ -500,8 +501,11 @@ class Predictor(BasePredictor): yield Path(file) torch.cuda.empty_cache() return - + @torch.no_grad() def run(self): + torch.manual_seed(self.args.seed) + random.seed(self.args.seed) + invalid_slices = [slice(tokenizer.num_image_tokens, None)] strategy_cogview2 = CoglmStrategy(invalid_slices, temperature=1.0, top_k=16) @@ -509,8 +513,6 @@ class Predictor(BasePredictor): temperature=self.args.temperature, top_k=self.args.top_k, temperature2=self.args.coglm_temperature2) - torch.manual_seed(self.args.seed) - random.seed(self.args.seed) workdir = tempfile.mkdtemp() os.makedirs(f"{workdir}/output/stage1", exist_ok=True) os.makedirs(f"{workdir}/output/stage2", exist_ok=True) @@ -562,8 +564,8 @@ class Predictor(BasePredictor): ) logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time)) output_tokens_1st = torch.cat(output_list_1st, dim=0) - given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] - + given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400] + # generate subsequent frames: total_frames = self.generate_frame_num enc_duration = tokenizer.encode(str(float(duration))+"秒")