
Load keras with custom metrics (snn_toolbox, closed)

RKCZ commented on August 21, 2024
Load keras with custom metrics

Comments (6)

rbodo commented on August 21, 2024

Currently, the toolbox is set up for classification, and compiles each model with two metrics: top-1 accuracy, and top-k accuracy, where k can be specified by the user and equals 1 by default.
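
The two default metrics can be illustrated with a small plain-Python sketch of top-k accuracy: a sample counts as correct when its true class is among the k highest-scoring outputs. The function name `top_k_accuracy` and the toy scores are illustrative only. Ties at the k-th position are counted as hits here, which is how TensorFlow's `in_top_k` treats them; whether the toolbox's own `in_top_k` call behaves identically is an assumption.

```python
def top_k_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores.

    scores: list of per-class score lists, one list per sample.
    labels: list of integer class indices.
    A sample is a hit when fewer than k classes score strictly higher
    than its true class, so ties at the boundary count as hits.
    """
    if not scores:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # Count the classes that beat the true class outright.
        higher = sum(1 for s in row if s > row[label])
        if higher < k:
            hits += 1
    return hits / len(scores)


# Hypothetical output scores for three samples and three classes.
scores = [
    [0.1, 0.7, 0.2],  # true class 1 ranks first
    [0.5, 0.3, 0.2],  # true class 1 ranks second
    [0.2, 0.3, 0.5],  # true class 0 ranks third
]
labels = [1, 1, 0]

print(top_k_accuracy(scores, labels, k=1))  # 1 of 3 samples correct
print(top_k_accuracy(scores, labels, k=2))  # 2 of 3 samples correct
```

With the default k = 1 the two metrics coincide, which is consistent with the "(top-1, top-1)" label in the simulation log further down.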

It would be a welcome extension to make this behavior more flexible (allowing custom metrics) and to support, for instance, object detection.

For now, a custom metric can be used with the following hacks:

  1. In file snntoolbox/parsing/utils.py, import your custom metric:

    from x.y import my_metric

  2. In the custom_objects dict within the function get_custom_activations_dict, add your metric to the dict:

    'my_metric': my_metric

  3. In build_parsed_model, replace the top-k metric with yours:

    self.parsed_model.compile(
        'sgd', 'categorical_crossentropy',
        ['accuracy', keras.metrics.top_k_categorical_accuracy])

becomes

    self.parsed_model.compile(
        'sgd', 'categorical_crossentropy', 
        self.input_model.metrics)

(Here I'm assuming your model has been trained with metrics=['accuracy', 'my_metric'].)

  4. Finally, in snntoolbox/simulation/utils.py, replace
    top5score_moving += sum(in_top_k(output_b_l_t[:, :, -1], truth_b, self.top_k))
    top5acc_moving = top5score_moving / num_samples_seen

by

    top5acc_moving = keras.backend.get_value(my_metric(
        keras.backend.constant(truth_d),
        keras.backend.constant(guesses_d)))

(Here I'm assuming you implemented your metric as a Keras/TF function. You can avoid the keras.backend.constant conversion by using a plain-Python version of your metric here.)
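
A plain-Python version of the metric, as suggested above, might look like the following sketch of precision for integer class labels. It assumes that `truth_d` and `guesses_d` are lists of true and predicted class indices accumulated over all samples seen so far (an assumption about the toolbox internals, not verified here). The values below are made up, and `positive_class` is a hypothetical parameter.

```python
def my_metric(truth, guesses, positive_class=1):
    """Precision of `positive_class`: TP / (TP + FP).

    truth, guesses: equal-length sequences of integer class indices.
    Returns 0.0 when the positive class was never predicted.
    """
    if len(truth) != len(guesses):
        raise ValueError("truth and guesses must have the same length")
    true_pos = sum(1 for t, g in zip(truth, guesses)
                   if g == positive_class and t == positive_class)
    predicted_pos = sum(1 for g in guesses if g == positive_class)
    return true_pos / predicted_pos if predicted_pos else 0.0


# Hypothetical labels, assumed to be accumulated over all batches so far.
truth_d = [1, 0, 1, 1, 0]
guesses_d = [1, 1, 1, 0, 0]

# No keras.backend.constant conversion needed for a plain-Python metric.
top5acc_moving = my_metric(truth_d, guesses_d)
print(f"precision: {top5acc_moving:.2%}")  # precision: 66.67%
```

Because the function is re-evaluated on the full lists, the result is a moving value over all samples seen so far rather than a per-batch score, provided the lists really do accumulate across batches.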

Now, when you run the toolbox, you will see output like this:

Evaluating parsed model on 100 samples...
Top-1 accuracy: 100.00%  (<== accuracy)
Top-5 accuracy: 100.00%  (<== my_metric)

The first value is still the top-1 accuracy as before, but the second line now reports your metric. (Change the print function to whatever label you like.)

When simulating the SNN you will get:

Current accuracy of batch:
0.00%_10.00%_10.00%_0.00%_0.00%_10.00%_20.00%_40.00%_60.00%_60.00%_70.00%_90.00%_90.00%_100.00% (<== accuracy)
Moving accuracy of SNN (top-1, top-1): 100.00% (<== accuracy), 100.00% (<== my_metric).
Moving accuracy of ANN (top-1, top-1): 100.00% (<== accuracy), 100.00% (<== my_metric).

This recipe has been used, for instance, to support the "precision" metric. Again, a more sustainable implementation of this would be very welcome.

RKCZ commented on August 21, 2024

Thank you for the detailed instructions, and I am sorry it took me so long to respond.
I followed the directions but could not get it to work.
The metric I am trying to use is keras.metrics.AUC(). The issue is that I don't know what value should be added to the custom_objects dict.
I compile the original model with the following command:

model.compile(
      optimizer=keras.optimizers.Adam(), loss=keras.losses.BinaryCrossentropy(),
      metrics=['binary_accuracy', keras.metrics.AUC(name='auc')])

I tried adding the mapping 'auc': keras.metrics.AUC(), which raised ValueError: Unknown metric function: {'class_name': 'AUC', 'config': {'name': ...
I then tried changing it to 'auc': keras.metrics.AUC().update_status(), and after that I tried to create a new function to wrap the metric:

def auc(y_true, y_pred):
  auc = keras.metrics.AUC()
  auc.update_state(y_true, y_pred)
  return auc.result().numpy()

but I always got a similar error. Do you know how to use keras.metrics instances?

rbodo commented on August 21, 2024

I think the custom_objects mapping might be case-sensitive; try

'AUC': keras.metrics.AUC()

Also, it shouldn't make a difference to have name='auc' in the constructor, but to be safe I'd just leave that out at first.

By the way, using 'binary_accuracy' instead of 'accuracy' will result in unexpected behavior when testing the SNN (the toolbox assumes 'accuracy').
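
To see how the two can disagree, the sketch below compares Keras-style binary accuracy (each output thresholded at 0.5) with an argmax-based accuracy on a model that has a single sigmoid output. Assuming the toolbox's accuracy bookkeeping compares argmax indices (not confirmed in this thread), a one-unit output always yields index 0 for both labels and predictions. The helper names and data are hypothetical.

```python
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Keras-style binary accuracy: threshold each output, compare to label."""
    hits = sum(1 for t, p in zip(y_true, y_pred)
               if t == (1.0 if p > threshold else 0.0))
    return hits / len(y_true)


def argmax(row):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def categorical_accuracy(y_true_rows, y_pred_rows):
    """Argmax-based accuracy: compare the indices of the largest entries."""
    hits = sum(1 for t, p in zip(y_true_rows, y_pred_rows)
               if argmax(t) == argmax(p))
    return hits / len(y_true_rows)


# Single sigmoid output per sample: binary labels and predicted probabilities.
labels = [1.0, 0.0, 1.0, 0.0]
probs = [0.9, 0.8, 0.2, 0.1]

print(binary_accuracy(labels, probs))  # 0.5
# As 1-element rows, argmax is always 0, so every sample "matches".
print(categorical_accuracy([[t] for t in labels], [[p] for p in probs]))  # 1.0
```

On this toy data the thresholded accuracy is 50%, while the argmax-based figure is a meaningless 100%.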

RKCZ commented on August 21, 2024

Thank you for pointing out that 'accuracy' must be specified instead of 'binary_accuracy'.
I still get the same error even when I change the key to upper case.

rbodo commented on August 21, 2024

Don't know what it could be, will try to take a look later this week.

rbodo commented on August 21, 2024

I can get the model to compile using

'AUC': AUC

in the custom_objects dict (i.e. pass the class, not an instance of AUC).
