
wenjiedu / saits

256 stars, 5 watchers, 47 forks, 601 KB

The official PyTorch implementation of the paper "SAITS: Self-Attention-based Imputation for Time Series". A fast and state-of-the-art (SOTA) deep-learning neural network model for efficient time-series imputation (impute multivariate incomplete time series containing NaN missing data/values with machine learning). https://arxiv.org/abs/2202.08516

Home Page: https://doi.org/10.1016/j.eswa.2023.119619

License: MIT License

Python 96.96% Shell 3.04%
time-series imputation-model missing-values self-attention partially-observed-data partially-observed-time-series partially-observed interpolation time-series-imputation incomplete-data

saits's Introduction


Tip

🎉 [Updates in Feb 2024] Our survey paper Deep Learning for Multivariate Time Series Imputation: A Survey has been released on arXiv. The code is open source in the GitHub repo Awesome_Imputation. We comprehensively review the literature of the state-of-the-art deep-learning imputation methods for time series, provide a taxonomy for them, and discuss the challenges and future directions in this field.

🔥 [Updates in Apr 2024] We applied SAITS' embedding and training strategies to Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, and Autoformer in PyPOTS to make them applicable to the time-series imputation task.

‼️Kind reminder: This document answers many common questions, so please read it before you run the code.

This is the official code repository for the paper SAITS: Self-Attention-based Imputation for Time Series (arXiv preprint: https://arxiv.org/abs/2202.08516), which has been accepted by the journal Expert Systems with Applications (ESWA) [2022 IF 8.665, CiteScore 12.2, JCR-Q1, CAS-Q1, CCF-C]. You may not have heard of ESWA, but it was ranked 1st among the top Artificial Intelligence publications in Google Scholar in 2016, and it is still the top AI journal according to Google Scholar metrics.

SAITS is the first work to apply pure self-attention, without any recursive design, to general time-series imputation. You can treat it as a validated framework for time-series imputation and, more generally, for sequence imputation. The code is open source under the MIT license, so you are welcome to modify it for your own research and domain applications, although specific scenarios or data inputs will probably require some changes to the model structure or loss functions. There is also an incomplete list of scientific publications that reference SAITS.

🤗 Please cite SAITS in your publications if it helps with your work. Please star🌟 this repo to help others notice SAITS if you think it is useful. It really means a lot to our open-source research. Thank you! BTW, you may also like PyPOTS for easily modeling your partially-observed time-series datasets.

Important

📣 Attention please:

SAITS is now available in PyPOTS, a Python toolbox for data mining on POTS (Partially-Observed Time Series). An example of training SAITS to impute the PhysioNet-2012 dataset is shown below. With PyPOTS, easy peasy! 😉

👉 Example 👀
# pip install "pypots>=0.4"  (quote the spec so the shell does not treat > as a redirect)
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar
from pypots.data import load_specific_dataset
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae

# Data preprocessing. Tedious, but PyPOTS can help.
data = load_specific_dataset('physionet_2012')  # PyPOTS will automatically download and extract it.
X = data['X']
num_samples = len(X['RecordID'].unique())
X = X.drop(['RecordID', 'Time'], axis = 1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X_ori = X  # keep X_ori for validation
X = mcar(X, 0.1)  # randomly hold out 10% observed values as ground truth
dataset = {"X": X}  # X for model input
print(X.shape)  # (11988, 48, 37), 11988 samples and each sample has 48 time steps, 37 features

# Model training. This is PyPOTS showtime.
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_ffn=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
# Here I use the whole dataset as the training set because the ground truth is not visible to the model; you can also split it into train/val/test sets
saits.fit(dataset)
imputation = saits.impute(dataset)  # impute the originally-missing values and artificially-missing values
indicating_mask = np.isnan(X) ^ np.isnan(X_ori)  # indicating mask for imputation error calculation
mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)


☕️ Welcome to the universe of PyPOTS. Enjoy it and have fun!

❖ Motivation and Performance

⦿ Motivation: SAITS is developed primarily to help overcome the drawbacks (slow speed, memory constraints, and compounding error) of RNN-based imputation models and to obtain the state-of-the-art (SOTA) imputation accuracy on partially-observed time series.

⦿ Performance: SAITS outperforms BRITS by 12% ∼ 38% in MAE (mean absolute error) and trains 2.0 ∼ 2.6 times faster. Furthermore, SAITS outperforms Transformer (trained with our joint-optimization approach) by 2% ∼ 19% in MAE with a more efficient model structure (to obtain comparable performance, SAITS needs only 15% ∼ 30% of Transformer's parameters). Compared to NRTSI, another SOTA self-attention imputation model, SAITS achieves 7% ∼ 39% smaller mean squared error (above 20% in nine out of sixteen cases) while needing far fewer parameters and less imputation time in practice. Please refer to our full paper for more details about SAITS' performance.

❖ Brief Graphical Illustration of Our Methodology

Here we only show the two main components of our method: the joint-optimization training approach and SAITS structure. For the detailed description and explanation, please read our full paper Paper_SAITS.pdf in this repo or on arXiv.


Fig. 1: Training approach


Fig. 2: SAITS structure

❖ Citing SAITS

If you find SAITS helpful in your work, please cite our paper as below, ⭐️star this repository, and recommend it to others who may need it. 🤗 Thank you!

@article{du2023saits,
title = {{SAITS: Self-Attention-based Imputation for Time Series}},
journal = {Expert Systems with Applications},
volume = {219},
pages = {119619},
year = {2023},
issn = {0957-4174},
doi = {10.1016/j.eswa.2023.119619},
url = {https://arxiv.org/abs/2202.08516},
author = {Wenjie Du and David Cote and Yan Liu},
}

or

Wenjie Du, David Cote, and Yan Liu. SAITS: Self-Attention-based Imputation for Time Series. Expert Systems with Applications, 219:119619, 2023.

❖ Repository Structure

The SAITS implementation is in the directory modeling. Model configurations are in the directory configs, and dataset links and preprocessing scripts are in the directory dataset_generating_scripts. The directory NNI_tuning contains the hyper-parameter search configurations.

❖ Development Environment

All dependencies of our development environment are listed in the file conda_env_dependencies.yml. You can quickly create a usable Python environment with the Anaconda command conda env create -f conda_env_dependencies.yml.

❖ Datasets

For datasets downloading and generating, please check out the scripts in dir dataset_generating_scripts.

❖ Quick Run

Generate the dataset you need first. To do so, please check out the generating scripts in dir dataset_generating_scripts.

After generating the data, train and test your model. For example:

# create a dir to save logs and results
mkdir NIPS_results

# train a model
nohup python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    > NIPS_results/PhysioNet2012_SAITS_best.out &

# during training, you can run the command below to read the training log
less NIPS_results/PhysioNet2012_SAITS_best.out

# after training, pick the best model, set its path as the model to test in the config file, then run the command below to test it
python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    --test_mode

❗️Note that dataset paths and saving directories may differ on your computer, so please check them in the configuration files.

❖ Acknowledgments

Thanks to Ciena, Mitacs, and NSERC (Natural Sciences and Engineering Research Council of Canada) for funding support. Thanks to Ciena for providing computing resources. Thanks to all our reviewers for helping improve the quality of this paper. And thank you all for your attention to this work.

✨Stars/forks/issues/PRs are all welcome!


❖ Last but Not Least

If you have any further questions or are interested in collaborating, please take a look at my GitHub profile and feel free to contact me 😃.

saits's People

Contributors

wenjiedu


saits's Issues

Custom dataset

Thank you for your wonderful work. I would like to know whether I can use this model, or train a model from scratch, to impute my own time series.

Question about MAE

Hi, Wenjie

import torch

def masked_mae_cal(inputs, target, mask):
    """ calculate Mean Absolute Error over the positions where mask == 1"""
    return torch.sum(torch.abs(inputs - target) * mask) / (torch.sum(mask) + 1e-9)

I have a small question about how the MAE is calculated.
I found that you normalize the dataset with standard scaling, which means both the targets and the inputs are standardized. Why not calculate the MAE after inverting the scaling?
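To see how the two choices relate, here is a minimal NumPy sketch (an illustration, not code from the SAITS repo) that mirrors masked_mae_cal above and compares the MAE on standardized values with the MAE after inverting the standard scaling. The helper names and toy data are hypothetical. With a single feature, inverse scaling multiplies every error by that feature's standard deviation, so the MAE simply scales by σ and model rankings are unchanged; with several features, each error is weighted by its own feature's σ, so the two MAEs are no longer proportional.

```python
import numpy as np

def masked_mae(pred, target, mask):
    """NumPy analogue of masked_mae_cal: mean absolute error over positions where mask == 1."""
    return np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9)

def inverse_standardize(x, mean, std):
    """Undo z-score scaling (x_scaled = (x - mean) / std) feature-wise on the last axis."""
    return x * std + mean

rng = np.random.default_rng(0)
# Hypothetical toy data: 100 time steps, 1 feature, in original units.
raw = rng.normal(loc=50.0, scale=10.0, size=(100, 1))
mean, std = raw.mean(axis=0), raw.std(axis=0)
target_scaled = (raw - mean) / std

# Pretend predictions are the targets plus some noise in standardized space.
pred_scaled = target_scaled + rng.normal(scale=0.1, size=target_scaled.shape)
mask = (rng.random(target_scaled.shape) < 0.2).astype(float)  # positions to evaluate

mae_scaled = masked_mae(pred_scaled, target_scaled, mask)
mae_original = masked_mae(inverse_standardize(pred_scaled, mean, std),
                          inverse_standardize(target_scaled, mean, std), mask)

# With one feature, the original-unit MAE equals std * standardized MAE.
print(mae_scaled, mae_original, mae_original / mae_scaled, std[0])
```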

window truncate function

import numpy as np

def window_truncate(feature_vectors, seq_len):
    """ Generate time series samples, truncating windows from time-series data with a given sequence length.
    Parameters
    ----------
    feature_vectors: time series data, len(shape)=2, [total_length, feature_num]
    seq_len: sequence length
    """
    # Non-overlapping windows starting at 0, seq_len, 2*seq_len, ...; any trailing remainder is dropped.
    start_indices = np.arange(feature_vectors.shape[0] // seq_len) * seq_len
    sample_collector = []
    for idx in start_indices:
        sample_collector.append(feature_vectors[idx: idx + seq_len])

    return np.asarray(sample_collector).astype('float32')

Wenjie,

I have some questions, if you don't mind clarifying:

  1. In the implementation, is the training data generated by dividing the time series into segments of the sequence length?
  2. What is the advantage of this training-data configuration over a sliding-window approach, e.g., generating the training set with a one-time-step lag: [t-n, t-n+1, ... t], [t-n+1, t-n+2, ... t+1], [t-n+2, t-n+3, ... t+2]? Wouldn't the sliding-window approach generate more training samples?
  3. I am not very familiar with the transformer architecture. In a typical RNN-based imputation method, there are the concepts of sequence length (i.e., the length of historical or future data used as input) and prediction horizon (i.e., how far into the future or the past the model tries to impute). What are the equivalent concepts for SAITS, or does the concept of a prediction horizon exist at all?
  4. I understand from your paper that the sequence length is fixed across models for comparison purposes. How does the sequence length affect imputation accuracy? How would you recommend choosing an appropriate sequence length for the problem at hand?
  5. An unrelated question: does PyPOTS currently work with the Air Quality dataset?

Thanks in advance,
Haochen
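For question 2, a sliding-window generator with a configurable stride is easy to sketch next to window_truncate. This is a hypothetical helper, not part of the SAITS code: stride=1 gives the one-step-lag scheme from the question, and stride=seq_len reproduces the non-overlapping windows. The trade-off is that overlapping windows are highly correlated with each other, and if the train/test split is made after windowing, overlapping windows can leak between sets, so splitting the raw series first is the safer order.

```python
import numpy as np

def sliding_window(feature_vectors, seq_len, stride=1):
    """Cut windows of length seq_len from a [total_length, feature_num] array.

    stride=1 yields one window per time step (overlapping, lag-1 scheme);
    stride=seq_len gives the same non-overlapping windows as window_truncate.
    """
    total_length = feature_vectors.shape[0]
    start_indices = np.arange(0, total_length - seq_len + 1, stride)
    windows = [feature_vectors[idx: idx + seq_len] for idx in start_indices]
    return np.asarray(windows).astype('float32')

series = np.arange(20, dtype=float).reshape(10, 2)  # 10 time steps, 2 features
print(sliding_window(series, seq_len=4).shape)            # (7, 4, 2): overlapping
print(sliding_window(series, seq_len=4, stride=4).shape)  # (2, 4, 2): non-overlapping
```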

Question about output of the first DMSA

Hello, I want to ask about saits.py in the modeling directory of your code. I only used the first DMSA module and fed in X and the missing mask the same way you do, but after going through the encoder layer, the data becomes all NaN. What could cause this?
Looking forward to your reply
Looking forward to your reply
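One likely cause, offered as an assumption since the modified code isn't shown, is that NaNs reach the layer: a single NaN anywhere in a sequence propagates through the attention matmul and softmax into every output row, so missing positions need to be filled (for example with zeros via nan_to_num) before the values are combined with the missing mask. The NumPy sketch below uses a plain single-head attention (not the repo's DMSA code) to show the effect; the shapes and weights are hypothetical. A second possible source of NaN with diagonal masking is a sequence of length 1, where masking the diagonal leaves no valid score and the softmax divides by zero.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Plain single-head scaled dot-product self-attention over time steps (x: [T, D])."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
T, D, d_k = 5, 3, 4
w_q, w_k, w_v = (rng.normal(size=(D, d_k)) for _ in range(3))

x = rng.normal(size=(T, D))
x[2, 1] = np.nan                                  # a single missing value
missing_mask = (~np.isnan(x)).astype(float)       # 1 = observed, 0 = missing

out_raw = self_attention(x, w_q, w_k, w_v)
print(np.isnan(out_raw).all())    # True: the one NaN has spread to every output

x_filled = np.nan_to_num(x, nan=0.0)              # fill missing values before the model
# Input built from the filled values plus the mask, concatenated along the feature axis.
x_in = np.concatenate([x_filled, missing_mask], axis=-1)
w_q2, w_k2, w_v2 = (rng.normal(size=(2 * D, d_k)) for _ in range(3))
out_filled = self_attention(x_in, w_q2, w_k2, w_v2)
print(np.isnan(out_filled).any())  # False
```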

Using CSV files versus h5 data

Hello Wenjie,

Thank you for releasing the code. I have a couple of questions. I am trying to run the code with the Air Quality dataset in Google Colab, and these are my doubts:

  1. !CUDA_VISIBLE_DEVICES=2 python run_models.py --config_path configs/AirQuality_SAITS_best.ini
    Running this gives me the following error message.
    OSError: Unable to open file (unable to open file: name = 'dataset_generating_scripts/RawData/AirQuality/PRSA_Data_20130301-20170228/datasets.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

All of the dataset files are in .csv format.
1a) Is there an option to use the default .csv data?
1b) How do I convert the .csv files to h5 format?

  2. Where should we change the dataset file path used for training?
    Is it in the file configs/AirQuality_SAITS_best.ini?

Please do let me know, thanks.

Niharika

Configs of ETTm1

Hello,
Could you share the configuration settings on the ETTm1 dataset?
Thanks!

Question about loss

Thank you for your work, and please understand that this is not a direct question about the code. Is there any reason the loss function does not include a classification error term? Some models that perform reconstruction and imputation include a classification error in the loss function. Have you ever trained models this way? If so, please let me know what the results were like.

Final error calculation

Hello Wenjie,

I have a doubt regarding the calculation of the final error metrics on the test data.

Suppose my sample data looks like this:

date          A       B
timestamp1    3       5
timestamp2    4       7
timestamp3    6       8
timestamp4    8       10

After introducing 50% missingness:

date          A       B
timestamp1    NaN     5
timestamp2    NaN     7
timestamp3    6       8
timestamp4    NaN     NaN

After imputation:

date          A       B
timestamp1    2       5
timestamp2    4       7
timestamp3    6       8
timestamp4    6       5

  1. Are the MAE, RMSE, and MRE calculated only on the imputed values or on the whole dataset?
  2. Can you explain the MAE, RMSE, and MRE formulas/equations used?

Thank you, please let me know.

Regards
Niharika Joshi
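The NumPy sketch below applies masked metrics in the same form as masked_mae_cal (quoted in the "Question about MAE" issue) to the example tables above, evaluating only where values were artificially removed, as the PyPOTS example does with calc_mae(..., indicating_mask). The RMSE and MRE helpers are assumptions modeled on that masked form; in particular, MRE here is the total absolute error divided by the total absolute ground truth on the mask, which is one common definition, so check the repo's own metric functions for the exact form.

```python
import numpy as np

# Example data from the tables above: columns A and B, four timestamps.
X_ori = np.array([[3, 5], [4, 7], [6, 8], [8, 10]], dtype=float)
X_missing = np.array([[np.nan, 5], [np.nan, 7], [6, 8], [np.nan, np.nan]])
X_imputed = np.array([[2, 5], [4, 7], [6, 8], [6, 5]], dtype=float)

# Evaluate only the artificially-removed entries: ground truth known, model never saw it.
indicating_mask = (np.isnan(X_missing) & ~np.isnan(X_ori)).astype(float)

def masked_mae(pred, target, mask):
    return np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9)

def masked_rmse(pred, target, mask):
    return np.sqrt(np.sum(np.square(pred - target) * mask) / (np.sum(mask) + 1e-9))

def masked_mre(pred, target, mask):
    # Relative error: total absolute error divided by total absolute ground truth on the mask.
    return np.sum(np.abs(pred - target) * mask) / (np.sum(np.abs(target * mask)) + 1e-9)

print(masked_mae(X_imputed, X_ori, indicating_mask))   # ≈ 2.0   (errors 1, 0, 2, 5)
print(masked_rmse(X_imputed, X_ori, indicating_mask))  # ≈ 2.739 (sqrt(30 / 4))
print(masked_mre(X_imputed, X_ori, indicating_mask))   # ≈ 0.32  (8 / 25)
```

Entries that were observed in the input (such as timestamp3) do not contribute, because they would reward the model for copying values it already saw.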

pd.concat

FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
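This warning comes from pandas rather than from SAITS itself: DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0. A minimal sketch of the usual fix, assuming the dataset-generating script builds a DataFrame by appending per-station frames in a loop (the frames and column names below are hypothetical):

```python
import pandas as pd

# Hypothetical per-station frames, standing in for the Air Quality CSV files.
station_frames = [
    pd.DataFrame({"station": ["A", "A"], "PM2.5": [10.0, 12.0]}),
    pd.DataFrame({"station": ["B", "B"], "PM2.5": [8.0, None]}),
]

# Deprecated pattern (FutureWarning in pandas 1.4+, removed in 2.0):
#     df_collector = df_collector.append(frame)
# Replacement: collect the frames in a list and concatenate once.
df_collector = pd.concat(station_frames, ignore_index=True)
print(df_collector.shape)  # (4, 2)
```

Concatenating once is also faster than appending inside a loop, because each append copies the whole accumulated frame.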

Test data

Hi,

After some modifications and adding some code snippets, I was able to train, validate, and get the MAE for the test data.
I want to obtain the de-normalized value after the imputation happens in test data, both predicted and actual. Can you help?
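One way to do this, sketched in NumPy under the assumption that the data was standardized per feature before being reshaped to [n_samples, n_steps, n_features], as in the PyPOTS example above: flatten back to 2-D, apply x * scale + mean, and reshape. Note that the example calls StandardScaler().fit_transform(...) without keeping the scaler, so you would need to keep a reference to it (or its mean_ and scale_ attributes) to invert the scaling. The helper name and data are hypothetical; apply the same function to both the imputed and the ground-truth arrays.

```python
import numpy as np

def denormalize(x_scaled, mean, scale):
    """Invert per-feature z-score scaling on an array shaped [n_samples, n_steps, n_features].

    mean and scale are 1-D arrays of length n_features (for example, a fitted
    sklearn StandardScaler's mean_ and scale_ attributes).
    """
    n_samples, n_steps, n_features = x_scaled.shape
    flat = x_scaled.reshape(-1, n_features)   # back to the 2-D layout the scaler was fitted on
    flat = flat * scale + mean                # broadcast over the feature axis
    return flat.reshape(n_samples, n_steps, n_features)

rng = np.random.default_rng(0)
raw = rng.normal(loc=[20.0, 1000.0], scale=[5.0, 50.0], size=(30, 2))  # hypothetical data
mean, scale = raw.mean(axis=0), raw.std(axis=0)
scaled = ((raw - mean) / scale).reshape(3, 10, 2)   # 3 samples x 10 steps x 2 features

restored = denormalize(scaled, mean, scale)
print(np.allclose(restored.reshape(-1, 2), raw))    # True
```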

Question about temporal dependencies and feature correlations captured by DMSA

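Hello, I have a question about the self-attention in the paper. The N×N self-attention matrix Q·Kᵀ represents attention relationships within one dimension of length N. However, the paper states, "Such a mechanism makes DMSA able to capture the temporal dependencies and feature correlations between time steps in the high dimensional space with only one attention operation", meaning that a single attention matrix in DMSA captures attention across two dimensions at once. How does one attention operation capture both types of attention?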
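One way to read that sentence, shown as a simplified single-head NumPy sketch (an illustration, not the repo's implementation): the attention map is indeed T × T over time steps, but each query and key vector is a learned projection of all features at its time step, so each weight scores how a mixture of features at one step relates to a mixture of features at another step. The "diagonally masked" part sets the diagonal scores to a very large negative value (-inf here) before the softmax, so a time step cannot attend to itself. The embedding, shapes, and weights are hypothetical.

```python
import numpy as np

def diagonally_masked_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over time steps with the diagonal masked out.

    x: [T, d_model]; every row already mixes all features of one time step.
    Returns the outputs [T, d_v] and the attention weights [T, T].
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # [T, T]: one score per pair of time steps
    np.fill_diagonal(scores, -np.inf)         # a step may not attend to itself
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                  # exp(-inf) = 0 on the diagonal
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
T, n_features, d_model, d_k = 6, 3, 8, 4
features = rng.normal(size=(T, n_features))
w_embed = rng.normal(size=(n_features, d_model))  # projection that mixes the features
x = features @ w_embed                            # [T, d_model]
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, attn = diagonally_masked_attention(x, w_q, w_k, w_v)
print(attn.shape, np.diag(attn))  # (6, 6) and a diagonal of zeros
```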

Calculation of the loss function

Thank you for excellent work!

In the code, the MIT imputation loss does not cover the complement feature vector.

Secondly, the paper says the raw data X without artificial masking is used as input to the MIT formula, but in the corresponding code I found that you used the artificially masked X̂.

Is there something I am misunderstanding? I look forward to your clarification!

Loss_MIT wrong?

I saw that the MIT loss computation in core.py was:

MIT_loss = self.customized_loss_func(
    X_tilde_3, inputs["X_ori"], inputs["indicating_mask"]
)

which computes the MAE between X̃3 and X_ori, and this differs from the paper.
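For reference, here is what the quoted line computes, as a NumPy sketch that assumes customized_loss_func is a masked MAE like masked_mae_cal quoted earlier. indicating_mask is 1 exactly where a value was observed in X_ori but hidden from the model, so the MIT loss averages the error over those entries only; changing X_tilde_3 anywhere else leaves the loss unchanged. Whether this matches the paper's notation is for the authors to confirm; the data and shapes below are hypothetical.

```python
import numpy as np

def masked_mae(pred, target, mask):
    # Same form as masked_mae_cal: average absolute error over entries where mask == 1.
    return np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9)

rng = np.random.default_rng(0)
X_ori = rng.normal(size=(4, 3))                 # hypothetical ground truth (fully observed here)
holdout = rng.random(X_ori.shape) < 0.3         # entries artificially hidden for MIT
holdout[0, 0] = True                            # ensure at least one entry is held out
X_hat = np.where(holdout, np.nan, X_ori)        # what the model receives
indicating_mask = (np.isnan(X_hat) & ~np.isnan(X_ori)).astype(float)

X_tilde_3 = rng.normal(size=X_ori.shape)        # stand-in for the final reconstruction
mit_loss = masked_mae(X_tilde_3, X_ori, indicating_mask)

# Perturbing the reconstruction only where indicating_mask == 0 leaves the MIT loss unchanged.
X_tilde_3_perturbed = X_tilde_3 + 100.0 * (1 - indicating_mask)
print(np.isclose(mit_loss, masked_mae(X_tilde_3_perturbed, X_ori, indicating_mask)))  # True
```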

Some questions about multivariate time series Imputation.

Thank you for your work. I recently read your paper SAITS: Self-Attention-based Imputation for Time Series. I am also working on multivariate time-series imputation, and I have some questions I hope to discuss with you.
1. I recently ran your method on my own dataset. My approach is to first split the data into a training set and a test set, then build the time sequences, train on the training set, and evaluate on the test set. I know that a data imputation algorithm is unsupervised and does not use the true values of the missing data, so some people split the data into training and test sets while others don't, and your algorithm's results differ between these two partitioning methods. How do you view dataset partitioning?
2. Could your algorithm overfit? The back-propagated loss is the MAE on the non-missing items, not the MAE on the whole dataset, and I feel that as the number of training epochs increases, it will gradually tend to overfit.
3. Currently, the algorithm stops when it reaches a specified number of epochs, and the right number of epochs has to be found for each dataset. If we split the data into training and test sets, can we stop training when the MAE on the missing items of the training set reaches its minimum?
Thank you very much
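On question 3, a common alternative to a fixed epoch count is early stopping on a validation MAE computed at artificially-masked positions that the model never trains on (for example, an extra slice of observed values held out with the same masking used for MIT). A minimal, framework-agnostic sketch; the class name, patience value, and loss sequence are hypothetical and not part of SAITS:

```python
class EarlyStopping:
    """Stop training when the validation metric has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_mae):
        """Record this epoch's validation MAE; return True when training should stop."""
        if val_mae < self.best - self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = val_mae, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation MAEs (on held-out masked values) that bottom out at epoch 3.
val_history = [0.52, 0.41, 0.37, 0.35, 0.36, 0.36, 0.38, 0.40, 0.42]
stopper = EarlyStopping(patience=3)
for epoch, val_mae in enumerate(val_history):
    if stopper.step(epoch, val_mae):
        break
print(epoch, stopper.best_epoch, stopper.best)  # 6 3 0.35
```

In practice you would also save the model weights whenever best_epoch changes and restore them after stopping.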

Question about hyperparameter optimization

Hello, I would like to ask you about the hyperparameter optimization for the model. In your file NNI_tuning/SAITS/SAITS_searching_config.yml, you described the settings for the hyperparameters and the training command, which also includes a JSON file. However, when I tried to run the command for hyperparameter optimization on SAITS, I encountered an error: "No option 'mit' in section: 'training'". I supplemented the missing parameters and ran it again, but I only obtained the parameters set in the SAITS_basic_config.ini file. Could you please advise me on how to iterate through the parameters in the JSON file to obtain the optimal parameters?

How should the NNI fine-tuning be understood?

First, the file SAITS_basic_config.ini in the NNI_tuning folder is missing two args, "MIT" and "ORT", which affects running the script python ../../run_models.py --config_path SAITS_basic_config.ini --param_searching_mode. You may want to add these two args to the .ini file and also check the other .ini files if you have time.
Second, I am wondering how to understand NNI's parameter tuning. More specifically, which parameters does NNI change? When and by how much do they change? Are only the parameters listed in the SAITS_searching_space.json file changed?

Thanks for your attention.
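On which parameters change: an NNI tuner samples only the keys declared in the search-space JSON (each written in NNI's {"_type": ..., "_value": ...} format) and hands one sampled dict to each trial, which trial code typically reads with nni.get_next_parameter(); everything else keeps its value from the base config. The sketch below imitates that flow without NNI installed, using a random sampler and a configparser base config. The section and key names are hypothetical, not copied from SAITS_basic_config.ini or SAITS_searching_space.json. It also shows why the reported error mentions 'mit' in lowercase: configparser lowercases option names by default.

```python
import configparser
import random

# Hypothetical search space in NNI's JSON format (only "choice" and "uniform" handled here).
search_space = {
    "n_layers": {"_type": "choice", "_value": [1, 2, 3]},
    "d_model": {"_type": "choice", "_value": [64, 128, 256]},
    "dropout": {"_type": "uniform", "_value": [0.0, 0.5]},
}

def sample(space, rng):
    """Draw one trial's parameters, similar to what a random tuner hands to a trial."""
    params = {}
    for name, spec in space.items():
        if spec["_type"] == "choice":
            params[name] = rng.choice(spec["_value"])
        elif spec["_type"] == "uniform":
            low, high = spec["_value"]
            params[name] = rng.uniform(low, high)
    return params

# Hypothetical base config; keys absent from the search space keep these values.
base = configparser.ConfigParser()
base.read_string("""
[model]
n_layers = 2
d_model = 256
dropout = 0.1
[training]
epochs = 100
MIT = True
ORT = True
""")

trial_params = sample(search_space, random.Random(0))
for name, value in trial_params.items():
    base["model"][name] = str(value)   # only searched keys are overwritten

print(trial_params)
print(dict(base["training"]))          # untouched by the search; keys stored as 'mit', 'ort'
```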

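Can the training step length (n_steps) be adjusted dynamically?

Hello, in the provided example.py, saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10), n_steps is set to 48 because every RecordID in the example dataset has 48 samples.
However, in my dataset the number of samples per RecordID varies, e.g., 1, 7, or even 216. Would it work to set n_steps to the largest per-RecordID count, e.g., 216? Or is there another approach? Thank you very much!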
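A minimal sketch of the padding idea in the question (an assumption, not a documented SAITS feature): pad each RecordID's sequence to n_steps with NaN so the padded steps look like missing values, and truncate anything longer. One caveat to weigh: with n_steps=216 and records as short as 1 step, almost all of a short sample would be "missing", and the model will also produce imputations for the padded steps, which should be discarded afterwards using the returned mask. Splitting long records into fixed-length windows is the main alternative. The helper name and data are hypothetical.

```python
import numpy as np

def pad_records(records, n_steps):
    """Stack variable-length records ([len_i, n_features] arrays) into [n_records, n_steps, n_features].

    Shorter records are padded with NaN (treated as missing); longer ones are truncated.
    Also returns a boolean mask marking the real (non-padded) time steps.
    """
    n_features = records[0].shape[1]
    out = np.full((len(records), n_steps, n_features), np.nan, dtype=float)
    is_real = np.zeros((len(records), n_steps), dtype=bool)
    for i, rec in enumerate(records):
        length = min(len(rec), n_steps)
        out[i, :length] = rec[:length]
        is_real[i, :length] = True
    return out, is_real

rng = np.random.default_rng(0)
records = [rng.normal(size=(n, 3)) for n in (1, 7, 216)]  # lengths from the question
X, is_real = pad_records(records, n_steps=216)
print(X.shape, is_real.sum(axis=1))  # shape (3, 216, 3); real lengths 1, 7, 216
```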
