diff --git a/requirements.txt b/requirements.txt index cb1a66eaf..2b978a635 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ termcolor >= 1.1.0 trimesh >= 2.37.22 # Required by trimesh. networkx +ai-edge-litert >= 1.0.1 \ No newline at end of file diff --git a/tensorflow_graphics/util/test_case.py b/tensorflow_graphics/util/test_case.py index f1ac35545..d28c7db0f 100644 --- a/tensorflow_graphics/util/test_case.py +++ b/tensorflow_graphics/util/test_case.py @@ -31,6 +31,10 @@ import tensorflow as tf from tensorflow_graphics.util import tfg_flags +# pylint: disable=g-direct-tensorflow-import +from ai-edge-litert import interpreter as tfl_interpreter +# pylint: enable=g-direct-tensorflow-import + FLAGS = flags.FLAGS @@ -364,7 +368,7 @@ def assert_tf_lite_convertible(self, sess, in_tensors, out_tensors) tflite_model = converter.convert() # Load TFLite model and allocate tensors. - interpreter = tf.lite.Interpreter(model_content=tflite_model) + interpreter = tfl_interpreter.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() # If no test inputs provided then randomly generate inputs. if test_inputs is None: