Replicating ESRNN results with nixtlats (I): M4 Yearly dataset
[nikstla] (noun, Nahuatl) Period of time.
- Introduction
- Install nixtlats
- GPU usage
- Import libraries
- Import M4 Yearly data
- Define Time Series Datasets
- Define Time Series Loaders
- Define ESRNN model
- Train model
- Make forecasts
- Evaluate results
- Original results
- Conclusion
- How to contribute
In this post we introduce nixtlats: a library of state-of-the-art deep learning models for time series forecasting, written in PyTorch and focused on usability and replicability. In this first post we present the Exponential Smoothing with Recurrent Neural Networks (ESRNN) model, winner of the M4 time series competition, and show that our implementation achieves performance similar to that of the original submission. Throughout the post we describe the pipeline for training models with nixtlats, along with a brief description of the datasets included in the library. This work is inspired by an earlier implementation of ESRNN in PyTorch.
The nixtlats library is available on PyPI, so to install it just run:
!pip install nixtlats
To use a GPU in Colab: Edit > Notebook settings or Runtime > Change runtime type and select GPU as Hardware accelerator.
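Optionally, you can verify that PyTorch sees the accelerator:
import torch
# should print True when a GPU runtime is selected
print(torch.cuda.is_available())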
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch as t
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything
from nixtlats.data.datasets.m4 import M4, M4Info, M4Evaluation
from nixtlats.data.tsdataset import TimeSeriesDataset
from nixtlats.data.tsloader import TimeSeriesLoader
from nixtlats.models.esrnn.esrnn import ESRNN
seed_everything(117982, workers=True)
The nixtlats library provides functions to download and manipulate M4 data. The M4.load method downloads the full dataset and returns the target, exogenous and static dataframes; the train and test split is built below from each group's forecast horizon. The library also provides a wide variety of datasets; see the documentation.
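For example, the metadata of any group can be inspected through M4Info (here Monthly, purely as an illustration; the attributes are the same ones used below for Yearly):
from nixtlats.data.datasets.m4 import M4Info
# each group exposes its name, forecast horizon and number of series
monthly = M4Info['Monthly']
print(monthly.name, monthly.horizon, monthly.n_ts)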
group = M4Info['Yearly']
# Y_df: target series, S_df: static variables (the exogenous variables are not used here)
Y_df, _, S_df = M4.load(directory='data', group=group.name)
# the last `horizon` observations of each series form the test set
Y_df_test = Y_df.groupby('unique_id').tail(group.horizon).copy()
Y_df_train = Y_df.drop(Y_df_test.index)
To avoid leakage, we set the test values to zero.
Y_df_test.loc[:, 'y'] = 0
nixtlats requires a dummy test set to make forecasts.
Y_df_full = pd.concat([Y_df_train, Y_df_test]).sort_values(['unique_id', 'ds'], ignore_index=True)
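As an optional check that the dummy test set is well formed, the last group.horizon values of each series in Y_df_full should now be zero:
# the tail of each series should be the zeroed dummy test set
tails = Y_df_full.groupby('unique_id').tail(group.horizon)
assert (tails['y'] == 0).all()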
The pipeline for model training follows standard deep learning practice. First, a Dataset must be instantiated. The TimeSeriesDataset class returns the complete series in each iteration, which is useful for recurrent models such as ESRNN. To be instantiated, the class receives the target series Y_df as a pandas dataframe with columns unique_id, ds and y. Additionally, temporal exogenous variables X_df and static variables S_df can be included. In this case we only use static variables, as in the original model.
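For illustration, a toy Y_df in the expected long format could look like this (the identifiers and values are invented):
import pandas as pd
# hypothetical example of the long format expected by TimeSeriesDataset
Y_df_example = pd.DataFrame({
    'unique_id': ['Y1', 'Y1', 'Y1', 'Y2', 'Y2', 'Y2'],  # series identifier
    'ds': [1, 2, 3, 1, 2, 3],  # time index (integer steps here; timestamps are also common)
    'y': [10.0, 12.0, 13.0, 20.0, 21.0, 23.0],  # target values
})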
train_ts_dataset = TimeSeriesDataset(Y_df=Y_df_train, S_df=S_df,
input_size=4,
output_size=group.horizon)
test_ts_dataset = TimeSeriesDataset(Y_df=Y_df_full, S_df=S_df,
input_size=4,
output_size=group.horizon)
Once the training and test TimeSeriesDatasets are defined, the loaders must be instantiated through TimeSeriesLoader, which is a wrapper of PyTorch's DataLoader and can therefore receive all of its original arguments.
train_ts_loader = TimeSeriesLoader(dataset=train_ts_dataset,
batch_size=16,
shuffle=False)
test_ts_loader = TimeSeriesLoader(dataset=test_ts_dataset,
batch_size=1024,
eq_batch_size=False,
shuffle=False)
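Since all DataLoader arguments are accepted, a standard option such as num_workers can be forwarded as well; for instance (a sketch, not used in the rest of this post):
train_ts_loader_parallel = TimeSeriesLoader(dataset=train_ts_dataset,
                                            batch_size=16,
                                            shuffle=False,
                                            num_workers=2)  # standard DataLoader argument for parallel loading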
The model is instantiated as follows. For an overview of all parameters, see the documentation.
model = ESRNN(n_series=group.n_ts,
n_x=0, n_s=1,
sample_freq=1,
input_size=4,
output_size=group.horizon,
learning_rate=0.0025,
lr_scheduler_step_size=6,
lr_decay=0.08,
per_series_lr_multip=0.8,
gradient_clipping_threshold=20,
rnn_weight_decay=0,
level_variability_penalty=50,
testing_percentile=50,
training_percentile=50,
cell_type='GRU',
state_hsize=30,
dilations=[[1, 2], [2, 6]],
add_nl_layer=True,
loss='SMYL',
val_loss='SMAPE',
seasonality=[])
The nixtlats library is designed to operate with pytorch-lightning, so all of its functionality can be used. Add gpus=1 to the trainer for GPU training on Colab.
trainer = pl.Trainer(max_epochs=15,
progress_bar_refresh_rate=10,
deterministic=True)
trainer.fit(model, train_ts_loader)
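For example, the equivalent GPU run on Colab would only change the trainer instantiation; the call to trainer.fit stays the same:
trainer = pl.Trainer(max_epochs=15,
                     progress_bar_refresh_rate=10,
                     gpus=1,  # train on a single GPU
                     deterministic=True)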
For each test batch, the ESRNN model returns three elements: the target values y, the forecasts y_hat and the mask mask.
outputs = trainer.predict(model, test_ts_loader)
_, y_hat, mask = zip(*outputs)
The ESRNN model returns forecasts for every rolling window; we only need the last one.
# keep only the last window of forecasts for each series
y_hat = t.cat([y_hat_[:, -1] for y_hat_ in y_hat]).cpu().numpy()
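The forecasts can be sanity-checked before evaluation; given how y_hat is built above, it should contain one row per series and one column per step of the horizon:
# one forecast per series and per step of the horizon
assert y_hat.shape == (group.n_ts, group.horizon)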
An M4 performance evaluator function is included within nixtlats to facilitate the reproduction of the results.
M4Evaluation.evaluate('data', 'Yearly', y_hat)
In addition, the results of the original submissions can be evaluated as follows.
esrnn_url = 'https://github.com/Nixtla/m4-forecasts/raw/master/forecasts/submission-118.zip'
M4Evaluation.evaluate('data', 'Yearly', esrnn_url)
In this post we present nixtlats, a deep learning library for time series forecasting. As can be seen for the Yearly group of M4, the results are similar to those obtained by the original implementation. These results are encouraging because they were obtained without the ensembling of different models used by the original submission.
The full code is publicly available on GitHub. To contribute, you can fork the repository and open a PR with your improvements. You can also create issues if you run into problems running the model.