-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Equal * Not equal * greater than * greater or equal to * less than * less or equal to
- Loading branch information
1 parent
b00c49d
commit 5d067a9
Showing
1 changed file
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
import pytest | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy | ||
|
||
from infrastructure import verify_module | ||
|
||
# Note: TTNN does not support boolean data type, so bfloat16 is used instead. | ||
# Hence the output of comparison operation is bflaot16. JAX can not perform any | ||
# computation due to mismatch in output data type (in testing infrastructure). | ||
# The following tests explicitly convert data type of comparison operation | ||
# output for the verification purposes. | ||
|
||
|
||
def test_equal(): | ||
def module_equal(a, b): | ||
c = a == b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_equal, [(64, 64), (64, 64)], dtype=jnp.bfloat16) | ||
|
||
|
||
def test_notEqual(): | ||
def module_notEqual(a, b): | ||
c = a != b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_notEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) | ||
|
||
|
||
def test_greaterThan(): | ||
def module_greaterThan(a, b): | ||
c = a > b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_greaterThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) | ||
|
||
|
||
def test_greaterEqual(): | ||
def module_greaterEqual(a, b): | ||
c = a >= b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_greaterEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) | ||
|
||
|
||
def test_lessThan(): | ||
def module_lessThan(a, b): | ||
c = a < b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_lessThan, [(64, 64), (64, 64)], dtype=jnp.bfloat16) | ||
|
||
|
||
def test_lessEqual(): | ||
def module_lessEqual(a, b): | ||
c = a <= b | ||
return jax.lax.convert_element_type(c, jnp.float32) | ||
|
||
verify_module(module_lessEqual, [(64, 64), (64, 64)], dtype=jnp.bfloat16) |