
fzenke / spytorch


Tutorial for surrogate gradient learning in spiking neural networks

Languages: Jupyter Notebook 99.26%, Makefile 0.02%, TeX 0.34%, Gnuplot 0.04%, Python 0.34%

spytorch's People

Contributors

fzenke, mshalvagal


spytorch's Issues

Issue in running Tutorial-4

When I am running the following piece of code in Tutorial-4:

loss_hist = train(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs)

I am getting the following error:
[screenshot of the error message]

Could you please suggest how to resolve this issue?

maybe simplification

I don't understand why the 'rst' variable exists; it seems to always equal 'out'. Changing it to rst = out yields the same results...

def spike_fn(x):
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out
...
# Here we loop over time
for t in range(nb_steps):
    mthr = mem-1.0
    out = spike_fn(mthr) 
    rst = torch.zeros_like(mem)
    c = (mthr > 0)
    rst[c] = torch.ones_like(mem)[c] 
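Since 'rst' is built from the condition mthr > 0, which is exactly the condition spike_fn tests, the two tensors should agree elementwise. A quick self-contained check (a random membrane tensor stands in for the real one):

```python
import torch

def spike_fn(x):
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out

torch.manual_seed(0)
mem = torch.randn(32, 100)  # stand-in for the membrane potentials
mthr = mem - 1.0

out = spike_fn(mthr)

# Construction from the loop in run_snn:
rst = torch.zeros_like(mem)
c = (mthr > 0)
rst[c] = torch.ones_like(mem)[c]

print(torch.equal(rst, out))  # the two constructions agree elementwise
```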

Explicit labels mapping missing

I couldn't find an explicit mapping between label numbers and their content. When printing the contents of the extras, the first 10 positions (0-9) are occupied by English words, followed by the German words. I therefore deduce that labels 0-9 correspond to the English partition of the dataset. Could you please confirm?

Thank you in advance.

propagation delay

Hi Zenke,
I have a question about the SNN model. If I feed a spike image into an SNN with L layers at time step n, the output of the last layer is only affected by that input at time step n + L - 1. In deep networks this delay should be taken into account, because it increases the total number of time steps required.
[screenshot illustrating the layer-wise propagation delay]
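The delay can be seen in a toy feed-forward chain where each layer at step t only reads the previous layer's activity from step t-1 (a dependency-free sketch with made-up sizes, not the tutorial's model):

```python
L = 4          # number of layers (assumed for illustration)
nb_steps = 8

# A unit pulse enters layer 0 at time step 0; each layer simply passes
# the previous layer's activity onward, one time step later.
layer_state = [0.0] * L
first_active_step = [None] * L

for t in range(nb_steps):
    new_state = []
    for l in range(L):
        inp = 1.0 if (l == 0 and t == 0) else 0.0
        prev = layer_state[l - 1] if l > 0 else 0.0
        s = inp + prev
        if s != 0.0 and first_active_step[l] is None:
            first_active_step[l] = t
        new_state.append(s)
    layer_state = new_state

print(first_active_step)  # → [0, 1, 2, 3]: layer l first responds at step l
```

The last layer (l = L - 1) first responds at step n + L - 1, matching the delay described above.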

v1 update

Hi, thanks for the tutorial!
I noticed that in Tutorial 4 the recurrent weights (v1) are not updated during training.
Do you have any suggestions?
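One common cause (a guess, not a diagnosis of the notebook) is that the recurrent weight tensor never makes it into the optimizer's parameter list, or lacks requires_grad. A minimal sketch of what must hold for v1 to receive updates (sizes and the learning rate are illustrative assumptions):

```python
import torch

# A parameter only receives updates if (a) it requires gradients and
# (b) it is handed to the optimizer.
nb_inputs, nb_hidden = 700, 100
w1 = torch.randn(nb_inputs, nb_hidden, requires_grad=True)
v1 = torch.randn(nb_hidden, nb_hidden, requires_grad=True)

params = [w1, v1]  # omitting v1 here would silently freeze the recurrent weights
optimizer = torch.optim.Adam(params, lr=2e-4)
```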

Compute recurrent contribution from spikes

Hey Friedemann,

thank you for the very comprehensive tutorial! I have a question on the way the recurrence is computed in tutorial 4. If I understand the equation for the dynamics of the current correctly, the recurrence should be computed with the spiking neuron state:

mthr = mem-1.0
out = spike_fn(mthr)
h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (out, v1))

Instead, in tutorial 4 a separate hidden state is kept that bypasses the spike function:

h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (h1, v1))

Is this done deliberately? Judging from simulating a few epochs, the two versions seem to perform similarly.
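The two variants can be run side by side in a stripped-down loop (shapes, the decay constant, and the surrogate-free spike function are illustrative assumptions, not the tutorial's exact code):

```python
import torch

def spike_fn(x):
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out

torch.manual_seed(0)
batch, hidden, steps = 4, 8, 20
v1 = torch.randn(hidden, hidden) * 0.1
h1_from_input = torch.randn(batch, steps, hidden)

def run(recurrence_from_spikes):
    mem = torch.zeros(batch, hidden)
    h1 = torch.zeros(batch, hidden)
    spikes = []
    for t in range(steps):
        out = spike_fn(mem - 1.0)
        # Variant A feeds the binary spikes back; variant B feeds the
        # continuous hidden state back, as in the notebook's version.
        src = out if recurrence_from_spikes else h1
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", src, v1)
        mem = 0.9 * mem * (1.0 - out) + h1  # toy leaky update with reset
        spikes.append(out)
    return torch.stack(spikes, dim=1)

spikes_a = run(True)   # recurrence from spikes
spikes_b = run(False)  # recurrence from the un-spiked state
```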

Thank you,

Simon

resetting with "out" instead of "rst"?

(This is a comment, not an issue.)

Hi Friedemann,
First off, thanks a lot for these great tutorials; I've enjoyed playing with them a lot and learned a lot :-)
One question: in the run_snn function, why do you bother constructing the "rst" tensor? Why not subtract the "out" tensor, which also contains the output spikes? I've tried it, and it seems to work.
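A minimal version of the membrane update that subtracts the spike tensor directly (toy decay constant and shapes assumed; the tutorial's actual run_snn carries more state):

```python
import torch

def spike_fn(x):
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out

torch.manual_seed(1)
beta = 0.9                      # toy membrane decay constant (assumed)
mem = torch.rand(16, 50) * 2.0  # toy potentials, some above threshold 1.0
inp = torch.rand(16, 50) * 0.2  # toy input current

out = spike_fn(mem - 1.0)
# Reset by subtracting the spike tensor itself: where out == 1 this removes
# one threshold's worth of potential; no separate rst tensor is needed.
new_mem = beta * mem + inp - out
```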
Just curious.
Best,

Tim

Software/Machine description available?

Hey Friedemann,

thanks for making the examples available, they look very helpful.
However, to make them fully reproducible, I think some additional information about the "technical dependencies" is needed.

In particular, the list of software packages used (incl. version and build-variant information), plus a specification of the machine hardware (CPU architecture, GPUs).

Preferably, the former could be expressed as a recipe for building a container (a Dockerfile or, for better HPC compatibility, a Singularity recipe), perhaps even using an explicitly versioning package manager like Spack.

Cheers,
Eric

Spike times shifted

I have the impression that the spike recordings are shifted by one time step in all tutorials. Could you check whether this is indeed the case?

From my understanding, time step 0 is recorded twice for the spikes, once during initialisation

  mem = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
  spk_rec = [mem]

and once within the simulation of time step 0:

  for t in range(nb_steps):
      mthr = mem-1.0
      out = spike_fn(mthr)
      ...
      spk_rec.append(out)

As a result, the indices appear shifted when comparing

print(torch.nonzero((mem_rec-1.0) > 0.0))
print(torch.nonzero(spk_rec))
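The off-by-one can be reproduced in a stripped-down loop (toy sizes and a toy membrane update, purely to illustrate the indexing):

```python
import torch

def spike_fn(x):
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out

torch.manual_seed(0)
batch_size, nb_hidden, nb_steps = 2, 5, 10
inputs = torch.rand(batch_size, nb_steps, nb_hidden) * 0.5

mem = torch.zeros(batch_size, nb_hidden)
mem_rec = [mem]
spk_rec = [mem]  # seeding with zeros shifts every later spike by one index

for t in range(nb_steps):
    out = spike_fn(mem - 1.0)
    mem = 0.9 * mem * (1.0 - out) + inputs[:, t]  # toy leaky update with reset
    mem_rec.append(mem)
    spk_rec.append(out)

mem_rec = torch.stack(mem_rec, dim=1)
spk_rec = torch.stack(spk_rec, dim=1)
# The spike recorded at index t+1 corresponds to mem_rec crossing
# threshold at index t, i.e. the recordings are shifted by one step.
```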

Thanks,
Simon

Dataset never decompressed

Hello,

I believe I ran into a possible issue here.
Because of line 37, the check in line 38 will always be false unless one already has the uncompressed dataset.

def get_and_gunzip(origin, filename, md5hash=None, cache_dir=None, cache_subdir=None):
    gz_file_path = get_file(filename, origin, md5_hash=md5hash, cache_dir=cache_dir, cache_subdir=cache_subdir)
    hdf5_file_path = gz_file_path
    if not os.path.isfile(hdf5_file_path) or os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path):
        print("Decompressing %s" % gz_file_path)
        with gzip.open(gz_file_path, 'r') as f_in, open(hdf5_file_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return hdf5_file_path

If I change line 37 to

hdf5_file_path = gz_file_path[:-3]

it works for me.
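Here gunzip_if_needed is a hypothetical stand-alone version of the helper (the real get_and_gunzip also downloads the file first via Keras's get_file); it exists only to exercise the proposed [:-3] fix end to end on a throwaway file:

```python
import gzip
import os
import shutil
import tempfile

def gunzip_if_needed(gz_file_path):
    # With the fix, the target path drops the ".gz" suffix, so the
    # decompression branch is actually taken on the first run.
    hdf5_file_path = gz_file_path[:-3]
    if (not os.path.isfile(hdf5_file_path)
            or os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path)):
        print("Decompressing %s" % gz_file_path)
        with gzip.open(gz_file_path, 'rb') as f_in, open(hdf5_file_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return hdf5_file_path

# Round-trip on a throwaway file:
tmpdir = tempfile.mkdtemp()
gz_path = os.path.join(tmpdir, "shd_test.h5.gz")
with gzip.open(gz_path, 'wb') as f:
    f.write(b"dummy hdf5 payload")
out_path = gunzip_if_needed(gz_path)
```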

Best,
Aaron

Problem in SpyTorchTutorial2

Hello,

It was a very nice and interesting tutorial; thank you for preparing it.

Tutorial 1 ran without any problems, but in tutorial 2 some dtype problems occurred. After fixing them, training was very slow on a GTX 980 (I have run some very deep models on this configuration). Could you please describe your configuration, as well as the training and response times?
