Multimodal tasks in CLiMB

Adding a new task to CLiMB consists of three main components:

- Data processing: a `TaskDataset` class and a `build_task_dataloader()` method in `data/visionlanguage_datasets/<task>_dataset.py`
- Creating the task trainer: a `TaskTrainer` class in `train/visionlanguage_tasks/<task>_trainer.py`
- Adding the task configuration and parameters to `configs/task_configs.py`
This figure shows how the various classes and methods for adding a new task are structured, with arrows indicating which methods call which other methods and instantiate other classes.
In `data/visionlanguage_datasets/<task>_dataset.py`, you have to create:

- a `TaskDataset` class
- a `task_batch_collate()` method
- a `build_task_dataloader()` method
i. `TaskDataset`: This is a `torch.utils.data.Dataset` class. The `__init__()` method can take in whatever arguments you want, but should contain at least two arguments:

- `data_dir`: A string pointing to the location of the task's data files
- `split`: A string indicating whether this is a train or val split of the dataset
- Optional: `images_dataset`, an instance of a Dataset from `data/image_datasets/` that wraps a common image dataset. For instance, SNLI-VE uses images from Flickr30k, so its Dataset takes an `images_dataset` object belonging to the `Flickr30KImagesDataset` class, which handles the image processing for all tasks that use Flickr30K images.
The `TaskDataset` class must have a `__getitem__()` method that returns the text and image inputs (the latter may be retrieved from `images_dataset.get_image_data()` if applicable) and the output label, in the form of a dictionary.

An example of a `TaskDataset` class for SNLI-VE, the `SnliVEDataset` class, can be seen here.
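As a rough illustration, here is a hypothetical minimal sketch of such a class. The real class would subclass `torch.utils.data.Dataset`; torch is omitted here to keep the sketch self-contained, and the field names and in-memory examples are placeholders, not the actual SNLI-VE annotation format.

```python
# Hypothetical sketch of a TaskDataset. In CLiMB this would subclass
# torch.utils.data.Dataset; the data below stands in for files in data_dir.
class ToyTaskDataset:
    def __init__(self, data_dir, split, images_dataset=None):
        self.data_dir = data_dir          # location of the task's data files
        self.split = split                # 'train' or 'val'
        self.images_dataset = images_dataset
        # Stand-in for annotations that would be loaded from data_dir.
        self.examples = [
            {"text": "A dog runs through the grass.", "image_id": "img_0", "label": 0},
            {"text": "A cat sleeps on a couch.", "image_id": "img_1", "label": 2},
        ]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        ex = self.examples[index]
        # If the task shares images with other tasks, fetch the processed
        # image from the shared images_dataset.
        image = (self.images_dataset.get_image_data(ex["image_id"])
                 if self.images_dataset is not None else None)
        # Return text input, image input, and output label as a dictionary.
        return {"text": ex["text"], "image": image, "label": ex["label"]}
```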
ii. `task_batch_collate()`: This is a method that collates inputs from different instances in a batch into batched inputs and outputs, in the form of a dictionary. This includes padding of text and image tensors if needed.

The method's arguments are:

- `batch`: List of dictionaries, each dictionary corresponding to a batch instance, returned from `TaskDataset.__getitem__()`
- `visual_input_type`: A string indicating the image format that the model takes as input (in the case of ViLT, `visual_input_type='pil-image'`). This string is retrieved from the model configuration file. Image collation can be done using the `image_collate()` method in `data/image_collation.py`, using the `visual_input_type` argument.

The method returns a dictionary, with each item corresponding to a batched input or output item.

An example of a `task_batch_collate` method for SNLI-VE, called `snli_ve_batch_collate`, can be seen here.
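The collate step can be sketched as follows. This is a hypothetical simplification: it handles only the `'pil-image'` case and skips tensor padding, whereas the real code would delegate image batching to `image_collate()` in `data/image_collation.py`.

```python
# Hypothetical sketch of a task_batch_collate() method.
def toy_batch_collate(batch, visual_input_type):
    # batch: list of dicts, one per instance, from TaskDataset.__getitem__()
    texts = [ex["text"] for ex in batch]
    labels = [ex["label"] for ex in batch]
    if visual_input_type == "pil-image":
        # Raw PIL images need no tensor padding here; just gather them.
        # (The real code would call image_collate(batch, visual_input_type).)
        images = [ex["image"] for ex in batch]
    else:
        raise NotImplementedError(f"Unsupported input type: {visual_input_type}")
    # One dictionary, each item a batched input or output.
    return {"texts": texts, "images": images, "labels": labels}
```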
iii. `build_task_dataloader()`: This is a method that returns a `torch.utils.data.DataLoader`, using an instance of `TaskDataset` and the `task_batch_collate()` method.

An example of a `build_task_dataloader` method for SNLI-VE, called `build_snli_ve_dataloader`, can be seen here.
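One way to wire the collate method into a loader is sketched below. In the real code this would return `torch.utils.data.DataLoader(dataset, batch_size=..., collate_fn=partial(task_batch_collate, visual_input_type=...))`; a plain generator and a toy collate function are used here so the sketch has no torch dependency.

```python
from functools import partial

def toy_collate(batch, visual_input_type):
    # Stand-in for task_batch_collate(); visual_input_type would steer
    # image collation in the real method.
    return {"texts": [ex["text"] for ex in batch],
            "labels": [ex["label"] for ex in batch]}

def build_toy_dataloader(dataset, batch_size, visual_input_type):
    # Bind visual_input_type the same way a DataLoader's collate_fn would be.
    collate_fn = partial(toy_collate, visual_input_type=visual_input_type)
    for start in range(0, len(dataset), batch_size):
        yield collate_fn(dataset[start:start + batch_size])
```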
In `train/visionlanguage_tasks/<task>_trainer.py`, you have to create a `TaskTrainer` class. The Trainer should ideally be designed to be as model-agnostic as possible.

The `TaskTrainer` takes `args` from the user, task configurations from `configs/task_configs.py`, and model configurations from `configs/model_configs.py`.

`TaskTrainer` has the following methods:

- `__init__()`: Creates the train and validation dataloaders (using the `build_task_dataloader()` method from `data/visionlanguage_datasets/<task>_dataset.py`) and sets training hyperparameters.
- `create_optimizer()`: Creates the optimizer for training.
- `train()`: Creates the optimizer from `create_optimizer()`, and does the training by calling the `train_step()` method (which in turn calls `forward_pass()`).
- `eval_forgetting()`: Called after the model has already been trained on this task in the upstream CL phase. Loads a model checkpoint from a future task, and evaluates on this task using `eval()`.
- `eval()`: Evaluation over the validation dataloader. This method is called from the `train()` method after every epoch, and can be used to evaluate forgetting from `eval_forgetting()`.

An example of a `TaskTrainer` for SNLI-VE, called `SNLIVETrainer`, can be seen here.
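The method layout above can be sketched as a skeleton. Everything here is a stand-in: the dataloaders, model, and "optimizer" are toy objects, whereas the real trainer would build torch dataloaders via `build_task_dataloader()`, a torch optimizer in `create_optimizer()`, and checkpoint loading in `eval_forgetting()`.

```python
# Hypothetical skeleton mirroring the TaskTrainer method layout.
class ToyTaskTrainer:
    def __init__(self, args, task_config):
        self.num_epochs = task_config["num_epochs"]
        # Stand-ins for dataloaders built via build_task_dataloader().
        self.train_batches = args["train_batches"]
        self.val_batches = args["val_batches"]

    def create_optimizer(self, model):
        # Would return e.g. AdamW over model parameters.
        return {"lr": 1e-4}

    def forward_pass(self, model, batch):
        return model(batch)

    def train_step(self, model, optimizer, batch):
        # Loss computation and the optimizer step would go here.
        return self.forward_pass(model, batch)

    def train(self, model):
        optimizer = self.create_optimizer(model)
        for _ in range(self.num_epochs):
            for batch in self.train_batches:
                self.train_step(model, optimizer, batch)
            score = self.eval(model)   # evaluate after every epoch
        return score

    def eval(self, model):
        # Evaluation over the validation dataloader (toy accuracy).
        correct = sum(self.forward_pass(model, b) == b["label"]
                      for b in self.val_batches)
        return correct / max(len(self.val_batches), 1)

    def eval_forgetting(self, checkpoint_model):
        # Would load a checkpoint from a future task, then re-evaluate
        # on this task via eval().
        return self.eval(checkpoint_model)
```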
In `configs/task_configs.py`, you need to create a `<task>_config` dictionary, containing the following keys:

- `task_name`: Name of the task (for display purposes)
- `data_dir`: The directory within the overall data directory containing this task's data
- `image_source` (if applicable): If the task dataset uses an `images_dataset` such as MS-COCO or Flickr30K that is shared between tasks
- Learning hyperparameters: `num_epochs`, `lr`, `weight_decay`, `adam_epsilon`, `warmup_ratio`
- `task_trainer`: The `TaskTrainer` class for this task (defined above)
- `random_baseline_score`: Performance of a random baseline on this task (as a percentage). Used for computing upstream knowledge transfer scores.

An example of a `task_config` for SNLI-VE, called `snli_ve_config`, can be seen here.
The `task_config` is then stored in the `task_configs` dictionary, with the `task_key` being the key. For instance, for SNLI-VE, `task_key = 'snli-ve'`.
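A hypothetical config entry might look like the following. All hyperparameter values and the baseline score are illustrative placeholders, not the values used in CLiMB, and the trainer is referenced by name only rather than as a class object.

```python
# Hypothetical task_config entry for an SNLI-VE-style task.
snli_ve_config = {
    "task_name": "SNLI-VE",
    "data_dir": "snli-ve/",          # relative to the overall data directory
    "image_source": "flickr30k",     # shares Flickr30K images with other tasks
    "num_epochs": 5,                 # illustrative hyperparameters
    "lr": 1e-4,
    "weight_decay": 1e-2,
    "adam_epsilon": 1e-8,
    "warmup_ratio": 0.1,
    "task_trainer": "SNLIVETrainer", # the TaskTrainer class in the real code
    "random_baseline_score": 33.33,  # e.g. chance accuracy for 3-way classification
}

# Registered under its task_key in the task_configs dictionary:
task_configs = {"snli-ve": snli_ve_config}
```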