Skip to content

Commit

Permalink
Fix typo in functional model input flattening (keras-team#402)
Browse files Browse the repository at this point in the history
* Fix typo in functional model input flattening

* Add unit test
  • Loading branch information
ianstenbit authored Jun 26, 2023
1 parent c0fd9bc commit 9d923de
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras_core/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _flatten_to_reference_inputs(self, inputs, allow_extra_keys=True):
if isinstance(inputs, dict):
ref_inputs = self._inputs_struct
if not nest.is_nested(ref_inputs):
ref_inputs = [self._nested_inputs]
ref_inputs = [self._inputs_struct]
if isinstance(ref_inputs, dict):
# In the case that the graph is constructed with dict input
# tensors, We will use the original dict key to map with the
Expand Down
18 changes: 18 additions & 0 deletions keras_core/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ def test_basic_flow_dict_io(self):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

def test_named_input_dict_io(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
x = layers.Dense(5)(input_a)
outputs = layers.Dense(4)(x)

model = Functional(input_a, outputs)

# Eager call
in_val = {"a": np.random.random((2, 3))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
in_val = {"a": input_a_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

def test_layer_getters(self):
# Test mixing ops and layers
input_a = Input(shape=(3,), batch_size=2, name="input_a")
Expand Down

0 comments on commit 9d923de

Please sign in to comment.