
chebykan's People

Contributors

iiisak, k-h-ismail, synodicmonth


chebykan's Issues

Poor generalization

I tried using ChebyKAN to fit signal waveforms, but it generalizes poorly. What might be the reason?
[figure: fit on the training data]

[figure: fit on the test data]
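For context, a minimal reproduction sketch of this kind of experiment (my own construction, not the issue author's code): fit a toy waveform on one interval and evaluate on a held-out interval. The waveform, intervals, and architecture are all hypothetical; ChebyKANLayer(input_dim, output_dim, degree) is the layer from this repo.

import torch
import torch.nn as nn
from ChebyKANLayer import ChebyKANLayer

torch.manual_seed(0)

# Hypothetical toy setup: a two-layer ChebyKAN fitting a 1-D waveform.
model = nn.Sequential(
    ChebyKANLayer(1, 8, 4),  # (input_dim, output_dim, degree)
    ChebyKANLayer(8, 1, 4),
)

target = lambda x: torch.sin(5 * x)
x_train = torch.linspace(-1, 1, 200).unsqueeze(1)
x_test = torch.linspace(1, 2, 100).unsqueeze(1)  # outside the training range

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(2000):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x_train), target(x_train))
    loss.backward()
    optimizer.step()

with torch.no_grad():
    print("train MSE:", nn.functional.mse_loss(model(x_train), target(x_train)).item())
    print("test MSE: ", nn.functional.mse_loss(model(x_test), target(x_test)).item())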

ChebyKAN has trouble solving dynamical systems

Hi, ChebyKAN is indeed simple, elegant, and powerful; I believe it can do even more.

So I applied it to solving some models with dynamical systems (economic models, to be precise), where the governing equations are quite similar to PDEs.

The main problem I encountered is that ChebyKAN is more prone to getting "stuck", which prevents training from going any further. Here are two illustrations, with the KAN structure

valuefunction_KAN = KAN(width=[2,5,5,1], grid=5, k=3, grid_eps=1.0, noise_scale_base=0.25)

and ChebyKAN structure

class ChebyKAN(nn.Module):
    def __init__(self):
        super(ChebyKAN, self).__init__()
        # Each layer is ChebyKANLayer(input_dim, output_dim, degree).
        self.chebykan1 = ChebyKANLayer(2, 8, 8)
        self.chebykan2 = ChebyKANLayer(8, 16, 5)
        self.chebykan3 = ChebyKANLayer(16, 1, 5)

    def forward(self, x):
        x = self.chebykan1(x)
        x = self.chebykan2(x)
        x = self.chebykan3(x)
        return x

valuefunction_cheb = ChebyKAN()

The results in the first figure were obtained with LBFGS, and those in the second with Adam at learning rate 1e-2.

[figure 1: results with LBFGS]

[figure 2: results with Adam, lr 1e-2]

I have tested it multiple times, with different input and output dimensions and degrees ranging from 4 to 12; the issue remains.
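For reference, a sketch of the two training setups compared above (my construction; the objective below is a stand-in, since the actual economic-model residual isn't shown in the issue):

import torch
import torch.nn as nn

model = ChebyKAN()  # the class defined above

# Hypothetical stand-in objective; the real dynamical-system residual is not shown.
x = torch.rand(512, 2) * 2 - 1
target = torch.sin(3 * x[:, :1]) * torch.cos(2 * x[:, 1:])

# LBFGS (first figure): requires a closure that re-evaluates the loss on
# every inner iteration.
lbfgs = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=20)

def closure():
    lbfgs.zero_grad()
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    return loss

for _ in range(50):
    lbfgs.step(closure)

# Adam (second figure): plain first-order updates.
# adam = torch.optim.Adam(model.parameters(), lr=1e-2)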

No Dropout found in any experiment

As the title says, I haven't found any regularization method such as dropout in your current implementation and experiments. Is there a specific reason for leaving it out? Perhaps the learned activation function is already very complex and shouldn't be as prone to overfitting as traditional perceptrons?

The experiments in the following repository don't use dropout either: https://github.com/1ssb/torchkan
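For concreteness, here is one way dropout could be inserted between ChebyKAN layers if regularization turned out to matter. This is a sketch of mine, not code from the repo; the architecture mirrors the one from the issue above:

import torch.nn as nn
from ChebyKANLayer import ChebyKANLayer

class ChebyKANWithDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.net = nn.Sequential(
            ChebyKANLayer(2, 8, 8),
            nn.Dropout(p),  # zeroes features between layers during training
            ChebyKANLayer(8, 16, 5),
            nn.Dropout(p),
            ChebyKANLayer(16, 1, 5),
        )

    def forward(self, x):
        return self.net(x)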

ChebyKAN layer is equivalent to custom activation + nn.Linear

Hi, very interesting idea, kudos!

I believe the proposed layer is equivalent to the following combination (I fix the degree to 4 for simplicity):

import torch
import torch.nn as nn

from ChebyKANLayer import ChebyKANLayer

class ChebyActivation(nn.Module):
    def __init__(self, degree):
        assert degree == 4
        super().__init__()

    def forward(self, x):
        # Squash inputs into [-1, 1], the domain of the Chebyshev polynomials.
        x = torch.tanh(x)

        # Concatenate the Chebyshev polynomials of the first kind, T_0..T_4.
        x = torch.cat(
            [
                torch.ones_like(x),       # T_0(x) = 1
                x,                        # T_1(x) = x
                2 * x**2 - 1,             # T_2(x)
                4 * x**3 - 3 * x,         # T_3(x)
                8 * x**4 - 8 * x**2 + 1,  # T_4(x)
            ],
            dim=1,
        )
        return x

input_dim = 128
output_dim = 256
# Compare the two variants across 100 random initializations.
for _ in range(100):
    variant_0 = ChebyKANLayer(input_dim, output_dim, 4)
    variant_1 = nn.Sequential(
        ChebyActivation(4),
        nn.Linear(input_dim * 5, output_dim, bias=False),
    )
    # Copy the Chebyshev coefficients into the linear layer so both variants
    # compute the same function; the permute/flatten matches the degree-major
    # layout produced by ChebyActivation's cat.
    variant_1[1].weight.data.copy_(variant_0.cheby_coeffs.permute(1, 2, 0).flatten(1))
    for _ in range(100):
        x = torch.randn(1234, input_dim)
        res1 = variant_0(x)
        res2 = variant_1(x)

        assert (
            res1 - res2
        ).abs().max() < 1e-6, "Found inconsistency between implementations!"

print("Two implementations are equivalent!")

This makes it a variant of a LAN network (see App. B.2 in the KAN paper), which is nice, but it's a double-edged sword.

On the one hand, with this rewrite you can train it quite efficiently (by checkpointing the ChebyActivation function and using the optimized CUDA Linear kernel).
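To make the checkpointing point concrete, a minimal sketch (my construction, assuming the ChebyActivation class above): the (degree+1)-times-wider expansion is recomputed during the backward pass instead of being stored.

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedChebyLinear(nn.Module):
    def __init__(self, input_dim, output_dim, degree=4):
        super().__init__()
        self.act = ChebyActivation(degree)  # defined above
        self.linear = nn.Linear(input_dim * (degree + 1), output_dim, bias=False)

    def forward(self, x):
        # Recompute the cheap Chebyshev expansion in the backward pass instead
        # of keeping the 5x-wider activation tensor in memory.
        expanded = checkpoint(self.act, x, use_reentrant=False)
        return self.linear(expanded)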

On the other hand, modern networks like LLaMA 3 already use Gated Linear Unit activations, which should give roughly equivalent representational power (though I'm not 100% sure on this point).
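For comparison, a minimal SwiGLU block of the kind used in LLaMA-style MLPs (dimensions illustrative; this is the standard formulation, not code from either repo):

import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # SwiGLU(x) = W_down(SiLU(W_gate x) * (W_up x))
        return self.down(nn.functional.silu(self.gate(x)) * self.up(x))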

Do you think this reasoning is correct, or am I missing something?

Thanks in advance!
