
snnlibpy's Introduction

WrappedBindsNET


This project is being built around one idea: making BindsNET, a PyTorch-based framework for Spiking Neural Networks, even easier to use.
Almost all of this small library lives in snnlib.py, so the various constants should be easy to tweak.
You can also change the class variables directly from main.py.
It is purely for my personal use, but anyone who wants to use it is welcome to.
(I push small, minor updates fairly often.)

Now that the author has completed a master's degree, major updates are unlikely, but you are welcome to extend and use the library.

The library is unfinished, so there may still be bugs.
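
As a rough illustration of what changing a class variable from main.py means in Python, here is a minimal sketch. `SpikingStandIn` and its default values are hypothetical stand-ins, not the library's real class; only the attribute names `IMAGE_DIR` and `gpu` are taken from the sample code in this README.

```python
class SpikingStandIn:
    """Hypothetical stand-in for the library class; not the real Spiking."""

    IMAGE_DIR = "images/"  # assumed default; shared by all instances
    gpu = True             # assumed default

    def describe(self):
        return "dir={} gpu={}".format(self.IMAGE_DIR, self.gpu)


snn = SpikingStandIn()
snn.IMAGE_DIR += "diehl/"  # creates an instance attribute; the class default is untouched
snn.gpu = False            # same: only this instance is affected

other = SpikingStandIn()
print(snn.describe())    # dir=images/diehl/ gpu=False
print(other.describe())  # dir=images/ gpu=True

# Assigning on the class changes the default for every instance
# that has not overridden the attribute itself.
SpikingStandIn.IMAGE_DIR = "out/"
```

Overriding on the instance keeps the change local to one network, while assigning on the class changes the default everywhere, which is probably only what you want for a one-off script.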

Environment

The library has been confirmed to run without problems in the following environment.

  • OS.........macOS 10.15 or Ubuntu 16.04 LTS
  • Python.....3.6.* or 3.7.* (or later)
  • BindsNET...0.2.7 (does not work with versions below 0.2.7)
  • PyTorch....1.10 (GPU: torch 1.3.0+cu92, torchvision 0.4.1+cu92)
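
If you want to check a machine against the list above, a small helper along these lines may be enough. It is only a sketch: `version_tuple` and `meets_minimum` are names made up for this example, and you would pass in whatever version string your installation reports (for instance from `pip show bindsnet`).

```python
import re
import sys


def version_tuple(version):
    """Turn '1.3.0+cu92' into (1, 3, 0), ignoring any local build suffix."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True when the installed version is at least the minimum."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so '0.2' compares equal to '0.2.0'
    b += (0,) * (width - len(b))
    return a >= b


python_ok = sys.version_info >= (3, 6)
print("Python >= 3.6:", python_ok)
print("BindsNET 0.2.7 ok:", meets_minimum("0.2.7", "0.2.7"))
print("BindsNET 0.2.5 ok:", meets_minimum("0.2.5", "0.2.7"))
```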

Example

  • Sample code
from wbn import Spiking


if __name__ == '__main__':

    # Build SNNs and decide the number of input neurons and the simulation time.
    snn = Spiking(input_l=784, obs_time=300, dt=0.5)
    snn.IMAGE_DIR += 'diehl/'

    # Add a layer and give the num of neurons and the neuron model.
    snn.add_layer(n=100,
                  node=snn.DIEHL_COOK,   # neuron model
                  w=snn.W_SIMPLE_RAND,   # initialize weights
                  rule=snn.SIMPLE_STDP,  # learning rule
                  nu=(1e-4, 1e-2),       # learning rate
                  )

    # Add an inhibitory layer
    snn.add_inhibit_layer(inh_w=-128)

    # Load dataset
    snn.load_MNIST()

    # Check your network architecture
    snn.print_model()

    # For a small network, running on the GPU may be slower than on the CPU.
    # You can switch GPU use on or off directly, as below.
    # snn.gpu = False

    # Move the network to the GPU if one is available.
    snn.to_gpu()

    # Plot weight maps before training
    snn.plot(plt_type='wmps', prefix='0', f_shape=(10, 10))

    # Run the network
    for i in range(3):
        snn.run()

        snn.plot(plt_type='wmps', prefix='{}'.format(i+1), f_shape=(10, 10))  # plot maps

    # Plot test accuracy transition
    snn.plot(plt_type='history', prefix='result')

    # Plot weight maps after training
    snn.plot(plt_type='wmps', prefix='result', f_shape=(10, 10))

    # Plot output spike trains after training
    snn.plot(plt_type='sp', range=10)

    print(snn.history)

or very simply,

from wbn import DiehlCook_unsupervised_model  # packed sample simulation code

DiehlCook_unsupervised_model()

is enough. (This function is really my own backup copy, so it is handy for checking that everything works properly.)
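
To get a feel for the `obs_time=300, dt=0.5` arguments in the sample, the sketch below computes how many discrete steps one input presentation would take. It assumes the wrapper hands both values to BindsNET unchanged and that the step count is the simulated time divided by `dt`; `simulation_steps` is a name invented for this example.

```python
def simulation_steps(obs_time, dt):
    """Number of discrete steps when obs_time is advanced in increments of dt."""
    if dt <= 0:
        raise ValueError("dt must be positive")
    return int(obs_time / dt)


steps = simulation_steps(obs_time=300, dt=0.5)
print(steps)  # 600
```

Under that assumption, halving `dt` doubles the number of steps per sample, so run time grows accordingly.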

  • Generated image samples
    • A weight map before training

    • A weight map after STDP training on 10,000 MNIST samples
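
For readers wondering what such a weight map contains, here is a minimal pure-Python sketch of one plausible layout: each neuron's incoming weights are reshaped into a tile and the tiles are placed on a grid, for example 100 tiles of 28×28 on a 10×10 grid as `f_shape=(10, 10)` in the sample suggests. This is an assumption about the plot, not the library's actual plotting code, and `tile_weight_maps` is a hypothetical helper.

```python
def tile_weight_maps(weights, f_shape, tile_shape):
    """Arrange per-neuron weight vectors into one 2-D grid image.

    weights    -- list of flat weight vectors, one per neuron
    f_shape    -- (rows, cols) of the grid, e.g. (10, 10) for 100 neurons
    tile_shape -- (height, width) of each tile, e.g. (28, 28) for MNIST
    """
    rows, cols = f_shape
    height, width = tile_shape
    if len(weights) != rows * cols:
        raise ValueError("f_shape must hold exactly one tile per neuron")
    grid = [[0.0] * (cols * width) for _ in range(rows * height)]
    for index, vector in enumerate(weights):
        if len(vector) != height * width:
            raise ValueError("each weight vector must fill one tile")
        top = (index // cols) * height   # row offset of this neuron's tile
        left = (index % cols) * width    # column offset of this neuron's tile
        for y in range(height):
            for x in range(width):
                grid[top + y][left + x] = vector[y * width + x]
    return grid


# Tiny example: 4 neurons with 2x2 "images" arranged in a 2x2 grid.
weights = [[float(n)] * 4 for n in range(4)]
grid = tile_weight_maps(weights, f_shape=(2, 2), tile_shape=(2, 2))
for row in grid:
    print(row)
```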

BindsNET references

【docs】
Welcome to BindsNET’s documentation! — bindsnet 0.2.5 documentation

【Github】
Hananel-Hazan/bindsnet: Simulation of spiking neural networks (SNNs) using PyTorch.

【Paper】
BindsNET: A Machine Learning-Oriented Spiking Neural Networks Library in Python

snnlibpy's People

Contributors

hiroshiaraki


snnlibpy's Issues

Plotting spikes, weights, and performance

Hi,
I just started following your repository. You have done very well.
I wanted to plot my input spikes, weights, and model performance, but after training my plots didn't show.
I'm not sure how to solve this. Can you tell me how I can fix it?
My code looks like this:
for epoch in range( n_epochs ):
    if epoch % progress_interval == 0:
        print( "Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start) )
        start = t()

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True)

    for step, batch in enumerate( train_dataloader ):
        # Get next input sample.

        inputs = {"X": batch["encoded_image"].view( time, 1, 1, 28, 28 )}
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}
        label = batch["label"]

        # Run the network on the input.
        network.run( inputs=inputs, time=time, input_time_dim=1 )

        # Optionally plot various simulation information.
        if plot:
            image = batch["image"].view( 28, 28 )

            inpt = inputs["X"].view( time, 784 ).sum( 0 ).view( 28, 28 )
            weights1 = conv_conn.w
            _spikes = {
                "X": spikes["X"].get( "s" ).view( time, -1 ),
                "Y": spikes["Y"].get( "s" ).view( time, -1 ),
            }
            _voltages = {"Y": voltages["Y"].get( "v" ).view( time, -1 )}

            inpt_axes, inpt_ims = plot_input(
                image, inpt, label=label, axes=inpt_axes, ims=inpt_ims
            )
            spike_ims, spike_axes = plot_spikes( _spikes, ims=spike_ims, axes=spike_axes )
            weights1_im = plot_conv2d_weights( weights1, im=weights1_im )
            voltage_ims, voltage_axes = plot_voltages(
                _voltages, ims=voltage_ims, axes=voltage_axes
            )

            #plt.ioff()
            #plt.show()

        plt.ioff()
        # plt.pause( 1 )
        plt.show()

        network.reset_state_variables()  # Reset state variables.

print( "Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start) )
print( "Training complete.\n" )
