saharshlaud / nbeats_time_series

This repo provides an implementation of the N-BEATS algorithm introduced in https://arxiv.org/abs/1905.10437 and enables reproducing the experimental results presented in the paper using a simple Jupyter Notebook.

License: GNU General Public License v3.0

Languages: Python 12.19%, Jupyter Notebook 87.81%
Topics: deep-learning, machine-learning, pytorch, time-series, timeseries


NBeats_Time_Series

N-BEATS: Neural basis expansion analysis for interpretable time series forecasting

N-BEATS is a type of neural network that was first described in a paper published in the 2019 ICLR conference by Oreshkin et al. The authors reported that N-BEATS outperformed the M4 forecast competition winner by 3%. The M4 winner was a hybrid between a recurrent neural network and Holt-Winters exponential smoothing — whereas N-BEATS implements a “pure” deep neural architecture.

nbeats pipeline

In this repo, we use nbeats_forecast, an end-to-end library for univariate time series forecasting with N-BEATS. It builds on nbeats-pytorch and simplifies forecasting by providing an interface similar to scikit-learn and Keras.

Requires: Python >=3.6

Installation

$ pip install nbeats_forecast

Import

from nbeats_forecast import NBeats

Input

numpy array of size nx1

Output

Forecasted values as numpy array of size mx1
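As a quick sanity check (a minimal NumPy sketch, not part of the library), a flat 1-D series can be reshaped into the expected n x 1 column vector before being passed to the model:

```python
import numpy as np

# The model expects a column vector of shape (n, 1); a flat 1-D series
# can be reshaped accordingly before fitting.
series = np.arange(24, dtype=float)   # shape (24,)
data = series.reshape(-1, 1)          # shape (24, 1)
print(data.shape)  # (24, 1)
```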

Mandatory Parameters for the model:

  • data
  • period_to_forecast

A basic model with only mandatory parameters can be used to get forecasted values as shown below:

import pandas as pd
from nbeats_forecast import NBeats

data = pd.read_csv('data.csv')
data = data.values   # univariate time series data of shape n x 1 (numpy array)

model = NBeats(data=data, period_to_forecast=12)
model.fit()
forecast = model.predict()

The other optional parameters of the NBeats model (as described in the paper) can be tuned for better performance. If they are not passed, the default values listed in the table below are used.

| Parameter | Type | Default Value | Description |
| --- | --- | --- | --- |
| backcast_length | integer | 3 * period_to_forecast | Explained in the paper |
| path | string | ' ' | Path to save intermediate training checkpoints |
| checkpoint_file_name | string | 'nbeats-training-checkpoint.th' | Name for the checkpoint file, ending in .th |
| mode | string | 'cpu' | Any of the torch.device modes |
| batch_size | integer | len(data)/15 | Size of a training batch |
| thetas_dims | list of integers | [7, 8] | Explained in the paper |
| nb_blocks_per_stack | integer | 3 | Explained in the paper |
| share_weights_in_stack | boolean | False | Explained in the paper |
| train_percent | float (below 1) | 0.8 | Fraction of the data used for training |
| save_checkpoint | boolean | False | Save intermediate checkpoint files |
| hidden_layer_units | integer | 128 | Hidden layer units |
| stack | list of integers | [1, 1] | Stacks added to the model as per the paper, passed as a list of integers. Mapping: 1: GENERIC_BLOCK, 2: TREND_BLOCK, 3: SEASONALITY_BLOCK |
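As an illustration, the optional parameters from the table can be collected into a keyword dictionary and passed to the constructor. The values below are arbitrary examples, not tuned recommendations, and the final NBeats call is commented out since it requires the library to be installed:

```python
# Illustrative settings for the optional parameters documented above.
# All values here are example choices, not recommendations.
params = dict(
    backcast_length=36,            # 3 * period_to_forecast, for a 12-step horizon
    mode="cpu",                    # any torch.device mode
    thetas_dims=[2, 8],
    nb_blocks_per_stack=3,
    share_weights_in_stack=False,
    train_percent=0.8,
    hidden_layer_units=128,
    stack=[2, 3],                  # 2: TREND_BLOCK, 3: SEASONALITY_BLOCK
)

# from nbeats_forecast import NBeats
# model = NBeats(data=data, period_to_forecast=12, **params)
# model.fit()
# forecast = model.predict()
```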

Repository Structure

Model

PyTorch implementation of N-BEATS can be found in models/nbeats.py

Implementation

The notebooks directory contains a notebook with a univariate time series analysis of one of the series in the provided dataset. It walks through loading the time series data, visualizing the series, implementing the NBeats model, and training it on 75% of the data. It then tests on the remaining 25%, reports evaluation metrics on the predictions, and plots the forecast.
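The chronological split and error metric used in the notebook can be sketched with plain NumPy (the helper names here are illustrative, not from the repo):

```python
import numpy as np

def split_series(series, train_frac=0.75):
    """Chronological split: the first train_frac of the series is used for
    training, the remainder for testing (no shuffling for time series)."""
    cut = int(len(series) * train_frac)
    return series[:cut], series[cut:]

def mape(actual, predicted):
    """Mean absolute percentage error of a forecast, in percent."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    return float(np.mean(np.abs((actual - predicted) / actual)) * 100)

train, test = split_series(np.arange(100.0))
print(len(train), len(test))          # 75 25
print(mape([100, 200], [110, 190]))   # 7.5
```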

Iteration

The iteration folder contains a simple script that iterates over all 50 datasets in the directory and produces an NBeats model for each individual time series, along with its prediction plot.

How to use the repo

To use this repo to implement the NBeats model, follow these steps:

  1. Fork or download this repository
  2. Install nbeats_forecast library on your system.
  3. Open the notebooks/nbeats.ipynb file and modify the file path as per your dataset.
  4. Run the notebook to get the NBeats model results for your dataset.
  5. Open the iteration/nbeats_iter.py file and modify the dataset path according to your directory.
  6. Run the iteration file to create NBeats models for all the time series in your directory.

What works and what doesn't

The given Machine Learning Challenge was truly inspiring and tested my concepts and knowledge in the field of Machine Learning. Out of the given tasks, I was able to implement the NBeats model using the PyTorch implementation and apply it to all 50 time series data files.

The second task, regarding RevIN normalization, was a little beyond my current knowledge base. Although I came to understand how RevIN works — normalizing and then denormalizing the time series data to improve performance — I was not able to integrate it into my NBeats pipeline.
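The core idea of RevIN — normalize each series by its own statistics before forecasting and invert the transform on the model's outputs — can be sketched in plain NumPy. This omits RevIN's learnable affine parameters and is only an illustration of the concept, not the paper's implementation:

```python
import numpy as np

def revin_normalize(x, eps=1e-5):
    """Normalize a series by its own mean/std; return stats for inversion."""
    mu, sigma = float(x.mean()), float(x.std())
    return (x - mu) / (sigma + eps), (mu, sigma)

def revin_denormalize(y, stats, eps=1e-5):
    """Invert the instance normalization on model outputs."""
    mu, sigma = stats
    return y * (sigma + eps) + mu

x = np.array([10.0, 12.0, 14.0, 16.0])
z, stats = revin_normalize(x)
restored = revin_denormalize(z, stats)
print(np.allclose(restored, x))  # True
```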

All the tasks were performed to the best of my abilities, and I hope to greatly expand my knowledge and capabilities in the field of Machine Learning as a part of the Greendeck ML team.
