load_model_from_hdf5 doesn't load axis param for Concatenate layer #20963

Closed
gkoundry opened this issue Feb 26, 2025 · 2 comments · Fixed by #20973

@gkoundry

When a Keras model is loaded from an HDF5 file, the axis parameter of the Concatenate layer is dropped from the model config.
The code at https://github.com/keras-team/keras/blob/6c3dd68c1b2e7783d15244279959e89fe88ee346/ pops the param from the config and only puts it back if it's a list, which AFAIK it never is.
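
To illustrate the pattern described here, the following is a minimal standalone sketch. It is not the Keras source: `restore_axis_buggy` and `restore_axis_fixed` are hypothetical names, the config is a plain dictionary made up for illustration, and the "fixed" variant is only one possible repair, not necessarily the one that was merged.

```python
def restore_axis_buggy(config):
    """Hypothetical: pop "axis" and put it back only when it is a list."""
    config = dict(config)  # work on a copy so the caller's dict is untouched
    axis = config.pop("axis", None)
    if isinstance(axis, list) and len(axis) == 1:
        config["axis"] = int(axis[0])
    # An integer axis (the usual case) is never restored here.
    return config


def restore_axis_fixed(config):
    """Hypothetical repair: unwrap a one-element list, keep anything else."""
    config = dict(config)
    axis = config.pop("axis", None)
    if isinstance(axis, list) and len(axis) == 1:
        config["axis"] = int(axis[0])
    elif axis is not None:
        config["axis"] = axis
    return config


# Illustrative layer config; the "name" value is made up.
layer_config = {"name": "concatenate", "axis": 1}
buggy = restore_axis_buggy(layer_config)
fixed = restore_axis_fixed(layer_config)
```

With an integer axis, `buggy` no longer contains an "axis" key, so a layer rebuilt from it would fall back to its default axis, while `fixed` keeps the value.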

Script to reproduce the issue:

from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model
from keras.src.legacy.saving import legacy_h5_format

input1 = Input(shape=(1,4), name='input1')
input2 = Input(shape=(1,4), name='input2')
input3 = Input(shape=(2,4), name='input3')
concat1 = Concatenate(axis=1)([input1, input2])
concat2 = Concatenate(axis=-1)([concat1, input3])
output = Dense(1, activation='sigmoid')(concat2)

model = Model(inputs=[input1, input2, input3], outputs=output)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

legacy_h5_format.save_model_to_hdf5(model, "cat_test.h5")
legacy_h5_format.load_model_from_hdf5("cat_test.h5")

Running this script results in load_model_from_hdf5 failing at layer concatenate_1 with the error:

ValueError: A Concatenate layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 1, 8), (None, 2, 4)]
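
The shapes in that message follow from the first Concatenate layer falling back to the default axis=-1 once axis=1 is lost. A small sketch of the shape arithmetic, using a hypothetical helper `concat_shape` that only mimics the shape rule and does not call Keras:

```python
def concat_shape(shapes, axis=-1):
    """Hypothetical helper: output shape of concatenating `shapes` along `axis`.

    Assumes all shapes have the same rank. Raises ValueError when any
    dimension other than the concatenation axis differs.
    """
    rank = len(shapes[0])
    axis = axis % rank  # normalise negative axes
    for shape in shapes[1:]:
        for i in range(rank):
            if i != axis and shape[i] != shapes[0][i]:
                raise ValueError(f"mismatched shapes {shapes} for axis {axis}")
    out = list(shapes[0])
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)


in1, in2, in3 = (None, 1, 4), (None, 1, 4), (None, 2, 4)

# As saved: the first concatenation uses axis=1.
saved_concat1 = concat_shape([in1, in2], axis=1)             # (None, 2, 4)
saved_concat2 = concat_shape([saved_concat1, in3], axis=-1)  # (None, 2, 8)

# As loaded with the axis dropped: the first concatenation uses the default axis=-1.
loaded_concat1 = concat_shape([in1, in2])                    # (None, 1, 8)
try:
    concat_shape([loaded_concat1, in3], axis=-1)
    load_failed = False
except ValueError:
    load_failed = True  # (None, 1, 8) vs (None, 2, 4) differ on axis 1
```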

@mehtamansi29
Collaborator

Hi @gkoundry -

Thanks for reporting the issue. If you run your reproduction code with the latest Keras (3.8.0) and save the model with the .keras extension instead of .h5, as shown here, the code works without error.

import keras

# `model` is the model built in the reproduction script
keras.saving.save_model(model, "cat_test.keras")
model1 = keras.saving.load_model("cat_test.keras")

A gist is attached here for reference.

@gkoundry
Author

I know there is a better way to do this, but I had some older HDF5 models that I was trying to load, which is how I ran into this. It does seem to be a bug, and it was difficult to track down, so I figured I would report it anyway to save someone else a headache if they run into it. However, I understand if you decide it's not worth maintaining old legacy code.
