batch size

This commit is contained in:
Stephan Auerhahn 2022-07-26 02:49:23 -07:00
parent 2d43d2ff70
commit 0e90671806

View File

@@ -439,6 +439,7 @@ class Predictor(BasePredictor):
         args.coglm_temperature2 = 0.89
         args.generate_frame_num = 5
         args.stage1_max_inference_batch_size = -1
+        args.max_inference_batch_size = 8
         args.top_k = 12
         args.use_guidance_stage1 = True
         args.use_guidance_stage2 = False
@@ -500,8 +501,11 @@ class Predictor(BasePredictor):
             yield Path(file)
         torch.cuda.empty_cache()
         return
+    @torch.no_grad()
     def run(self):
+        torch.manual_seed(self.args.seed)
+        random.seed(self.args.seed)
         invalid_slices = [slice(tokenizer.num_image_tokens, None)]
         strategy_cogview2 = CoglmStrategy(invalid_slices,
                                           temperature=1.0, top_k=16)
@@ -509,8 +513,6 @@ class Predictor(BasePredictor):
                                  temperature=self.args.temperature, top_k=self.args.top_k,
                                  temperature2=self.args.coglm_temperature2)
-        torch.manual_seed(self.args.seed)
-        random.seed(self.args.seed)
         workdir = tempfile.mkdtemp()
         os.makedirs(f"{workdir}/output/stage1", exist_ok=True)
         os.makedirs(f"{workdir}/output/stage2", exist_ok=True)