diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 48104e570..fb0352567 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4620,5 +4620,51 @@ def scaled_dot_product_attention(self, inputs, input_types):
         attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
         return attn_weight
 
+    def nan_to_num(self, inputs, input_types):
+        data = inputs[0]
+        nan_value = inputs[1]
+        posinf = inputs[2]
+        neginf = inputs[3]
+
+        dtype = input_types[0]
+        assert dtype == "float32", (
+            f"Expected dtype to be float32, but got {dtype}. "
+            f"Support for {dtype} has not been added yet."
+        )
+
+        # Defaults follow torch.nan_to_num: NaN -> 0.0, positive infinity ->
+        # dtype max, negative infinity -> dtype min.
+        dtype_max = np.finfo(dtype).max
+        dtype_min = np.finfo(dtype).min
+        nan_tensor = _expr.const(nan_value if nan_value is not None else 0.0, dtype)
+        posinf_tensor = _expr.const(posinf if posinf is not None else dtype_max, dtype)
+        neginf_tensor = _expr.const(neginf if neginf is not None else dtype_min, dtype)
+
+        # Each comparison runs against the original data, so the three
+        # selects cannot interfere with one another.
+        result = _op.where(_op.isnan(data), nan_tensor, data)
+        result = _op.where(_op.equal(data, _expr.const(np.inf, dtype)), posinf_tensor, result)
+        result = _op.where(_op.equal(data, _expr.const(-np.inf, dtype)), neginf_tensor, result)
+        return result
+
+    def atan2(self, inputs, input_types):
+        # torch.atan2(input, other) treats input as y and other as x.
+        data_y = inputs[0]
+        data_x = inputs[1]
+        dtype = input_types[0]
+
+        atan_res = _op.atan(_op.divide(data_y, data_x))
+
+        # atan alone only covers (-pi/2, pi/2); results in the left
+        # half-plane (x < 0) must be shifted by +/- pi for the full range.
+        pi = _expr.const(np.pi, dtype)
+        zero = _expr.const(0.0, dtype)
+        correction = _op.where(
+            _op.less(data_x, zero),
+            _op.where(_op.greater_equal(data_y, zero), pi, -pi),
+            zero,
+        )
+        return _op.add(atan_res, correction)
+
     # Operator mappings
     def create_convert_map(self):
@@ -4920,6 +4966,8 @@ def create_convert_map(self):
             "aten::linalg_vector_norm": self.linalg_vector_norm,
             "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
             "aten::lift_fresh": self.identity,
+            "aten::nan_to_num": self.nan_to_num,
+            "aten::atan2": self.atan2,
         }
 
     def update_convert_map(self, custom_map):
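
For reviewers who want to sanity-check the two converters locally, a minimal end-to-end sketch follows. It is not part of the patch; the module name and input values are illustrative only, while relay.frontend.from_pytorch and graph_executor are existing TVM APIs. It traces both aten ops with PyTorch, imports the graph into Relay, and compares the compiled result against PyTorch's own output:

    import numpy as np
    import torch
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor


    class NanToNumAtan2(torch.nn.Module):
        """Illustrative module hitting aten::nan_to_num and aten::atan2."""

        def forward(self, x, y):
            cleaned = torch.nan_to_num(x)  # default None args -> 0.0 / finfo limits
            return torch.atan2(cleaned, y)


    # Inputs cover NaN, +inf, -inf, and all four atan2 quadrant cases.
    x = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0])
    y = torch.tensor([1.0, -1.0, -2.0, -0.5])

    traced = torch.jit.trace(NanToNumAtan2().eval(), (x, y))
    mod, params = relay.frontend.from_pytorch(
        traced, [("x", list(x.shape)), ("y", list(y.shape))]
    )

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)

    runtime = graph_executor.GraphModule(lib["default"](tvm.cpu()))
    runtime.set_input("x", tvm.nd.array(x.numpy()))
    runtime.set_input("y", tvm.nd.array(y.numpy()))
    runtime.run()

    # The compiled graph should agree with eager PyTorch elementwise.
    np.testing.assert_allclose(
        runtime.get_output(0).numpy(), traced(x, y).numpy(), rtol=1e-5, atol=1e-5
    )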