diff --git a/nbs/models.gru.ipynb b/nbs/models.gru.ipynb index 7cb14f21c..7f0608a5f 100644 --- a/nbs/models.gru.ipynb +++ b/nbs/models.gru.ipynb @@ -1,5 +1,14 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1" + ] + }, { "cell_type": "code", "execution_count": null, @@ -70,6 +79,7 @@ "outputs": [], "source": [ "#| export\n", + "import warnings\n", "from typing import Optional\n", "\n", "import torch\n", @@ -91,7 +101,7 @@ " \"\"\" GRU\n", "\n", " Multi Layer Recurrent Network with Gated Units (GRU), and\n", - " MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained \n", + " MLP decoder. The network has non-linear activation functions, it is trained \n", " using ADAM stochastic gradient descent. The network accepts static, historic \n", " and future exogenous data, flattens the inputs.\n", "\n", @@ -101,7 +111,7 @@ " `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.
\n", " `encoder_n_layers`: int=2, number of layers for the GRU.
\n", " `encoder_hidden_size`: int=200, units for the GRU's hidden state size.
\n", - " `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.
\n", + " `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.
\n", " `encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.
\n", " `encoder_dropout`: float=0., dropout regularization applied to GRU outputs.
\n", " `context_size`: int=10, size of context vector for each timestamp on the forecasting window.
\n", @@ -143,7 +153,7 @@ " inference_input_size: int = -1,\n", " encoder_n_layers: int = 2,\n", " encoder_hidden_size: int = 200,\n", - " encoder_activation: str = 'tanh',\n", + " encoder_activation: Optional[str] = None,\n", " encoder_bias: bool = True,\n", " encoder_dropout: float = 0.,\n", " context_size: int = 10,\n", @@ -199,6 +209,14 @@ " **trainer_kwargs\n", " )\n", "\n", + " if encoder_activation is not None:\n", + " warnings.warn(\n", + " \"The 'encoder_activation' argument is deprecated and will be removed in \"\n", + " \"future versions. The activation function in GRU is frozen in PyTorch and \"\n", + " \"it cannot be modified.\",\n", + " DeprecationWarning,\n", + " )\n", + "\n", " # RNN\n", " self.encoder_n_layers = encoder_n_layers\n", " self.encoder_hidden_size = encoder_hidden_size\n", @@ -322,7 +340,7 @@ "import matplotlib.pyplot as plt\n", "\n", "from neuralforecast import NeuralForecast\n", - "from neuralforecast.models import GRU\n", + "# from neuralforecast.models import GRU\n", "from neuralforecast.losses.pytorch import DistributionLoss\n", "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n", "\n", diff --git a/neuralforecast/models/gru.py b/neuralforecast/models/gru.py index 900eac162..9a6d92325 100644 --- a/neuralforecast/models/gru.py +++ b/neuralforecast/models/gru.py @@ -3,7 +3,8 @@ # %% auto 0 __all__ = ['GRU'] -# %% ../../nbs/models.gru.ipynb 6 +# %% ../../nbs/models.gru.ipynb 7 +import warnings from typing import Optional import torch @@ -13,12 +14,12 @@ from ..common._base_recurrent import BaseRecurrent from ..common._modules import MLP -# %% ../../nbs/models.gru.ipynb 7 +# %% ../../nbs/models.gru.ipynb 8 class GRU(BaseRecurrent): """GRU Multi Layer Recurrent Network with Gated Units (GRU), and - MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained + MLP decoder. The network has non-linear activation functions, it is trained using ADAM stochastic gradient descent. The network accepts static, historic and future exogenous data, flattens the inputs. @@ -28,7 +29,7 @@ class GRU(BaseRecurrent): `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.
`encoder_n_layers`: int=2, number of layers for the GRU.
`encoder_hidden_size`: int=200, units for the GRU's hidden state size.
- `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.
+ `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.
`encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.
`encoder_dropout`: float=0., dropout regularization applied to GRU outputs.
`context_size`: int=10, size of context vector for each timestamp on the forecasting window.
@@ -72,7 +73,7 @@ def __init__( inference_input_size: int = -1, encoder_n_layers: int = 2, encoder_hidden_size: int = 200, - encoder_activation: str = "tanh", + encoder_activation: Optional[str] = None, encoder_bias: bool = True, encoder_dropout: float = 0.0, context_size: int = 10, @@ -129,6 +130,14 @@ def __init__( **trainer_kwargs ) + if encoder_activation is not None: + warnings.warn( + "The 'encoder_activation' argument is deprecated and will be removed in " + "future versions. The activation function in GRU is frozen in PyTorch and " + "it cannot be modified.", + DeprecationWarning, + ) + # RNN self.encoder_n_layers = encoder_n_layers self.encoder_hidden_size = encoder_hidden_size