Comments (13)

fchollet avatar fchollet commented on May 18, 2024

What are you trying to do here? Pickle a custom class that includes a Keras model?

from keras.

CAW9 avatar CAW9 commented on May 18, 2024

Not exactly. I am trying to enable pickle dumps for a KerasClassifier wrapper. Previously, it was proposed in #4274 that this should be done with the following code:

import copy
import io

import h5py
import tensorflow as tf


class KerasClassifier(tf.keras.wrappers.scikit_learn.KerasClassifier):
    """
    TensorFlow Keras API neural network classifier.

    Works around the tf.keras.wrappers.scikit_learn.KerasClassifier
    serialization issue using BytesIO and HDF5 in order to enable
    pickle dumps.

    Adapted from:
    https://github.com/keras-team/keras/issues/4274#issuecomment-519226139
    """

    def __getstate__(self):
        state = self.__dict__
        if "model" in state:
            model = state["model"]
            # Save the model into an in-memory HDF5 file.
            model_hdf5_bio = io.BytesIO()
            with h5py.File(model_hdf5_bio, 'w') as file:
                model.save(file)
                # tf.keras.models.save_model(model, file, save_format="h5")
            # Temporarily swap in the buffer so deepcopy never touches the model.
            state["model"] = model_hdf5_bio
            state_copy = copy.deepcopy(state)
            state["model"] = model  # restore the live model on the instance
            return state_copy
        else:
            return state

    def __setstate__(self, state):
        if "model" in state:
            # Rebuild the model from the in-memory HDF5 buffer.
            model_hdf5_bio = state["model"]
            with h5py.File(model_hdf5_bio, 'r') as file:
                state["model"] = tf.keras.models.load_model(file)
        self.__dict__ = state
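
For readers unfamiliar with the pattern, here is a minimal standard-library sketch of the same idea: swap an unpicklable attribute for its serialized bytes in __getstate__ and rebuild it in __setstate__. FakeModel and PicklableWrapper are hypothetical stand-ins invented for illustration; nothing here calls Keras or h5py, and this variant copies the state dict instead of temporarily mutating it.

```python
import io
import json
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a Keras model; the lock makes it unpicklable."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._lock = threading.Lock()  # pickle raises TypeError on lock objects

    def save(self, fileobj):
        # Write the weights to any binary file-like object.
        fileobj.write(json.dumps(self.weights).encode("utf-8"))

    @classmethod
    def load(cls, fileobj):
        # Rebuild a model from the bytes written by save().
        return cls(json.loads(fileobj.read().decode("utf-8")))


class PicklableWrapper:
    """Swaps the model for its serialized bytes while pickling."""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Shallow copy, so the live instance's __dict__ is never mutated.
        state = self.__dict__.copy()
        if "model" in state:
            buffer = io.BytesIO()
            state["model"].save(buffer)
            state["model"] = buffer.getvalue()
        return state

    def __setstate__(self, state):
        if "model" in state:
            state["model"] = FakeModel.load(io.BytesIO(state["model"]))
        self.__dict__ = state


wrapper = PicklableWrapper(FakeModel([0.1, 0.2, 0.3]))
restored = pickle.loads(pickle.dumps(wrapper))
```

Copying the dict first avoids the swap-and-restore dance in the original workaround, at the cost of holding the serialized bytes in memory alongside the live model during the dump.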

Since then, tf.keras.wrappers.scikit_learn.KerasClassifier has been deprecated and essentially replaced by scikeras.

The custom __getstate__ and __setstate__ exist because these workarounds were historically needed (and recommended in the linked issue) to serialize the KerasClassifier wrapper object. Migrating to scikeras was not a problem, but now that I am using Python 3.12, TensorFlow 2.16.1, and Keras 3.x, the previously suggested mechanisms for serializing a KerasClassifier have broken down.

Thank you again!

fchollet avatar fchollet commented on May 18, 2024

Did you try just using pickle with no workaround? It might well work with Keras 3.3.3.

CAW9 avatar CAW9 commented on May 18, 2024

The KerasClassifier object (with no workarounds attempted) does serialize just fine before fitting:

KerasClassifier(
model=<function build_tf_estimator.<locals>.build_tf_model at 0x144b60220>
build_fn=None
warm_start=False
random_state=None
optimizer=rmsprop
loss=None
metrics=None
batch_size=200
validation_batch_size=None
verbose=2
callbacks=None
validation_split=0.0
shuffle=True
run_eagerly=False
epochs=30
class_weight=None
)

p = dill.dumps(est)
print(dill.loads(p))


But after fitting I get a large cascade of errors:

est.fit(get_df_values(X), y, **kwargs)
p = dill.dumps(est)
print(dill.loads(p))

  File "/Users/cgladue/downloads/py312/automl_model.py", line 202, in fit_on_all_data
    print(dill.loads(p))
          ^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 303, in loads
    return load(file, ignore, **kwds)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 289, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 444, in load
    obj = StockUnpickler.load(self)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/scikeras/_saving_utils.py", line 15, in unpack_keras_model
    return load_model(b, compile=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 141, in load_model
    return _load_model_from_fileobj(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 170, in _load_model_from_fileobj
    model = deserialize_keras_object(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/serialization_lib.py", line 720, in deserialize_keras_object
    raise TypeError(
TypeError: <class 'keras.src.models.sequential.Sequential'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

Along with several other errors of the same form, like

Exception encountered: <class 'keras.src.layers.core.dense.Dense'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

I will look into this get_config and from_config suggestion, but I do not yet understand why fitting the model would make it unserializable.
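
Note that in the traceback above the failure is raised from dill.loads, not dill.dumps: the object was written out and only failed when being rebuilt. A small standard-library sketch for telling those two cases apart follows. roundtrip_status and BreaksOnLoad are hypothetical names made up for this illustration, and plain pickle is used here rather than dill.

```python
import pickle


def roundtrip_status(obj):
    """Report whether obj survives a pickle round trip, and which half failed."""
    try:
        payload = pickle.dumps(obj)
    except Exception as exc:
        return f"dump failed: {type(exc).__name__}"
    try:
        pickle.loads(payload)
    except Exception as exc:
        return f"load failed: {type(exc).__name__}"
    return "ok"


def _rebuild_broken():
    # Called by pickle at load time; mimics a deserializer that rejects its input.
    raise TypeError("could not be deserialized properly")


class BreaksOnLoad:
    """Hypothetical object that serializes fine but cannot be rebuilt."""

    def __reduce__(self):
        # pickle stores a reference to _rebuild_broken and calls it on load.
        return (_rebuild_broken, ())


statuses = [
    roundtrip_status([1, 2, 3]),
    roundtrip_status(lambda x: x),
    roundtrip_status(BreaksOnLoad()),
]
```

Running such a check both before and after fit would show whether fitting changes which half of the round trip breaks.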

SuryanarayanaY avatar SuryanarayanaY commented on May 18, 2024

Hi @CAW9 ,

If the model constructor has Keras layers then you need to implement get_config and from_config methods explicitly.

CAW9 avatar CAW9 commented on May 18, 2024

I've reproduced the error using only keras, and no scikeras dependency ( https://colab.research.google.com/drive/1ps1jt8WMINt0mOqvnGkE2sNdTqHLNkW1?usp=sharing ). This time, it fails to serialize before fitting.

To clarify, you are suggesting that I cannot define and use a simple keras model like this:

import tensorflow as tf

# X_shape and y_nunique are defined elsewhere in the surrounding script.
def build_tf_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            10, input_dim=X_shape[1], activation=tf.nn.relu),
        tf.keras.layers.Dense(y_nunique, activation=tf.nn.softmax)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.2, amsgrad=True, name='Adam'),
        loss='sparse_categorical_crossentropy',
    )
    return model

Do I instead need to wrap the Keras model in a custom class and define custom get_config and from_config methods on that class?

CAW9 avatar CAW9 commented on May 18, 2024

If in that example you change keras from 3.3.3 to 2.15.0, the serialization does not fail.

SuryanarayanaY avatar SuryanarayanaY commented on May 18, 2024

I have reproduced the behaviour with tf-nightly (Keras 3). With TF 2.15 it works fine. Attached gist here.

SuryanarayanaY avatar SuryanarayanaY commented on May 18, 2024

The code is failing at the keras.activations.get(identifier) step. But when I pass the activation tf.nn.relu to it directly, it's not failing.

identifier = tf.nn.relu
config = keras.activations.get(identifier)
print(type(config))
callable(config)

# Output
3.3.3.dev2024050303
<class 'function'>
True

CAW9 avatar CAW9 commented on May 18, 2024

I was able to fix my code by changing:

activation=tf.nn.relu to activation="relu"
and
activation=tf.nn.softmax to activation="softmax"

This is a satisfactory workaround for me.

If you would like to close the issue, I support you. If you feel that there is still a bug you need to address, feel free to leave it open.

Thank you for your help on this!

SuryanarayanaY avatar SuryanarayanaY commented on May 18, 2024

With the identifier as a string, it works. IMO this is still a bug when the identifier is either a dict or a function.

fchollet avatar fchollet commented on May 18, 2024

Passing TF objects directly (e.g. tf.nn.softmax) does not play well with serialization. Make sure to pass Keras objects -- could be "softmax" or keras.ops.softmax.

When passing an external object, the object should be passed via the custom_objects dict at deserialization time (e.g. in load_model or deserialize_keras_object). However, because TF is nuts, the names of your objects aren't what you expect (e.g. tf.nn.softmax is named softmax_v2) so you have to take that into account (e.g. pass custom_objects={"softmax_v2": tf.nn.softmax}).
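
To make the naming point concrete, here is a toy sketch of name-based serialization. It is not Keras's actual implementation: BUILTINS, serialize, and deserialize are hypothetical helpers written for this illustration, and softmax_v2 is a plain Python function standing in for tf.nn.softmax.

```python
def softmax(x):
    """Stand-in for a built-in activation that the registry knows by name."""
    return x


def softmax_v2(x):
    """Stand-in for an external function whose __name__ differs from its alias."""
    return x


# The public alias and the function's real name disagree, much like
# tf.nn.softmax being named softmax_v2 as described in the comment above.
external_softmax = softmax_v2

BUILTINS = {"softmax": softmax}


def serialize(fn):
    # Only the name is stored, not the function object itself.
    return fn.__name__


def deserialize(name, custom_objects=None):
    # Custom objects extend (and may override) the built-in registry.
    lookup = {**BUILTINS, **(custom_objects or {})}
    if name not in lookup:
        raise TypeError(f"{name!r} could not be deserialized properly")
    return lookup[name]


saved_name = serialize(external_softmax)
resolved = deserialize(saved_name, custom_objects={"softmax_v2": external_softmax})
```

Without the custom_objects mapping, deserialize(saved_name) raises TypeError in this sketch, because the stored name is "softmax_v2" rather than "softmax". Passing the string "softmax" in the first place sidesteps the lookup problem entirely.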
