Introduction

In this post we introduce nixtlats: a library of state-of-the-art deep learning models for time series forecasting, written in PyTorch and focused on usability and replicability. We present the Exponential Smoothing with Recurrent Neural Networks (ESRNN) model, winner of the M4 time series competition, and show that our implementation achieves performance similar to the original submission. Throughout the post we describe the pipeline for training models with nixtlats and briefly cover the datasets included in the library. This work is inspired by an earlier PyTorch implementation of ESRNN.

Install nixtlats

The nixtlats library is available on PyPI, so you can install it with pip:

!pip install nixtlats

GPU usage

To use a GPU in Colab, go to Edit > Notebook settings or Runtime > Change runtime type and select GPU as the hardware accelerator.
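Optionally, you can verify that PyTorch sees the accelerator with a quick check:

import torch

# True once a CUDA device is visible to PyTorch (e.g., after enabling
# the GPU accelerator in Colab).
print(torch.cuda.is_available())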

Import libraries

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch as t
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything

from nixtlats.data.datasets.m4 import M4, M4Info, M4Evaluation
from nixtlats.data.tsdataset import TimeSeriesDataset
from nixtlats.data.tsloader import TimeSeriesLoader
from nixtlats.models.esrnn.esrnn import ESRNN
seed_everything(117982, workers=True)
Global seed set to 117982
117982

Import M4 Yearly data

The nixtlats library provides functions to download and manipulate M4 data. The M4.load method returns the target series along with exogenous and static variables; below we split the target into train and test sets ourselves. The library also provides a wide variety of datasets; see the documentation.

group = M4Info['Yearly']  # group metadata: name, horizon, number of series
Y_df, _, S_df = M4.load(directory='data', group=group.name)
# Hold out the last `horizon` observations of each series as the test set.
Y_df_test = Y_df.groupby('unique_id').tail(group.horizon).copy()
Y_df_train = Y_df.drop(Y_df_test.index)
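
Optionally, we can take a quick look at what was loaded; Y_df holds the target series in long format and S_df carries one row of static metadata per series:

# Inspect the loaded dataframes and the group metadata.
print(Y_df.head())
print(S_df.head())
print(f'{group.n_ts} series with forecast horizon {group.horizon}')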

To avoid leakage, we set the test values to zero.

Y_df_test.loc[:, 'y'] = 0

nixtlats requires a dummy test set to make forecasts.

Y_df_full = pd.concat([Y_df_train, Y_df_test]).sort_values(['unique_id', 'ds'], ignore_index=True)
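
As an optional sanity check (illustrative, not required by the library), each series should now contribute exactly group.horizon zeroed test observations:

# Every series holds `group.horizon` test rows, all set to zero.
assert (Y_df_test.groupby('unique_id').size() == group.horizon).all()
assert (Y_df_test['y'] == 0).all()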

Define Time Series Datasets

The pipeline for model training follows standard deep learning practice: first, a Dataset must be instantiated. The TimeSeriesDataset class returns the complete series in each iteration, which is useful for recurrent models such as ESRNN. To instantiate it, the class receives the target series Y_df as a pandas dataframe with columns unique_id, ds, and y. Temporal exogenous variables X_df and static variables S_df can also be included. Here we only use static variables, as in the original model. A minimal example of the expected format follows.
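For illustration only, a hypothetical pair of dataframes in this format (synthetic values, with a simple integer time index) might look like this:

# Two toy series in the long format expected by TimeSeriesDataset.
toy_Y_df = pd.DataFrame({
    'unique_id': ['Y1'] * 3 + ['Y2'] * 3,   # series identifier
    'ds': list(range(1, 4)) * 2,            # time index
    'y': [10.0, 12.0, 13.0, 5.0, 6.0, 7.5]  # target values
})
# One row of static features per series.
toy_S_df = pd.DataFrame({'unique_id': ['Y1', 'Y2'],
                         'category': ['Macro', 'Micro']})

With the M4 dataframes, the training and test datasets are instantiated as follows: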

train_ts_dataset = TimeSeriesDataset(Y_df=Y_df_train, S_df=S_df,
                                     input_size=4,
                                     output_size=group.horizon)
test_ts_dataset = TimeSeriesDataset(Y_df=Y_df_full, S_df=S_df,
                                    input_size=4,
                                    output_size=group.horizon)

Define Time Series Loaders

Once the training and test TimeSeriesDatasets are defined, the loaders must be instantiated through TimeSeriesLoader, a wrapper around PyTorch's DataLoader that accepts all of its original arguments.

train_ts_loader = TimeSeriesLoader(dataset=train_ts_dataset,
                                   batch_size=16,
                                   shuffle=False)
test_ts_loader = TimeSeriesLoader(dataset=test_ts_dataset,
                                  batch_size=1024,
                                  eq_batch_size=False,
                                  shuffle=False)
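
Because TimeSeriesLoader forwards keyword arguments to PyTorch's DataLoader, standard options apply as well. For example, a hypothetical variant of the training loader (here under a new name) could use parallel workers, as the pytorch-lightning warning during training below suggests:

# Illustrative variant: same training loader with parallel data loading.
train_ts_loader_mp = TimeSeriesLoader(dataset=train_ts_dataset,
                                      batch_size=16,
                                      shuffle=False,
                                      num_workers=8)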

Define ESRNN model

The model is instantiated as follows. For an overview of all parameters, see the documentation.

model = ESRNN(n_series=group.n_ts,
              n_x=0, n_s=1,
              sample_freq=1,
              input_size=4,
              output_size=group.horizon,
              learning_rate=0.0025,
              lr_scheduler_step_size=6,
              lr_decay=0.08,
              per_series_lr_multip=0.8,
              gradient_clipping_threshold=20,
              rnn_weight_decay=0,
              level_variability_penalty=50,
              testing_percentile=50,
              training_percentile=50,
              cell_type='GRU',
              state_hsize=30,
              dilations=[[1, 2], [2, 6]],
              add_nl_layer=True,
              loss='SMYL',
              val_loss='SMAPE',
              seasonality=[])

Train model

The nixtlats library is designed to work with pytorch-lightning, so all of its functionality can be used. For GPU training on Colab, add gpus=1 to the Trainer, as sketched below.
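A sketch of the same trainer configured for a single GPU (using the pytorch-lightning 1.x gpus argument; the run below stays on CPU):

# Sketch: identical trainer, but training on one GPU (e.g., on Colab).
trainer_gpu = pl.Trainer(max_epochs=15,
                         progress_bar_refresh_rate=10,
                         deterministic=True,
                         gpus=1)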

trainer = pl.Trainer(max_epochs=15,
                     progress_bar_refresh_rate=10, 
                     deterministic=True)
trainer.fit(model, train_ts_loader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/home/federicogarza2/anaconda3/envs/nixtla/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: you defined a validation_step but have no val_dataloader. Skipping val loop
  warnings.warn(*args, **kwargs)

  | Name  | Type   | Params
---------------------------------
0 | model | _ESRNN | 44.2 K
---------------------------------
44.2 K    Trainable params
0         Non-trainable params
44.2 K    Total params
0.177     Total estimated model params size (MB)
Global seed set to 117982
/home/federicogarza2/anaconda3/envs/nixtla/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)

Make forecasts

For each test batch, the ESRNN model returns three elements: the target values y, the forecasts y_hat, and the mask mask.

outputs = trainer.predict(model, test_ts_loader)
/home/federicogarza2/anaconda3/envs/nixtla/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: The dataloader, predict dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
_, y_hat, mask = zip(*outputs)

The ESRNN model returns forecasts for all rolling windows (one every sample_freq steps). We only need the last window.

y_hat = t.cat([y_hat_[:, -1] for y_hat_ in y_hat]).cpu().numpy()
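
As a quick check (assuming, as above, that the forecasts are stacked one row per series), y_hat should now contain one horizon-length forecast per series:

# Expect one `group.horizon`-step forecast for each of the `group.n_ts` series.
assert y_hat.shape == (group.n_ts, group.horizon)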

Evaluate results

An M4 performance evaluation function is included in nixtlats to facilitate reproducing the results.

M4Evaluation.evaluate('data', 'Yearly', y_hat)
            SMAPE      MASE       OWA
Yearly  13.348044  2.974721  0.782632

Original results

In addition, the results of the original submission can be evaluated as follows.

esrnn_url = 'https://github.com/Nixtla/m4-forecasts/raw/master/forecasts/submission-118.zip'
M4Evaluation.evaluate('data', 'Yearly', esrnn_url)
            SMAPE      MASE       OWA
Yearly  13.175716  2.979908  0.778012

Conclusion

In this post we presented nixtlats, a deep learning library for time series forecasting. As the M4 Yearly case shows, the results are similar to those obtained by the original implementation (OWA of 0.783 versus 0.778). These results are encouraging because we did not use the ensembles of different models employed by the original submission.

How to contribute

The full code is publicly available on GitHub. To contribute, you can fork the repository and open a PR with your improvements. You can also open issues if you have problems running the model.