diff --git a/tests/TTIR/test_compare.py b/tests/TTIR/test_compare.py new file mode 100644 index 00000000..6c563490 --- /dev/null +++ b/tests/TTIR/test_compare.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest +import jax +import jax.numpy as jnp +import numpy + +from infrastructure import verify_module + +# Note: TTNN does not support boolean data type, so bfloat16 is used instead. +# Hence the output of comparison operation is bflaot16. JAX can not perform any +# computation due to mismatch in output data type (in testing infrastructure). +# The following tests explicitly convert data type of comparison operation +# output for the verification purposes. + + +def test_equal(): + def module_equal(a, b): + c = a == b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_equal, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_notEqual(): + def module_notEqual(a, b): + c = a != b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_notEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_greaterThan(): + def module_greaterThan(a, b): + c = a > b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_greaterThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_greaterEqual(): + def module_greaterEqual(a, b): + c = a >= b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_greaterEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_lessThan(): + def module_lessThan(a, b): + c = a < b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_lessThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_lessEqual(): + def module_lessEqual(a, b): + c = a <= b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_lessEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16)