minor attempts

This commit is contained in:
Stephan Auerhahn 2022-07-25 20:47:48 -07:00
parent 74125c433b
commit ce39dbe208
2 changed files with 7 additions and 6 deletions

View File

@@ -27,6 +27,8 @@ from SwissArmyTransformer.resources import auto_create
from models.cogvideo_cache_model import CogVideoCacheModel from models.cogvideo_cache_model import CogVideoCacheModel
from coglm_strategy import CoglmStrategy from coglm_strategy import CoglmStrategy
sys.path.append('./Image-Local-Attention')
def get_masks_and_position_ids_stage1(data, textlen, framelen): def get_masks_and_position_ids_stage1(data, textlen, framelen):
# Extract batch size and sequence length. # Extract batch size and sequence length.
@@ -415,8 +417,7 @@ class InferenceModel_Interpolate(CogVideoCacheModel):
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S') logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', level=logging.DEBUG, datefmt='%Y-%m-%d %H:%M:%S')
subprocess.call("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True) subprocess.call("python setup.py develop", cwd="/src/Image-Local-Attention", shell=True)
sys.path.append('./Image-Local-Attention')
os.environ["SAT_HOME"] = "/sharefs/cogview-new" os.environ["SAT_HOME"] = "/sharefs/cogview-new"
args = get_args([ args = get_args([
"--batch-size", "1", "--batch-size", "1",
@@ -482,7 +483,7 @@ class Predictor(BasePredictor):
both_stages: bool = Input( both_stages: bool = Input(
description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True description="Run both stages (uncheck to run more quickly and output only a few frames)", default=True
), ),
use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True), use_guidance: bool = Input(description="Use stage 1 guidance (recommended)", default=True)
) -> typing.Iterator[Path]: ) -> typing.Iterator[Path]:
if translate: if translate:
prompt = self.translator.translate(prompt.strip()) prompt = self.translator.translate(prompt.strip())
@@ -632,7 +633,7 @@ class Predictor(BasePredictor):
for clip_i in range(len(imgs)): for clip_i in range(len(imgs)):
# os.makedirs(output_dir_full_paths[clip_i], exist_ok=True) # os.makedirs(output_dir_full_paths[clip_i], exist_ok=True)
my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False) my_save_multiple_images(imgs[clip_i], outputdir, subdir=f"frames/{clip_i}", debug=False)
os.system(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25") subprocess.call(f"gifmaker -i '{outputdir}'/frames/'{clip_i}'/0*.jpg -o '{outputdir}/{clip_i}.gif' -d 0.25", shell=True)
torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt')) torch.save(save_tokens, os.path.join(outputdir, 'frame_tokens.pt'))
logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time)) logging.info("CogVideo Stage1 completed. Taken time {:.2f}\n".format(time.time() - process_start_time))
@@ -763,7 +764,7 @@ class Predictor(BasePredictor):
for sample_i in range(sample_num): for sample_i in range(sample_num):
my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) my_save_multiple_images(decoded_sr_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125") subprocess.call(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125", shell=True)
logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime)) logging.info("Direct super-resolution completed. Taken time {:.2f}\n".format(time.time() - dsr_starttime))
else: else:
@@ -782,7 +783,7 @@ class Predictor(BasePredictor):
for sample_i in range(sample_num): for sample_i in range(sample_num):
my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False) my_save_multiple_images(decoded_videos[sample_i], outputdir,subdir=f"frames/{sample_i+sample_num*gpu_rank}", debug=False)
os.system(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125") subprocess.call(f"gifmaker -i '{outputdir}'/frames/'{sample_i+sample_num*gpu_rank}'/0*.jpg -o '{outputdir}/{sample_i+sample_num*gpu_rank}.gif' -d 0.125", shell=True)
#imgs = [] #imgs = []