CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting
Official PyTorch code repository for the CoST paper. The required dependencies can be installed with:
pip install -r requirements.txt
The datasets can be obtained and placed in the `datasets/` folder as follows:
- The 3 ETT datasets should be placed at `datasets/ETTh1.csv`, `datasets/ETTh2.csv` and `datasets/ETTm1.csv`.
- The Electricity dataset should be placed at `datasets/LD2011_2014.txt`; then run `electricity.py`.
- The Weather dataset (link from the Informer repository) should be placed at `datasets/WTH.csv`.
- For the M5 dataset, place `calendar.csv`, `sales_train_validation.csv`, `sales_train_evaluation.csv`, `sales_test_validation.csv` and `sales_test_evaluation.csv` in `datasets/` and run `m5.py`.
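Before training, it can help to confirm that the raw files listed above are where the loaders expect them. The following is a minimal sketch; the `missing_dataset_files` helper and the grouping are hypothetical and not part of the repository. It only checks the raw file names given above, not any files that `electricity.py` or `m5.py` may produce.

```python
from pathlib import Path
from typing import Dict, List

# File names taken from the dataset instructions; the grouping and helper are hypothetical.
EXPECTED_FILES: Dict[str, List[str]] = {
    "ETT": ["ETTh1.csv", "ETTh2.csv", "ETTm1.csv"],
    "Electricity": ["LD2011_2014.txt"],
    "Weather": ["WTH.csv"],
    "M5": [
        "calendar.csv",
        "sales_train_validation.csv",
        "sales_train_evaluation.csv",
        "sales_test_validation.csv",
        "sales_test_evaluation.csv",
    ],
}


def missing_dataset_files(root: str = "datasets") -> Dict[str, List[str]]:
    """Return {dataset group: [missing file names]} for raw files not found under root."""
    base = Path(root)
    missing: Dict[str, List[str]] = {}
    for group, names in EXPECTED_FILES.items():
        absent = [name for name in names if not (base / name).is_file()]
        if absent:
            missing[group] = absent
    return missing
```

Calling `missing_dataset_files()` from the repository root returns an empty dict once every raw file is in place; a group only needs to be complete for the datasets you intend to run.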
To train and evaluate CoST on a dataset, run the following command:
python train.py <dataset_name> <run_name> --archive <archive> --batch-size <batch_size> --repr-dims <repr_dims> --gpu <gpu> --eval
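If you launch several runs from Python rather than a shell, the command above can be assembled as an argument list. This is a sketch under stated assumptions: `build_train_command` is a hypothetical helper, it uses only the flags shown in the command above, and its defaults mirror the documented defaults (batch size 8, 320 representation dimensions, GPU 0). The archive check simply enforces the two values documented for `archive`.

```python
import shlex
from typing import List


def build_train_command(dataset_name: str, run_name: str, archive: str,
                        batch_size: int = 8, repr_dims: int = 320,
                        gpu: int = 0, evaluate: bool = True) -> List[str]:
    """Assemble the argv for train.py from the documented command-line flags."""
    if archive not in ("forecast_csv", "forecast_csv_univar"):
        raise ValueError(f"unexpected archive: {archive!r}")
    cmd = [
        "python", "train.py", dataset_name, run_name,
        "--archive", archive,
        "--batch-size", str(batch_size),
        "--repr-dims", str(repr_dims),
        "--gpu", str(gpu),
    ]
    if evaluate:
        cmd.append("--eval")  # evaluate after training
    return cmd


if __name__ == "__main__":
    # "ETTh1" and "my_run" are illustrative values only.
    print(shlex.join(build_train_command("ETTh1", "my_run", "forecast_csv")))
```

Passing the returned list to `subprocess.run(cmd, check=True)` would launch the run; keeping the arguments as a list avoids shell quoting issues with run names.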
Detailed descriptions of the arguments are as follows:
Parameter name | Description of parameter |
---|---|
dataset_name | The dataset name |
run_name | The folder name used to save model, output and evaluation metrics. This can be set to any word |
archive | The archive name that the dataset belongs to. This can be set to forecast_csv or forecast_csv_univar |
batch_size | The batch size (defaults to 8) |
repr_dims | The representation dimensions (defaults to 320) |
gpu | The GPU index used for training and inference (defaults to 0) |
eval | Whether to perform evaluation after training |
kernels | Kernel sizes for the mixture of AR experts module |
alpha | Weight for the loss function |
(For descriptions of more arguments, run `python train.py -h`.)
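To give a feel for what `kernels` controls, here is a toy, pure-Python illustration of a mixture of autoregressive experts: each expert is a 1D causal convolution with its own kernel size, and the expert outputs are averaged. This is not the repository's module. The uniform weights `1/k` (making each expert a causal moving average), the single channel, and the function names are simplifying assumptions; the actual PyTorch module learns its weights and operates on multi-channel features.

```python
from typing import List, Sequence


def causal_conv1d(x: Sequence[float], weights: Sequence[float]) -> List[float]:
    """Causal 1D convolution: output[t] depends only on x[t], x[t-1], ..., x[t-k+1].

    Positions before the start of the series are treated as zeros (left padding).
    """
    k = len(weights)
    out = []
    for t in range(len(x)):
        acc = 0.0
        for j in range(k):
            if t - j >= 0:
                acc += weights[j] * x[t - j]
        out.append(acc)
    return out


def mixture_of_ar_experts(x: Sequence[float], kernels: Sequence[int]) -> List[float]:
    """Average the outputs of one causal-convolution expert per kernel size.

    Toy assumption: each expert uses fixed uniform weights 1/k instead of learned ones.
    """
    expert_outputs = [causal_conv1d(x, [1.0 / k] * k) for k in kernels]
    n = len(expert_outputs)
    return [sum(values) / n for values in zip(*expert_outputs)]
```

For example, `mixture_of_ar_experts([1, 2, 3, 4], [1, 2])` averages the identity expert with a two-step moving average. Larger kernel sizes let an expert look further into the past, which is why a range of sizes is supplied.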
After training and evaluation, the trained encoder, output and evaluation metrics can be found in `training/<DatasetName>/<RunName>_<Date>_<Time>/`.
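Because each run folder is suffixed with a date and time, repeated runs with the same run name accumulate side by side. A minimal sketch for locating the newest one follows; `latest_run_dir` is a hypothetical helper, and it selects by folder modification time because the exact `<Date>_<Time>` format is not specified here.

```python
from pathlib import Path
from typing import Optional


def latest_run_dir(dataset_name: str, run_name: str,
                   root: str = "training") -> Optional[Path]:
    """Return the most recently modified <root>/<dataset_name>/<run_name>_* folder, or None."""
    base = Path(root) / dataset_name
    if not base.is_dir():
        return None
    candidates = [p for p in base.glob(f"{run_name}_*") if p.is_dir()]
    if not candidates:
        return None
    # Modification time is used instead of parsing the (unspecified) date/time suffix.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Note that the pattern is a prefix match, so a run name such as `run` also matches folders created for `run_b`; choose distinct run names if that matters.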
Scripts: The scripts for reproducing the results are provided in the `scripts/` folder.
Q: `ValueError: Found array with dim 4. StandardScaler expected <= 2.`
A: Please install the package versions listed in `requirements.txt`, in particular `scikit_learn==0.24.1`.
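To check whether an environment actually matches the pinned versions, a small stdlib-only sketch (Python 3.8+) can compare `name==version` lines against what is installed. The function names are hypothetical; lines without an exact `==` pin are skipped, and how a distribution name is looked up (for example `scikit_learn` versus `scikit-learn`) may depend on your Python version's `importlib.metadata`.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple


def pinned_requirements(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines from requirements.txt content; other lines are skipped."""
    pins: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def _installed_version(name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def version_mismatches(
    text: str,
    get_version: Callable[[str], Optional[str]] = _installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (required, installed)} for every pin that is not matched exactly."""
    result: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, required in pinned_requirements(text).items():
        installed = get_version(name)
        if installed != required:
            result[name] = (required, installed)
    return result
```

Running `version_mismatches(open("requirements.txt").read())` lists the packages that differ from their pins; an entry for `scikit_learn` would point to the version mismatch behind the error above.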
The implementation of CoST relies on resources from the following codebases and repositories; we thank the original authors for open-sourcing their work.
Please consider citing our paper if you find this code useful in your research.
@inproceedings{woo2022cost,
  title={Co{ST}: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting},
  author={Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=PilZY3omXV2}
}