Skip to content

Commit

Permalink
Add tests for compare op
Browse files Browse the repository at this point in the history
* Equal
* Not equal
* greater than
* greater or equal to
* less than
* less or equal to
  • Loading branch information
mmanzoorTT committed Nov 29, 2024
1 parent b00c49d commit 5d067a9
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tests/TTIR/test_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#

import pytest
import jax
import jax.numpy as jnp
import numpy

from infrastructure import verify_module

# Note: TTNN does not support boolean data type, so bfloat16 is used instead.
# Hence the output of comparison operation is bflaot16. JAX can not perform any
# computation due to mismatch in output data type (in testing infrastructure).
# The following tests explicitly convert data type of comparison operation
# output for the verification purposes.


def test_equal():
def module_equal(a, b):
c = a == b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_equal, [(64, 64), (64, 64)], dtype=jnp.bfloat16)


def test_notEqual():
def module_notEqual(a, b):
c = a != b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_notEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16)


def test_greaterThan():
def module_greaterThan(a, b):
c = a > b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_greaterThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16)


def test_greaterEqual():
def module_greaterEqual(a, b):
c = a >= b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_greaterEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16)


def test_lessThan():
def module_lessThan(a, b):
c = a < b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_lessThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16)


def test_lessEqual():
def module_lessEqual(a, b):
c = a <= b
return jax.lax.convert_element_type(c, jnp.float32)

verify_module(module_lessEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16)

0 comments on commit 5d067a9

Please sign in to comment.