What are you trying to do here? Pickle a custom class that includes a Keras model?
Not exactly. I am trying to enable pickle dumps for a KerasClassifier wrapper. Previously, it was proposed in #4274 that this should be done with the following code:
```python
import copy
import io

import h5py
import tensorflow as tf


class KerasClassifier(tf.keras.wrappers.scikit_learn.KerasClassifier):
    """
    TensorFlow Keras API neural network classifier.

    Workaround for the tf.keras.wrappers.scikit_learn.KerasClassifier
    serialization issue, using BytesIO and HDF5 to enable pickle dumps.

    Adapted from:
    https://github.com/keras-team/keras/issues/4274#issuecomment-519226139
    """

    def __getstate__(self):
        state = self.__dict__
        if "model" in state:
            model = state["model"]
            # Serialize the Keras model into an in-memory HDF5 file.
            model_hdf5_bio = io.BytesIO()
            with h5py.File(model_hdf5_bio, 'w') as file:
                model.save(file)
                # tf.keras.models.save_model(model, file, save_format="h5")
            # Temporarily swap the model for its bytes, copy the state,
            # then put the live model back on this object.
            state["model"] = model_hdf5_bio
            state_copy = copy.deepcopy(state)
            state["model"] = model
            return state_copy
        else:
            return state

    def __setstate__(self, state):
        if "model" in state:
            # Rebuild the Keras model from the in-memory HDF5 file.
            model_hdf5_bio = state["model"]
            with h5py.File(model_hdf5_bio, 'r') as file:
                state["model"] = tf.keras.models.load_model(file)
        self.__dict__ = state
```
Since then, tf.keras.wrappers.scikit_learn.KerasClassifier has been deprecated and essentially replaced by scikeras.
The custom __getstate__ and __setstate__ exist because these workarounds were historically needed (and recommended in the linked issue) to serialize the KerasClassifier wrapper object. Migrating to scikeras was not a problem, but now that Python 3.12, TensorFlow 2.16.1, and Keras 3.x are in use, the previously suggested mechanisms for serializing a KerasClassifier have broken down.
Thank you again!
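The swap-the-model-for-bytes idea behind the workaround can be sketched without mutating the live object's __dict__. This is a minimal, library-free sketch under stated assumptions: FakeModel and PicklableWrapper are hypothetical stand-ins (a threading.Lock makes FakeModel unpicklable, and its to_bytes/from_bytes methods take the place of a real model's save/load step). It is not a drop-in fix for Keras 3, where the byte conversion would have to go through Keras's own saving functions.

```python
import json
import pickle
import threading


class FakeModel:
    """Hypothetical stand-in for a model that pickle cannot handle
    directly but that can round-trip through bytes."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._lock = threading.Lock()  # locks are not picklable

    def to_bytes(self):
        return json.dumps(self.weights).encode("utf-8")

    @classmethod
    def from_bytes(cls, data):
        return cls(json.loads(data.decode("utf-8")))


class PicklableWrapper:
    """Hypothetical wrapper: swaps the model for bytes when pickled
    and rebuilds it when unpickled."""

    def __init__(self, model=None, epochs=30):
        self.model = model
        self.epochs = epochs

    def __getstate__(self):
        # Shallow-copy the dict so the live object keeps its real model.
        state = self.__dict__.copy()
        if state.get("model") is not None:
            state["model"] = state["model"].to_bytes()
        return state

    def __setstate__(self, state):
        if state.get("model") is not None:
            state["model"] = FakeModel.from_bytes(state["model"])
        self.__dict__.update(state)


wrapper = PicklableWrapper(FakeModel([0.5, -1.25]))
restored = pickle.loads(pickle.dumps(wrapper))
print(restored.model.weights)  # [0.5, -1.25]
```

Copying self.__dict__ up front leaves the live wrapper untouched, so the swap-and-restore step and the deep copy of the whole state are not needed.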
Did you try just using pickle with no workaround? It might well work with Keras 3.3.3.
The KerasClassifier object (with no workarounds attempted) does serialize just fine before fitting:
```
KerasClassifier(
	model=<function build_tf_estimator.<locals>.build_tf_model at 0x144b60220>
	build_fn=None
	warm_start=False
	random_state=None
	optimizer=rmsprop
	loss=None
	metrics=None
	batch_size=200
	validation_batch_size=None
	verbose=2
	callbacks=None
	validation_split=0.0
	shuffle=True
	run_eagerly=False
	epochs=30
	class_weight=None
)
```

```python
p = dill.dumps(est)
print(dill.loads(p))
```
But after fitting I get a large cascade of errors:
```python
est.fit(get_df_values(X), y, **kwargs)
p = dill.dumps(est)
print(dill.loads(p))
```
```
  File "/Users/cgladue/downloads/py312/automl_model.py", line 202, in fit_on_all_data
    print(dill.loads(p))
          ^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 303, in loads
    return load(file, ignore, **kwds)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 289, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/dill/_dill.py", line 444, in load
    obj = StockUnpickler.load(self)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/scikeras/_saving_utils.py", line 15, in unpack_keras_model
    return load_model(b, compile=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 141, in load_model
    return _load_model_from_fileobj(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/saving_lib.py", line 170, in _load_model_from_fileobj
    model = deserialize_keras_object(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cgladue/downloads/py312/keras/src/saving/serialization_lib.py", line 720, in deserialize_keras_object
    raise TypeError(
TypeError: <class 'keras.src.models.sequential.Sequential'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.
```
Along with several other errors of the same form, like
```
Exception encountered: <class 'keras.src.layers.core.dense.Dense'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.
```
I will look into this get_config and from_config suggestion, but I do not yet understand why fitting the model would make it unserializable.
Hi @CAW9,
If the model constructor takes Keras layers, then you need to implement the get_config and from_config methods explicitly.
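The rule behind this advice, and behind the Keras error message quoted earlier in the thread (objects returned by get_config() must be explicitly deserialized in from_config()), can be illustrated with a small, library-free toy. Dense, Stack, REGISTRY, serialize_object and deserialize_object are hypothetical names, not the Keras API; in Keras 3 the corresponding helpers are keras.saving.serialize_keras_object and keras.saving.deserialize_keras_object.

```python
# Toy, hypothetical classes and helpers; this is not the Keras API.


class Dense:
    """Toy layer whose config holds only plain data."""

    def __init__(self, units):
        self.units = units

    def get_config(self):
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# Toy registry mapping class names back to classes.
REGISTRY = {"Dense": Dense}


def serialize_object(obj):
    # Record the class name plus the object's plain-data config.
    return {"class_name": type(obj).__name__, "config": obj.get_config()}


def deserialize_object(data):
    # Look the class up by name and rebuild the object from its config.
    return REGISTRY[data["class_name"]].from_config(data["config"])


class Stack:
    """Toy container whose constructor takes other objects,
    the way a model takes layers."""

    def __init__(self, layers):
        self.layers = list(layers)

    def get_config(self):
        # Nested objects become plain data here...
        return {"layers": [serialize_object(layer) for layer in self.layers]}

    @classmethod
    def from_config(cls, config):
        # ...so they must be explicitly rebuilt here; passing
        # config["layers"] straight to cls() would leave plain dicts.
        layers = [deserialize_object(item) for item in config["layers"]]
        return cls(layers)


stack = Stack([Dense(10), Dense(3)])
restored = Stack.from_config(stack.get_config())
print([layer.units for layer in restored.layers])  # [10, 3]
```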
I've reproduced the error using only Keras, with no scikeras dependency (https://colab.research.google.com/drive/1ps1jt8WMINt0mOqvnGkE2sNdTqHLNkW1?usp=sharing). This time, it fails to serialize even before fitting.
To clarify, you are suggesting that I cannot define and use a simple keras model like this:
```python
import tensorflow as tf


# X_shape and y_nunique come from the surrounding code (not shown).
def build_tf_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(
            10, input_dim=X_shape[1], activation=tf.nn.relu),
        tf.keras.layers.Dense(y_nunique, activation=tf.nn.softmax)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.2, amsgrad=True, name='Adam'),
        loss='sparse_categorical_crossentropy',
    )
    return model
```
Do I instead need to wrap the Keras model in a custom class and define custom get_config and from_config methods in that class?
If in that example you change keras from 3.3.3 to 2.15.0, the serialization does not fail.
I have reproduced the behaviour with tf-nightly (Keras 3). With TF 2.15 it works fine. Attached gist here.
The code fails at the keras.activations.get(identifier) step. But when I pass the activation tf.nn.relu to it directly, it does not fail:
```python
import keras
import tensorflow as tf

identifier = tf.nn.relu
config = keras.activations.get(identifier)
print(type(config))
callable(config)
```

Output:

```
3.3.3.dev2024050303
<class 'function'>
True
```
I was able to fix my code by changing activation=tf.nn.relu to activation="relu" and activation=tf.nn.softmax to activation="softmax".
This is a satisfactory workaround for me.
If you would like to close the issue, that's fine with me. If you feel that there is still a bug you need to address, feel free to leave it open.
Thank you for your help on this!
Passing the identifier as a string works. IMO this is still a bug when the identifier is either a dict or a function.
Passing TF objects directly (e.g. tf.nn.softmax) does not play well with serialization. Make sure to pass Keras objects instead, e.g. "softmax" or keras.ops.softmax.

When passing an external object, the object should be passed via the custom_objects dict at deserialization time (e.g. in load_model or deserialize_keras_object). However, because TF is nuts, the names of your objects aren't what you expect (e.g. tf.nn.softmax is named softmax_v2), so you have to take that into account (e.g. pass custom_objects={"softmax_v2": tf.nn.softmax}).
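To make the name mismatch concrete, here is a deliberately simplified, library-free toy; it is not how Keras actually serializes functions. It records only a function's __name__ and resolves names against a built-in registry plus an optional custom_objects dict. relu, softmax_v2, external_softmax, BUILTINS, serialize and deserialize are all hypothetical names chosen to mirror the situation described above.

```python
import math

# Toy, hypothetical serialization scheme; this is not the Keras implementation.


def relu(x):
    """Stand-in for a built-in activation the toy registry knows by name."""
    return max(0.0, x)


def softmax_v2(values):
    """Stand-in for an external function whose __name__ differs from
    the public alias it is normally used through."""
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Public alias, analogous to tf.nn.softmax; its __name__ is still "softmax_v2".
external_softmax = softmax_v2

# Toy registry of objects the "library" can resolve on its own.
BUILTINS = {"relu": relu}


def serialize(fn):
    # Toy scheme: only the function's __name__ is recorded.
    return fn.__name__


def deserialize(name, custom_objects=None):
    # Later entries win, so user-supplied custom_objects override BUILTINS.
    lookup = {**BUILTINS, **(custom_objects or {})}
    if name not in lookup:
        raise ValueError(f"Unknown object: {name!r}")
    return lookup[name]


saved = serialize(external_softmax)
print(saved)  # softmax_v2
restored = deserialize(saved, custom_objects={"softmax_v2": external_softmax})
print(restored is external_softmax)  # True
```

In this toy, deserialize(saved) with no custom_objects raises ValueError, and so does custom_objects={"softmax": external_softmax}, because the stored name is softmax_v2 rather than the alias.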