Im-Promptu: In-Context Composition from Image Prompts

This is the official repository for the NeurIPS 2023 paper "Im-Promptu: In-Context Composition from Image Prompts". Website: https://jha-lab.github.io/impromptu/.
The following shell script creates an Anaconda environment called "impromptu" and installs all the required packages:
source env_setup.sh
The datasets can be downloaded from the following link and should be placed in the ./datasets directory. Detailed information about the benchmarks can be found in the ./benchmarks/README.md file.
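The benchmarks ship as HDF5 files. A minimal sketch for inspecting one is given below; the "images" key is hypothetical, so consult ./benchmarks/README.md for the actual layout.

import h5py

# Inspect a downloaded benchmark file. The "images" key below is
# hypothetical; the actual layout is documented in ./benchmarks/README.md.
with h5py.File("./datasets/shapes3d/train.h5", "r") as f:
    f.visit(print)  # print the name of every group/dataset in the file
    if "images" in f:  # hypothetical key
        print(f["images"].shape, f["images"].dtype)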
Solving analogies by a simple transformation over the pixel space. Given an analogy A : B :: C : D, the completion is predicted as

$\hat{D} = C + (B - A)$
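In code, this baseline reduces to a single arithmetic operation on image tensors; a minimal sketch (the function and variable names are illustrative):

import torch

def pixel_analogy(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    """Complete the analogy A : B :: C : ? directly in pixel space.
    All inputs are image tensors of identical shape with values in [0, 1]."""
    D_hat = C + (B - A)           # apply the A -> B pixel delta to C
    return D_hat.clamp(0.0, 1.0)  # keep the result a valid image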
An example command to run the pixel transformation model:
python3 learners/pixel.py --dataset shapes3d --batch_size 64 --data_path ./datasets/shapes3d/train.h5 --logs_dir ./logs_dir/ --phase val
The various arguments that can be passed to the script are:
--dataset = Name of the dataset (options: shapes3d, clevr, bitmoji)
--batch_size = Batch size for training
--data_path = Path to the training data
--logs_dir = Path to the directory where the logs will be stored
--phase = Split of the dataset to evaluate on
Monolithic vector representation to solve visual analogies. The architecture is laid out in ./learners/monolithic.py.
A training instance of the monolithic learner is given below:
cd train_scripts
python train_monolithic.py --epochs 100 --dataset shapes3d --data_path ../datasets/shapes3d/train.h5 --image_size 64 --seed 0 --d_model 192 --logs_dir ../logs_dir/
Hyperparameters can be tweaked as follows:
--epochs = Number of training epochs
--dataset = Name of the dataset used to spawn a Dataset from ./utils/create_dataset.py
--d_model = Latent vector dimension
--image_size = Input image size
--lr_main = Peak learning rate
--lr_warmup_steps = Learning rate warmup steps for the linear schedule (see the sketch after this list)
--data_path = Path to the dataset
--logs_dir = Path to the directory where the logs will be stored
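The exact schedule lives in the training script; below is a minimal sketch of a linear warmup to the peak rate --lr_main over --lr_warmup_steps steps, under the assumption that the rate simply holds at the peak afterwards.

def linear_warmup_lr(step: int, lr_main: float, warmup_steps: int) -> float:
    """Ramp the learning rate linearly from 0 to lr_main over warmup_steps,
    then hold it at the peak. Any decay the actual training script applies
    after warmup is not reproduced here."""
    if step < warmup_steps:
        return lr_main * (step + 1) / warmup_steps
    return lr_main

# Example: set the rate on a PyTorch optimizer at every step.
# for group in optimizer.param_groups:
#     group["lr"] = linear_warmup_lr(step, lr_main=1e-4, warmup_steps=15000)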
Patch abstractions to solve visual analogies. The architecture is laid out in ./learners/patch_network.py. An example training command:
cd train_scripts
python3 train_patch_network.py --batch_size 16 --dataset shapes3d --img_channels 3 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_enc_heads 4 --num_enc_blocks 4 --num_dec_blocks 4 --num_heads 4 --seed 3
Additional hyperparameters are as follows:
--vocab_size = Size of the dVAE vocabulary
--num_dec_blocks = Number of decoder blocks
--num_enc_blocks = Number of context encoder blocks
--num_heads = Number of attention heads in the decoder
--num_enc_heads = Number of attention heads in the context encoder
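For intuition, the patch pathway discretizes an image into tokens drawn from a finite vocabulary of size --vocab_size. The sketch below stands in for the learned dVAE tokenizer with a simple nearest-neighbor codebook lookup; treat it purely as illustration, not the repository's tokenizer.

import torch
import torch.nn.functional as F

def tokenize_patches(images: torch.Tensor, codebook: torch.Tensor,
                     patch: int = 4) -> torch.Tensor:
    """Map each (patch x patch) image patch to the id of its nearest
    codebook vector; a stand-in for the learned dVAE tokenizer.
    images   : (B, C, H, W), with H and W divisible by `patch`
    codebook : (vocab_size, C * patch * patch)
    returns  : (B, num_patches) integer token ids
    """
    # Flatten non-overlapping patches: (B, num_patches, C * patch * patch)
    patches = F.unfold(images, kernel_size=patch, stride=patch).transpose(1, 2)
    # L2 distance from every patch to every codebook entry: (B, num_patches, vocab_size)
    dists = torch.cdist(patches, codebook.unsqueeze(0))
    return dists.argmin(dim=-1)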
Solving analogies by learning object-centric representations. The architecture is laid out in ./learners/object_centric_learner.py.
cd train_scripts/
python train.py --img_channels 3 --dataset shapes3d --batch_size 32 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_iterations 3 --num_slots 3 --num_enc_heads 4 --num_enc_blocks 4 --num_dec_heads 4 --num_heads 4 --slate_encoder_path ./logs_dir_pretrain/SLATE/best_encoder.pt --lr_warmup_steps 15000 --seed 0 --log_path ./logs_dir/
Additional hyperparameters are as follows:
--num_slots = Number of object slots per image
--num_iterations = Number of slot attention iterations
--slate_encoder_path = Path to the pre-trained SLATE encoder
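--num_slots and --num_iterations parameterize slot attention, which iteratively binds a fixed set of slot vectors to encoded image features. A compact, self-contained sketch of the mechanism is given below; it is illustrative, not the exact module in ./learners/object_centric_learner.py.

import torch
import torch.nn as nn

class SlotAttention(nn.Module):
    """Compact Slot Attention (Locatello et al., 2020). Illustrative only:
    the paper's post-update MLP residual is omitted for brevity, and the
    repository's actual module may differ."""

    def __init__(self, num_slots: int, dim: int, num_iterations: int = 3):
        super().__init__()
        self.num_slots, self.num_iterations = num_slots, num_iterations
        self.scale = dim ** -0.5
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.gru = nn.GRUCell(dim, dim)
        self.norm_inputs = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: (B, N, dim) flattened image features from an encoder
        B, _, dim = inputs.shape
        inputs = self.norm_inputs(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)
        # Sample the initial slots from a learned Gaussian
        slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(
            B, self.num_slots, dim, device=inputs.device)
        for _ in range(self.num_iterations):
            q = self.to_q(self.norm_slots(slots))
            # Slots compete for input locations: softmax over the slot axis
            attn = (q @ k.transpose(1, 2) * self.scale).softmax(dim=1)
            attn = attn / attn.sum(dim=-1, keepdim=True)  # weighted mean per slot
            updates = attn @ v                            # (B, num_slots, dim)
            slots = self.gru(updates.reshape(-1, dim),
                             slots.reshape(-1, dim)).view(B, self.num_slots, dim)
        return slots

Used as SlotAttention(num_slots=3, dim=192, num_iterations=3), it maps features of shape (batch, num_locations, 192) to three slot vectors, one per object.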
Sequential prompter variant. The architecture is laid out in ./learners/sequential_prompter.py. An example training command:
cd train_scripts/
python train_prompt_.py --img_channels 3 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_iterations 3 --num_slots 3 --slate_encoder_path ./logs_dir_pretrain/shapes3d_SLATE/best_encoder.pt --seed 0
Cite our work using the following BibTeX entry:
@misc{dedhia2023impromptu,
  title={Im-Promptu: In-Context Composition from Image Prompts},
  author={Bhishma Dedhia and Michael Chang and Jake C. Snell and Thomas L. Griffiths and Niraj K. Jha},
  year={2023},
  eprint={2305.17262},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
The Clear BSD License. Copyright (c) 2023, Bhishma Dedhia and Jha Lab. All rights reserved.
See the LICENSE file for more details.