-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathveas_pilot_forecast_cv_rnn.yaml
44 lines (38 loc) · 1.47 KB
/
veas_pilot_forecast_cv_rnn.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# @package _global_
# Optuna hyperparameter search for the VEAS pilot forecasting RNN, e.g.:
# python train.py -m hparams_search=veas_pilot_forecast_cv_rnn
defaults:
  # Base configuration this sweep builds on.
  - veas_pilot_cv_base
  # Replace the selected experiment and model config groups.
  - override /experiment: veas_pilot_forecast
  - override /model: veas_pilot_rnn_forecast
# Switch that selects the loss function. It is null here and is assigned by
# the sweeper's `use_dtw` search-space entry (or a command-line override).
use_dtw: null

model:
  loss_fn:
    # Soft-DTW metric when use_dtw is truthy, otherwise torch's MSE loss.
    # NOTE(review): relies on an `eval` resolver being registered by the
    # project before this config is resolved - confirm.
    _target_: ${eval:'"src.metrics.soft_dtw.SoftDTWMetric" if ${use_dtw} else "torch.nn.MSELoss"'}

# NOTE(review): kept at top level on the assumption that it lists top-level
# config keys (such as `use_dtw`) to log - confirm against the code that
# reads it.
log_hyperparameters_custom:
  - "use_dtw"
# Disabled override: use the CPU accelerator when use_dtw is set, "auto" otherwise.
# trainer:
#   accelerator: ${eval:'"cpu" if ${use_dtw} else "auto"'}
# Optuna hyperparameter search definition.
# The sweeper optimizes the value returned by the function decorated with
# @hydra.main.
# Docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
hydra:
  sweeper:
    study_name: veas_pilot-forecast-cv-rnn

    # Hyperparameter search space. Each key is a Hydra override and each
    # value is a sweep expression; they must live under
    # hydra.sweeper.params to be picked up by the sweeper.
    params:
      # Model hyperparameters. The `++` prefix adds the key if it is
      # missing from the composed config and overrides it otherwise.
      ++model.hidden_dim: range(1, 200)
      ++model.dropout: range(0, 0.6, 0.1)
      ++model.n_rnn_layers: range(1, 5)
      ++model.input_chunk_length: range(1, 288)

      # Loss selection; feeds the top-level `use_dtw` switch.
      use_dtw: choice(True, False)

      ++model.optimizer_kwargs.lr: interval(1e-5, 1e-2)

      # On/off toggle for each entry of `use_inputs`.
      use_inputs.oxygen: choice(True, False)
      use_inputs.ammonium: choice(True, False)
      use_inputs.filterpressure_1: choice(True, False)
      use_inputs.filterpressure_8: choice(True, False)
      use_inputs.turb: choice(True, False)
      use_inputs.temp: choice(True, False)
      use_inputs.methanol: choice(True, False)
      use_inputs.orto-p: choice(True, False)
      use_inputs.tunnelwater: choice(True, False)