diff --git a/.gitignore b/.gitignore index 0f05e72d..ee73da37 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ build/ # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties` # should NOT be excluded as they contain compiler settings and other important # information for Eclipse / Flash Builder. +playground/ diff --git a/src/flag_gems/ops/flip.py b/src/flag_gems/ops/flip.py index 56a66124..38f152ae 100644 --- a/src/flag_gems/ops/flip.py +++ b/src/flag_gems/ops/flip.py @@ -4,11 +4,12 @@ import triton from ..utils import pointwise_dynamic +from ..utils.tensor_wrapper import StridedBuffer @pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")]) @triton.jit -def flip_func(x, **kwargs): +def copy_func(x): return x @@ -29,10 +30,17 @@ def flip(A: torch.Tensor, dims) -> torch.Tensor: n = 0 offset = 0 for i in range(len(flip_dims_b)): - if flip_dims_b[i] and A.size()[i] > 1 and A.stride()[i] != 0: + if flip_dims_b[i] and A.size(i) > 1 and A.stride(i) != 0: offset += strides[i] * (A.shape[i] - 1) strides[i] = -strides[i] n += 1 if n == 0 or A.numel() <= 1: return A.clone() - return flip_func(A, out0_offset=offset, out0_strides=strides) + out = torch.empty_like(A) + # a flipped view of A + flipped_A = StridedBuffer(A, strides=strides, offset=offset) + + # TODO: flip op can have a custom task simplification method, but we skip it now and just use A's rank. + overload = copy_func.instantiate(A.ndim) + overload(flipped_A, out0=out) + return out diff --git a/src/flag_gems/utils/code_utils.py b/src/flag_gems/utils/code_utils.py index 6f71fa14..b318b940 100644 --- a/src/flag_gems/utils/code_utils.py +++ b/src/flag_gems/utils/code_utils.py @@ -102,6 +102,9 @@ def writelines(self, lines): for line in lines: self.writeline(line) + def writemultiline(self, s): + self.writelines(s.splitlines()) + def indent(self, offset=1): @contextlib.contextmanager def ctx(): diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index c91db26c..287045f0 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -1,17 +1,22 @@ import importlib import os -import threading -from typing import Any, Callable, List, Mapping, Optional, Tuple +from dataclasses import dataclass +from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple import torch -import torch._prims_common as utils import triton -from triton import language as tl from triton.runtime.jit import JITFunction from flag_gems.utils.code_cache import cache_dir -from flag_gems.utils.code_utils import IndentedBuffer, NameSpace -from flag_gems.utils.shape_utils import broadcast_shapes +from flag_gems.utils.code_utils import IndentedBuffer +from flag_gems.utils.shape_utils import ( + all_c_contiguous, + all_the_same_shape, + all_the_same_stride, + broadcasted_stride, +) +from flag_gems.utils.tensor_wrapper import StridedBuffer +from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion # ------------------ Operation Description --------------------------- @@ -33,7 +38,21 @@ def _check_sized_list(container, size): assert len(container) == size -class OPDesc: +def _tuple_content(strings: Sequence[str]) -> str: + # comma separated list + if len(strings) == 0: + return "" + if len(strings) == 1: + return f"{strings[0]}," + else: + return ", ".join(strings) + + +def _cs(strings: Iterable[str]) -> str: + return ", ".join(strings) + + +class FunctionSchema: _num_inputs: int _is_tensor: List[bool] 
_dtypes: List[Optional[type]] @@ -51,19 +70,21 @@ def __init__( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - promotion_methods: Optional[List[Tuple[int, ...]]] = None, + promotion_methods=None, ): if is_tensor is not None: _check_typed_list(is_tensor, bool) if dtypes is not None: _check_typed_list(dtypes, (type, type(None))) + if promotion_methods is None: raise ValueError( "No type promotion method provided! You must provide type promotion method for each output!" ) else: - self._promotion_methods = promotion_methods - + self._promotion_methods = self.canonicalize_promotion_methods( + promotion_methods + ) if num_inputs is not None: self._num_inputs = num_inputs if is_tensor is not None: @@ -102,6 +123,7 @@ def __init__( self._num_outputs = num_outputs _check_sized_list(promotion_methods, num_outputs) else: + self._num_outputs = 1 self._num_outputs = len(promotion_methods) assert self._num_inputs >= 1 @@ -109,6 +131,17 @@ def __init__( self._num_input_tensors = sum(self._is_tensor) self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors + self._input_id = self._compute_input_id() + + @staticmethod + def canonicalize_promotion_methods(promotion_methods): + canonicalized = [] + for item in promotion_methods: + *arg_indices, method = item + canonicalized.append( + (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method]) + ) + return canonicalized def num_inputs(self): # num of arguments, outputs not included @@ -123,6 +156,9 @@ def is_tensor(self, arg_id: int) -> bool: def input_type(self, arg_id) -> Optional[type]: return self._dtypes[arg_id] + def output_type(self, i): + return self._promotion_methods[i] + def num_input_tensors(self) -> int: return self._num_input_tensors @@ -132,28 +168,11 @@ def num_output_tensors(self) -> int: def num_non_tensor_args(self) -> int: return self._num_non_tensor_inputs - def type_promotion_methods(self) -> List[Tuple[int, ...]]: - return self._promotion_methods - - def _match_enum_by_string( - self, input_str: str - ) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: - for kind in utils.ELEMENTWISE_TYPE_PROMOTION_KIND: - if input_str.lower() == kind.name.lower(): - return kind - raise ValueError(f"No matching enum member found for input: {input_str}") - - def ith_type_promotion_args(self, i) -> List[int]: - return self._promotion_methods[i][:-1] - - def ith_type_promotion_kind(self, i) -> utils.ELEMENTWISE_TYPE_PROMOTION_KIND: - return self._match_enum_by_string(self._promotion_methods[i][-1]) - - def signature(self, outputs_in_arg: bool = False): + def signature(self, outputs_in_arg: bool = False) -> str: input_types = [] for is_tensor, dtype in zip(self._is_tensor, self._dtypes): if is_tensor: - input_types.append("Tensor") + input_types.append("StridedBuffer") else: if dtype is None: input_types.append("scalar") @@ -161,598 +180,493 @@ def signature(self, outputs_in_arg: bool = False): input_types.append(_type_name(dtype)) output_types = [] - for _ in range(self.num_outputs()): - output_types.append("Tensor") + if outputs_in_arg: + for i in range(self.num_outputs()): + output_types.append(f"StridedBuffer(a{1}!)") input_types.extend(output_types) - sig = f'Pointwise: ({", ".join(input_types)}) -> ({", ".join(output_types)})' + else: + for _ in range(self.num_outputs()): + output_types.append("StridedBuffer") + sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}' return sig - def __str__(self) -> str: - return self.signature(outputs_in_arg=False) - - 
-# --------------------------- pointwise wrapper genration ----------------------------------- -def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: - """Generate parameter declaration with type annotation for wrapper function. - Example: in0: torch.Tensor, val0: float, out0: torch.Tensor - """ - parameters: List[str] = [] - - input_tensor_index = 0 - non_tensor_index = 0 - for i in range(op_desc.num_inputs()): - if op_desc._is_tensor[i]: - parameters.append(f"in{input_tensor_index}: torch.Tensor") - input_tensor_index += 1 - else: - if op_desc.input_type(i) is not None: - parameters.append( - f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}" - ) + def _compute_input_id(self): + input_tensor_index = 0 + non_tensor_index = 0 + mapping: List[int] = [] + for i in range(self.num_inputs()): + if self.is_tensor(i): + mapping.append(input_tensor_index) + input_tensor_index += 1 else: - parameters.append(f"val{non_tensor_index}") - non_tensor_index += 1 - - if include_outputs: - output_tensor_index = 0 - for i in range(op_desc.num_outputs()): - parameters.append(f"out{output_tensor_index}: torch.Tensor") - output_tensor_index += 1 + mapping.append(non_tensor_index) + non_tensor_index += 1 + return mapping - parameters.append("**kwargs") + def input_index(self, idx): + return self._input_id[idx] - return ", ".join(parameters) + def __str__(self) -> str: + return self.signature(outputs_in_arg=False) -def ith_parameter_for_type_promotion(op_desc: OPDesc, ith: int) -> str: - """Generate parameter reference for i-th type promotion rule - Example: in0, val0, out0 - """ - parameters: List[str] = [] +@dataclass +class CodeGenConfig: + max_tile_size: int + max_grid_size: Tuple[int, int, int] + max_num_warps_per_cta: int - input_tensor_index = 0 - non_tensor_index = 0 - for i in range(op_desc.num_inputs()): - if i not in op_desc.ith_type_promotion_args(ith): - if op_desc._is_tensor[i]: - input_tensor_index += 1 - else: - non_tensor_index += 1 - continue - if op_desc._is_tensor[i]: - parameters.append(f"in{input_tensor_index}") - input_tensor_index += 1 - else: - parameters.append(f"val{non_tensor_index}") - non_tensor_index += 1 + prefer_block_pointer: bool + # TODO: add 1d tile back + prefer_1d_tile: bool - return ", ".join(parameters) +class KernelGenerator: + def __init__( + self, + function_schema: FunctionSchema, + scalar_fn: triton.JITFunction, + rank: int, + name: str, + ): + self.fx = function_schema + self.fn = scalar_fn + self.ndim = rank + self.name = name + + self.fn_name = scalar_fn.__name__ + self.fn_module = scalar_fn.__module__ + + def gen_import_function(self, code: IndentedBuffer): + code.writeline(f'"""Quoted source of {self.fn_name}:') + code.writemultiline(self.fn.src) + code.writeline('"""') + code.newline() -def parameter_ref_for_wrapper( - op_desc: OPDesc, - include_outputs: bool = False, - include_offset: bool = False, - include_kwargs: bool = False, -) -> str: - """Generate parameter reference for wrapper function. 
- Example: in0, val0, out0, out0_offset - """ - parameters: List[str] = [] - - input_tensor_index = 0 - non_tensor_index = 0 - for i in range(op_desc.num_inputs()): - if op_desc._is_tensor[i]: - parameters.append(f"in{input_tensor_index}") - input_tensor_index += 1 + def gen_decorators(self, code): + code.writeline("@libentry()") + num_non_tensor_args = self.fx.num_non_tensor_args() + if num_non_tensor_args > 0: + # we do not specialize non tensor args since they are passed into the inlined function + # which means that their values may not deserve specialization + non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)] + code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})") else: - parameters.append(f"val{non_tensor_index}") - non_tensor_index += 1 - - if include_outputs: - output_tensor_index = 0 - for i in range(op_desc.num_outputs()): - parameters.append(f"out{output_tensor_index}") - if include_offset: - parameters.append(f"out{output_tensor_index}_offset") - output_tensor_index += 1 + code.writeline("@triton.jit") - if include_kwargs: - parameters.append("**kwargs") + def input_name(self, i): + is_tensor = self.fx.is_tensor(i) + name = "in" if is_tensor else "val" + index = self.fx.input_index(i) + return f"{name}{index}" - return ", ".join(parameters) + def output_name(self, i): + return f"out{i}" + def gen_signature(self, code): + code.writeline(f"def {self.name}(") + with code.indent(): + input_tensor_index = 0 + non_tensor_index = 0 + output_tensor_index = 0 + + schema = self.fx + # signature: inputs ptrs & non tensor inputs + for i in range(schema.num_inputs()): + if schema.is_tensor(i): + code.writeline( + f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" + ) + input_tensor_index += 1 + else: + if schema.input_type(i) is not None: + code.writeline( + f"val{non_tensor_index}: {_type_name(schema.input_type(i))}," + ) + else: + code.writeline(f"val{non_tensor_index},") + non_tensor_index += 1 -def output_ref_for_wrapper(op_desc: OPDesc) -> str: - """Generate output variable refernece for wrapper function. 
- Example: out0, out1 - """ - parameters: List[str] = [f"out{i}" for i in range(op_desc.num_outputs())] - return ", ".join(parameters) - - -def docstring_for_functional_wrapper(op_desc: OPDesc): - doc = f'"""Generated wrapper function with {str(op_desc)}"""' - return doc - - -def docstring_for_destination_passing_wrapper(op_desc: OPDesc): - doc = f'"""Generated wrapper function with {op_desc.signature(outputs_in_arg=True)}"""' - return doc - - -def generate_imports(code: IndentedBuffer) -> IndentedBuffer: - code.writeline("import math") - code.writeline("import torch") - code.writeline("import triton") - code.writeline("from triton import language as tl") - code.newline() - code.writeline("from flag_gems.utils.shape_utils import (") - code.writeline(" broadcast_shapes,") - code.writeline(" broadcasted_stride,") - code.writeline(" c_contiguous_stride,") - code.writeline(" volume,") - code.writeline(" Stride,") - code.writeline(")") - code.writeline("from flag_gems.utils.libentry import libentry") - code.writeline("from flag_gems.utils.type_utils import type_promotion") - code.writeline("import torch._prims_common as utils") - code.newline() - code.newline() - return code - - -def generate_functional_pointwise_wrapper( - op_desc: OPDesc, - wrapper_name: str, - destination_passing_func_name: str, - code: IndentedBuffer, -) -> IndentedBuffer: - # wrapper signature - parameters: str = parameter_for_wrapper(op_desc, include_outputs=False) - wrapper_signature: str = f"def {wrapper_name}({parameters}):" - code.writeline(wrapper_signature) - - with code.indent(): - # docstring - wrapper_docstring = docstring_for_functional_wrapper(op_desc) - code.writeline(wrapper_docstring) - - shapes_str = ", ".join( - f"in{i}.shape" for i in range(op_desc.num_input_tensors()) - ) - code.writeline(f"shape = broadcast_shapes([{shapes_str}])") + # signature: output ptrs + for i in range(schema.num_outputs()): + code.writeline( + f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" + ) + output_tensor_index += 1 + + # signature: strides, for each tensor arguments + ndim = self.ndim + if ndim > 0: + # strides for inputs + for i in range(schema.num_input_tensors()): + stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim)) + code.writeline(f"{stride_args}, # strides for in{i}") + + # strides for outputs + for i in range(schema.num_output_tensors()): + stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim)) + code.writeline(f"{stride_args}, # strides for out{i}") + + # task space, used to reconstruct multi index + task_space_args = _cs(f"s{i}: int" for i in range(ndim)) + code.writeline(f"{task_space_args}, # task_space") + + # number of tasks, used to compute mask + code.writeline("num_tasks: int,") + + # tile size & tiles_per_cta, gsl style + if ndim > 0: + code.writeline("tiles_per_cta: int,") + tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim)) + code.writeline(f"{tile_sizes},") + code.writeline("one_tile_per_cta: tl.constexpr,") + code.writeline("):") + + def gen_num_tiles(self, code): + # tile-grid size + ndim = self.ndim + for i in range(ndim): + if i < ndim: + code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})") + + def gen_body_for_0d(self, code): + schema = self.fx + inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] + outputs_to_scalar_fn = [ + self.output_name(i) for i in range(schema.num_output_tensors()) + ] + inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) + outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) - # output 
allocation - num_output_tensor_index = 0 - for i in range(op_desc.num_outputs()): - type_promotion_args = ith_parameter_for_type_promotion(op_desc, i) - k_type_promotion = op_desc.ith_type_promotion_kind(i) + code.writeline("# loads") + for i in range(schema.num_input_tensors()): code.writeline( - ( - f"out{num_output_tensor_index} = " - f"torch.empty(shape, dtype=type_promotion" - f"({type_promotion_args}, type_promotion=utils.{k_type_promotion})[1], " - f"device=in0.device)" - ) + f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) " + "# workaround the bug on bool, we should use the pointer's dtype)" ) - num_output_tensor_index += 1 + code.newline() - # call destination_passing_func - output_names: str = output_ref_for_wrapper(op_desc) - call_str = ( - f"{output_names} = {destination_passing_func_name}" - f"({parameter_ref_for_wrapper(op_desc, include_outputs=True, include_offset=False, include_kwargs=True)})" + code.writeline("# compute") + code.writeline( + f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" ) - code.writeline(call_str) + code.newline() - return_str = f"return {output_names}" - code.writeline(return_str) + code.writeline("# stores") + for i in range(schema.num_output_tensors()): + code.writeline( + f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))" + ) code.newline() + return code + + # nd tile 1d grid kernel with block pointer + def gen_body_one_tile_per_cta_bptr(self, code): + ndim = self.ndim + schema = self.fx + + # block pointer for each operand + shape = _tuple_content(tuple(f"s{i}" for i in range(ndim))) + offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim))) + tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim))) + order = _tuple_content(tuple(str(i) for i in reversed(range(ndim)))) + + # reconstruct pid multi index + code.writeline( + "# pid multi index recontruction: we use c ordering, right axes changes fastest" + ) + for i in reversed(range(ndim)): + if i > 0: + code.writeline(f"tile_id{i} = tile_id % num_tiles{i}") + code.writeline(f"tile_id //= num_tiles{i}") + else: + code.writeline(f"tile_id{i} = tile_id") code.newline() - return code - - -def generate_destination_passing_pointwise_wrapper( - op_desc: OPDesc, - rank: int, - wrapper_name: str, - kernel_name: str, - code: IndentedBuffer, -) -> IndentedBuffer: - # wrapper signature - parameters: str = parameter_for_wrapper(op_desc, include_outputs=True) - wrapper_signature: str = f"def {wrapper_name}({parameters}):" - code.writeline(wrapper_signature) - - with code.indent(): - # docstring - wrapper_docstring = docstring_for_destination_passing_wrapper(op_desc) - code.writeline(wrapper_docstring) - - if rank > 0: - code.writeline("shape = out0.shape") - code.writeline("num_tasks = volume(shape)") - if rank > 0: - code.writeline("tile_size = min(512, triton.next_power_of_2(num_tasks))") - code.writeline("num_warps = 4") - code.writeline("num_ctas = min(65535, triton.cdiv(num_tasks, tile_size))") + # cta_offsets + code.writeline("# tile offsets") + for i in range(ndim): + code.writeline(f"offset{i} = tile_id{i} * tile_size{i}") + + # loads + code.writeline("# loads") + for i in range(schema.num_input_tensors()): + strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim))) code.writeline( - "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)" + f"in{i}_bptr = tl.make_block_ptr(" + f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))" + ) + code.writeline( + f"in{i} = tl.load(in{i}_bptr, 
boundary_check=({order})).to(in{i}_ptr.type.element_ty) " + "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)" ) - else: - code.writeline("num_warps = 1") - code.writeline("num_ctas = 1") - code.writeline("grid = (num_ctas, 1, 1)") code.newline() - # input strides for each input tensor w.r.t. the task index space - if rank > 0: - code.writeline("# strides of each tensor argument w.r.t the task space") - for i in range(op_desc.num_input_tensors()): - code.writeline( - f"in{i}_strides = broadcasted_stride(in{i}.shape, in{i}.stride(), shape)" - ) - for i in range(op_desc.num_output_tensors()): - code.writeline(f"if 'out{i}_offset' in kwargs:") - with code.indent(): - code.writeline(f"out{i}_offset = kwargs['out{i}_offset']") - code.writeline("else:") - with code.indent(): - code.writeline(f"out{i}_offset = 0") - - code.writeline(f"if 'out{i}_strides' in kwargs:") - with code.indent(): - code.writeline(f"out{i}_strides = kwargs['out{i}_strides']") - code.writeline("else:") - with code.indent(): - code.writeline(f"out{i}_strides = out{i}.stride()") - else: - for i in range(op_desc.num_output_tensors()): - code.writeline(f"out{i}_offset = 0") + # compute + # TODO: sepearate this part + inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] + outputs_to_scalar_fn = [ + self.output_name(i) for i in range(schema.num_output_tensors()) + ] + inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) + outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) + + code.writeline("# compute") + code.writeline( + f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" + ) code.newline() - # grid - code.writeline("# kernel launch") + # stores + code.writeline( + "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype" + ) + for i in range(schema.num_output_tensors()): + strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim))) + code.writeline( + f"out{i}_bptr = tl.make_block_ptr(" + f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))" + ) + code.writeline( + f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))" + ) - # launch kernel - code.writeline("with torch.cuda.device(in0.device.index):") + def gen_body_gsl_bptr(self, code): + code.writeline("num_ctas = tl.num_programs(0)") + code.writeline("for j in range(0, tiles_per_cta):") with code.indent(): - kernel_launch: str = f"{kernel_name}[grid](" - code.writeline(kernel_launch) + code.writeline("tile_id = pid + j * num_ctas") + self.gen_body_one_tile_per_cta_bptr(code) - with code.indent(): - code.writeline( - "{},".format( - parameter_ref_for_wrapper( - op_desc, - include_outputs=True, - include_offset=True, - include_kwargs=False, - ) - ) - ) + # nd tile 1d grid kernel with block of pointers + def gen_body_one_tile_per_cta(self, code): + pass - if rank > 0: - for i in range(op_desc.num_input_tensors()): - s = ", ".join(f"in{i}_strides[{j}]" for j in range(rank)) - code.writeline(f"{s}, # stride for in{i}") + def gen_body_gsl(self, code): + pass - for i in range(op_desc.num_output_tensors()): - s = ", ".join(f"out{i}_strides[{j}]" for j in range(rank)) - code.writeline(f"{s}, # stride for out{i}") + def codegen_nd_tile(self, code): + """Generate kernel nd tile & 1d grid with gsl support with block pointer.""" + self.gen_import_function(code) + self.gen_decorators(code) + self.gen_signature(code) - shape_args: str = ", ".join(f"shape[{i}]" for i in 
range(rank)) - code.writeline(f"{shape_args}, # task indexing space") - code.writeline("num_tasks, # num tasks") - code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") - code.writeline("tile_size=tile_size,") - code.writeline("one_tile_per_cta=tiles_per_cta==1,") - code.writeline("num_warps=num_warps,") - code.writeline(")") + # function body for rank-0 + if self.ndim == 0: + with code.indent(): + self.gen_body_for_0d(code) + return code - # return - code.writeline(f"return {output_ref_for_wrapper(op_desc)}") - code.newline() + with code.indent(): + code.writeline("pid = tl.program_id(0)") + self.gen_num_tiles(code) + # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute + code.writeline("if one_tile_per_cta: # monolitic kernel style") + with code.indent(): + code.writeline("tile_id = pid") + self.gen_body_one_tile_per_cta_bptr(code) + # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ + code.writeline("else: # grid-stride-loop style kernel") + with code.indent(): + self.gen_body_gsl_bptr(code) code.newline() - return code - - -def generate_pointwise_kernel( - op_desc: OPDesc, - scalar_fn: JITFunction, - rank: int, - kernel_name: str, - code: IndentedBuffer, -) -> IndentedBuffer: - # make the inlined function visible in the context - fn_name = scalar_fn.__name__ - code.writeline(f"from {scalar_fn.__module__} import {fn_name}") - code.writeline(f"inlined_f = {fn_name}._scalar_fn") - code.newline() - - # the decorators - code.writeline("@libentry()") - if op_desc.num_non_tensor_args() > 0: - # we do not specialize non tensor args since they are passed into the inlined function - # which means that their values may not deserve specialization - non_specialize_arg_names = [ - f"val{i}" for i in range(op_desc.num_non_tensor_args()) - ] - code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})") - else: - code.writeline("@triton.jit") + return code - # signature - code.writeline(f"def {kernel_name}(") - function_ns = NameSpace() - with code.indent(): - input_tensor_index = 0 - non_tensor_index = 0 - output_tensor_index = 0 - # signature: inputs ptrs & non tensor inputs - for i in range(op_desc.num_inputs()): - if op_desc.is_tensor(i): - code.writeline( - f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" - ) - function_ns.create_name(f"in{input_tensor_index}_ptr") - input_tensor_index += 1 + def codegen_1d_tile(self, code): + """Generate kernel 1d tile & 1d grid with gsl support.""" + pass + + +class WrapperGenerator: + def __init__( + self, + function_schema: FunctionSchema, + jit_fn_name: str, + ndim: int, + name: str, + config: CodeGenConfig, + ): + self.fx = function_schema + self.jit_fn_name = jit_fn_name + self.ndim = ndim + self.name = name + self.config = config + + def input_name(self, i): + is_tensor = self.fx.is_tensor(i) + name = "in" if is_tensor else "val" + index = self.fx.input_index(i) + return f"{name}{index}" + + def output_name(self, i): + return f"out{i}" + + def gen_signature(self, code: IndentedBuffer): + # TODO: check if triton handles constexprs transitively + schema = self.fx + params: List[str] = [] + for i in range(schema.num_inputs()): + if schema.is_tensor(i): + params.append(f"{self.input_name(i)}: StridedBuffer") else: - if op_desc.input_type(i) is not None: - code.writeline( - f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}," - ) + arg_type = schema.input_type(i) + if arg_type is not None: + params.append(f"{self.input_name(i)}: 
{_type_name(arg_type)}") else: - code.writeline(f"val{non_tensor_index},") - function_ns.create_name(f"val{non_tensor_index}") - non_tensor_index += 1 - - # signature: output ptrs - for i in range(op_desc.num_outputs()): + params.append(f"{self.input_name(i)}") + # NOTE: [the wrapper's signature and rules for passing parameters ] + # input params: must be passed by position, since the names are renamed to + # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd + # So we enforce that these parameters must be passed by position. + # maybe we can fix it later + # output parameters: must be passed by keywore, since the scalar function + # do not have output parameters(think of it as some scalar function, output + # parameter does not make sense in this case.) They are added to allow destination + # passing style API. Output parameter is convenient in cases where we want + # to use some pre-defiend outputs(especially when they are some views of other + # tensors). We emphasize that these parameters are added in-addition, we enforce + # that they be passed by keyword. After all, out0, out1, ... does not mismatch + # names form the scalar function, since it does not have output parameters. + params.append("/") + params.append("*") # output params must be passed by keyword + for i in range(schema.num_output_tensors()): + params.append(f"{self.output_name(i)}: StridedBuffer=None") + code.writeline(f"def {self.name}({_cs(params)}):") + + def gen_docstring(self, code: IndentedBuffer): + schema = self.fx + doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' + code.writeline(doc) + + def gen_task_partition(self, code: IndentedBuffer): + code.writeline("# task partitioning") + ndim = self.ndim + if ndim == 0: + code.writeline("num_warps = 1") + code.writeline("num_ctas = 1") + else: + code.writeline("shape = out0.shape") + code.writeline("num_tasks = out0.numel()") + max_tile_size = self.config.max_tile_size code.writeline( - f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" + f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)" ) - code.writeline(f"out{output_tensor_index}_offset: int,") - function_ns.create_name(f"out{output_tensor_index}_ptr") - function_ns.create_name(f"out{output_tensor_index}_offset") - output_tensor_index += 1 - - # signature: strides, for each tensor arguments - # only add this arguments when rank > 0 - if rank > 0: - # strides for inputs - for i in range(op_desc.num_input_tensors()): - for j in range(rank): - function_ns.create_name(f"in{i}_stride{j}") - stride_args = ", ".join(f"in{i}_stride{j}: int" for j in range(rank)) - code.writeline(f"{stride_args}, # strides for in{i}") - - # strides for outputs - for i in range(op_desc.num_output_tensors()): - for j in range(rank): - function_ns.create_name(f"out{i}_stride{j}") - stride_args = ", ".join(f"out{i}_stride{j}: int" for j in range(rank)) - code.writeline(f"{stride_args}, # strides for out{i}") - - # task space, used to reconstruct multi index - task_space_args = ", ".join(f"s{i}: int" for i in range(rank)) - for i in range(rank): - function_ns.create_name(f"s{i}") - code.writeline(f"{task_space_args}, # task_space") - - # number of tasks, used to compute mask - code.writeline("num_tasks: int,") - function_ns.create_name("num_tasks") - - # tile size & tiles_per_cta, gsl style - if rank > 0: - code.writeline("tiles_per_cta,") - function_ns.create_name("tiles_per_cta") - - code.writeline("tile_size: tl.constexpr,") - 
function_ns.create_name("tile_size") - - code.writeline("one_tile_per_cta: tl.constexpr,") - function_ns.create_name("one_tile_per_cta") - code.writeline("):") - - # input & output names - inputs_to_scalar_fn = [] - input_tensor_index = 0 - non_tensor_index = 0 - for i in range(op_desc.num_inputs()): - if op_desc.is_tensor(i): - inputs_to_scalar_fn.append(f"in{input_tensor_index}") - input_tensor_index += 1 - else: - inputs_to_scalar_fn.append(f"val{non_tensor_index}") - non_tensor_index += 1 - inputs_to_scalar_fn: str = ", ".join(inputs_to_scalar_fn) + code.writeline("tile_size = math.prod(tile_sizes)") + code.writeline( + "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))" + ) + max_grid_size0 = self.config.max_grid_size[0] + code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") + + code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") + code.writeline("num_warps = heuristics_for_num_warps(tile_size)") + code.writeline("one_tile_per_cta = tiles_per_cta==1") + code.writeline("grid = (num_ctas, 1, 1)") + + def gen_kernel_launch(self, code: IndentedBuffer): + schema = self.fx + ndim = self.ndim - outputs_to_scalar_fn = [f"out{i}" for i in range(op_desc.num_outputs())] - outputs_to_scalar_fn: str = ", ".join(outputs_to_scalar_fn) + code.writeline("# kernel launch") + for i in range(schema.num_input_tensors()): + code.writeline(f"in{i}_strides = in{i}.stride()") + for i in range(schema.num_output_tensors()): + code.writeline(f"out{i}_strides = out{i}.stride()") - # function body for rank-0 - if rank == 0: + code.writeline("with torch.cuda.device(in0.device):") with code.indent(): - code.writeline("# loads") - for i in range(op_desc.num_input_tensors()): - ptrs_expr: str = f"in{i}_ptr" - load_stmt: str = f"in{i} = tl.load({ptrs_expr})" - function_ns.create_name(f"in{i}") # add to the namespace - code.writeline(load_stmt) - code.newline() - - code.writeline("# compute") - code.writeline(f"{outputs_to_scalar_fn} = inlined_f({inputs_to_scalar_fn})") - code.newline() - - code.writeline("# stores") - for i in range(op_desc.num_output_tensors()): - ptrs_expr: str = f"out{i}_ptr + out{i}_offset" - store_stmt: str = f"tl.store({ptrs_expr}, out{i})" - code.writeline(store_stmt) - code.newline() - return code + code.writeline(f"{self.jit_fn_name}[grid](") + with code.indent(): + params = [] + # NOTE: WRAP + for i in range(schema.num_inputs()): + if schema.is_tensor(i): + params.append(f"{self.input_name(i)}") + else: + params.append(self.input_name(i)) + for i in range(schema.num_output_tensors()): + params.append(f"{self.output_name(i)}") - with code.indent(): - # get pid - code.writeline("# task id & masking") - pid_stmt = "pid = tl.program_id(0)" - code.writeline(pid_stmt) - function_ns.create_name("pid") + code.writeline(f"{_cs(params)},") - code.writeline("num_ctas = tl.num_programs(0)") - function_ns.create_name("num_ctas") + if ndim > 0: + for i in range(schema.num_input_tensors()): + s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) + code.writeline(f"{s}, # stride for in{i}") - # get tid (a.k.a task id) - tid_stmt = "init_tid = pid * tile_size + tl.arange(0, tile_size)" - code.writeline(tid_stmt) - function_ns.create_name("init_tid") + for i in range(schema.num_output_tensors()): + s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) + code.writeline(f"{s}, # stride for out{i}") - # one-tile-per-cta, monolithic kernel style - code.writeline("if one_tile_per_cta: # monolitic kernel style") - with code.indent(): - 
tid_stmt = "tid = init_tid" - code.writeline(tid_stmt) - function_ns.create_name("tid") - - # only apply masking when rank > 0 - # since we only load a value instead of a block of values when the rank is 0 - mask_stmt: str = "mask = tid < num_tasks" - code.writeline(mask_stmt) - function_ns.create_name("mask") - code.newline() - - # reconstruct multi index - code.writeline("# multi index recontruction") - for i in reversed(range(rank)): - if i > 0: - code.writeline(f"i{i} = tid % s{i}") - code.writeline(f"tid //= s{i}") - else: - code.writeline(f"i{i} = tid") - function_ns.create_name(f"{i}") - code.newline() - - # loads - code.writeline("# loads") - for i in range(op_desc.num_input_tensors()): - ptrs_expr: str = " + ".join( - f"i{j} * in{i}_stride{j}" for j in range(rank) - ) - ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" - load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" - function_ns.create_name(f"in{i}") # add to the namespace - code.writeline(load_stmt) - code.newline() - - # compute - code.writeline("# compute") - code.writeline(f"{outputs_to_scalar_fn} = inlined_f({inputs_to_scalar_fn})") - code.newline() - - # stores - code.writeline("# stores") - for i in range(op_desc.num_output_tensors()): - ptrs_expr: str = " + ".join( - f"i{j} * out{i}_stride{j}" for j in range(rank) - ) - ptrs_expr: str = f"out{i}_ptr + out{i}_offset + {ptrs_expr}" - store_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" - code.writeline(store_stmt) + shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) + code.writeline(f"{shape_args}, # task indexing space") + code.writeline("num_tasks, # num tasks") + code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") + for i in range(ndim): + code.writeline(f"tile_size{i}=tile_sizes[{i}],") + code.writeline("one_tile_per_cta=one_tile_per_cta,") + code.writeline("num_warps=num_warps,") + code.writeline(")") + + def gen_return(self, code: IndentedBuffer): + return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors())) + code.writeline(f"return {return_exprs}") + + def codegen_nd_tile(self, code): + self.gen_signature(code) - # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ - code.writeline("else: # grid-stride-loop style kernel") with code.indent(): - code.writeline("for j in range(0, tiles_per_cta):") - function_ns.create_name("j") - with code.indent(): - tid_stmt = "tid = init_tid + j * tile_size * num_ctas" - code.writeline(tid_stmt) - function_ns.create_name("tid") - - # only apply masking when rank > 0 - # since we only load a value instead of a block of values when the rank is 0 - mask_stmt: str = "mask = tid < num_tasks" - code.writeline(mask_stmt) - function_ns.create_name("mask") - code.newline() - - # reconstruct multi index - code.writeline("# multi index recontruction") - for i in reversed(range(rank)): - if i > 0: - code.writeline(f"i{i} = tid % s{i}") - code.writeline(f"tid //= s{i}") - else: - code.writeline(f"i{i} = tid") - function_ns.create_name(f"{i}") - code.newline() - - # loads - code.writeline("# loads") - for i in range(op_desc.num_input_tensors()): - ptrs_expr: str = " + ".join( - f"i{j} * in{i}_stride{j}" for j in range(rank) - ) - ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" - load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" - function_ns.create_name(f"in{i}") # add to the namespace - code.writeline(load_stmt) - code.newline() - - # compute - code.writeline("# compute") - code.writeline( - f"{outputs_to_scalar_fn} = 
inlined_f({inputs_to_scalar_fn})" - ) - code.newline() + self.gen_docstring(code) + self.gen_task_partition(code) + self.gen_kernel_launch(code) + self.gen_return(code) + code.newline() + return code - # stores - code.writeline("# stores") - for i in range(op_desc.num_output_tensors()): - ptrs_expr: str = " + ".join( - f"i{j} * out{i}_stride{j}" for j in range(rank) - ) - ptrs_expr: str = f"out{i}_ptr + out{i}_offset + {ptrs_expr}" - store_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" - code.writeline(store_stmt) - code.newline() - return code - - -def generate_code( - op_desc: OPDesc, - scalar_fn: JITFunction, - inputs: Tuple[Any], - wrapper_name: str, - destination_passing_func_name: str, - kernel_name: str, - code: IndentedBuffer, -) -> IndentedBuffer: - assert ( - len(inputs) == op_desc.num_inputs() - ), "the number of inputs does not match {str(op_desc)}" - input_tensor_ids = [i for i in range(op_desc.num_inputs()) if op_desc.is_tensor(i)] - tensor_shapes = [inputs[i].shape for i in input_tensor_ids] - shape = broadcast_shapes(tensor_shapes) - rank = len(shape) - - # the only runtime determined factor is the rank of the task space - code = generate_imports(code) - code = generate_functional_pointwise_wrapper( - op_desc, wrapper_name, destination_passing_func_name, code - ) - code = generate_destination_passing_pointwise_wrapper( - op_desc, rank, destination_passing_func_name, kernel_name, code - ) - code = generate_pointwise_kernel(op_desc, scalar_fn, rank, kernel_name, code) - return code + +class ModuleGenerator: + def __init__( + self, + function_schema: FunctionSchema, + scalar_fn: triton.JITFunction, + ndim: int, + jit_fn_name: str, + wrapper_name: str, + config: CodeGenConfig, + ): + self.wrapper_gen = WrapperGenerator( + function_schema, jit_fn_name, ndim, wrapper_name, config + ) + self.kernel_gen = KernelGenerator(function_schema, scalar_fn, ndim, jit_fn_name) + + @staticmethod + def generate_imports(code: IndentedBuffer) -> IndentedBuffer: + code.writeline("import math") + code.writeline("from typing import Optional") + code.writeline("import torch") + code.writeline("import triton") + code.writeline("from triton import language as tl") + code.writeline( + "from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes" + ) + code.newline() + code.writeline("from flag_gems.utils.shape_utils import (") + code.writeline(" heuristics_for_tile_size,") + code.writeline(" heuristics_for_num_warps,") + code.writeline(")") + code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer") + code.writeline("from flag_gems.utils.libentry import libentry") + code.newline() + code.newline() + return code + + def codegen(self, code: IndentedBuffer): + # the only runtime determined factor is the rank of the task space + code = self.generate_imports(code) + code = self.wrapper_gen.codegen_nd_tile(code) + code = self.kernel_gen.codegen_nd_tile(code) + return code class PointwiseDynamicFunction: @@ -761,64 +675,190 @@ class PointwiseDynamicFunction: The generated code are written out to the cache directory (defaults to ~/.flaggems). 
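+
+    A minimal usage sketch (mirrors the pattern used in the tests; `add` and the
+    input tensors below are hypothetical examples, not part of this module):
+
+        @pointwise_dynamic(num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")])
+        @triton.jit
+        def add(x, y):
+            return x + y
+
+        x = torch.randn(4, device="cuda")
+        y = torch.randn(4, device="cuda")
+        out = add(x, y)  # inputs by position; outputs (out0=...) only by keyword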
""" - def __init__(self, op_desc: OPDesc, scalar_fn: JITFunction): - self._op_desc = op_desc + def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction): + self.fx = op_desc assert isinstance(scalar_fn, JITFunction) self._scalar_fn = scalar_fn self._scalar_fn_cache_key = scalar_fn.cache_key self.pid = os.getpid() - self.lock = threading.Lock() # instantiated & cached overloads - self.overloads: Mapping[str, Callable] = {} + self.overloads: Mapping[int, Callable] = {} def __call__(self, *args, **kwargs): - # note: kwargs should not be used in JITFunction directly - key = f"{self.arg_key(*args)}" - cache = self.overloads - lock = self.lock - - while key not in cache: - # generate file & import it - with lock: - if key in cache: + # inputs must be passed by position, outputs must be passed by keyword + ndim, args, kwargs = self.prepare_args(*args, **kwargs) + overload = self.instantiate(ndim) + out = overload(*args, **kwargs) + return self.unwrap(out) + + @staticmethod + def use_fast_path(tensors): + return all_the_same_shape(tensors) and ( + all_c_contiguous(tensors) + or ( + all_the_same_stride(tensors) + and torch.ops.aten.is_non_overlapping_and_dense(tensors[0]) + ) + ) + + def prepare_args(self, *args, **kwargs): + # output allocation(when needed) + # task simplification & task-rank infernece & input-output reinterpretation + schema = self.fx + outputs_that_need_allocation: List[int] = [] + out_tensors = [] + for i in range(schema.num_output_tensors()): + k = f"out{i}" + if k in kwargs: + out_tensors.append(kwargs[k]) + else: + outputs_that_need_allocation.append(i) + # input arguments must be passed by position + in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] + + # output dtype promotions + outputs_dtypes_for_allocation = [] + for i in outputs_that_need_allocation: + *arg_indices, method = schema._promotion_methods[i] + promote_args = (args[j] for j in arg_indices) + _, dtype = type_promotion(*promote_args, type_promotion=method) + outputs_dtypes_for_allocation.append(dtype) + + tensors = out_tensors + in_tensors + if self.use_fast_path(tensors): # dimension collapse & use physical ordering + allocated_outputs = [ + torch.empty_like(tensors[0], dtype=dtype) + for dtype in outputs_dtypes_for_allocation + ] + task_shape = (tensors[0].numel(),) + strides = (1,) + ndim = 1 + args = tuple( + ( + StridedBuffer(item, task_shape, strides) + if schema.is_tensor(i) + else item + ) + for i, item in enumerate(args) + ) + kwargs = { + k: StridedBuffer(item, task_shape, strides) + for k, item in kwargs.items() + } + for seq_id, output_id in enumerate(outputs_that_need_allocation): + kwargs[f"out{output_id}"] = StridedBuffer( + allocated_outputs[seq_id], task_shape, strides + ) + else: + # a simple strategy: all the undefined tensors will follow the first + # tensor that is not broadcated, no attempts to simplify task, no reordering, + # no dimenion collapsing + shapes = tuple(item.shape for item in tensors) + task_shape = torch.broadcast_shapes(*shapes) + ndim = len(task_shape) + for item in tensors: + if item.shape == task_shape: + allocated_outputs = [ + torch.empty_like(item, dtype=dtype) + for dtype in outputs_dtypes_for_allocation + ] break - code = IndentedBuffer() - code = generate_code( - self._op_desc, - self._scalar_fn, - args, - "_wrapper", - "_wrapper_out", - "_jit_function", - code, + else: # nobreak + device = tensors[0].device + allocated_outputs = [ + torch.empty(task_shape, dtype=dtype, device=device) + for dtype in outputs_dtypes_for_allocation + 
] + args = tuple( + ( + StridedBuffer( + item, + task_shape, + broadcasted_stride(item.shape, item.stride(), task_shape), + ) + if schema.is_tensor(i) + else item ) - - file_name = f"pointwise_dynamic_{self._scalar_fn_cache_key}_rank_{key}_pid_{self.pid}.py" - - with open(cache_dir() / file_name, "wt", encoding="utf-8") as f: - f.write(code.getvalue()) - - # load - spec = importlib.util.spec_from_file_location( - f"_gen_module_{self._scalar_fn_cache_key}_rank_{key}_pid_{self.pid}", - f.name, + for i, item in enumerate(args) + ) + kwargs = { + k: StridedBuffer( + item, + task_shape, + broadcasted_stride(item.shape, item.stride(), task_shape), ) - m = importlib.util.module_from_spec(spec) - # do not expose it to sys.modules - # sys.modules["_add_module"] = m - spec.loader.exec_module(m) - overload = getattr(m, "_wrapper") - cache[key] = overload - - overload = self.overloads[key] - return overload(*args, **kwargs) - - def arg_key(self, *args): - tensors = [item for item in args if torch.is_tensor(item)] - max_rank = max(item.ndim for item in tensors) - return max_rank + for k, item in kwargs.items() + } + for seq_id, output_id in enumerate(outputs_that_need_allocation): + item = allocated_outputs[seq_id] + kwargs[f"out{output_id}"] = StridedBuffer( + item, + task_shape, + broadcasted_stride(item.shape, item.stride(), task_shape), + ) + return (ndim, args, kwargs) + + def unwrap(self, tensors): + if self.fx.num_output_tensors() == 1: + return tensors.unwrap() + return tuple(item.unwrap() for item in tensors) + + def instantiate(self, ndim): + # NOTE: a manually instantiated overload does not run `prepare_args` as + # preprocessing, so you have to manually allocate outputs and make sure that + # the inputs & outputs actually fit the manually instantiated overload + if ndim in self.overloads: + return self.overloads[ndim] + + key = str(ndim) + code = IndentedBuffer() + config = CodeGenConfig(8192, (65536, 65536, 65536), 32, True, False) + module_gen = ModuleGenerator( + self.fx, + self._scalar_fn, + ndim, + "_kernel", + "_wrapper", + config, + ) + module_gen.codegen(code) + + # NOTE: [why write the generated code to a file] + # triton uses inspect to get the source of the jitted function, which requires + # that the source code can be found by inspect + # We write it into a file, since inspect cannot find the source of functions dynamically + # created via exec string. We can help inspect to find the source by hacking the linecache + # library, but we find generating a module simpler, since we can generate 2 functions: + # the kernel and the wrapper, and the wrapper calls the kernel. + file_name = f"pointwise_dynamic_{self._scalar_fn_cache_key}_rank_{key}_pid_{self.pid}.py" + with open(cache_dir() / file_name, "wt", encoding="utf-8") as f: + f.write(code.getvalue()) + + # load + spec = importlib.util.spec_from_file_location( + f"_gen_module_{self._scalar_fn_cache_key}_rank_{key}_pid_{self.pid}", + f.name, + ) + m = importlib.util.module_from_spec(spec) + # do not expose it to sys.modules + # sys.modules["_add_module"] = m + + # NOTE: [why not import the scalar function] + # we do not re-import the scalar function, although the generated kernel **calls** it + # Since a function's __name__ may be changed, importing it by name from the module where it + # is defined may not find it; the name may also be rebound to something else, so importing via name + # cannot guarantee that the scalar function is imported.
+ # So we copy the scalar function and its __globals__ to the generated module to do this + # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime + spec.loader.exec_module(m) + m.__dict__.update(self._scalar_fn.__globals__) + m.__dict__[self._scalar_fn.__name__] = self._scalar_fn + + overload = getattr(m, "_wrapper") + self.overloads[key] = overload + return overload def pointwise_dynamic( @@ -834,7 +874,7 @@ def decorator(fn): nonlocal num_inputs if (num_inputs is None) and (is_tensor is None) and (dtypes is None): num_inputs = len(fn.arg_names) - op_desc = OPDesc( + op_desc = FunctionSchema( num_inputs=num_inputs, is_tensor=is_tensor, dtypes=dtypes, @@ -846,108 +886,3 @@ def decorator(fn): if f is not None: return decorator(f) return decorator - - -if __name__ == "__main__": - - @pointwise_dynamic( - is_tensor=[True, False, True], - dtypes=[None, float, None], - promotion_methods=[(0, 1, 2, "DEFAULT")], - ) - @triton.jit - def saxpy(x, alpha, y): - return x * alpha + y - - x = torch.randn((3, 4), device="cuda") - y = torch.randn((4,), device="cuda") - out1 = saxpy(x, 2.0, y) - out2 = x * 2.0 + y - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic( - is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")] - ) - @triton.jit - def saxpy(x, alpha, y): - return x * alpha + y - - out1 = saxpy(x, 2.0, y) - out2 = x * 2.0 + y - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) - @triton.jit - def ge(x, y): - return x > y - - out1 = ge(x, y) - out2 = x > y - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) - @triton.jit - def ordinary(x, y): - return tl.sin(x) + tl.cos(y) - - out1 = ordinary(x, y) - out2 = torch.sin(x) + torch.cos(y) - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) - @triton.jit - def ordinary2(x, y): - return tl.sin(x) + tl.cos(y) - - out1 = ordinary2(x, y) - out2 = torch.sin(x) + torch.cos(y) - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) - @triton.jit - def ordinary2(x, y): - return tl.sin(x) + tl.cos(y) - - x = torch.tensor(1.0, device="cuda") - y = torch.tensor(2.0, device="cuda") - out1 = ordinary2(x, y) - out2 = torch.sin(x) + torch.cos(y) - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() - - @pointwise_dynamic( - is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")] - ) - @triton.jit - def eq(x, y): - return x.to(tl.float32) == y.to( - tl.float32 - ) # ensures that y is not used for specialization - - x = torch.arange(10, device="cuda") - y = 1 - # by default value 1 is treated as constexpr even thought it is not marked as constexpr - # do_not_specialize avoids this - out1 = eq(x, y) - out2 = x == y - print(out1) - print(out2) - torch.testing.assert_close(out1, out2) - print() diff --git a/src/flag_gems/utils/shape_utils.py b/src/flag_gems/utils/shape_utils.py index f98ff89a..553a1fd6 100644 --- a/src/flag_gems/utils/shape_utils.py +++ b/src/flag_gems/utils/shape_utils.py @@ -1,8 +1,9 @@ import functools import operator -from typing import Iterable, Tuple +from typing import Iterable, Sequence, Tuple import torch +import triton Shape = Tuple[int] 
Stride = Tuple[int] @@ -125,11 +126,67 @@ def c_contiguous_stride(shape: Shape) -> Stride: s = 1 for size in reversed(shape): strides.append(s) - s *= size - + s *= max(size, 1) # treat size 0 as size 1 return tuple(reversed(strides)) +def f_contiguous_stride(shape: Shape) -> Stride: + strides = [] + s = 1 + for size in shape: + strides.append(s) + s *= max(size, 1) # treat size 0 as size 1 + return tuple(strides) + + +def ordered_stride(shape: Shape, order: Perm) -> Stride: + strides = [0] * len(shape) + s = 1 + for i in order: + strides[i] = s + s *= max(shape[i], 1) # treat size 0 as size 1 + return tuple(strides) + + +def all_the_same_shape(tensors: Sequence[torch.Tensor]) -> bool: + if len(tensors) == 0: + return True + shape = tensors[0].shape + return all(item.shape == shape for item in tensors[1:]) + + +def all_the_same_stride(tensors: Sequence[torch.Tensor]) -> bool: + if len(tensors) == 0: + return True + stride = tensors[0].stride() + return all(item.stride() == stride for item in tensors[1:]) + + +def all_c_contiguous(tensors: Sequence[torch.Tensor]) -> bool: + if len(tensors) == 0: + return True + return all(tensor.is_contiguous() for tensor in tensors) + + +def heuristics_for_tile_size(max_tile_size, *sizes): + tile_sizes = [] + for size in sizes: + tile_size = min(max_tile_size, triton.next_power_of_2(size)) + tile_sizes.append(tile_size) + max_tile_size = max(1, max_tile_size // tile_size) + return tuple(tile_sizes) + + +# This should be part of CodeGenConfig +def heuristics_for_num_warps(tile_size): + if tile_size < 2048: + return 4 + elif tile_size < 4096: + return 8 + else: + return 16 + + def dim_compress(inp, dims): if isinstance(dims, int): dims = [dims] diff --git a/src/flag_gems/utils/tensor_wrapper.py b/src/flag_gems/utils/tensor_wrapper.py new file mode 100644 index 00000000..ada9ca36 --- /dev/null +++ b/src/flag_gems/utils/tensor_wrapper.py @@ -0,0 +1,79 @@ +import math + +import torch + + +class TypedPtr: + """This is the minimal requirement for a type to be treated as a tensor in a triton + jit function. Basically it is a typed pointer, without knowing the device, size, shape, + strides, etc. + """ + + def __init__(self, ptr: int, dtype: torch.dtype): + self.ptr = ptr + self.dtype = dtype + + def data_ptr(self) -> int: + return self.ptr + + @classmethod + def from_tensor(cls, tensor: torch.Tensor, offset: int = 0): + return cls(tensor.data_ptr() + tensor.element_size() * offset, tensor.dtype) + + @classmethod + def reinterpret_tensor(cls, tensor: torch.Tensor, dtype: torch.dtype, offset=0): + return cls(tensor.data_ptr() + dtype.itemsize * offset, dtype) + + +class StridedBuffer: + """A drop-in replacement of torch.Tensor that can be used in wrappers generated by + PointwiseDynamicFunction. It allows us to use a different shape, stride and data + pointer than the base tensor. + + It is a kind of reinterpretation of the base tensor. We make this class since we + cannot get a Tensor view with negative strides via torch APIs, while we need this + to implement flip op. + + Although the generated code can accept both torch.Tensor & StridedBuffer, a StridedBuffer + may not have all the methods that torch.Tensor does. We add some attributes & methods + with the same name as torch.Tensor, which are used in the generated code. But we + may not cover all the methods; add one if what you need is missing here. + + It can also be used in triton kernels since it also has dtype & data_ptr().
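+
+    A minimal usage sketch (hypothetical values; mirrors how flip.py builds a flipped
+    view of a tensor by negating strides and offsetting to the last element):
+
+        base = torch.arange(6, device="cuda").reshape(2, 3)  # strides (3, 1)
+        flipped = StridedBuffer(base, strides=(-3, -1), offset=5)
+        assert flipped.size() == (2, 3) and flipped.stride() == (-3, -1)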
+ """ + + def __init__( + self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0 + ): + self._base = base + self.dtype = dtype or base.dtype + if offset == 0: + self._data_ptr = self._base.data_ptr() + else: + offset = self.dtype.itemsize * offset + self._data_ptr = self._base.data_ptr() + offset + self.shape = shape if shape is not None else self._base.shape + self._strides = strides if strides is not None else self._base.stride() + self.device = self._base.device + self.ndim = len(self.shape) + + def stride(self): + return self._strides + + def size(self): + return self.shape + + def element_size(self): + return self.dtype.itemsize + + def numel(self): + return math.prod(self.shape) + + def dim(self): + return self.ndim + + def unwrap(self): + return self._base + + def data_ptr(self): + return self._data_ptr diff --git a/src/flag_gems/utils/type_utils.py b/src/flag_gems/utils/type_utils.py index 996cfb21..91a86ddd 100644 --- a/src/flag_gems/utils/type_utils.py +++ b/src/flag_gems/utils/type_utils.py @@ -1,8 +1,8 @@ -import torch._prims_common as utils +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes -def type_promotion(*args, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND): - computation_dtype, result_dtype = utils.elementwise_dtypes( +def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND): + computation_dtype, result_dtype = elementwise_dtypes( *args, type_promotion_kind=type_promotion, ) diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 9b33b9d7..014a0d08 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -30,6 +30,7 @@ ALL_FLOAT_DTYPES = [torch.float16, torch.float32, torch.float64, torch.bfloat16] INT_DTYPES = [torch.int16, torch.int32] ALL_INT_DTYPES = [torch.int16, torch.int32, torch.int64] +BOOL_TYPES = [torch.bool] SCALARS = [0.001, -0.999, 100.001, -111.999] DIM_LIST = [0, 1] diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index f7944e92..da6e705b 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -1,4 +1,5 @@ import logging +import random import pytest import torch @@ -8,6 +9,7 @@ from .accuracy_utils import ( ALL_FLOAT_DTYPES, ALL_INT_DTYPES, + BOOL_TYPES, FLOAT_DTYPES, INT_DTYPES, POINTWISE_SHAPES, @@ -67,14 +69,18 @@ def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseand(shape, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) + if dtype in BOOL_TYPES: + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + else: + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -86,12 +92,16 @@ def test_accuracy_bitwiseand(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseand_scalar(shape, dtype): - inp1 = torch.randint( - 
low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = 0x00FF + if dtype in BOOL_TYPES: + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp2 = bool(random.randint(0, 2)) + else: + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = 0x00FF ref_inp1 = to_reference(inp1) ref_out = torch.bitwise_and(ref_inp1, inp2) @@ -102,12 +112,16 @@ def test_accuracy_bitwiseand_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseand_scalar_tensor(shape, dtype): - inp1 = 0x00FF - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) + if dtype in BOOL_TYPES: + inp1 = bool(random.randint(0, 2)) + inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + else: + inp1 = 0x00FF + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) ref_inp2 = to_reference(inp2) ref_out = torch.bitwise_and(inp1, ref_inp2) @@ -118,14 +132,18 @@ def test_accuracy_bitwiseand_scalar_tensor(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseor(shape, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) + if dtype in BOOL_TYPES: + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + else: + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -137,12 +155,16 @@ def test_accuracy_bitwiseor(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseor_scalar(shape, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = 0x00FF + if dtype in BOOL_TYPES: + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp2 = bool(random.randint(0, 2)) + else: + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = 0x00FF ref_inp1 = to_reference(inp1) ref_out = torch.bitwise_or(ref_inp1, inp2) @@ -153,12 +175,16 @@ def test_accuracy_bitwiseor_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseor_scalar_tensor(shape, dtype): - inp1 = 0x00FF - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) + if dtype in BOOL_TYPES: + inp1 = bool(random.randint(0, 2)) + inp2 = torch.randint(0, 2, size=shape, dtype=torch.bool, device="cuda") + else: + inp1 = 0x00FF + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) ref_inp2 = to_reference(inp2) ref_out = torch.bitwise_or(inp1, ref_inp2) diff --git a/tests/test_pointwise_dynamic.py b/tests/test_pointwise_dynamic.py new file 
mode 100644 index 00000000..da6062f4 --- /dev/null +++ b/tests/test_pointwise_dynamic.py @@ -0,0 +1,337 @@ +import pytest +import torch +import triton + +from flag_gems.utils.pointwise_dynamic import FunctionSchema, pointwise_dynamic + + +def test_function_schema_with_non_tensor_input(): + schema = FunctionSchema( + is_tensor=[True, False, True], + dtypes=[None, float, None], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + assert schema.num_input_tensors() == 2 + assert schema.num_output_tensors() == 1 + assert schema.num_inputs() == 3 + assert schema.num_non_tensor_args() == 1 + assert schema.input_index(0) == 0 # the first input is the first input tensor + assert schema.input_index(1) == 0 # the second input is the first non tensor input + assert schema.input_index(2) == 1 # the third input is the second input tensor + + +def test_function_schema_mismatch_input_num1(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + is_tensor=[True, False, True], + dtypes=[None], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_mismatch_input_num2(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + is_tensor=[True, False, True], + num_inputs=2, + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_mismatch_input_num3(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + num_inputs=2, + dtypes=[None, None, None], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_missing_output_dtype_promotion_rules(): + with pytest.raises(ValueError): + schema = FunctionSchema( + num_inputs=2, + dtypes=[None, None, None], + ) + _ = schema + + +def test_function_schema_mismatch_output_num(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + num_inputs=1, + num_outputs=2, + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_missing_input_info(): + with pytest.raises(ValueError): + schema = FunctionSchema( + num_outputs=2, + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_no_tensor_inputs1(): + # no tensor input is okay with FunctionSchema + schema = FunctionSchema( + is_tensor=[False, False, False], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_no_tensor_inputs2(): + schema = FunctionSchema( + num_inputs=3, + is_tensor=[False, False, False], + promotion_methods=[(0, 1, 2, "DEFAULT")], + ) + _ = schema + + +def test_function_schema_no_outputs1(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + is_tensor=[False, False, False], + promotion_methods=[], + ) + _ = schema + + +def test_function_schema_no_outputs2(): + with pytest.raises(AssertionError): + schema = FunctionSchema( + is_tensor=[False, False, False], + num_outputs=0, + promotion_methods=[], + ) + _ = schema + + +def test_function_schema_illegal_dtypes(): + with pytest.raises(AssertionError): + schema = FunctionSchema(dtypes=[0, False, "a"]) + _ = schema + + +def test_function_schema_multiple_outputs(): + schema = FunctionSchema( + num_inputs=3, + num_outputs=2, + promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, "ALWAYS_BOOL")], + ) + _ = schema + + +def test_dynamic_function_without_non_tensor_args(): + @pointwise_dynamic(num_inputs=2, promotion_methods=[(0, 1, "DEFAULT")]) + @triton.jit + def add(x, y): + return x + y + + SIZE = 2 + for ndim in range(10): + shape = [SIZE] * ndim + x = torch.randn(shape, device="cuda") + y = 
torch.randn_like(x) + out = add(x, y) + torch.testing.assert_close(out, x + y) + + +def test_dynamic_function_with_non_tensor_args(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + ) + @triton.jit + def axpy(x, y, alpha): + return alpha * x + y + + SIZE = 2 + for ndim in range(10): + shape = [SIZE] * ndim + x = torch.randn(shape, device="cuda") + y = torch.randn_like(x) + alpha = 2.0 + out = axpy(x, y, alpha) + torch.testing.assert_close(out, alpha * x + y) + + +def test_dynamic_function_with_multiple_outputs(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + num_outputs=2, + promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")], + ) + @triton.jit + def multiple_out(x, y, alpha): + return alpha * x + y, alpha * x - y + + SIZE = 2 + for ndim in range(10): + shape = [SIZE] * ndim + x = torch.randn(shape, device="cuda") + y = torch.randn_like(x) + alpha = 2.0 + out0, out1 = multiple_out(x, y, alpha) + torch.testing.assert_close(out0, alpha * x + y) + torch.testing.assert_close(out1, alpha * x - y) + + +def test_dynamic_function_with_broadcasting(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + ) + @triton.jit + def axpy(x, y, alpha): + return alpha * x + y + + SIZE = 10 + x = torch.randn([SIZE, 1, SIZE], device="cuda") + y = torch.randn([1, SIZE, 1], device="cuda") + alpha = 2.0 + out = axpy(x, y, alpha) + torch.testing.assert_close(out, alpha * x + y) + + +def test_dynamic_function_with_broadcasting2(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + ) + @triton.jit + def axpy(x, y, alpha): + return alpha * x + y + + SIZE = 10 + x = torch.randn([SIZE, 1, SIZE], device="cuda") + y = torch.randn([], device="cuda") + alpha = 2.0 + out = axpy(x, y, alpha) + torch.testing.assert_close(out, alpha * x + y) + + +def test_dynamic_function_with_predefined_out(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT")], + ) + @triton.jit + def axpy(x, y, alpha): + return alpha * x + y + + SIZE = 10 + x = torch.randn([SIZE, SIZE, SIZE], device="cuda") + y = torch.randn([], device="cuda") + alpha = 2.0 + o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + out = axpy(x, y, alpha, out0=o) + torch.testing.assert_close(out, alpha * x + y) + + +def test_dynamic_function_with_some_predefined_out1(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")], + ) + @triton.jit + def axpyaxmy(x, y, alpha): + return alpha * x + y, alpha * x - y + + SIZE = 10 + x = torch.randn([SIZE, SIZE, SIZE], device="cuda") + y = torch.randn([], device="cuda") + alpha = 2.0 + o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + out0, out1 = axpyaxmy(x, y, alpha, out0=o) + assert out0 is o + torch.testing.assert_close(out0, alpha * x + y) + torch.testing.assert_close(out1, alpha * x - y) + + +def test_dynamic_function_with_some_predefined_out2(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")], + ) + @triton.jit + def axpyaxmy(x, y, alpha): + return alpha * x + y, alpha * x - y + + SIZE = 10 + x = torch.randn([SIZE, SIZE, SIZE], device="cuda") + y = torch.randn([], device="cuda") + alpha = 2.0 + o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + out0, out1 = axpyaxmy(x, y, alpha, out1=o) + assert 
out1 is o + torch.testing.assert_close(out0, alpha * x + y) + torch.testing.assert_close(out1, alpha * x - y) + + +def test_dynamic_function_with_bool_input_and_output(): + @pointwise_dynamic( + num_inputs=1, is_tensor=[True], promotion_methods=[(0, "DEFAULT")] + ) + @triton.jit + def invert(x): + return ~x + + SIZE = 10 + x = torch.randn([SIZE, SIZE, SIZE], device="cuda") > 0 + notx = invert(x) + + torch.testing.assert_close(notx, ~x) + + +def test_dynamic_function_manual_instantiation(): + @pointwise_dynamic( + num_inputs=1, is_tensor=[True], promotion_methods=[(0, "DEFAULT")] + ) + @triton.jit + def invert(x): + return ~x + + SIZE = 10 + x = torch.randn([SIZE, SIZE, SIZE], device="cuda") > 0 + o = torch.empty_like(x) + # manually instantiated overload does not handle output allocation + # since it is kind of low level + notx = invert.instantiate(3)(x, out0=o) + torch.testing.assert_close(notx, ~x) + + +def test_dynamic_function_with_nd_buffer(): + @pointwise_dynamic( + num_inputs=3, + is_tensor=[True, True, False], + promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")], + ) + @triton.jit + def axpyaxmy(x, y, alpha): + return alpha * x + y, alpha * x - y + + M, N, K = 40, 60, 80 + x = torch.randn([M, N, K], device="cuda")[::2, ::2, ::2] + y = torch.randn([N // 2, K // 2, M // 2], device="cuda").permute(2, 0, 1) + alpha = 2.0 + o = torch.empty([M // 2, N // 2, K // 2], device="cuda") + out0, out1 = axpyaxmy(x, y, alpha, out0=o) + assert out0 is o + torch.testing.assert_close(out0, alpha * x + y) + torch.testing.assert_close(out1, alpha * x - y) diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py index f3b553e0..b0359f7e 100644 --- a/tests/test_unary_pointwise_ops.py +++ b/tests/test_unary_pointwise_ops.py @@ -6,6 +6,7 @@ from .accuracy_utils import ( ALL_FLOAT_DTYPES, ALL_INT_DTYPES, + BOOL_TYPES, DIM_POINTWISE_SHAPES, DIMS, FLOAT_DTYPES, @@ -31,11 +32,14 @@ def test_accuracy_abs(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", INT_DTYPES) +@pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwisenot(shape, dtype): - inp = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) + if dtype in BOOL_TYPES: + inp = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + else: + inp = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) ref_inp = to_reference(inp) ref_out = torch.bitwise_not(ref_inp)
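
Reviewer note (not part of the patch): a minimal usage sketch of the new shape_utils helpers, assuming flag_gems with this patch applied and triton importable, and assuming the max(size, 1) handling of zero-size dimensions shown above. The concrete values are worked out from the function bodies in the diff.

from flag_gems.utils.shape_utils import (
    c_contiguous_stride,
    f_contiguous_stride,
    ordered_stride,
    heuristics_for_tile_size,
    heuristics_for_num_warps,
)

shape = (2, 3, 4)
# Row-major strides: the innermost dimension varies fastest.
assert c_contiguous_stride(shape) == (12, 4, 1)
# Column-major strides: the outermost dimension varies fastest.
assert f_contiguous_stride(shape) == (1, 2, 6)
# `order` lists dimensions from fastest-varying to slowest-varying.
assert ordered_stride(shape, (2, 1, 0)) == (12, 4, 1)  # same as C-contiguous
assert ordered_stride(shape, (0, 1, 2)) == (1, 2, 6)   # same as F-contiguous

# Each task size is rounded up to a power of two, capped by whatever budget is
# left after earlier dimensions took their share of max_tile_size.
assert heuristics_for_tile_size(512, 10, 1000) == (16, 32)
assert heuristics_for_num_warps(16 * 32) == 4  # a 512-element tile -> 4 warps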
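
Similarly, a hedged sketch for the new tensor_wrapper module: it only exercises the attributes and methods that StridedBuffer and TypedPtr expose in the diff above, needs no CUDA device, and is illustrative rather than part of the change.

import torch

from flag_gems.utils.tensor_wrapper import StridedBuffer, TypedPtr

base = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Reinterpret `base` so the last dimension is walked backwards: the offset points
# at the last element of the first row and the stride along dim 1 is negative.
rev = StridedBuffer(base, shape=(2, 3), strides=(3, -1), offset=2)

assert rev.size() == (2, 3) and rev.stride() == (3, -1)
assert rev.ndim == 2 and rev.dim() == 2 and rev.numel() == 6
assert rev.dtype == torch.float32 and rev.element_size() == 4
# data_ptr() is shifted by `offset` elements relative to the base tensor.
assert rev.data_ptr() == base.data_ptr() + 2 * base.element_size()
assert rev.unwrap() is base

# TypedPtr keeps only a data pointer plus a dtype, which is all a value needs
# to be passed to a triton jit function as a tensor argument.
p = TypedPtr.from_tensor(base, offset=3)
assert p.dtype == torch.float32
assert p.data_ptr() == base.data_ptr() + 3 * base.element_size()

The negative-stride case is exactly the view that torch APIs cannot express, which is why the wrapper exists.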