Skip to content

Commit

Permalink
[Operator] Add tile op
Browse files Browse the repository at this point in the history
  • Loading branch information
zfu82 committed Aug 14, 2024
1 parent 88df9bb commit fa7554a
Show file tree
Hide file tree
Showing 6 changed files with 503 additions and 0 deletions.
16 changes: 16 additions & 0 deletions benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,19 @@ def masked_fill_args(dtype, batch, size):
sizes=SIZES,
)
bench.run()


def test_perf_tile():
def tile_kwargs(dtype, batch, size):
return {"dims": [2, 4]}

bench = Benchmark(
op_name="tile",
torch_op=torch.tile,
arg_func=unary_arg,
dtypes=FLOAT_DTYPES,
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=tile_kwargs,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def enable(lib=aten_lib):
lib.impl("isclose", isclose, "CUDA")
lib.impl("allclose", allclose, "CUDA")
lib.impl("flip", flip, "CUDA")
lib.impl("tile", tile, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")


Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .sub import sub
from .sum import sum, sum_dim
from .tanh import tanh
from .tile import tile
from .topk import topk
from .triu import triu
from .uniform import uniform_
Expand Down Expand Up @@ -168,6 +169,7 @@
"softmax",
"sub",
"tanh",
"tile",
"triu",
"topk",
"max",
Expand Down
Loading

0 comments on commit fa7554a

Please sign in to comment.