diff --git a/docs/README.md b/docs/README.md
index 76db2e4b..32685a95 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -70,8 +70,9 @@
CPU practice guide |
GPU practice guide |
- C++ API support |
- OpenXLA |
+ C++ API support |
+ OpenXLA |
+ Keras 3 |
@@ -132,3 +133,6 @@
* OpenXLA
IntelĀ® Extension for TensorFlow\* adopts a uniform Device API PJRT as the supported device plugin mechanism to implement Intel GPU backend for OpenXLA support on TensorFlow frontend.
+
+* Keras 3
+ Keras 3 with TensorFlow comes with a significant enhancement - the Just-In-Time (JIT) compilation is enabled by default. This feature leverages the XLA (Accelerated Linear Algebra) compiler to optimize TensorFlow computations. See Keras 3 to avoid possible performance issues and error.
\ No newline at end of file
diff --git a/docs/guide/images/keras3.png b/docs/guide/images/keras3.png
new file mode 100644
index 00000000..0823e484
Binary files /dev/null and b/docs/guide/images/keras3.png differ
diff --git a/docs/guide/keras3_support.md b/docs/guide/keras3_support.md
new file mode 100644
index 00000000..2988cdbf
--- /dev/null
+++ b/docs/guide/keras3_support.md
@@ -0,0 +1,49 @@
+# Keras 3 Overview
+
+[Keras](https://keras.io/about/) is a deep learning API written in Python and capable of running on top of either JAX, TensorFlow, or PyTorch. Both JAX and TensorFlow backend compiles the model by XLA and delivers the best training and prediction performance on GPU. But results vary from model to model, as non XLA TensorFlow is occasionaly faster on GPU. The following image show how ITEX works with XLA, Keras 3 TensorFlow backend and legacy Keras.
+
+
+
+
+
+
+## Use Case with different performance
+There are serval use cases that can lead to diffent performance.
+
+* Default
+Users use Keras 3 and the model supports jit, the model will runs into XLA.
+If user script does not contains keras related code and does not enables XLA in tensorflow. There will be performance regression. Set environment variable `ITEX_DISABLE_XLA=1` to avoid regression. After ITEX XLA disabled, users can choose wether to use NPD (default) or stream excutor for better performance by environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.
+
+* Legacy Keras
+To continue using Keras 2.0, do the following.
+1. Install `tf-keras` via `pip install tf-keras`
+2. To switch `tf.keras` to use Keras 2 (`tf-keras`), set the environment variable `TF_USE_LEGACY_KERAS=1` directly or in your python program with `import os;os.environ["TF_USE_LEGACY_KERAS"]="1"`. Please note that this will set it for all packages in your Python runtime program
+3. Change the keras import: replace `import keras` with `import tf_keras as keras`. Update any `from keras import ` to `from tf_keras`.
+
+Users can choose wether to use NPD (default) or stream excutor for better performance by environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.
+
+* Keras 3 with jit_compile disabled
+Users can disable jit_compile by `model.jit_compile=False` or `model.compile(..., jit_compile=False)`. The use of itex ops override can also lead to disabling jit_compile. In this case, `ITEX_DISABLE_XLA=1` must be set.
+
+* Enable XLA through TensorFlow.
+Users can enable XLA through TensorFlow by add environment variable `TF_XLA_FLAGS="--tf_xla_auto_jit=1"`. Use `tf_xla_auto_jit=1` for auto clustering TF ops into XLA, `tf_xla_auto_jit=2` for compiling all into XLA. Users should set `model.jit_compile=False` if keras model is used. If ITEX custom ops is used or `ITEX_OPS_OVERRIDE` is set, users should use `tf_xla_auto_jit=1` to avoid error.
+
+
+
+
+
+## Situations leads to warning or Error
+We list all invalid cases here. Keras version equals to 0 means model script does not use Keras.
+
+Note that in any cases, `import keras` first before `import tensorflow` will cause an error due to circular import in ITEX.
+
+| OPS_OVERRIDE | TF_AUTO_JIT_FLAG | Keras version | NPD | Jit Compile | Warning | Error | Solution |
+|--------------|------------------|---------------|-----|-------------|---------|-------|----------|
+| Any | 0 | 0 | 0 | NA | | PluggableDevice cannot work with latest Keras. | `ITEX_DISABLE_XLA=1` |
+| Any | 0 | 0 | 1 | NA | Perf Regression | | `ITEX_DISABLE_XLA=1` |
+| Any | Any | 2 | Any | 1 | | | Unkown behavior, not supported. Use `TF_AUTO_JIT_FLAG="--tf_xla_auto_jit=1"` or `2` to enable XLA |
+| Any | 0 | 3 | 0 | Any | | Cannot close NPD when keras 3 | `ITEX_DISABLE_XLA=1` |
+| Any | 0 | 3 | 1 | 0 | | perf regression | `ITEX_DISABLE_XLA=1` |
+| Any | 1 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
+| Any | 2 | Any | 0 | Any | | Cannot close NPD | `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1` |
+| 1 | 2 | Any | 1 | Any | custom op not supported by XLA | | `ITEX_OPS_OVERRIDE=0` |
diff --git a/itex/core/kernels/xpu_kernel.cc b/itex/core/kernels/xpu_kernel.cc
index 9b40093d..2dfca0cc 100644
--- a/itex/core/kernels/xpu_kernel.cc
+++ b/itex/core/kernels/xpu_kernel.cc
@@ -85,9 +85,30 @@ void TF_InitKernel() {
bool ops_override = false;
ITEX_CHECK_OK(
itex::ReadBoolFromEnvVar("ITEX_OPS_OVERRIDE", false, &ops_override));
+ // clang-format off
if (ops_override) {
- PyRun_SimpleString("import intel_extension_for_tensorflow as itex;\n");
- PyRun_SimpleString("itex.experimental_ops_override();\n");
+ PyRun_SimpleString(
+ "try:\n"
+ " import os;\n"
+ " if os.environ.get('TF_USE_LEGACY_KERAS', None) in ('true', 'True', '1'):\n" // NOLINT(whitespace/line_length)
+ " from intel_extension_for_tensorflow.python.experimental_ops_override import experimental_ops_override;\n" // NOLINT(whitespace/line_length)
+ " else:\n"
+ " from intel_extension_for_tensorflow.python.experimental_ops_override_k3 import experimental_ops_override;\n" // NOLINT(whitespace/line_length)
+ " from intel_extension_for_tensorflow.python.override_keras3 import override_keras3;\n" // NOLINT(whitespace/line_length)
+ " experimental_ops_override();\n"
+ " override_keras3();\n"
+ "except BaseException:\n"
+ " print('please import ITEX or tensorflow berfore keras')\n"
+ " quit()\n");
+ } else {
+ PyRun_SimpleString(
+ "try:\n"
+ " from intel_extension_for_tensorflow.python.override_keras3 import override_keras3;\n" // NOLINT(whitespace/line_length)
+ " override_keras3();\n"
+ "except BaseException:\n"
+ " print('please import ITEX or tensorflow berfore keras')\n"
+ " quit()\n");
}
+ // clang-format on
#endif // CC_BUILD
}
diff --git a/itex/python/base_init.py b/itex/python/base_init.py
index 8ea93ac3..383236d6 100644
--- a/itex/python/base_init.py
+++ b/itex/python/base_init.py
@@ -33,3 +33,7 @@
if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
from intel_extension_for_tensorflow.python.experimental_ops_override import experimental_ops_override
+else:
+ from intel_extension_for_tensorflow.python.experimental_ops_override_k3 import experimental_ops_override
+
+from intel_extension_for_tensorflow.python.override_keras3 import override_keras3
diff --git a/itex/python/experimental_ops_override_k3.py b/itex/python/experimental_ops_override_k3.py
new file mode 100644
index 00000000..b41cf57b
--- /dev/null
+++ b/itex/python/experimental_ops_override_k3.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ITEX optimization for some TensorFlow API."""
+import logging
+import os
+import types
+import tensorflow as tf
+
+
+from keras import ops
+
+
+from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import _layer_norm
+
+format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+logging.basicConfig(level=logging.INFO, format=format_str)
+logger = logging.getLogger(__name__)
+
+
+def copy_func(f, name=None):
+ '''
+ return a function with same code, globals, defaults, closure, and
+ name (or provide a new name)
+ '''
+ fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__,
+ f.__defaults__, f.__closure__)
+ # in case f was given attrs (note this dict is a shallow copy):
+ fn.__dict__.update(f.__dict__)
+ return fn
+
+
+def _can_use_onednn_layer_norm(self, ndims):
+ """Return false if Itex layernorm implementation cannot be used.
+
+ Check if the axis is contiguous and can be collapsed into the last axis.
+ The self.axis is assumed to have no duplicates.
+ """
+ self._data_format = "NHWC" # pylint: disable=protected-access
+ self._is_one_axis_len = None # pylint: disable=protected-access
+ can_use_onednn_layer_norm = True
+ axis = sorted(self.axis)
+ if axis[-1] != ndims - 1 or ndims < 2 or ndims > 4 or axis[-1] - axis[0] != len(axis) - 1: # pylint: disable=line-too-long
+ can_use_onednn_layer_norm = False
+
+ if can_use_onednn_layer_norm and (axis[-1] == 3 or self.axis[-1] == -1):
+ self.data_format = 'NHWC'
+
+ if len(axis) == 1:
+ self._is_one_axis_len = True # pylint: disable=protected-access
+ else:
+ self._is_one_axis_len = False # pylint: disable=protected-access
+
+ if self.dtype == 'float64':
+ raise ValueError(
+ 'Itex Layernorm only support float32, bfloat16 and float16.') # pylint: disable=line-too-long
+
+ return can_use_onednn_layer_norm
+
+
+def experimental_ops_override():
+ '''
+ using itex api in some tf and keras functions.
+ '''
+ try:
+ from pkg_resources import packaging # pylint: disable=import-outside-toplevel
+ version = packaging.version.parse
+ if version(tf.__version__) < version("2.16.1"):
+ return
+
+ from keras.src import backend # pylint: disable=import-outside-toplevel
+ from keras.src.utils import tf_utils # pylint: disable=import-outside-toplevel
+
+ import keras
+ tf_ln_call = copy_func(keras.layers.LayerNormalization.call)
+ tf_gn_call = copy_func(keras.layers.GroupNormalization.call)
+ tf_gn_build = copy_func(keras.layers.GroupNormalization.build)
+
+ except BaseException: # pylint: disable=broad-except
+ return
+
+ def itex_layer_norm_build(self, input_shape):
+ self.supports_jit = False
+ if self.compute_dtype == "float16" or self.compute_dtype == "bfloat16": # pylint: disable=no-else-return
+ self._param_dtype = "float32"
+ else:
+ self._param_dtype = self.dtype or dtypes.float32
+ ndims = len(input_shape)
+ if ndims is None:
+ raise ValueError(
+ 'Input shape %s has undefined rank.' % input_shape)
+ if isinstance(self.axis, list):
+ shape = tuple([input_shape[dim] for dim in self.axis])
+ else:
+ shape = (input_shape[self.axis],)
+ self.axis = [self.axis]
+ for idx, x in enumerate(self.axis):
+ if x < 0:
+ self.axis[idx] = ndims + x
+ param_shape = [input_shape[dim] for dim in self.axis]
+ if self.scale or self.rms_scaling:
+ self.gamma = self.add_weight(
+ name="gamma",
+ shape=shape,
+ initializer=self.gamma_initializer,
+ regularizer=self.gamma_regularizer,
+ constraint=self.gamma_constraint,
+ trainable=True,
+ dtype=self._param_dtype,
+ )
+ else:
+ self.gamma = None
+ self._gamma_const = ops.ones(
+ dtype=self._param_dtype, shape=param_shape)
+
+ if self.center and not self.rms_scaling:
+ self.beta = self.add_weight(
+ name="beta",
+ shape=shape,
+ initializer=self.beta_initializer,
+ regularizer=self.beta_regularizer,
+ constraint=self.beta_constraint,
+ trainable=True,
+ dtype=self._param_dtype,
+ )
+ else:
+ self.beta = None
+ self._beta_const = ops.zeros(
+ dtype=self._param_dtype, shape=param_shape)
+ self._use_layernorm = _can_use_onednn_layer_norm(self, ndims)
+ self.built = True
+
+ def _layer_norm_inference_or_training(self, inputs, gamma, beta, training):
+ """Returns the output of layer norm."""
+ def _layer_norm_training():
+ return _layer_norm(
+ inputs,
+ scale=gamma,
+ offset=beta,
+ epsilon=self.epsilon,
+ is_training=True,
+ data_format=self._data_format)
+
+ def _layer_norm_inference():
+ return _layer_norm(
+ inputs,
+ scale=gamma,
+ offset=beta,
+ epsilon=self.epsilon,
+ is_training=False,
+ data_format=self._data_format)
+
+ output, _, _ = tf.__internal__.smart_cond.smart_cond(
+ training, _layer_norm_training, _layer_norm_inference)
+ return output
+
+ def itex_layer_norm_call(self, inputs, training=None):
+ if not self._use_layernorm: # pylint: disable=protected-access
+ return tf_ln_call(self, inputs) # pylint: disable=not-callable
+ if self.rms_scaling: # pylint: disable=protected-access
+ return tf_ln_call(self, inputs) # pylint: disable=not-callable
+ if training is None:
+ is_training = True
+ if isinstance(training, int):
+ is_training = bool(training)
+ if not self.trainable:
+ # When the layer is not trainable, it overrides the value passed from
+ # model.
+ is_training = False
+ # Compute the axes along which to reduce the mean / variance
+ inputs = ops.cast(inputs, self.compute_dtype)
+ # Compute the axes along which to reduce the mean / variance
+ input_shape = inputs.shape
+ ndims = len(input_shape)
+
+ # Broadcasting only necessary for norm when the axis is not just
+ # the last dimension
+ broadcast_shape = [1] * ndims
+ for dim in self.axis:
+ broadcast_shape[dim] = input_shape[dim]
+
+ def _broadcast(v):
+ if (
+ v is not None
+ and len(v.shape) != ndims
+ and self.axis != [ndims - 1]
+ ):
+ return ops.reshape(v, broadcast_shape)
+ return v
+
+ input_dtype = inputs.dtype
+ if input_dtype in (tf.float16, tf.bfloat16) and self.dtype == "float32" and not self._use_layernorm:
+ # If mixed precision is used, cast inputs to float32 so that
+ # this is at least as numerically stable as the fused version.
+ inputs = ops.cast(inputs, "float32")
+
+ beta = self.beta if self.beta is not None else self._beta_const
+ gamma = self.gamma if self.gamma is not None else self._gamma_const
+ if self._is_one_axis_len:
+ outputs = _layer_norm_inference_or_training(self, inputs, gamma, beta,
+ is_training)
+ return outputs
+ else:
+ # Collapse dims before self.axis, and dims in self.axis
+ pre_dim, in_dim = (1, 1)
+ axis = sorted(self.axis)
+ tensor_shape = inputs.shape
+ for dim in range(0, ndims):
+ dim_tensor = tensor_shape[dim]
+ if dim < axis[0]:
+ pre_dim = pre_dim * dim_tensor
+ else:
+ assert dim in axis
+ in_dim = in_dim * dim_tensor
+
+ squeezed_shape = [1, pre_dim, in_dim]
+ inputs = ops.reshape(inputs, squeezed_shape)
+
+ # self.gamma and self.beta have the wrong shape for layer_norm, so
+ # we cannot pass them as the scale and offset parameters. Therefore, we
+ # create two constant tensors in correct shapes for layer_norm and
+ # later construct a separate calculation on the scale and offset.
+ scale = ops.ones([in_dim], dtype="float32")
+ offset = ops.zeros([in_dim], dtype="float32")
+
+ # Compute layer normalization.
+ outputs = _layer_norm_inference_or_training(self, inputs, scale,
+ offset, is_training)
+ outputs = ops.reshape(outputs, tensor_shape)
+ scale, offset = _broadcast(
+ self.gamma), _broadcast(self.beta)
+
+ if scale is not None:
+ outputs = outputs * ops.cast(scale, outputs.dtype)
+ if offset is not None:
+ outputs = outputs + ops.cast(offset, outputs.dtype)
+ return outputs
+
+ try:
+ keras.layers.LayerNormalization.call = itex_layer_norm_call
+ keras.layers.LayerNormalization.build = itex_layer_norm_build
+ logger.info("itex experimental ops override is enabled.")
+ except BaseException: # pylint: disable=broad-except
+ logger.error("Cannot override itex ops.")
+ try:
+ import keras # pylint: disable=import-outside-toplevel
+ keras.src.layers.normalization.layer_normalization.LayerNormalization.call = itex_layer_norm_call
+ keras.src.layers.normalization.layer_normalization.LayerNormalization.build = itex_layer_norm_build
+ except BaseException: # pylint: disable=broad-except
+ logger.warning(
+ "itex experimental ops override: Keras is not installed.") # pylint: disable=line-too-long
diff --git a/itex/python/ops/group_norm_k3.py b/itex/python/ops/group_norm_k3.py
index 0d3f97df..e0b7f3db 100644
--- a/itex/python/ops/group_norm_k3.py
+++ b/itex/python/ops/group_norm_k3.py
@@ -238,7 +238,7 @@ def itex_group_norm_call(self, inputs):
if dim > 1:
in_dim = in_dim * dim_tensor
- squeezed_shape = [tensor_shape[0], tensor_shape[1], in_dim]
+ squeezed_shape = [-1, tensor_shape[1], in_dim]
inputs = ops.reshape(reshaped_inputs, squeezed_shape)
# self.gamma and self.beta have the wrong shape for layer_norm, so
@@ -247,19 +247,21 @@ def itex_group_norm_call(self, inputs):
# later construct a separate calculation on the scale and offset.
scale = ops.ones([in_dim], dtype="float32")
offset = ops.zeros([in_dim], dtype="float32")
-
outputs, _, _ = _layer_norm(inputs,
scale=scale,
offset=offset,
epsilon=self.epsilon,
is_training=True)
- outputs = ops.reshape(outputs, tensor_shape)
+ out_tensor_shape = list(tensor_shape)
+ out_tensor_shape[0] = -1 if out_tensor_shape[0] is None else out_tensor_shape[0]
+ outputs = ops.reshape(outputs, out_tensor_shape)
if axis != 1:
perm_back_shape = list(range(0, group_ndims))
perm_back_shape.pop(1)
perm_back_shape.insert(axis, 1)
outputs = ops.transpose(outputs, perm_back_shape)
-
+ input_shape = list(input_shape)
+ input_shape[0] = -1 if input_shape[0] is None else input_shape[0]
outputs = ops.reshape(outputs, input_shape)
if self.scale:
diff --git a/itex/python/ops/layer_norm_k3.py b/itex/python/ops/layer_norm_k3.py
index 2038146c..9c553277 100644
--- a/itex/python/ops/layer_norm_k3.py
+++ b/itex/python/ops/layer_norm_k3.py
@@ -355,7 +355,7 @@ def _broadcast(v):
variance = ops.var(inputs, axis=self.axis, keepdims=True)
inv = ops.rsqrt(variance + self.epsilon)
- outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
+ outputs = inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
else:
if self._use_layernorm:
beta = self.beta if self.beta is not None else self._beta_const
diff --git a/itex/python/override_keras3.py b/itex/python/override_keras3.py
new file mode 100644
index 00000000..436d6c3d
--- /dev/null
+++ b/itex/python/override_keras3.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""check keras model if it supports jit compile."""
+import keras
+import os
+import types
+
+XLA_AUTO_CLUSTER = False
+if "--tf_xla_auto_jit=1" in os.environ.get("TF_XLA_FLAGS", "").replace(" ", ""):
+ XLA_AUTO_CLUSTER = True
+
+
+def copy_func(f, name=None):
+ '''
+ return a function with same code, globals, defaults, closure, and
+ name (or provide a new name)
+ '''
+ fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__,
+ f.__defaults__, f.__closure__)
+ # in case f was given attrs (note this dict is a shallow copy):
+ fn.__dict__.update(f.__dict__)
+ return fn
+
+
+keras_model_compile = copy_func(keras.src.trainers.trainer.Trainer.compile)
+keras_model_predict = copy_func(
+ keras.src.backend.tensorflow.trainer.TensorFlowTrainer.predict)
+
+
+def itex_model_compile(self,
+ optimizer="rmsprop",
+ loss=None,
+ loss_weights=None,
+ metrics=None,
+ weighted_metrics=None,
+ run_eagerly=False,
+ steps_per_execution=1,
+ jit_compile="auto",
+ auto_scale_loss=True,
+ ):
+ keras_model_compile(self, # pylint: disable=not-callable
+ optimizer=optimizer,
+ loss=loss,
+ loss_weights=loss_weights,
+ metrics=metrics,
+ weighted_metrics=weighted_metrics,
+ run_eagerly=run_eagerly,
+ steps_per_execution=steps_per_execution,
+ jit_compile=jit_compile,
+ auto_scale_loss=auto_scale_loss,)
+ if ((not self.jit_compile) and os.environ.get("ITEX_DISABLE_XLA", "0") in ("false", "0") and (not XLA_AUTO_CLUSTER)):
+ print("This keras model does not support jit compile, please use legacy keras or set ITEX_DISABLE_XLA=1")
+ quit()
+
+
+def itex_predict(
+ self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
+):
+ if ((not self.jit_compile) and os.environ.get("ITEX_DISABLE_XLA", "0") in ("false", "0") and (not XLA_AUTO_CLUSTER)):
+ print("This keras model is not jit compiled, please compile it or use legacy keras or set ITEX_DISABLE_XLA=1")
+ quit()
+ return keras_model_predict(self, x, batch_size, verbose, steps, callbacks) # pylint: disable=not-callable
+
+
+def override_keras3():
+ '''
+ override model_supports_jit
+ '''
+ if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"):
+ return
+ try:
+ from pkg_resources import packaging # pylint: disable=import-outside-toplevel
+ version = packaging.version.parse
+ if version(keras.__version__) >= version("3.0.0"):
+ keras.src.trainers.trainer.Trainer.compile = itex_model_compile
+ keras.src.backend.tensorflow.trainer.TensorFlowTrainer.predict = itex_predict
+
+ except BaseException: # pylint: disable=broad-except
+ logger.warning(
+ "itex.override_keras3 failed") # pylint: disable=line-too-long
diff --git a/test/sanity/nn/group_normalization_k3_test.py b/test/sanity/nn/group_normalization_k3_test.py
index 122cfab6..d3b9dbe3 100644
--- a/test/sanity/nn/group_normalization_k3_test.py
+++ b/test/sanity/nn/group_normalization_k3_test.py
@@ -14,6 +14,7 @@
# =============================================================================
import os
os.environ['TF_USE_LEGACY_KERAS']='0'
+os.environ['ITEX_DISABLE_XLA']='1'
import tensorflow.compat.v2 as tf
import numpy as np
diff --git a/test/sanity/nn/rms_normalization_test_k3.py b/test/sanity/nn/rms_normalization_test_k3.py
index 0502f93f..04f04102 100644
--- a/test/sanity/nn/rms_normalization_test_k3.py
+++ b/test/sanity/nn/rms_normalization_test_k3.py
@@ -14,6 +14,7 @@
# =============================================================================
import os
os.environ['TF_USE_LEGACY_KERAS']='0'
+os.environ['ITEX_DISABLE_XLA']='1'
import tensorflow as tf
import numpy as np
diff --git a/third_party/build_option/sycl/runtime/itex_gpu_runtime.h b/third_party/build_option/sycl/runtime/itex_gpu_runtime.h
index d984ac7e..f9be30fa 100644
--- a/third_party/build_option/sycl/runtime/itex_gpu_runtime.h
+++ b/third_party/build_option/sycl/runtime/itex_gpu_runtime.h
@@ -134,10 +134,11 @@ class ITEXNpdConfig {
ITEX_LOG(FATAL) << "PluggableDevice cannot enable XLA! Please export "
"ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1 to enable XLA.";
}
- if (!isLegacyKeras_) {
+ if (!isLegacyKeras_ && !isXLADisabled_) {
ITEX_LOG(FATAL)
<< "PluggableDevice cannot work with latest Keras. Please export "
- "TF_USE_LEGACY_KERAS=1 to use legacy keras.";
+ "TF_USE_LEGACY_KERAS=1 to use legacy keras or "
+ "ITEX_DISABLE_XLA=1 with keras jit compile disabled.";
}
}
}
@@ -151,15 +152,30 @@ class ITEXNpdConfig {
ReadVariableFromEnv("TF_USE_LEGACY_KERAS", &isLegacyKeras_);
// Check whether to use PJRT Buffer cache mechanism.
ReadVariableFromEnv("ITEX_CACHE_PJRT_BUFFER", &isPJRTBufferCached_);
+ // For jit_compile=False in keras
+ ReadVariableFromEnv("ITEX_DISABLE_XLA", &isXLADisabled_);
// Determine whether enable XLA auto JIT
isXlaAutoJitEnabled_ = static_cast(TF_GetXlaAutoJitEnabled());
- isXlaAutoJitEnabled_ |= isLegacyKeras_ ? 0 : 1;
+ if (!isXlaAutoJitEnabled_ && (!isLegacyKeras_ || !isXLADisabled_)) {
+ ITEX_LOG(WARNING) << "Set TF_USE_LEGACY_KERAS=0 or ITEX_DISABLE_XLA=1 if "
+ "your script does not use keras";
+ }
+ // Keras 3.0.
+ if (!isLegacyKeras_) {
+ // XLA auto_jit is off and XLA is disabled in ITEX.
+ if (isXLADisabled_ && !isXlaAutoJitEnabled_) {
+ isXlaAutoJitEnabled_ = false;
+ } else {
+ isXlaAutoJitEnabled_ = true;
+ }
+ }
CheckNPDConfig();
// Preparations for enabling XLA auto JIT
if (isXlaAutoJitEnabled_) {
- ITEX_VLOG(1) << "ITEX XLA auto_jit is enabled!";
+ ITEX_LOG(WARNING) << "ITEX XLA auto_jit is enabled! There will be "
+ "performance drop if you are not using XLA.";
setenv("ITEX_REMAPPER", "0", 0);
setenv("ITEX_LAYOUT_OPT", "0", 0);
setenv("ITEX_ENABLE_MULTIPLE_STREAM", "1", 0);
@@ -172,6 +188,7 @@ class ITEXNpdConfig {
bool isNextPluggableDeviceEnabled_ = true;
bool isLegacyKeras_ = false;
+ bool isXLADisabled_ = false;
bool isXlaAutoJitEnabled_ = false;
bool isPJRTBufferCached_ = true;
};