diff --git a/train/run.py b/train/run.py index efc1939..6a0f15e 100644 --- a/train/run.py +++ b/train/run.py @@ -20,7 +20,6 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: - if trainer.fast_dev_run: raise Exception( "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode." @@ -40,7 +39,6 @@ def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel(Callback): - def __init__(self, log: str = "gradients", log_freq: int = 100): self.log = log self.log_freq = log_freq @@ -52,7 +50,6 @@ def on_train_start(self, trainer, pl_module): class UploadCheckpointsAsArtifact(Callback): - def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): self.ckpt_dir = ckpt_dir self.upload_best_only = upload_best_only