batch size

Stephan Auerhahn 2022-07-26 02:49:23 -07:00
parent 2d43d2ff70
commit 0e90671806


@@ -439,6 +439,7 @@ class Predictor(BasePredictor):
args.coglm_temperature2 = 0.89
args.generate_frame_num = 5
args.stage1_max_inference_batch_size = -1
args.max_inference_batch_size = 8
args.top_k = 12
args.use_guidance_stage1 = True
args.use_guidance_stage2 = False
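The new args.max_inference_batch_size = 8 caps how many samples go through the model in one forward pass. A minimal sketch of the chunking pattern such a cap implies; chunked_inference and its model argument are illustrative, not this repo's API:

import torch

def chunked_inference(model, tokens, max_batch=8):
    # Run the model on slices of at most max_batch samples to bound peak GPU memory.
    outputs = []
    for start in range(0, tokens.size(0), max_batch):
        outputs.append(model(tokens[start:start + max_batch]))
    return torch.cat(outputs, dim=0)

With max_batch=8, a batch of 20 runs as three forward passes of 8, 8, and 4 rows.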
@@ -500,8 +501,11 @@ class Predictor(BasePredictor):
yield Path(file)
torch.cuda.empty_cache()
return
@torch.no_grad()
def run(self):
torch.manual_seed(self.args.seed)
random.seed(self.args.seed)
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
strategy_cogview2 = CoglmStrategy(invalid_slices,
temperature=1.0, top_k=16)
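This hunk decorates run() with @torch.no_grad() and moves the RNG seeding to the top of the method. A hedged sketch of the pattern with illustrative names: no_grad disables gradient tracking (and its activation memory) for everything inside run(), and seeding both torch and random up front makes each call reproducible before any sampling happens.

import random
import torch

class Sampler:
    def __init__(self, seed):
        self.seed = seed

    @torch.no_grad()                  # autograd is off for the whole method
    def run(self):
        torch.manual_seed(self.seed)  # seeds torch's CPU and CUDA RNGs
        random.seed(self.seed)        # seeds Python's built-in RNG
        return torch.randn(4)         # stand-in for the real sampling loop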
@@ -509,8 +513,6 @@ class Predictor(BasePredictor):
temperature=self.args.temperature, top_k=self.args.top_k,
temperature2=self.args.coglm_temperature2)
torch.manual_seed(self.args.seed)
random.seed(self.args.seed)
workdir = tempfile.mkdtemp()
os.makedirs(f"{workdir}/output/stage1", exist_ok=True)
os.makedirs(f"{workdir}/output/stage2", exist_ok=True)
@@ -562,8 +564,8 @@ class Predictor(BasePredictor):
)
logging.info("[First Frame]Taken time {:.2f}\n".format(time.time() - start_time))
output_tokens_1st = torch.cat(output_list_1st, dim=0)
given_tokens = output_tokens_1st[:, text_len_1st+1:text_len_1st+401].unsqueeze(1) # given_tokens.shape: [bs, frame_num, 400]
# generate subsequent frames:
total_frames = self.generate_frame_num
enc_duration = tokenizer.encode(str(float(duration))+"秒")
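The given_tokens slice above assumes the stage-1 output sequence is laid out as [text tokens | separator | 400 image tokens]; the slice keeps the 400 image tokens of the first frame, and unsqueeze(1) adds the frame axis. A small sketch with made-up sizes:

import torch

bs, text_len = 2, 64                                     # illustrative, not the model's real sizes
seq = torch.randint(0, 20000, (bs, text_len + 1 + 400))  # [text | sep | image tokens]

frame = seq[:, text_len + 1 : text_len + 401]  # [bs, 400] first-frame image tokens
given = frame.unsqueeze(1)                     # [bs, 1, 400] adds the frame axis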