Skip to content

Commit

Permalink
implement evaluate() for torch backend (keras-team#255)
Browse files Browse the repository at this point in the history
Co-authored-by: Haifeng Jin <[email protected]>
  • Loading branch information
haifeng-jin and haifeng-jin authored Jun 3, 2023
1 parent d02ed05 commit 4de3449
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
3 changes: 2 additions & 1 deletion keras_core/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def _initialize(self, value):
self._value.requires_grad_(self.trainable)

def _direct_assign(self, value):
    """Copy `value` into the backing torch tensor in place.

    The copy runs under `torch.no_grad()` so the in-place `copy_` is not
    recorded by autograd — copying into a leaf tensor that requires grad
    would otherwise raise a RuntimeError.
    """
    with torch.no_grad():
        self._value.copy_(value)

def _convert_to_tensor(self, value, dtype=None):
    """Delegate to the backend-level `convert_to_tensor` helper."""
    tensor = convert_to_tensor(value, dtype=dtype)
    return tensor
Expand Down
2 changes: 1 addition & 1 deletion keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def empty(shape, dtype="float32"):

def equal(x1, x2):
    """Elementwise equality of two tensors, matching `np.equal`.

    Uses `torch.eq` (elementwise, broadcasting, returns a bool tensor)
    rather than `torch.equal`, which collapses the comparison to a single
    Python bool for the whole tensor.
    """
    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
    return torch.eq(x1, x2)


def exp(x):
Expand Down
62 changes: 61 additions & 1 deletion keras_core/backend/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ def predict(
):
pass

def test_step(self, data):
    """Run a single evaluation step on one batch.

    Unpacks the batch, runs the forward pass in inference mode, records
    the loss in the loss tracker, and returns the updated metric results.
    """
    batch = data[0]
    x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(batch)
    # Pass `training=False` only when the model's call signature accepts it.
    call_kwargs = {"training": False} if self._call_has_training_arg() else {}
    y_pred = self(x, **call_kwargs)
    loss = self.compute_loss(
        x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
    )
    self._loss_tracker.update_state(loss)
    return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs,
):
    """Evaluate the model on a dataset and return loss/metric values.

    Iterates over the evaluation data one batch at a time, running
    `test_step` under `torch.no_grad()`, driving the callback lifecycle,
    and returns the final metric results either as a dict
    (`return_dict=True`) or as a flat list in metrics order.

    NOTE(review): middle parameters (`y` … `callbacks`) were hidden by a
    collapsed diff fold; names are grounded by their uses in the body,
    defaults follow Keras convention — confirm against the full file.
    """
    # TODO: respect compiled trainable state
    # Internal flag set by `fit()` so validation can reuse its iterator.
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")

    if use_cached_eval_dataset:
        epoch_iterator = self._eval_epoch_iterator
    else:
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            model=self,
        )

    # Switch the torch Module to testing (inference) mode.
    self.eval()

    callbacks.on_test_begin()
    logs = None
    self.reset_metrics()
    for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
        callbacks.on_test_batch_begin(step)
        # No gradients are needed during evaluation.
        with torch.no_grad():
            logs = self.test_step(data)
        callbacks.on_test_batch_end(step, logs)
    logs = self._pythonify_logs(self.get_metrics_result())
    callbacks.on_test_end(logs)

    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)

0 comments on commit 4de3449

Please sign in to comment.