remigenet / tkan

TKAN: Temporal Kolmogorov-Arnold Networks

License: Other

Python 100.00%
kolmogorov-arnold-networks temporal temporal-networks tensorflow tensorflow2 timeseries-forecasting tkan jax keras keras3


TKAN (Temporal Kolmogorov-Arnold Networks) is a neural network architecture designed to improve multi-horizon time series forecasting. This Keras implementation provides TKAN as a layer that can be used within sequential models. The implementation is tested to be compatible with TensorFlow, JAX and Torch. In our tests, JAX is the fastest backend for it, while Torch is very slow (probably because it is not well optimized for this layer). This is the original implementation of the paper. The KAN part of the implementation is inspired by efficient_kan and is also available separately; it works similarly to efficient_kan and therefore not exactly like the original KAN implementation.

If performance matters, the best setup tested used the JAX Docker image, followed by installing JAX with pip install "jax[cuda12]". This is the setup used in the example section, where you can compare the training time and performance of TKAN, LSTM and GRU. I also discourage running the example as is with Torch: at the moment, running it through the Keras Torch backend appears to be much slower than using Torch directly, even for GRU or LSTM.
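Because the backend has such a large effect on speed, it may help to see how one is selected. Keras 3 reads the KERAS_BACKEND environment variable when it is first imported, so the variable has to be set before `import keras`. The snippet below is only a minimal sketch; the helper name `select_keras_backend` is hypothetical and is not part of tkan or Keras.

```python
import os

# The three backends this README reports testing against.
TESTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name: str = "jax") -> str:
    """Set KERAS_BACKEND (hypothetical helper; call before importing keras)."""
    if name not in TESTED_BACKENDS:
        raise ValueError(f"unknown backend {name!r}, expected one of {TESTED_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_keras_backend("jax")

# Import keras (and tkan) only after the variable is set, for example:
#   import keras
#   from tkan import TKAN
#   print(keras.backend.backend())
```

The same effect can be had by exporting KERAS_BACKEND in the shell before starting Python.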

TKAN representation

Installation

Install TKAN directly from PyPI:

pip install tkan

Dependencies are managed using pyproject.toml.

Usage

TKAN can be used within TensorFlow models to handle complex sequential patterns in data. Its implementation reproduces the architecture of RNNs in TensorFlow, with a Cell class and a Layer that inherits from RNN, so that it integrates cleanly with TensorFlow. Here is an example that demonstrates how to use TKAN in a sequential model:

import keras
from tkan import TKAN

# X_train_seq (samples, timesteps, features) and y_train_seq (samples, horizons)
# are assumed to be arrays prepared beforehand.

# Example model using TKAN with B-spline activations
model = keras.Sequential([
    keras.layers.Input(shape=X_train_seq.shape[1:]),
    # Define the parameters of each KANLinear sublayer as a dict
    TKAN(100, sub_kan_configs=[{'spline_order': 3, 'grid_size': 10}, {'spline_order': 1, 'grid_size': 5}, {'spline_order': 4, 'grid_size': 6}], return_sequences=True, use_bias=True),
    # Or use an int or float to specify only the exponent of the spline
    TKAN(100, sub_kan_configs=[1, 2, 3, 3, 4], return_sequences=True, use_bias=True),
    # Or use a string to select a standard activation, which uses Dense sublayers instead of KANLinear
    TKAN(100, sub_kan_configs=['relu', 'relu', 'relu', 'relu', 'relu'], return_sequences=True, use_bias=True),
    # Or use None for the default activation
    TKAN(100, sub_kan_configs=[None for _ in range(3)], return_sequences=False, use_bias=True),
    keras.layers.Dense(y_train_seq.shape[1]),
])

You can find a more complete example with comparison with other models in the example folder.
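The model definition takes X_train_seq and y_train_seq as given. As a rough sketch of how such arrays might be built from a single multivariate series, the snippet below slices sliding windows with NumPy. The function name `make_windows`, the window length and the number of forecast steps are illustrative assumptions, not the preprocessing used in the example folder.

```python
import numpy as np


def make_windows(series, seq_len, n_ahead, target_col=0):
    """Slice a (time, features) array into model inputs and multi-step targets.

    Hypothetical helper: X has shape (samples, seq_len, features) and
    y has shape (samples, n_ahead), taken from column `target_col`.
    """
    series = np.asarray(series, dtype="float32")
    n_samples = len(series) - seq_len - n_ahead + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window covers seq_len consecutive time steps.
    X = np.stack([series[i:i + seq_len] for i in range(n_samples)])
    # Each target holds the n_ahead values that follow the input window.
    y = np.stack([
        series[i + seq_len:i + seq_len + n_ahead, target_col]
        for i in range(n_samples)
    ])
    return X, y


# Toy data: 50 time steps, 3 features (values are arbitrary).
toy = np.arange(150, dtype="float32").reshape(50, 3)
X_train_seq, y_train_seq = make_windows(toy, seq_len=10, n_ahead=4)
print(X_train_seq.shape, y_train_seq.shape)
```

With this toy input, X_train_seq.shape[1:] is (10, 3) and y_train_seq.shape[1] is 4, which are the values the input layer and the final Dense layer read from the arrays.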

Please cite our work if you use this repo:

@article{genet2024tkan,
  title={Tkan: Temporal kolmogorov-arnold networks},
  author={Genet, Remi and Inzirillo, Hugo},
  journal={arXiv preprint arXiv:2405.07344},
  year={2024}
}

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

tkan's People

Contributors

escherba, remigenet, remigenetaplo

tkan's Issues

MinMaxScaler

Just FYI: your example calls fit_transform on the train and test data with separate scaler objects. It should call fit only on the train data, then transform the test data with the same scaler object that was fitted on train.
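The point of this issue can be illustrated as follows: the scaling parameters are estimated from the training split only and then reused on the test split. The sketch below uses plain NumPy so that it stays self-contained; the class `TrainOnlyMinMax` is a hypothetical stand-in, and scikit-learn's MinMaxScaler offers the same separation through its fit and transform methods.

```python
import numpy as np


class TrainOnlyMinMax:
    """Hypothetical min-max scaler whose statistics come from the data passed to fit."""

    def fit(self, data):
        data = np.asarray(data, dtype="float64")
        self.low_ = data.min(axis=0)
        span = data.max(axis=0) - self.low_
        # Avoid division by zero for constant columns.
        self.span_ = np.where(span == 0, 1.0, span)
        return self

    def transform(self, data):
        return (np.asarray(data, dtype="float64") - self.low_) / self.span_


train = np.array([[0.0], [5.0], [10.0]])
test = np.array([[2.5], [12.0]])

scaler = TrainOnlyMinMax().fit(train)  # fit on the training split only
train_scaled = scaler.transform(train)
test_scaled = scaler.transform(test)   # reuse the same parameters on test
print(test_scaled.ravel())
```

Scaled test values may fall outside [0, 1]. That is expected, since the range was learned from the training data alone.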

Slow training

I am able to create a model, but training is incredibly slow compared to an LSTM model with a similar number of parameters (800 ms/step vs 7 ms/step).
I am wondering if others have experienced similar issues?
I am using an RTX 3090 GPU and TensorFlow 2.9 with Python 3.10.

Thanks.

GPU

Does this not support GPU?
