
kanrl's Introduction

Kolmogorov-Arnold Q-Network (KAQN) - KAN applied to Reinforcement learning, initial experiments

This small project tests the novel Kolmogorov-Arnold Network (KAN) architecture in a reinforcement learning setting, applied to the CartPole problem.

KANs are promising alternatives to Multi-Layer Perceptrons (MLPs). KANs have strong mathematical foundations just like MLPs: MLPs are based on the universal approximation theorem, while KANs are based on the Kolmogorov-Arnold representation theorem. KANs and MLPs are dual: KANs have activation functions on edges, while MLPs have activation functions on nodes. According to the KAN authors, this simple change makes KANs better (sometimes much better!) than MLPs in terms of both model accuracy and interpretability.
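To make the edge-versus-node distinction concrete, here is a minimal pure-Python sketch of one layer of each kind. It is only an illustration, not the code used in this project: the helper `make_edge_function` is hypothetical, and a small polynomial basis stands in for the B-splines used in the KAN paper.

```python
import math

def mlp_layer(x, weights, activation=math.tanh):
    """MLP: scalar weights on the edges, one fixed activation applied at each output node."""
    return [activation(sum(w * xi for w, xi in zip(row, x))) for row in weights]

def kan_layer(x, edge_functions):
    """KAN: one univariate function per edge; each output node only sums its incoming edges."""
    return [sum(phi(xi) for phi, xi in zip(row, x)) for row in edge_functions]

def make_edge_function(coeffs):
    """Hypothetical learnable edge function: a polynomial with trainable coefficients.
    (Assumption: a polynomial basis replaces the B-splines of the KAN paper.)"""
    return lambda t: sum(c * t ** k for k, c in enumerate(coeffs))

x = [0.5, -1.0]

# MLP layer: 2 inputs -> 2 outputs, the nonlinearity lives on the nodes.
mlp_out = mlp_layer(x, weights=[[1.0, 2.0], [0.0, -1.0]])

# KAN layer: 2 inputs -> 2 outputs, the nonlinearity lives on each of the 4 edges.
kan_out = kan_layer(x, edge_functions=[
    [make_edge_function([0.0, 1.0]), make_edge_function([0.0, 0.0, 1.0])],  # t and t^2
    [math.sin, make_edge_function([1.0])],                                  # sin(t) and the constant 1
])
```

Because every edge carries its own one-dimensional function, each of them can be plotted or matched against a symbolic formula on its own, which is where the interpretability argument comes from.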

mlp_kan_compare

For more information about this novel architecture please visit:

Experimentation

The implementation of Kolmogorov-Arnold Q-Network (KAQN) offers a promising avenue in reinforcement learning. In this project, we replace the Multi-Layer Perceptron (MLP) component of Deep Q-Networks (DQN) with the Kolmogorov-Arnold Network. Furthermore, we employ the Double Deep Q-Network (DDQN) update rule to enhance stability and learning efficiency.
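As a rough sketch of the Double DQN update rule mentioned above: the online network selects the next action and the target network evaluates it. The function names and the toy Q-values below are illustrative assumptions, not taken from this repository.

```python
def ddqn_target(reward, next_q_online, next_q_target, done, gamma=0.99):
    """Double DQN target for a single transition.

    next_q_online / next_q_target: lists of Q-values for the next state,
    produced by the online and the target network respectively.
    """
    if done:
        return reward
    # The online network picks the greedy action ...
    best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    # ... and the target network provides the value of that action.
    return reward + gamma * next_q_target[best_action]

def dqn_target(reward, next_q_target, done, gamma=0.99):
    """Classic DQN target, shown for comparison: the target network both selects and evaluates."""
    if done:
        return reward
    return reward + gamma * max(next_q_target)

# Toy values (hypothetical): the two networks disagree on the best next action.
next_q_online = [1.0, 2.0]
next_q_target = [5.0, 3.0]
y_double = ddqn_target(1.0, next_q_online, next_q_target, done=False)
y_vanilla = dqn_target(1.0, next_q_target, done=False)
```

When the two networks disagree, the Double DQN target is never larger than the classic one, which is how the rule reduces the overestimation of Q-values.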

The following plot compares the DDQN implementation with a KAN (width=8) against a classical MLP (width=32) on the CartPole-v1 environment, over 500 episodes and 32 seeds (with 50 warm-up episodes).

Episode length evolution during training on CartPole-v1

The following plot displays the interpretable policy learned by KAQN during a successful training session.

Interpretable policy for CartPole

  • Observation: KAQN exhibits unstable learning and struggles to solve CartPole-v1 across multiple seeds with the current hyperparameters (refer to config.yaml).
  • Next Steps: Further investigation is needed to select more suitable hyperparameters. KAQN may struggle with the non-stationary nature of value function approximation. Alternative configurations, or adapting KAQN for policy learning, are worth exploring.
  • Performance Comparison: KAQN is more than 10x slower than DQN despite having fewer parameters. This applies to both inference and training.
  • Interpretable Policy: The policy learned with KANs is more interpretable than one learned with an MLP. I'm currently working on extracting an interpretable policy.

KAN for RL interpretability

In a web application, we showcase a method to make a trained RL policy interpretable by transferring the knowledge of a pre-trained RL policy to a KAN. The process involves:

  • Train the KAN on observations from trajectories generated by a pre-trained RL policy, so that the KAN learns to map observations to the corresponding actions.
  • Apply symbolic regression algorithms to the KAN's learned mapping.
  • Extract an interpretable policy expressed in symbolic form.
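The first step above is a form of behavior cloning, which can be sketched in pure Python as follows. Everything here is a simplifying assumption and not the app's actual code: `teacher_policy` is a hand-written stand-in for a pre-trained policy, observations are sampled at random instead of being collected from real trajectories, and a logistic-regression student replaces the KAN.

```python
import math
import random

def teacher_policy(obs):
    """Hypothetical stand-in for a pre-trained policy: push toward the side the pole is falling."""
    _, _, angle, angular_velocity = obs
    return 1 if angle + angular_velocity > 0 else 0

def collect_dataset(n_samples, rng):
    """Stand-in for rolling out trajectories: sample observations and label them with the teacher."""
    data = []
    for _ in range(n_samples):
        obs = [rng.uniform(-1.0, 1.0) for _ in range(4)]
        data.append((obs, teacher_policy(obs)))
    return data

def train_student(data, epochs=50, lr=0.5):
    """Fit a logistic-regression student (a placeholder for the KAN) with stochastic gradient descent."""
    weights, bias = [0.0] * 4, 0.0
    for _ in range(epochs):
        for obs, action in data:
            logit = bias + sum(w * o for w, o in zip(weights, obs))
            logit = max(-30.0, min(30.0, logit))  # clamp to keep exp() well-behaved
            prob = 1.0 / (1.0 + math.exp(-logit))
            error = prob - action  # gradient of the log-loss with respect to the logit
            weights = [w - lr * error * o for w, o in zip(weights, obs)]
            bias -= lr * error
    return weights, bias

def student_policy(obs, weights, bias):
    """Greedy action of the student: 1 if the learned score is positive, else 0."""
    return 1 if bias + sum(w * o for w, o in zip(weights, obs)) > 0 else 0

rng = random.Random(0)
train_data = collect_dataset(500, rng)
weights, bias = train_student(train_data)

# Fraction of held-out observations on which the student reproduces the teacher's action.
test_data = collect_dataset(200, rng)
agreement = sum(student_policy(o, weights, bias) == a for o, a in test_data) / len(test_data)
```

In this simplified setting the learned `weights` already read as a formula over the observation. With a KAN as the student, the symbolic regression step would play that role by fitting a closed-form expression to each learned edge function.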

To launch the app (after cloning the repo), run:

cd interpretable
python app.py

There is also a live demo on Hugging Face: https://huggingface.co/spaces/riiswa/RL-Interpretable-Policy-via-Kolmogorov-Arnold-Network

demo

Contributing

I welcome contributions from the community to enhance this project. There are plenty of opportunities to contribute, such as hyperparameter search, benchmarking against classic DQN, implementing other algorithms (REINFORCE, A2C, etc.) and supporting additional environments. Feel free to submit pull requests with your contributions or open issues to discuss ideas and improvements. Together, we can explore the full potential of KAN and advance the field of reinforcement learning ❤️.

kanrl's People

Contributors

riiswa
