Skip to content

Commit

Permalink
Add Max operator test plan [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
vobojevicTT committed Jan 22, 2025
1 parent 2cf6021 commit 25a0755
Showing 1 changed file with 173 additions and 0 deletions.
173 changes: 173 additions & 0 deletions forge/test/operators/pytorch/reduce/test_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch

from torch import nn

from typing import List, Dict
from loguru import logger

from forge.op_repo import TensorShape
from forge.verify.config import VerifyConfig

from forge.verify.value_checkers import AllCloseValueChecker

from test.operators.utils import InputSourceFlags, VerifyUtils
from test.operators.utils import InputSource
from test.operators.utils import TestVector
from test.operators.utils import TestPlan
from test.operators.utils.compat import TestDevice
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon
from test.operators.utils import ValueRanges

from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass


class ModelFromAnotherOpMax(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_another_op"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
xx = torch.add(x, x)
return self.operator(xx, **self.kwargs)[0]


class ModelDirectMax(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_host"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
return self.operator(x, **self.kwargs)[0]


class ModelConstEvalPassMax(nn.Module):
def __init__(self, operator, shape: TensorShape, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_const_eval_pass"
self.operator = operator
self.kwargs = kwargs
self.c = (torch.rand(shape, requires_grad=False) - 0.5).detach()

def forward(self, x):
cc = self.operator(self.c, **self.kwargs)[0]
xx = self.operator(x, **self.kwargs)[0]
return torch.add(xx, cc)



class TestVerification:

MODEL_TYPES = {
InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp,
InputSource.FROM_HOST: ModelDirect,
InputSource.FROM_DRAM_QUEUE: ModelDirect,
InputSource.CONST_EVAL_PASS: ModelConstEvalPass,
}

MODEL_TYPES_MAX_SPECIFIC = {
InputSource.FROM_ANOTHER_OP: ModelFromAnotherOpMax,
InputSource.FROM_HOST: ModelDirectMax,
InputSource.FROM_DRAM_QUEUE: ModelDirectMax,
InputSource.CONST_EVAL_PASS: ModelConstEvalPassMax,
}

@classmethod
def verify(
cls,
test_device: TestDevice,
test_vector: TestVector,
input_params: List[Dict] = [],
warm_reset: bool = False,
):

input_source_flag: InputSourceFlags = None
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch, test_vector.operator)
kwargs = test_vector.kwargs if test_vector.kwargs else {}

if not kwargs:
model_type = cls.MODEL_TYPES[test_vector.input_source]
else:
model_type = cls.MODEL_TYPES_MAX_SPECIFIC[test_vector.input_source]

pytorch_model = (
model_type(operator, test_vector.input_shape, kwargs)
if test_vector.input_source in (InputSource.CONST_EVAL_PASS,)
else model_type(operator, kwargs)
)

input_shapes = tuple([test_vector.input_shape])

logger.trace(f"***input_shapes: {input_shapes}")

VerifyUtils.verify(
model=pytorch_model,
test_device=test_device,
input_shapes=input_shapes,
input_params=input_params,
input_source_flag=input_source_flag,
dev_data_format=test_vector.dev_data_format,
math_fidelity=test_vector.math_fidelity,
warm_reset=warm_reset,
value_range=ValueRanges.SMALL,
deprecated_verification=False,
verify_config=VerifyConfig(value_checker=AllCloseValueChecker()),
)


class TestParamsData:

__test__ = False

test_plan: TestPlan = None

operator = ["max"]

@classmethod
def generate_kwargs(cls, test_vector: TestVector):

dim = len(test_vector.input_shape)
dims = list(range(0, dim))

for i in dims:
for ch in [True, False]:
yield {"dim": i, "keepdim": ch}


TestParamsData.test_plan = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
test_vector,
),
collections=[
# torch.max(input)
TestCollection(
operators=TestParamsData.operator,
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
),
# torch.max(input, dim=..., keepdim=...)
TestCollection(
operators=TestParamsData.operator,
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
),
],
failing_rules=[],
)


def get_test_plans() -> List[TestPlan]:
return [TestParamsData.test_plan]

0 comments on commit 25a0755

Please sign in to comment.