diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8130b04 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +wandb/ +*.sh +*.txt + diff --git a/net_train.py b/net_train.py index 598f484..1619da5 100755 --- a/net_train.py +++ b/net_train.py @@ -20,6 +20,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning import loggers +from pytorch_lightning.profiler import SimpleProfiler + from simnet.lib.net import common from simnet.lib import datapoint @@ -69,6 +71,8 @@ def draw_detections( model = PanopticModel(hparams, epochs, train_ds, EvalMethod()) model_checkpoint = ModelCheckpoint(filepath=hparams.output, save_top_k=-1, period=1, mode='max') wandb_logger = loggers.WandbLogger(name=hparams.wandb_name, project='CenterSnap') + + profiler = SimpleProfiler() if hparams.finetune_real: trainer = pl.Trainer( @@ -94,6 +98,7 @@ def draw_detections( default_save_path=hparams.output, use_amp=False, print_nan_grads=True, + profiler=profiler ) trainer.fit(model) diff --git a/simnet/lib/net/dataset.py b/simnet/lib/net/dataset.py index 1984bfd..7a5ddc3 100755 --- a/simnet/lib/net/dataset.py +++ b/simnet/lib/net/dataset.py @@ -49,7 +49,7 @@ def __init__(self, dataset_uri, hparams, preprocess_image_func=None, datapoint_d super().__init__() if datapoint_dataset is None: datapoint_dataset = datapoint.make_dataset(dataset_uri) - self.datapoint_handles = datapoint_dataset.list() + self.datapoint_handles = datapoint_dataset.list()[:200] # No need to shuffle, already shufled based on random uids self.hparams = hparams if preprocess_image_func is None: