From 5d067a9c380f7f19618751e08e2b8f3a989ff108 Mon Sep 17 00:00:00 2001 From: Muhammad Asif Manzoor Date: Fri, 29 Nov 2024 15:01:03 +0000 Subject: [PATCH] Add tests for compare op * Equal * Not equal * greater than * greater or equal to * less than * less or equal to --- tests/TTIR/test_compare.py | 65 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/TTIR/test_compare.py diff --git a/tests/TTIR/test_compare.py b/tests/TTIR/test_compare.py new file mode 100644 index 00000000..6c563490 --- /dev/null +++ b/tests/TTIR/test_compare.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest +import jax +import jax.numpy as jnp +import numpy + +from infrastructure import verify_module + +# Note: TTNN does not support boolean data type, so bfloat16 is used instead. +# Hence the output of comparison operation is bflaot16. JAX can not perform any +# computation due to mismatch in output data type (in testing infrastructure). +# The following tests explicitly convert data type of comparison operation +# output for the verification purposes. + + +def test_equal(): + def module_equal(a, b): + c = a == b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_equal, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_notEqual(): + def module_notEqual(a, b): + c = a != b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_notEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_greaterThan(): + def module_greaterThan(a, b): + c = a > b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_greaterThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_greaterEqual(): + def module_greaterEqual(a, b): + c = a >= b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_greaterEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_lessThan(): + def module_lessThan(a, b): + c = a < b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_lessThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) + + +def test_lessEqual(): + def module_lessEqual(a, b): + c = a <= b + return jax.lax.convert_element_type(c, jnp.float32) + + verify_module(module_lessEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16)