From a93ced624bdee7498c03a5b2a3140decb8c455e8 Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Mon, 27 May 2024 09:55:55 +0800 Subject: [PATCH 01/16] [Codegen] refactor codegen for poinwise (#29) * refactor codegen for poinwise 1. support 0d task space; 2. support specifying output dtypes; 3. support non-tensor arguments to the scalar function; 4. add more type annotation in the generated code. --- src/flag_gems/utils/pointwise_dynamic.py | 723 +++++++++++++++++------ src/flag_gems/utils/pointwise_static.py | 303 ---------- 2 files changed, 544 insertions(+), 482 deletions(-) delete mode 100644 src/flag_gems/utils/pointwise_static.py diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 55ef2a9d..6726c30b 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -1,6 +1,5 @@ -from itertools import chain import importlib -from typing import List, Callable, Mapping +from typing import Tuple, List, Callable, Mapping, Optional, Any import torch import triton @@ -13,125 +12,353 @@ from flag_gems.utils.code_utils import IndentedBuffer, NameSpace -def generate_pointwise_wrapper( - inputs: List[torch.Tensor], - num_outputs: int, +# ------------------ Operation Description --------------------------- +def _type_name(type) -> str: + "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" + if type in (bool, int, float, str): + return type.__name__ + if isinstance(type, torch.dtype): + return str(type) + return str(type) + +def _check_typed_list(container, type): + for item in container: + assert isinstance(item, type) + +def _check_sized_list(container, size): + assert len(container) == size + +class OPDesc: + _num_inputs: int + _is_tensor: List[bool] + _dtypes: List[Optional[type]] + + _num_input_tensors: int + _num_non_tensor_inputs: int + + _num_outputs: int + _output_dtypes: List[torch.dtype] + + def __init__( + self, + *, + num_inputs: Optional[int] = None, + is_tensor: Optional[List[bool]] = None, + dtypes: Optional[List[Optional[type]]] = None, + num_outputs: Optional[int] = None, + output_dtypes: Optional[List[torch.dtype]] = None, + ): + if is_tensor is not None: + _check_typed_list(is_tensor, bool) + if dtypes is not None: + _check_typed_list(dtypes, (type, type(None))) + + if num_inputs is not None: + self._num_inputs = num_inputs + if is_tensor is not None: + _check_sized_list(is_tensor, num_inputs) + self._is_tensor = is_tensor + else: + self._is_tensor = [True] * num_inputs + + if dtypes is not None: + _check_sized_list(dtypes, num_inputs) + self._dtypes = dtypes + else: + self._dtypes = [None] * num_inputs + elif is_tensor is not None: + self._num_inputs = len(is_tensor) + self._is_tensor = is_tensor + if dtypes is not None: + _check_sized_list(dtypes, self._num_inputs) + self._dtypes = dtypes + else: + self._dtypes = [None] * self._num_inputs + elif dtypes is not None: + self._num_inputs = len(dtypes) + self._dtypes = dtypes + if is_tensor is not None: + _check_sized_list(is_tensor, self._num_inputs) + self._is_tensor = is_tensor + else: + self._is_tensor = [item is None for item in dtypes] + else: + raise ValueError("Cannot make OPDesc when none of (num_inputs, is_tensor, dtypes) is specified.") + + if output_dtypes is not None: + _check_typed_list(output_dtypes, torch.dtype) + + if num_outputs is not None: + self._num_outputs = num_outputs + if output_dtypes is not None: + _check_sized_list(output_dtypes, num_outputs) + self._output_dtypes = output_dtypes + else: + self._output_dtypes = [None] * num_inputs # infer from the 1st input + elif output_dtypes is not None: + self._num_outputs = len(output_dtypes) + self._output_dtypes = output_dtypes + else: + self._num_outputs = 1 + self._output_dtypes = [None] + + assert self._num_inputs >= 1 + assert self._num_outputs >= 1 + + self._num_input_tensors = sum(self._is_tensor) + self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors + + + def num_inputs(self): + # num of arguments, outputs not included + return self._num_inputs + + def num_outputs(self): + return self._num_outputs + + def is_tensor(self, arg_id: int) -> bool: + return self._is_tensor[arg_id] + + def input_type(self, arg_id) -> Optional[type]: + return self._dtypes[arg_id] + + def output_dtype(self, output_id) -> torch.dtype: + return self._output_dtypes[output_id] + + def num_input_tensors(self) -> int: + return self._num_input_tensors + + def num_output_tensors(self) -> int: + return self._num_outputs + + def num_non_tensor_args(self) -> int: + return self._num_non_tensor_inputs + + def signature(self, outputs_in_arg: bool = False): + input_types = [] + for is_tensor, dtype in zip(self._is_tensor, self._dtypes): + if is_tensor: + input_types.append("Tensor") + else: + if dtype is None: + input_types.append("scalar") + else: + input_types.append(_type_name(dtype)) + + output_types = [] + for dtype in self._output_dtypes: + if dtype is None: + output_types.append("Tensor") + else: + output_types.append(f"Tensor[{_type_name(dtype)}]") + if outputs_in_arg: + input_types.extend(output_types) + sig = f'Pointwise: ({", ".join(input_types)}) -> ({", ".join(output_types)})' + return sig + + def __str__(self) -> str: + return self.signature(outputs_in_arg=False) + +# --------------------------- pointwise wrapper genration ----------------------------------- +def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: + """Generate parameter declaration with type annotation for wrapper function. + Example: in0: torch.Tensor, val0: float, out0: torch.Tensor + """ + parameters: List[str] = [] + + input_tensor_index = 0 + non_tensor_index = 0 + for i in range(op_desc.num_inputs()): + if op_desc._is_tensor[i]: + parameters.append(f"in{input_tensor_index}: torch.Tensor") + input_tensor_index += 1 + else: + if op_desc.input_type(i) is not None: + parameters.append(f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}") + else: + parameters.append(f"val{non_tensor_index}") + non_tensor_index += 1 + + if include_outputs: + output_tensor_index = 0 + for i in range(op_desc.num_outputs()): + parameters.append(f"out{output_tensor_index}: torch.Tensor") + output_tensor_index += 1 + + return ", ".join(parameters) + +def parameter_ref_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: + """Generate parameter reference for wrapper function. + Example: in0, val0, out0 + """ + parameters: List[str] = [] + + input_tensor_index = 0 + non_tensor_index = 0 + for i in range(op_desc.num_inputs()): + if op_desc._is_tensor[i]: + parameters.append(f"in{input_tensor_index}") + input_tensor_index += 1 + else: + parameters.append(f"val{non_tensor_index}") + non_tensor_index += 1 + + if include_outputs: + output_tensor_index = 0 + for i in range(op_desc.num_outputs()): + parameters.append(f"out{output_tensor_index}") + output_tensor_index += 1 + + return ", ".join(parameters) + +def output_ref_for_wrapper(op_desc: OPDesc) -> str: + """Generate output variable refernece for wrapper function. + Example: out0, out1 + """ + parameters: List[str] = [f"out{i}" for i in range(op_desc.num_outputs())] + return ", ".join(parameters) + +def docstring_for_functional_wrapper(op_desc: OPDesc): + doc = f'"""Generated wrapper function with {str(op_desc)}"""' + return doc + +def docstring_for_destination_passing_wrapper(op_desc: OPDesc): + doc = f'"""Generated wrapper function with {op_desc.signature(outputs_in_arg=True)}"""' + return doc + +def generate_imports(code: IndentedBuffer) -> IndentedBuffer: + code.writeline("import math") + code.writeline("import torch") + code.writeline("import triton") + code.writeline("from triton import language as tl") + code.newline() + code.writeline( + "from flag_gems.utils.shape_utils import broadcast_shapes, broadcasted_stride, c_contiguous_stride, volume, Stride" + ) + code.writeline("from flag_gems.utils.libentry import libentry") + code.newline() + return code + +def generate_functional_pointwise_wrapper( + op_desc: OPDesc, + wrapper_name: str, + destination_passing_func_name: str, + code: IndentedBuffer +) -> IndentedBuffer: + # wrapper signature + parameters: str = parameter_for_wrapper(op_desc, include_outputs=False) + wrapper_signature: str = f"def {wrapper_name}({parameters}):" + code.writeline(wrapper_signature) + + with code.indent(): + # docstring + wrapper_docstring = docstring_for_functional_wrapper(op_desc) + code.writeline(wrapper_docstring) + + shapes_str = ", ".join(f"in{i}.shape" for i in range(op_desc.num_input_tensors())) + code.writeline(f"shape = broadcast_shapes([{shapes_str}])") + + # output allocation + num_output_tensor_index = 0 + for i in range(op_desc.num_outputs()): + if op_desc.input_type(i) is None: + code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype=in0.dtype, device=in0.device)") + else: + code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)") + num_output_tensor_index += 1 + + # call destination_passing_func + output_names: str = output_ref_for_wrapper(op_desc) + call_str = f"{output_names} = {destination_passing_func_name}({parameter_ref_for_wrapper(op_desc, include_outputs=True)})" + code.writeline(call_str) + + return_str = f"return {output_names}" + code.writeline(return_str) + code.newline() + return code + +def generate_destination_passing_pointwise_wrapper( + op_desc: OPDesc, + rank: int, wrapper_name: str, kernel_name: str, - scalar_fn: Callable, - code: IndentedBuffer, + code: IndentedBuffer ) -> IndentedBuffer: - """Generate code to call kernel for static shape. - Shape & stride computations are parts of the generated code. - """ - # number of inputs - num_inputs = len(inputs) - - # compute task index space from input shapes - tensor_shapes = tuple( - item.shape - for item in chain( - inputs, - ) - ) - shape = broadcast_shapes(tensor_shapes) - rank = len(shape) + # wrapper signature + parameters: str = parameter_for_wrapper(op_desc, include_outputs=True) + wrapper_signature: str = f"def {wrapper_name}({parameters}):" + code.writeline(wrapper_signature) # task partitioning, 1d task indexing tile_size = 512 num_warps = 4 - - # wrapper signature - input_parameters: List[str] = [f"in{i}: torch.Tensor" for i in range(num_inputs)] - arguments: str = ", ".join(input_parameters) - wrapper_signature: str = f"def {wrapper_name}({arguments}):" - code.writeline(wrapper_signature) + if rank == 0: # special case with rank-0, only 1 element to compute + tile_size = 32 + num_warps = 1 with code.indent(): # docstring - wrapper_docstring: str = f'"""Generated pointwise kernel with {num_inputs} input tensors and {num_outputs} output tensors."""' + wrapper_docstring = docstring_for_destination_passing_wrapper(op_desc) code.writeline(wrapper_docstring) - # ----- output allocation ----- - # NOTE: the layout of the output depends on - # 1. the first input, if it has no internal overlapping and has the same shape as the output, the output follows its layout - # 2. otherwise, the output is C-contiguous - shapes_str = ", ".join(f"in{i}.shape" for i in range(num_inputs)) + shapes_str = ", ".join(f"in{i}.shape" for i in range(op_desc.num_input_tensors())) code.writeline(f"shape = broadcast_shapes([{shapes_str}])") - - code.writeline("if shape == in0.shape:") - with code.indent(): - for i in range(num_outputs): - allocate_output: str = f"out{i}: torch.Tensor = torch.empty_like(in0)" - code.writeline(allocate_output) - code.writeline("else:") - with code.indent(): - for i in range(num_outputs): - allocate_output: str = f"out{i}: torch.Tensor = torch.empty(shape, dtype=in0.dtype, device=in0.device)" - code.writeline(allocate_output) + code.writeline(f"num_tasks = volume(shape)") + code.newline() # input strides for each input tensor w.r.t. the task index space - inputs: str = ",".join(f"in{i}" for i in range(num_inputs)) - code.writeline( - f"input_strides = tuple(broadcasted_stride(item.shape, item.stride(), shape) for item in ({inputs},))" - ) - # code.writeline(f"print(input_strides)") - # outputs are all c-contiguous, not the best actually - code.writeline( - f"output_strides = tuple(out0.stride() for _ in range({num_outputs}))" - ) - # code.writeline(f"print(output_strides)") + if rank > 0: + code.writeline("# strides of each tensor argument w.r.t the task space") + for i in range(op_desc.num_input_tensors()): + code.writeline(f"in{i}_strides = broadcasted_stride(in{i}.shape, in{i}.stride(), shape)") + + for i in range(op_desc.num_output_tensors()): + code.writeline(f"out{i}_strides = out{i}.stride()") + + code.newline() # grid - code.writeline("num_tasks = volume(shape)") + code.writeline("# kernel launch") grid_stmt: str = f"grid = triton.cdiv(num_tasks, {tile_size}), 1, 1" code.writeline(grid_stmt) # launch kernel kernel_launch: str = f"{kernel_name}[grid](" code.writeline(kernel_launch) - with code.indent(): - # input tensors - input_args: str = ", ".join(f"in{i}" for i in range(num_inputs)) - code.writeline(f"{input_args}, # input tensors") - # output tensors - output_args: str = ", ".join(f"out{i}" for i in range(num_outputs)) - code.writeline(f"{output_args}, # output tensors") - for i in range(num_inputs): - s = ", ".join(f"input_strides[{i}][{j}]" for j in range(rank)) - code.writeline(f"{s}, # stride for in{i}") + with code.indent(): + code.writeline(f"{parameter_ref_for_wrapper(op_desc, include_outputs=True)},") + + if rank > 0: + for i in range(op_desc.num_input_tensors()): + s = ", ".join(f"in{i}_strides[{j}]" for j in range(rank)) + code.writeline(f"{s}, # stride for in{i}") - for i in range(num_outputs): - s = ", ".join(f"output_strides[{i}][{j}]" for j in range(rank)) - code.writeline(f"{s}, # stride for out{i}") + for i in range(op_desc.num_output_tensors()): + s = ", ".join(f"out{i}_strides[{j}]" for j in range(rank)) + code.writeline(f"{s}, # stride for out{i}") - shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank)) - code.writeline(f"{shape_args}, # task indexing space") + shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank)) + if rank > 0: + code.writeline(f"{shape_args}, # task indexing space") + code.writeline(f"num_tasks, # num tasks") code.writeline(f"tile_size={tile_size},") code.writeline(f"num_warps={num_warps},") code.writeline(")") # return - code.writeline(f"return {output_args}") + code.writeline(f"return {output_ref_for_wrapper(op_desc)}") code.newline() - - # generate triton kernel - code = generate_pointwise_kernel( - num_inputs, num_outputs, rank, kernel_name, scalar_fn, code - ) return code - def generate_pointwise_kernel( - num_inputs: int, - num_outputs: int, + op_desc: OPDesc, + scalar_fn: JITFunction, rank: int, kernel_name: str, - scalar_fn: JITFunction, - code: IndentedBuffer, + code: IndentedBuffer ) -> IndentedBuffer: code.writeline("@libentry()") code.writeline("@triton.jit") @@ -140,35 +367,60 @@ def generate_pointwise_kernel( function_ns = NameSpace() # signature with code.indent(): - input_parameters = [f"in{i}_ptr" for i in range(num_inputs)] - output_parameters = [f"out{i}_ptr" for i in range(num_outputs)] - ptr_arguments = ", ".join(chain(input_parameters, output_parameters)) - code.writeline(f"{ptr_arguments},") - for arg_name in ptr_arguments: - function_ns.create_name(arg_name) - - for i in range(num_inputs): - for j in range(rank): - function_ns.create_name(f"stride_in{i}{j}") - stride_args = ", ".join(f"stride_in{i}{j}: int" for j in range(rank)) - code.writeline(f"{stride_args}, # strides for in{i}") - - for i in range(num_outputs): - for j in range(rank): - function_ns.create_name(f"stride_out{i}{j}") - stride_args = ", ".join(f"stride_out{i}{j}: int" for j in range(rank)) - code.writeline(f"{stride_args}, # strides for out{i}") - - task_space_args = ", ".join(f"s{i}: int" for i in range(rank)) - for i in range(rank): - function_ns.create_name(f"s{i}") - code.writeline(f"{task_space_args}, # task_space") + input_tensor_index = 0 + non_tensor_index = 0 + output_tensor_index = 0 + # inputs ptrs & non tensor inputs + for i in range(op_desc.num_inputs()): + if op_desc.is_tensor(i): + code.writeline(f"in{input_tensor_index}_ptr: tl.pointer_type,") + function_ns.create_name(f"in{input_tensor_index}_ptr") + input_tensor_index += 1 + else: + if op_desc.input_type(i) is not None: + code.writeline(f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))},") + else: + code.writeline(f"val{non_tensor_index},") + function_ns.create_name(f"val{non_tensor_index}") + non_tensor_index += 1 + + # output ptrs + for i in range(op_desc.num_outputs()): + code.writeline(f"out{output_tensor_index}_ptr: tl.pointer_type,") + function_ns.create_name(f"out{output_tensor_index}_ptr") + output_tensor_index += 1 + + + if rank > 0: + # strides for inputs + for i in range(op_desc.num_input_tensors()): + for j in range(rank): + function_ns.create_name(f"in{i}_stride{j}") + stride_args = ", ".join(f"in{i}_stride{j}: int" for j in range(rank)) + code.writeline(f"{stride_args}, # strides for in{i}") + + # strides for outputs + for i in range(op_desc.num_output_tensors()): + for j in range(rank): + function_ns.create_name(f"out{i}_stride{j}") + stride_args = ", ".join(f"out{i}_stride{j}: int" for j in range(rank)) + code.writeline(f"{stride_args}, # strides for out{i}") + + # task space, used to reconstruct multi index + task_space_args = ", ".join(f"s{i}: int" for i in range(rank)) + for i in range(rank): + function_ns.create_name(f"s{i}") + code.writeline(f"{task_space_args}, # task_space") + + # number of tasks, used to compute mask + code.writeline(f"num_tasks: int,") + function_ns.create_name("num_tasks") code.writeline("tile_size: tl.constexpr,") function_ns.create_name("tile_size") - code.writeline("):") + # function body with code.indent(): # get pid code.writeline("# task id & masking") @@ -176,145 +428,258 @@ def generate_pointwise_kernel( code.writeline(pid_stmt) function_ns.create_name("pid") - # tile size + # get tid (a.k.a task id) tid_stmt = "tid = pid * tile_size + tl.arange(0, tile_size)" code.writeline(tid_stmt) function_ns.create_name("tid") - # masking - volume_expr: str = " * ".join(f"s{i}" for i in range(rank)) - num_task_stmt: str = f"num_tasks = {volume_expr}" - code.writeline(num_task_stmt) - function_ns.create_name("num_tasks") - - mask_stmt: str = "mask = tid < num_tasks" - code.writeline(mask_stmt) - function_ns.create_name("mask") - code.newline() + if rank > 0: + # only apply masking when rank > 0 + # since we only load a value instead of a block of values when the rank is 0 + mask_stmt: str = "mask = tid < num_tasks" + code.writeline(mask_stmt) + function_ns.create_name("mask") + code.newline() # reconstruct multi index - code.writeline("# multi index recontruction") - for i in reversed(range(rank)): - code.writeline(f"i{i} = tid % s{i}") - function_ns.create_name(f"{i}") - if i > 0: - code.writeline(f"tid //= s{i}") - code.newline() + if rank > 0: + code.writeline("# multi index recontruction") + for i in reversed(range(rank)): + code.writeline(f"i{i} = tid % s{i}") + function_ns.create_name(f"{i}") + if i > 0: + code.writeline(f"tid //= s{i}") + code.newline() # loads code.writeline("# loads") - for i in range(num_inputs): - ptrs_expr: str = " + ".join(f"i{j} * stride_in{i}{j}" for j in range(rank)) - ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" - load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" + for i in range(op_desc.num_input_tensors()): + if rank > 0: + ptrs_expr: str = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(rank)) + ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" + load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" + else: + ptrs_expr: str = f"in{i}_ptr" + load_stmt: str = f"in{i} = tl.load({ptrs_expr})" function_ns.create_name(f"in{i}") # add to the namespace code.writeline(load_stmt) code.newline() # compute code.writeline("# compute") + + inputs_to_scalar_fn = [] + input_tensor_index = 0 + non_tensor_index = 0 + for i in range(op_desc.num_inputs()): + if op_desc.is_tensor(i): + inputs_to_scalar_fn.append(f"in{input_tensor_index}") + input_tensor_index += 1 + else: + inputs_to_scalar_fn.append(f"val{non_tensor_index}") + non_tensor_index += 1 + + outputs_to_scalar_fn = [f"out{i}" for i in range(op_desc.num_outputs())] + + compute_body = inline_function( scalar_fn, - [f"in{i}" for i in range(num_inputs)], - [f"out{i}" for i in range(num_outputs)], + inputs_to_scalar_fn, + outputs_to_scalar_fn, function_ns, ) for line in compute_body.strip().splitlines(): code.writeline(line) code.newline() - # loads + # stores code.writeline("# stores") - for i in range(num_outputs): - ptrs_expr: str = " + ".join(f"i{j} * stride_out{i}{j}" for j in range(rank)) - ptrs_expr: str = f"out{i}_ptr + {ptrs_expr}" - load_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" - code.writeline(load_stmt) + for i in range(op_desc.num_output_tensors()): + if rank > 0: + ptrs_expr: str = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(rank)) + ptrs_expr: str = f"out{i}_ptr + {ptrs_expr}" + store_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" + else: + ptrs_expr: str = f"out{i}_ptr" + store_stmt: str = f"tl.store({ptrs_expr}, out{i})" + code.writeline(store_stmt) code.newline() - return code +def generate_code( + op_desc: OPDesc, + scalar_fn: JITFunction, + inputs: Tuple[Any], + wrapper_name: str, + destination_passing_func_name:str, + kernel_name:str, + code: IndentedBuffer +) -> IndentedBuffer: + assert len(inputs) == op_desc.num_inputs(), "the number of inputs does not match {str(op_desc)}" + input_tensor_ids = [i for i in range(op_desc.num_inputs()) if op_desc.is_tensor(i)] + tensor_shapes = [inputs[i].shape for i in input_tensor_ids] + shape = broadcast_shapes(tensor_shapes) + rank = len(shape) -def generate_imports(code: IndentedBuffer) -> IndentedBuffer: - code.writeline("import math") - code.writeline("import torch") - code.writeline("import triton") - code.writeline("from triton import language as tl") - code.newline() - code.writeline( - "from flag_gems.utils.shape_utils import broadcast_shapes, broadcasted_stride, c_contiguous_stride, volume, Stride" - ) - code.writeline("from flag_gems.utils import libentry") - code.newline() + # the only runtime determined factor is the rank of the task space + code = generate_imports(code) + code = generate_functional_pointwise_wrapper(op_desc, wrapper_name, destination_passing_func_name, code) + code = generate_destination_passing_pointwise_wrapper(op_desc, rank, destination_passing_func_name, kernel_name, code) + code = generate_pointwise_kernel(op_desc, scalar_fn, rank, kernel_name, code) return code - class PointwiseDynamicFunction: """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction which are specialized according to the rank of the task space(the broadcasted shape of all input tensors). The generated code are written out to the cache directory (defaults to ~/.flaggems). """ - def __init__(self, scalar_fn: JITFunction): - self.scalar_fn = scalar_fn - self.scalar_fn_cache_key = scalar_fn.cache_key + def __init__(self, op_desc: OPDesc, scalar_fn: JITFunction): + self._op_desc = op_desc + + assert isinstance(scalar_fn, JITFunction) + self._scalar_fn = scalar_fn + self._scalar_fn_cache_key = scalar_fn.cache_key + + # instantiated & cached overloads self.overloads: Mapping[str, Callable] = {} - def __call__(self, *args, **kwargs): - key = f"{self.arg_key(*args, **kwargs)}" + def __call__(self, *args): + # It does not accept kwargs + key = f"{self.arg_key(*args)}" if key in self.overloads: overload = self.overloads[key] else: # generate file & import it code = IndentedBuffer() - code = generate_imports(code) - code = generate_pointwise_wrapper( - args, 1, "_wrapper", "_jit_function", self.scalar_fn, code - ) - - file_name = f"pointwise_dynamic_{self.scalar_fn_cache_key}_rank_{key}.py" + code = generate_code( + self._op_desc, + self._scalar_fn, + args, + "_wrapper", + "_wrapper_out", + "_jit_function", + code) + + file_name = f"pointwise_dynamic_{self._scalar_fn_cache_key}_rank_{key}.py" with open(cache_dir() / file_name, "wt", encoding="utf-8") as f: f.write(code.getvalue()) - f.close() # load - spec = importlib.util.spec_from_file_location("_add_module", f.name) + spec = importlib.util.spec_from_file_location( + f"_gen_module_{self._scalar_fn_cache_key}", + f.name) m = importlib.util.module_from_spec(spec) # do not expose it to sys.modules # sys.modules["_add_module"] = m spec.loader.exec_module(m) overload = getattr(m, "_wrapper") self.overloads[key] = overload - return overload(*args, **kwargs) + return overload(*args) - def arg_key(self, *args, **kwargs): + def arg_key(self, *args): tensors = [item for item in args if torch.is_tensor(item)] max_rank = max(item.ndim for item in tensors) return max_rank -def pointwise_dynamic(function: JITFunction): - return PointwiseDynamicFunction(function) +def pointwise_dynamic( + f: Optional[JITFunction] = None, + *, + num_inputs: Optional[int] = None, + is_tensor: Optional[List[bool]] = None, + dtypes: Optional[List[Optional[type]]] = None, + num_outputs: Optional[int] = None, + output_dtypes: Optional[List[type]] = None +): + def decorator(fn): + nonlocal num_inputs + if (num_inputs is None) and (is_tensor is None) and (dtypes is None): + num_inputs = len(fn.arg_names) + op_desc = OPDesc( + num_inputs=num_inputs, + is_tensor=is_tensor, + dtypes=dtypes, + num_outputs=num_outputs, + output_dtypes=output_dtypes) + return PointwiseDynamicFunction(op_desc, fn) + + if f is not None: + return decorator(f) + return decorator if __name__ == "__main__": + @pointwise_dynamic( + is_tensor=[True, False, True], + dtypes=[None, float, None]) + @triton.jit + def saxpy(x, alpha, y): + return x * alpha + y + + x = torch.randn((3, 4), device="cuda") + y = torch.randn((4,), device="cuda") + out1 = saxpy(x, 2.0, y) + out2 = x * 2.0 + y + print(out1) + print(out2) + print() + + @pointwise_dynamic(is_tensor=[True, False, True]) + @triton.jit + def saxpy(x, alpha, y): + return x * alpha + y + + out1 = saxpy(x, 2.0, y) + out2 = x * 2.0 + y + print(out1) + print(out2) + print() + + + @pointwise_dynamic(output_dtypes=[torch.bool]) + @triton.jit + def ge(x, y): + return x > y + + out1 = ge(x, y) + out2 = x > y + print(out1) + print(out2) + print() + + @pointwise_dynamic() + @triton.jit + def ordinary(x, y): + return tl.sin(x) + tl.cos(y) + + out1 = ordinary(x, y) + out2 = torch.sin(x) + torch.cos(y) + print(out1) + print(out2) + print() @pointwise_dynamic @triton.jit - def f(a, b): - c = a + b - return tl.sigmoid(c) + def ordinary2(x, y): + return tl.sin(x) + tl.cos(y) - a = torch.randn(100, 100, 100, device="cuda")[::2, ::3, ::2] - b = torch.randn_like(a) - # print(a.shape, a.stride()) + out1 = ordinary2(x, y) + out2 = torch.sin(x) + torch.cos(y) + print(out1) + print(out2) + print() - print(f(a, b)) - print(torch.sigmoid(a + b)) + @pointwise_dynamic + @triton.jit + def ordinary2(x, y): + return tl.sin(x) + tl.cos(y) - import triton + x = torch.tensor(1.0, device="cuda") + y = torch.tensor(2.0, device="cuda") + out1 = ordinary2(x, y) + out2 = torch.sin(x) + torch.cos(y) + print(out1) + print(out2) + print() - t1 = triton.testing.do_bench(lambda: f(a, b), return_mode="median") - t2 = triton.testing.do_bench(lambda: torch.sigmoid(a + b), return_mode="median") - print(t1) - print(t2) diff --git a/src/flag_gems/utils/pointwise_static.py b/src/flag_gems/utils/pointwise_static.py deleted file mode 100644 index b842cf41..00000000 --- a/src/flag_gems/utils/pointwise_static.py +++ /dev/null @@ -1,303 +0,0 @@ -from itertools import chain -import importlib -from typing import List, Callable, Mapping -import hashlib - -import torch -import triton -from triton.runtime.jit import JITFunction - -from flag_gems.utils.shape_utils import ( - broadcast_shapes, - broadcasted_stride, - c_contiguous_stride, - volume, - Shape, - Stride, -) -from flag_gems.utils.code_cache import cache_dir -from flag_gems.utils.inliner import inline_function -from flag_gems.utils.code_utils import IndentedBuffer, NameSpace - - -def generate_pointwise_wrapper( - inputs: List[torch.Tensor], - num_outputs: int, - wrapper_name: str, - kernel_name: str, - scalar_fn: JITFunction, - code: IndentedBuffer, -) -> IndentedBuffer: - """Generate code to call kernel for static shape. - Shape & stride computations are done in code-generation time. - """ - # number of inputs - num_inputs = len(inputs) - - # compute task index space from input shapes - tensor_shapes = tuple( - item.shape - for item in chain( - inputs, - ) - ) - shape = broadcast_shapes(tensor_shapes) - num_tasks = volume(shape) - - # # input strides for each input tensor w.r.t. the task index space - # input_strides = tuple(broadcasted_stride(item.shape, item.stride(), shape) for item in inputs) - # # outputs are all c-contiguous, not the best actually - # output_strides = tuple(c_contiguous_stride(shape) for _ in range(num_outputs)) - - # task partitioning, 1d task indexing - tile_size = 512 - num_warps = 4 - grid = triton.cdiv(num_tasks, tile_size), 1, 1 - - # wrapper signature - input_parameters: List[str] = [f"in{i}: torch.Tensor" for i in range(num_inputs)] - arguments: str = ", ".join(input_parameters) - wrapper_signature: str = f"def {wrapper_name}({arguments}):" - code.writeline(wrapper_signature) - - with code.indent(): - # docstring - wrapper_docstring: str = f'"""Generated pointwise kernel with {num_inputs} input tensors and {num_outputs} output tensors."""' - code.writeline(wrapper_docstring) - - # ----- output allocation ----- - # NOTE: the layout of the output depends on - # 1. the first input, if it has no internal overlapping and has the same shape as the output, the output follows its layout - # 2. otherwise, the output is C-contiguous - shape = broadcast_shapes([item.shape for item in inputs]) - code.writeline(f"shape = {shape}") - - code.writeline("if shape == in0.shape:") - with code.indent(): - for i in range(num_outputs): - allocate_output: str = f"out{i}: torch.Tensor = torch.empty_like(in0)" - code.writeline(allocate_output) - code.writeline("else:") - with code.indent(): - for i in range(num_outputs): - allocate_output: str = f"out{i}: torch.Tensor = torch.empty(shape, dtype=in0.dtype, device=in0.device)" - code.writeline(allocate_output) - - # input strides for each input tensor w.r.t. the task index space - input_strides = tuple( - broadcasted_stride(item.shape, item.stride(), shape) for item in (inputs) - ) - # outputs are all c-contiguous, not the best actually - output_strides = tuple(c_contiguous_stride(shape) for _ in range(num_outputs)) - - # grid - grid_stmt: str = f"grid = {grid}" - code.writeline(grid_stmt) - - # launch kernel - kernel_launch: str = f"{kernel_name}[grid](" - code.writeline(kernel_launch) - with code.indent(): - # input tensors - input_args: str = ", ".join(f"in{i}" for i in range(num_inputs)) - code.writeline(f"{input_args}, # input tensors") - - # output tensors - output_args: str = ", ".join(f"out{i}" for i in range(num_outputs)) - code.writeline(f"{output_args}, # output tensors") - - code.writeline(f"tile_size={tile_size},") - code.writeline(f"num_warps={num_warps},") - code.writeline(")") - - # return - code.writeline(f"return {output_args}") - code.newline() - - # generate triton kernel - code = generate_pointwise_kernel( - input_strides, output_strides, shape, "_jit_function", scalar_fn, code - ) - return code - - -def generate_pointwise_kernel( - input_strides: List[Stride], - output_strides: List[Stride], - task_space: Shape, - kernel_name: str, - scalar_fn: JITFunction, - code: IndentedBuffer, -) -> IndentedBuffer: - code.writeline("@libentry()") - code.writeline("@triton.jit") - code.writeline(f"def {kernel_name}(") - - num_inputs = len(input_strides) - num_outputs = len(output_strides) - rank = len(task_space) - - function_ns = NameSpace() - - # signature - with code.indent(): - input_parameters = [f"in{i}_ptr" for i in range(num_inputs)] - output_parameters = [f"out{i}_ptr" for i in range(num_outputs)] - ptr_arguments = ", ".join(chain(input_parameters, output_parameters)) - code.writeline(f"{ptr_arguments},") - for arg_name in ptr_arguments: - function_ns.create_name(arg_name) - - code.writeline("tile_size: tl.constexpr,") - function_ns.create_name("tile_size") - - code.writeline("):") - - num_tasks = volume(task_space) - with code.indent(): - # get pid - code.writeline("# task id & masking") - pid_stmt = "pid = tl.program_id(0)" - code.writeline(pid_stmt) - function_ns.create_name("pid") - # tile size - tid_stmt = "tid = pid * tile_size + tl.arange(0, tile_size)" - code.writeline(tid_stmt) - function_ns.create_name("tid") - # masking - mask_stmt: str = f"mask = tid < {num_tasks}" - code.writeline(mask_stmt) - function_ns.create_name("mask") - code.newline() - - # reconstruct multi index - code.writeline("# multi index recontruction") - for i in reversed(range(rank)): - code.writeline(f"i{i} = tid % {task_space[i]}") - code.writeline(f"i{i}") - if i > 0: - code.writeline(f"tid //= {task_space[i]}") - code.newline() - - # loads - code.writeline("# loads") - for i in range(num_inputs): - ptrs_expr: str = " + ".join( - f"i{j} * {input_strides[i][j]}" for j in range(rank) - ) - ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" - load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" - function_ns.create_name(f"in{i}") # add to the namespace - code.writeline(load_stmt) - code.newline() - - # compute - code.writeline("# compute") - compute_body = inline_function( - scalar_fn, - [f"in{i}" for i in range(num_inputs)], - [f"out{i}" for i in range(num_outputs)], - function_ns, - ) - for line in compute_body.strip().splitlines(): - code.writeline(line) - code.newline() - - # loads - code.writeline("# stores") - for i in range(num_outputs): - ptrs_expr: str = " + ".join( - f"i{j} * {output_strides[i][j]}" for j in range(rank) - ) - ptrs_expr: str = f"out{i}_ptr + {ptrs_expr}" - load_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" - code.writeline(load_stmt) - code.newline() - - return code - - -def generate_imports(code: IndentedBuffer) -> IndentedBuffer: - code.writeline("import math") - code.writeline("import torch") - code.writeline("import triton") - code.writeline("from triton import language as tl") - code.newline() - - code.writeline("from flag_gems.utils import libentry") - code.newline() - return code - - -class PointwiseStaticFunction: - """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction - which are specialized according to each input tensors shape & strides. - The generated code are written out to the cache directory (defaults to ~/.flaggems). - """ - - def __init__(self, scalar_fn: JITFunction): - self.scalar_fn = scalar_fn - self.scalar_fn_cache_key = scalar_fn.cache_key - self.overloads: Mapping[str, Callable] = {} - - def __call__(self, *args, **kwargs): - key = hashlib.sha256( - f"{self.arg_key(*args, **kwargs)}".encode("utf-8") - ).hexdigest() - if key in self.overloads: - overload = self.overloads[key] - else: - # generate file & import it - code = IndentedBuffer() - code = generate_imports(code) - code = generate_pointwise_wrapper( - args, 1, "_wrapper", "_jit_function", self.scalar_fn, code - ) - - file_name = f"pointwise_static_{self.scalar_fn_cache_key}_spec_{key}.py" - with open(cache_dir() / file_name, "wt", encoding="utf-8") as f: - f.write(code.getvalue()) - f.close() - - # load - spec = importlib.util.spec_from_file_location("_add_module", f.name) - m = importlib.util.module_from_spec(spec) - # do not expose it to sys.modules - # sys.modules["_add_module"] = m - spec.loader.exec_module(m) - overload = getattr(m, "_wrapper") - self.overloads[key] = overload - return overload(*args, **kwargs) - - def arg_key(self, *args, **kwargs): - tensors = [item for item in args if torch.is_tensor(item)] - shapes = tuple(item.shape for item in tensors) - strides = tuple(item.stride() for item in tensors) - return (shapes, strides) - - -def pointwise_static(function: JITFunction): - return PointwiseStaticFunction(function) - - -if __name__ == "__main__": - - @pointwise_static - @triton.jit - def f(a, b): - c = a + b - return tl.sigmoid(c) - - a = torch.randn(4, 4, device="cuda") - b = torch.randn_like(a) - # print(a.shape, a.stride()) - - print(f(a, b)) - print(torch.sigmoid(a + b)) - - import triton - - t1 = triton.testing.do_bench(lambda: f(a, b), return_mode="median") - t2 = triton.testing.do_bench(lambda: torch.sigmoid(a + b), return_mode="median") - print(t1) - print(t2) From e7131e3204f9080f9157624be9a836acb441ef8c Mon Sep 17 00:00:00 2001 From: Ping Zhu <58718936+pingzhuu@users.noreply.github.com> Date: Mon, 27 May 2024 11:27:36 +0800 Subject: [PATCH 02/16] Add fused apply_rotary_pos_emb (#25) * Add fused apply_rotary_pos_emb * Add contiguous check for cos/sin and refine shape check * refine by comment --- src/flag_gems/fused/__init__.py | 6 +- src/flag_gems/fused/rotary_embedding.py | 186 ++++++++++++++++++++++++ tests/flag_gems/op_accu_test.py | 111 +++++++++++++- tests/flag_gems/op_perf_test.py | 22 +++ 4 files changed, 318 insertions(+), 7 deletions(-) create mode 100644 src/flag_gems/fused/rotary_embedding.py diff --git a/src/flag_gems/fused/__init__.py b/src/flag_gems/fused/__init__.py index ed4f706c..fccd6d37 100644 --- a/src/flag_gems/fused/__init__.py +++ b/src/flag_gems/fused/__init__.py @@ -1,3 +1,4 @@ +from .rotary_embedding import apply_rotary_pos_emb from .skip_layernorm import skip_layer_norm from .skip_rms_norm import skip_rms_norm from .silu_and_mul import silu_and_mul @@ -5,8 +6,9 @@ __all__ = [ - "skip_layer_norm", - "skip_rms_norm", + "apply_rotary_pos_emb", + "skip_layer_norm", + "skip_rms_norm", "silu_and_mul", "gelu_and_mul", ] diff --git a/src/flag_gems/fused/rotary_embedding.py b/src/flag_gems/fused/rotary_embedding.py new file mode 100644 index 00000000..7f065f6c --- /dev/null +++ b/src/flag_gems/fused/rotary_embedding.py @@ -0,0 +1,186 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import libentry +import math + + +@libentry() +@triton.jit +def apply_rotary_pos_emb_kernel( + oq_ptr, + ok_ptr, + q_ptr, # (n_tokens, q_heads, head_dim) + k_ptr, # (n_tokens, k_heads, head_dim) + cos_ptr, # (max_seq_len, dim // 2) + sin_ptr, # (max_seq_len, dim // 2) + pos_ptr, # (n_tokens, ) + q_stride_s, + q_stride_h, + q_stride_d, + k_stride_s, + k_stride_h, + k_stride_d, + oq_stride_s, + oq_stride_h, + oq_stride_d, + ok_stride_s, + ok_stride_h, + ok_stride_d, + p_stride_s, + cos_stride_s, + sin_stride_s, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + PADDED_HEAD_DIM: tl.constexpr, + ROTARY_INTERLEAVED: tl.constexpr, + MAX_POSITION_EMBEDDINGS: tl.constexpr, +): + s_id = tl.program_id(0) + + pos_ptr += s_id * p_stride_s + pos_id = tl.load(pos_ptr) + cos_ptr += pos_id * cos_stride_s + sin_ptr += pos_id * sin_stride_s + + # note: set TRITON_DEBUG=1 to enable this check + tl.device_assert(pos_id < MAX_POSITION_EMBEDDINGS, "position id out of bound") + + ordered_block = tl.arange(0, PADDED_HEAD_DIM) + mask = ordered_block < HEAD_DIM + if ROTARY_INTERLEAVED: + odd_mask = ordered_block % 2 == 0 + rotated_block = tl.where(odd_mask, ordered_block + 1, ordered_block - 1) + sin_cos_block = ordered_block // 2 + cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.where(odd_mask, -sin, sin) + else: + rotated_block = (ordered_block + HEAD_DIM // 2) % HEAD_DIM + sin_cos_block = ordered_block % (HEAD_DIM // 2) + cos = tl.load(cos_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.load(sin_ptr + sin_cos_block, mask=mask, other=0.0).to(tl.float32) + sin = tl.where(rotated_block < HEAD_DIM // 2, sin, -sin) + + oq_ptr += s_id * oq_stride_s + q_ptr += s_id * q_stride_s + + for off_h in range(0, NUM_Q_HEADS): + ordered_cols = off_h * q_stride_h + (ordered_block * q_stride_d) + rotated_cols = off_h * q_stride_h + (rotated_block * q_stride_d) + output_offs = off_h * oq_stride_h + (ordered_block * oq_stride_d) + + q = tl.load(q_ptr + ordered_cols, mask=mask, other=0.0) + rotated_q = tl.load(q_ptr + rotated_cols, mask=mask, other=0.0) + y = q * cos + rotated_q * sin + tl.store(oq_ptr + output_offs, y, mask=mask) + + ok_ptr += s_id * ok_stride_s + k_ptr += s_id * k_stride_s + + for off_h in range(0, NUM_K_HEADS): + ordered_cols = off_h * k_stride_h + (ordered_block * k_stride_d) + rotated_cols = off_h * k_stride_h + (rotated_block * k_stride_d) + output_offs = off_h * ok_stride_h + (ordered_block * ok_stride_d) + + k = tl.load(k_ptr + ordered_cols, mask=mask, other=0.0) + rotated_k = tl.load(k_ptr + rotated_cols, mask=mask, other=0.0) + y = k * cos + rotated_k * sin + tl.store(ok_ptr + output_offs, y, mask=mask) + + +def apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids, + rotary_interleaved: bool = False, +): + """ + Apply rotary position embedding to q and k + + Args: + q: (*, q_heads, head_dim) + k: (*, k_heads, head_dim) + cos: (max_seq_len, head_dim // 2) + sin: (max_seq_len, head_dim // 2) + position_ids: (*, ) + rotary_interleaved: whether the head_dim is rotated in an interleaved way + + Returns: + q_embed: (*, q_heads, head_dim) + k_embed: (*, k_heads, head_dim) + """ + assert ( + k.shape[-1] == q.shape[-1] + ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}" + assert ( + cos.shape[-1] == sin.shape[-1] + ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}" + assert ( + cos.shape[-1] * 2 == q.shape[-1] + ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}" + assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension" + assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension" + + q_shape = q.shape + k_shape = k.shape + + assert ( + q.shape[:-2] == k.shape[:-2] + ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}" + assert ( + position_ids.shape == q.shape[:-2] + ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}" + + position_ids = position_ids.view(-1) + + q = q.view(-1, q.shape[-2], q.shape[-1]) + k = k.view(-1, k.shape[-2], k.shape[-1]) + + q_embed = torch.empty_like(q) + k_embed = torch.empty_like(k) + + n_tokens, q_heads, head_dim = q.shape + + # The block size must be the next power of two, sometimes we need to pad it. + padded_head_dim = max(triton.next_power_of_2(head_dim), 16) + + grid = (n_tokens,) + + apply_rotary_pos_emb_kernel[grid]( + q_embed, + k_embed, + q, + k, + cos, + sin, + position_ids, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + q_embed.stride(0), + q_embed.stride(1), + q_embed.stride(2), + k_embed.stride(0), + k_embed.stride(1), + k_embed.stride(2), + position_ids.stride(0), + cos.stride(0), + sin.stride(0), + q.shape[-2], + k.shape[-2], + head_dim, + padded_head_dim, + rotary_interleaved, + MAX_POSITION_EMBEDDINGS=cos.shape[0], + ) + q_embed = q_embed.view(q_shape) + k_embed = k_embed.view(k_shape) + return q_embed, k_embed diff --git a/tests/flag_gems/op_accu_test.py b/tests/flag_gems/op_accu_test.py index ccfa7e5a..343bb907 100644 --- a/tests/flag_gems/op_accu_test.py +++ b/tests/flag_gems/op_accu_test.py @@ -604,16 +604,15 @@ def test_accuracy_skip_rmsnorm(shape, dtype): ref_residual = residual.to(torch.float64) ref_weight = weight.to(torch.float64) - - def _torch_rms_norm(x, residual, weight, eps): + def _torch_rms_norm(x, residual, weight, eps): x = x + residual variance = x.pow(2).mean(-1, keepdim=True) hidden_states = x * torch.rsqrt(variance + eps) - return weight * hidden_states + return weight * hidden_states ref_out = _torch_rms_norm( ref_inp, - ref_residual, + ref_residual, weight=ref_weight, eps=eps, ) @@ -623,7 +622,7 @@ def _torch_rms_norm(x, residual, weight, eps): ) allclose_with_dtype(res_out, ref_out, dtype) - + @pytest.mark.parametrize( "shape", @@ -1394,6 +1393,108 @@ def test_accuracy_outer(shape, dtype): allclose_with_dtype(res_in2_grad, ref_in2_grad, dtype) +def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.cohere.modeling_cohere.rotate_half +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py +def rotate_interleave(x): + """Rotates interleave the hidden dims of the input.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def torch_apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids, + rotary_interleaved: bool = False, +): + + q = q.float() + k = k.float() + cos = cos[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + sin = sin[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + if rotary_interleaved: + cos = torch.repeat_interleave(cos, 2, dim=-1) # [bs, seq_len, 1, dim] + sin = torch.repeat_interleave(sin, 2, dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_interleave + else: + cos = torch.cat([cos, cos], dim=-1) # [bs, seq_len, 1, dim] + sin = torch.cat([sin, sin], dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_half + + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + + return q_embed, k_embed + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("max_seq_len", [512, 2048]) +@pytest.mark.parametrize("q_heads,k_heads", [(8, 1), (6, 2), (1, 1), (8, 8)]) +@pytest.mark.parametrize("head_dim", [64, 96, 128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("rotary_interleaved", [True, False]) +def test_apply_rotary_pos_emb( + batch_size, + max_seq_len, + q_heads, + k_heads, + head_dim, + dtype, + rotary_interleaved, +): + seq_len = torch.randint(1, max_seq_len, (1,)).item() + q = torch.randn( + (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device="cuda" + ) + k = torch.randn( + (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device="cuda" + ) + + position_ids = torch.randint(0, max_seq_len, (batch_size, seq_len), device="cuda") + cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device="cuda") + + q_embed_ref, k_embed_ref = torch_apply_rotary_pos_emb( + q=q, + k=k, + cos=cos, + sin=sin, + position_ids=position_ids, + rotary_interleaved=rotary_interleaved, + ) + q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb( + q=q, + k=k, + cos=cos, + sin=sin, + position_ids=position_ids, + rotary_interleaved=rotary_interleaved, + ) + + allclose_with_dtype(q_embed_out, q_embed_ref, dtype) + allclose_with_dtype(k_embed_out, k_embed_ref, dtype) + + @pytest.mark.parametrize( "shape", [(i, j * 64) for i in [2, 4, 4096] for j in range(1, 10)], diff --git a/tests/flag_gems/op_perf_test.py b/tests/flag_gems/op_perf_test.py index 88441419..45d60a1a 100644 --- a/tests/flag_gems/op_perf_test.py +++ b/tests/flag_gems/op_perf_test.py @@ -3,6 +3,7 @@ import time import triton import random +import op_accu_test from flag_gems import * @@ -464,6 +465,26 @@ def bench_triu(op, M, N, diagonal, dtype): return ms +rope_bench = Benchmark("rope") +rope_bench.bench_params(dtype=f16_f32_bf) +rope_bench.provider_ops( + gem=apply_rotary_pos_emb, torch=op_accu_test.torch_apply_rotary_pos_emb +) +rope_bench.arg_names("M") +rope_bench.arg_vals(sizes) +rope_bench.extra_args(num_heads=16, head_dim=128, max_seq_len=2048) + + +@rope_bench.perf +def bench_rope(op, M, num_heads, head_dim, max_seq_len, dtype): + q = torch.randn((M, num_heads, head_dim), dtype=dtype, device="cuda") + k = torch.randn((M, num_heads, head_dim), dtype=dtype, device="cuda") + position_ids = torch.randint(1, max_seq_len, (M,), device="cuda") + cos = torch.randn((max_seq_len, head_dim // 2), dtype=dtype, device="cuda") + sin = torch.randn((max_seq_len, head_dim // 2), dtype=dtype, device="cuda") + ms = run_bench(op, q, k, cos, sin, position_ids) + + silu_and_mul_bench = Benchmark("silu_and_mul") silu_and_mul_bench.bench_params(dtype=f16_f32_bf) silu_and_mul_bench.provider_ops( @@ -527,5 +548,6 @@ def bench_gelu_and_mul(op, M, N, dtype): bench_sub.run(print_data=True) bench_sub_scalar.run(print_data=True) bench_triu.run(print_data=True) +bench_rope.run(print_data=True) bench_silu_and_mul.run(print_data=True) bench_gelu_and_mul.run(print_data=True) From 0ad8e7c55945215f1b67430612e7ed0f9bdd6334 Mon Sep 17 00:00:00 2001 From: Hiujin Gwok <70586936+GwokHiujin@users.noreply.github.com> Date: Mon, 27 May 2024 14:55:42 +0800 Subject: [PATCH 03/16] [fix] Temporarily use upcasting to make prod support bf16 (#33) * [fix] Temporarily use upcasting to make prod support bf16 * add bfloat16 type to op_accu_test * simplify test cases for all, any --- src/flag_gems/ops/prod.py | 6 +++--- tests/flag_gems/op_accu_test.py | 20 ++++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/flag_gems/ops/prod.py b/src/flag_gems/ops/prod.py index d9bfb9ea..572147d3 100644 --- a/src/flag_gems/ops/prod.py +++ b/src/flag_gems/ops/prod.py @@ -23,7 +23,7 @@ def prod_kernel_mid( offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=1.0) + inp_val = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32) mid_value = tl.reduce(inp_val, axis=0, combine_fn=reduce_mul) mid_ptr = mid + pid tl.store(mid_ptr, mid_value.to(inp_val.dtype)) @@ -35,7 +35,7 @@ def prod_kernel_result(mid, out, mid_size, BLOCK_MID: tl.constexpr): offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < mid_size - mid_val = tl.load(mid_ptrs, mask=mask, other=1.0) + mid_val = tl.load(mid_ptrs, mask=mask, other=1.0).to(tl.float32) prod_val = tl.reduce(mid_val, axis=0, combine_fn=reduce_mul) tl.store(out, prod_val) @@ -97,7 +97,7 @@ def prod_kernel( mask1 = m_offset < M mask = m_offset[:, None] < M and n_offset[None, :] < N inp_ptrs = inp + offset - inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0) + inp_vals = tl.load(inp_ptrs, mask=mask, other=1.0).to(tl.float32) result_index = tl.reduce(inp_vals, axis=1, combine_fn=reduce_mul) out_ptrs = out + offset_index diff --git a/tests/flag_gems/op_accu_test.py b/tests/flag_gems/op_accu_test.py index 343bb907..fb3d6e3f 100644 --- a/tests/flag_gems/op_accu_test.py +++ b/tests/flag_gems/op_accu_test.py @@ -1305,7 +1305,7 @@ def test_accuracy_argmax(shape, dim, keepdim, dtype): "shape", [(4096, i * 64) for i in range(1, 20)], ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) def test_accuracy_prod(shape, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda") ref_out = torch.prod(inp.to(torch.float64)) @@ -1320,7 +1320,7 @@ def test_accuracy_prod(shape, dtype): ) @pytest.mark.parametrize("keepdim", [True, False]) @pytest.mark.parametrize("dim", [0, 1]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) def test_accuracy_prod_dim(shape, dim, keepdim, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda") @@ -1497,7 +1497,7 @@ def test_apply_rotary_pos_emb( @pytest.mark.parametrize( "shape", - [(i, j * 64) for i in [2, 4, 4096] for j in range(1, 10)], + [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) @pytest.mark.parametrize("kind", ["normal", "allTrue"]) @@ -1516,10 +1516,10 @@ def test_accuracy_all(shape, dtype, kind): @pytest.mark.parametrize( "shape", - [(i, j * 64) for i in [2, 4, 4096] for j in range(1, 10)], + [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], ) @pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1, -1, None]) +@pytest.mark.parametrize("dim", [0, 1, None]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) @pytest.mark.parametrize("kind", ["normal", "allTrue"]) def test_accuracy_all_dim(shape, dim, keepdim, dtype, kind): @@ -1536,7 +1536,7 @@ def test_accuracy_all_dim(shape, dim, keepdim, dtype, kind): @pytest.mark.parametrize( "shape", - [(1024, 1024, 16), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15), (2, 3, 5)], + [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], ) @pytest.mark.parametrize("dim", [[1, 0], [1, 2]]) @pytest.mark.parametrize("keepdim", [True, False]) @@ -1556,7 +1556,7 @@ def test_accuracy_all_dims(shape, dim, keepdim, dtype, kind): @pytest.mark.parametrize( "shape", - [(i, j * 64) for i in [2, 4, 4096] for j in range(1, 10)], + [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) @pytest.mark.parametrize("kind", ["normal", "allFalse"]) @@ -1575,10 +1575,10 @@ def test_accuracy_any(shape, dtype, kind): @pytest.mark.parametrize( "shape", - [(i, j * 64) for i in [2, 4, 4096] for j in range(1, 10)], + [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], ) @pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1, -1, None]) +@pytest.mark.parametrize("dim", [0, 1, None]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) @pytest.mark.parametrize("kind", ["normal", "allFalse"]) def test_accuracy_any_dim(shape, dim, keepdim, dtype, kind): @@ -1595,7 +1595,7 @@ def test_accuracy_any_dim(shape, dim, keepdim, dtype, kind): @pytest.mark.parametrize( "shape", - [(1024, 1024, 16), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15), (2, 3, 5)], + [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], ) @pytest.mark.parametrize("keepdim", [True, False]) @pytest.mark.parametrize("dim", [[1, 0], [1, 2]]) From fe391fee467626b93bcf961a692c899f2f98e6fc Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Mon, 27 May 2024 14:57:16 +0800 Subject: [PATCH 04/16] fix a bug in the case when output dtype is provided. (#34) --- src/flag_gems/utils/pointwise_dynamic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 6726c30b..7e8c90a4 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -261,7 +261,7 @@ def generate_functional_pointwise_wrapper( # output allocation num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): - if op_desc.input_type(i) is None: + if op_desc.output_dtype(i) is None: code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype=in0.dtype, device=in0.device)") else: code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)") @@ -623,6 +623,7 @@ def saxpy(x, alpha, y): out2 = x * 2.0 + y print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() @pointwise_dynamic(is_tensor=[True, False, True]) @@ -634,6 +635,7 @@ def saxpy(x, alpha, y): out2 = x * 2.0 + y print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() @@ -646,6 +648,7 @@ def ge(x, y): out2 = x > y print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() @pointwise_dynamic() @@ -657,6 +660,7 @@ def ordinary(x, y): out2 = torch.sin(x) + torch.cos(y) print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() @pointwise_dynamic @@ -668,6 +672,7 @@ def ordinary2(x, y): out2 = torch.sin(x) + torch.cos(y) print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() @pointwise_dynamic @@ -681,5 +686,6 @@ def ordinary2(x, y): out2 = torch.sin(x) + torch.cos(y) print(out1) print(out2) + torch.testing.assert_close(out1, out2) print() From 7b76f1dc242e68c5b2bf0937fc9c129c3f89b5f7 Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Tue, 28 May 2024 13:46:58 +0800 Subject: [PATCH 05/16] add do_not_specialize for non tensor arguments passed to the scalar function to ensure that they are never wrapped into tl.constexpr. (#37) --- src/flag_gems/utils/pointwise_dynamic.py | 25 +++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 7e8c90a4..5dcf69a4 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -361,7 +361,11 @@ def generate_pointwise_kernel( code: IndentedBuffer ) -> IndentedBuffer: code.writeline("@libentry()") - code.writeline("@triton.jit") + if op_desc.num_non_tensor_args() > 0: + non_specialize_arg_names = [f"val{i}" for i in range(op_desc.num_non_tensor_args())] + code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})") + else: + code.writeline("@triton.jit") code.writeline(f"def {kernel_name}(") function_ns = NameSpace() @@ -373,7 +377,7 @@ def generate_pointwise_kernel( # inputs ptrs & non tensor inputs for i in range(op_desc.num_inputs()): if op_desc.is_tensor(i): - code.writeline(f"in{input_tensor_index}_ptr: tl.pointer_type,") + code.writeline(f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type") function_ns.create_name(f"in{input_tensor_index}_ptr") input_tensor_index += 1 else: @@ -386,7 +390,7 @@ def generate_pointwise_kernel( # output ptrs for i in range(op_desc.num_outputs()): - code.writeline(f"out{output_tensor_index}_ptr: tl.pointer_type,") + code.writeline(f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type") function_ns.create_name(f"out{output_tensor_index}_ptr") output_tensor_index += 1 @@ -689,3 +693,18 @@ def ordinary2(x, y): torch.testing.assert_close(out1, out2) print() + @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) + @triton.jit + def eq(x, y): + return x.to(tl.float32) == y.to(tl.float32) # ensures that y is not used for specialization + x = torch.arange(10, device="cuda") + y = 1 + # by default value 1 is treated as constexpr even thought it is not marked as constexpr + # do_not_specialize avoids this + out1 = eq(x, y) + out2 = x == y + print(out1) + print(out2) + torch.testing.assert_close(out1, out2) + print() + From 982c4250177a4ce05d5fef15c06b17d5a44ff148 Mon Sep 17 00:00:00 2001 From: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com> Date: Tue, 28 May 2024 13:52:36 +0800 Subject: [PATCH 06/16] [Operators] implement v2 pointwise operators (#36) * [test] skip testing all_dim/all_dims/any_dim/any_dims for pytorch < 2.2.0 * [operator] implement scalar func of bitwise_and/bitwise_or/clamp * [operator] binary pointwise operators by codegen. eq/ne/ge/gt/le/lt/isinf/isnan. * [operator] matrix-vector multiplication --- OperatorList.md | 2 +- src/flag_gems/__init__.py | 18 ++ src/flag_gems/ops/__init__.py | 31 ++- src/flag_gems/ops/bitwise_and.py | 18 ++ src/flag_gems/ops/bitwise_or.py | 18 ++ src/flag_gems/ops/clamp.py | 39 +++- src/flag_gems/ops/eq.py | 29 +++ src/flag_gems/ops/ge.py | 29 +++ src/flag_gems/ops/gt.py | 29 +++ src/flag_gems/ops/isinf.py | 3 +- src/flag_gems/ops/isnan.py | 3 +- src/flag_gems/ops/le.py | 29 +++ src/flag_gems/ops/lt.py | 29 +++ src/flag_gems/ops/mv.py | 70 +++++++ src/flag_gems/ops/ne.py | 29 +++ tests/flag_gems/op_accu_test.py | 347 ++++++++++++++++++++++++++++++- 16 files changed, 711 insertions(+), 12 deletions(-) create mode 100644 src/flag_gems/ops/eq.py create mode 100644 src/flag_gems/ops/ge.py create mode 100644 src/flag_gems/ops/gt.py create mode 100644 src/flag_gems/ops/le.py create mode 100644 src/flag_gems/ops/lt.py create mode 100644 src/flag_gems/ops/mv.py create mode 100644 src/flag_gems/ops/ne.py diff --git a/OperatorList.md b/OperatorList.md index 5af5dde5..1dcf94ca 100644 --- a/OperatorList.md +++ b/OperatorList.md @@ -32,6 +32,7 @@ FlagGems will implement the following operators as planned. Version 1.0 will be ## v2.0 +- mv - all - any - bitwise_and @@ -41,7 +42,6 @@ FlagGems will implement the following operators as planned. Version 1.0 will be - eq - ge - gt -- is_nonzero - isinf - isnan - le diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index b04555bf..42877311 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -11,26 +11,44 @@ def enable(lib=aten_lib): lib.impl("add.Tensor", add, "CUDA") lib.impl("addmm", addmm, "CUDA") lib.impl("bitwise_and.Tensor", bitwise_and_tensor, "CUDA") + lib.impl("bitwise_and.Scalar", bitwise_and_scalar, "CUDA") + lib.impl("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, "CUDA") lib.impl("bitwise_not", bitwise_not, "CUDA") lib.impl("bitwise_or.Tensor", bitwise_or_tensor, "CUDA") + lib.impl("bitwise_or.Scalar", bitwise_or_scalar, "CUDA") + lib.impl("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, "CUDA") lib.impl("bmm", bmm, "CUDA") + lib.impl("clamp", clamp, "CUDA") lib.impl("clamp.Tensor", clamp_tensor, "CUDA") lib.impl("cos", cos, "CUDA") lib.impl("cumsum", cumsum, "CUDA") lib.impl("div.Tensor", div, "CUDA") lib.impl("native_dropout", native_dropout, "AutogradCUDA") + lib.impl("eq.Tensor", eq, "CUDA") + lib.impl("eq.Scalar", eq_scalar, "CUDA") lib.impl("exp", exp, "CUDA") + lib.impl("ge.Tensor", ge, "CUDA") + lib.impl("ge.Scalar", ge_scalar, "CUDA") lib.impl("gelu", gelu, "CUDA") lib.impl("native_group_norm", group_norm, "AutogradCUDA") + lib.impl("gt.Tensor", gt, "CUDA") + lib.impl("gt.Scalar", gt_scalar, "CUDA") lib.impl("isinf", isinf, "CUDA") lib.impl("isnan", isnan, "CUDA") lib.impl("native_layer_norm", layer_norm, "AutogradCUDA") + lib.impl("le.Tensor", le, "CUDA") + lib.impl("le.Scalar", le_scalar, "CUDA") + lib.impl("lt.Tensor", lt, "CUDA") + lib.impl("lt.Scalar", lt_scalar, "CUDA") lib.impl("rms_norm", rms_norm, "CUDA") lib.impl("mean", mean, "CUDA") lib.impl("mean.dim", mean_dim, "CUDA") lib.impl("mm", mm, "CUDA") lib.impl("mul.Tensor", mul, "CUDA") + lib.impl("mv", mv, "CUDA") + lib.impl("ne.Tensor", ne, "CUDA") + lib.impl("ne.Scalar", ne_scalar, "CUDA") lib.impl("neg", neg, "CUDA") lib.impl("pow.Scalar", pow_scalar, "CUDA") lib.impl("pow.Tensor_Scalar", pow_tensor_scalar, "CUDA") diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 665a00c8..df07c43e 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -3,25 +3,32 @@ from .abs import abs from .add import add from .addmm import addmm -from .bitwise_and import bitwise_and_tensor +from .bitwise_and import bitwise_and_tensor, bitwise_and_scalar, bitwise_and_scalar_tensor from .bitwise_not import bitwise_not -from .bitwise_or import bitwise_or_tensor +from .bitwise_or import bitwise_or_tensor, bitwise_or_scalar, bitwise_or_scalar_tensor from .bmm import bmm -from .clamp import clamp_tensor +from .clamp import clamp, clamp_tensor from .cos import cos from .cumsum import cumsum from .dropout import native_dropout from .div import div +from .eq import eq, eq_scalar from .exp import exp +from .ge import ge, ge_scalar from .gelu import gelu from .groupnorm import group_norm +from .gt import gt, gt_scalar from .isinf import isinf from .isnan import isnan from .layernorm import layer_norm +from .le import le, le_scalar +from .lt import lt, lt_scalar from .rms_norm import rms_norm from .mean import mean, mean_dim from .mm import mm from .mul import mul +from .mv import mv +from .ne import ne, ne_scalar from .neg import neg from .pow_scalar import pow_scalar from .pow_tensor_scalar import pow_tensor_scalar @@ -60,25 +67,43 @@ "abs", "addmm", "bitwise_and_tensor", + "bitwise_and_scalar", + "bitwise_and_scalar_tensor", "bitwise_not", "bitwise_or_tensor", + "bitwise_or_scalar", + "bitwise_or_scalar_tensor", "bmm", + "clamp", "clamp_tensor", "cos", "cumsum", "div", "native_dropout", + "eq", + "eq_scalar", "exp", + "ge", + "ge_scalar", "gelu", "group_norm", + "gt", + "gt_scalar", "isinf", "isnan", "layer_norm", + "le", + "le_scalar", + "lt", + "lt_scalar", "rms_norm", "mean", "mean_dim", "mm", "mul", + "mv", + "ne", + "ne_scalar", "neg", "pow_scalar", "pow_tensor_scalar", diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index 5ba041f3..b8a4848d 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -13,3 +13,21 @@ def bitwise_and_tensor(A, B): logging.debug("GEMS BITWISE AND") O = bitwise_and_func(A, B) return O + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def bitwise_and_func_scalar(x, y): + return x & y + + +def bitwise_and_scalar(A, B): + logging.debug("GEMS BITWISE AND SCALAR") + O = bitwise_and_func_scalar(A, B) + return O + + +def bitwise_and_scalar_tensor(A, B): + logging.debug("GEMS BITWISE AND SCALAR TENSOR") + O = bitwise_and_func_scalar(B, A) + return O \ No newline at end of file diff --git a/src/flag_gems/ops/bitwise_or.py b/src/flag_gems/ops/bitwise_or.py index 67234fb8..0237ff35 100644 --- a/src/flag_gems/ops/bitwise_or.py +++ b/src/flag_gems/ops/bitwise_or.py @@ -13,3 +13,21 @@ def bitwise_or_tensor(A, B): logging.debug("GEMS BITWISE OR") O = bitwise_or_func(A, B) return O + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def bitwise_or_func_scalar(x, y): + return x | y + + +def bitwise_or_scalar(A, B): + logging.debug("GEMS BITWISE OR SCALAR") + O = bitwise_or_func_scalar(A, B) + return O + + +def bitwise_or_scalar_tensor(A, B): + logging.debug("GEMS BITWISE OR SCALAR TENSOR") + O = bitwise_or_func_scalar(B, A) + return O diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index c061cf28..c8a9e706 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -6,23 +6,54 @@ @pointwise_dynamic @triton.jit -def clamp_func(x, mini, maxi): +def clamp_func_tensor(x, mini, maxi): return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) @pointwise_dynamic @triton.jit -def clamp_func_min(x, mini): +def clamp_func_min_tensor(x, mini): return tl.maximum(mini, x.to(tl.float32)) @pointwise_dynamic @triton.jit -def clamp_func_max(x, maxi): +def clamp_func_max_tensor(x, maxi): return tl.minimum(maxi, x.to(tl.float32)) def clamp_tensor(A, mini=None, maxi=None): + logging.debug("GEMS CLAMP TENSOR") + if mini is None and maxi is None: + raise ValueError("At least one of mini or maxi must not be None") + elif mini is None: + O = clamp_func_max_tensor(A, maxi) + elif maxi is None: + O = clamp_func_min_tensor(A, mini) + else: + O = clamp_func_tensor(A, mini, maxi) + return O + + +@pointwise_dynamic(is_tensor=[True, False, False]) +@triton.jit +def clamp_func(x, mini, maxi): + return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32))) + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def clamp_func_min(x, mini): + return tl.maximum(mini, x.to(tl.float32)) + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def clamp_func_max(x, maxi): + return tl.minimum(maxi, x.to(tl.float32)) + + +def clamp(A, mini=None, maxi=None): logging.debug("GEMS CLAMP") if mini is None and maxi is None: raise ValueError("At least one of mini or maxi must not be None") @@ -32,4 +63,4 @@ def clamp_tensor(A, mini=None, maxi=None): O = clamp_func_min(A, mini) else: O = clamp_func(A, mini, maxi) - return O + return O \ No newline at end of file diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py new file mode 100644 index 00000000..868f5737 --- /dev/null +++ b/src/flag_gems/ops/eq.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def eq_func(x, y): + return x.to(tl.float32) == y.to(tl.float32) + + +def eq(A, B): + logging.debug("GEMS EQ") + O = eq_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def eq_func_scalar(x, y): + return x.to(tl.float32) == y.to(tl.float32) + + +def eq_scalar(A, B): + logging.debug("GEMS EQ SCALAR") + O = eq_func_scalar(A, B) + return O diff --git a/src/flag_gems/ops/ge.py b/src/flag_gems/ops/ge.py new file mode 100644 index 00000000..7614369b --- /dev/null +++ b/src/flag_gems/ops/ge.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def ge_func(x, y): + return x.to(tl.float32) >= y + + +def ge(A, B): + logging.debug("GEMS GE") + O = ge_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def ge_func_scalar(x, y): + return x.to(tl.float32) >= y + + +def ge_scalar(A, B): + logging.debug("GEMS GE SCALAR") + O = ge_func_scalar(A, B) + return O diff --git a/src/flag_gems/ops/gt.py b/src/flag_gems/ops/gt.py new file mode 100644 index 00000000..6260fe1a --- /dev/null +++ b/src/flag_gems/ops/gt.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def gt_func(x, y): + return x.to(tl.float32) > y + + +def gt(A, B): + logging.debug("GEMS GT") + O = gt_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def gt_func_scalar(x, y): + return x.to(tl.float32) > y + + +def gt_scalar(A, B): + logging.debug("GEMS GT SCALAR") + O = gt_func_scalar(A, B) + return O diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index d5d9e7d4..fba408d9 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -1,10 +1,11 @@ +import torch import triton import triton.language as tl import logging from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(output_dtypes=[torch.bool]) @triton.jit def isinf_func(x): return tl.math.isinf(x.to(tl.float32)) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index 489ff004..8c77fd00 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -1,10 +1,11 @@ +import torch import triton import triton.language as tl import logging from ..utils import pointwise_dynamic -@pointwise_dynamic +@pointwise_dynamic(output_dtypes=[torch.bool]) @triton.jit def isnan_func(x): return tl.math.isnan(x.to(tl.float32)) diff --git a/src/flag_gems/ops/le.py b/src/flag_gems/ops/le.py new file mode 100644 index 00000000..e31437a6 --- /dev/null +++ b/src/flag_gems/ops/le.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def le_func(x, y): + return x.to(tl.float32) <= y + + +def le(A, B): + logging.debug("GEMS LE") + O = le_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def le_func_scalar(x, y): + return x.to(tl.float32) <= y + + +def le_scalar(A, B): + logging.debug("GEMS LE SCALAR") + O = le_func_scalar(A, B) + return O diff --git a/src/flag_gems/ops/lt.py b/src/flag_gems/ops/lt.py new file mode 100644 index 00000000..8d560ea1 --- /dev/null +++ b/src/flag_gems/ops/lt.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def lt_func(x, y): + return x.to(tl.float32) < y + + +def lt(A, B): + logging.debug("GEMS LT") + O = lt_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def lt_func_scalar(x, y): + return x.to(tl.float32) < y + + +def lt_scalar(A, B): + logging.debug("GEMS LT SCALAR") + O = lt_func_scalar(A, B) + return O diff --git a/src/flag_gems/ops/mv.py b/src/flag_gems/ops/mv.py new file mode 100644 index 00000000..ab449862 --- /dev/null +++ b/src/flag_gems/ops/mv.py @@ -0,0 +1,70 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import libentry + + +@libentry() +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_stages=s, num_warps=w) + for m in [32, 64, 128] + for n in [1, 2, 4, 8] + for s in [3, 4] + for w in [4, 8] + ], + key=["M", "N"], +) +@triton.jit +def mv_kernel( + A, + B, + C, + N, + M, + stride_an, + stride_am, + stride_bm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid = tl.program_id(0) + offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None] + offset_m = tl.arange(0, BLOCK_M)[None, :] + n_mask = offset_n < N + A_ptrs = A + offset_n * stride_an + offset_m * stride_am + B_ptrs = B + offset_m * stride_bm + acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32) + for m in range(0, M, BLOCK_M): + m_mask = m + offset_m < M + a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32) + b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32) + acc += a * b + A_ptrs += BLOCK_M * stride_am + B_ptrs += BLOCK_M * stride_bm + + acc = tl.sum(acc, axis=1) + C_ptrs = C + offset_n * stride_cn + tl.store(C_ptrs, acc[:, None], mask=n_mask) + + +def mv(inp, vec): + logging.debug("GEMS MV") + assert inp.shape[1] == vec.shape[0], "incompatible dimensions" + N, M = inp.shape + out = torch.empty((N,), device=inp.device, dtype=inp.dtype) + grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),) + mv_kernel[grid]( + inp, + vec, + out, + N, + M, + inp.stride(0), + inp.stride(1), + vec.stride(0), + out.stride(0), + ) + return out diff --git a/src/flag_gems/ops/ne.py b/src/flag_gems/ops/ne.py new file mode 100644 index 00000000..28696d0f --- /dev/null +++ b/src/flag_gems/ops/ne.py @@ -0,0 +1,29 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(output_dtypes=[torch.bool]) +@triton.jit +def ne_func(x, y): + return x.to(tl.float32) != y.to(tl.float32) + + +def ne(A, B): + logging.debug("GEMS NE") + O = ne_func(A, B) + return O + + +@pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) +@triton.jit +def ne_func_scalar(x, y): + return x.to(tl.float32) != y.to(tl.float32) + + +def ne_scalar(A, B): + logging.debug("GEMS NE SCALAR") + O = ne_func_scalar(A, B) + return O diff --git a/tests/flag_gems/op_accu_test.py b/tests/flag_gems/op_accu_test.py index fb3d6e3f..3dd7560d 100644 --- a/tests/flag_gems/op_accu_test.py +++ b/tests/flag_gems/op_accu_test.py @@ -2,6 +2,11 @@ import pytest import flag_gems + +major, minor = torch.__version__.split(".")[:2] +skip_expr = major < "2" or minor < "2" + + RESOLUTION = { torch.float16: 1e-3, torch.float32: 1.3e-6, @@ -163,6 +168,48 @@ def test_accuracy_bitwiseand(shape, dtype): assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0x000f, 0x7fff, -0x00ff], +) +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) +def test_accuracy_bitwiseand_scalar(shape, scalar, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + + ref_out = torch.bitwise_and(inp1, scalar) + with flag_gems.use_gems(): + res_out = torch.bitwise_and(inp1, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0x000f, 0x7fff, -0x00ff], +) +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) +def test_accuracy_bitwiseand_scalar_tensor(shape, scalar, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + + ref_out = torch.bitwise_and(scalar, inp1) + with flag_gems.use_gems(): + res_out = torch.bitwise_and(scalar, inp1) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], @@ -200,6 +247,48 @@ def test_accuracy_bitwiseor(shape, dtype): assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0x000f, 0x7fff, -0x00ff], +) +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) +def test_accuracy_bitwiseor_scalar(shape, scalar, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + + ref_out = torch.bitwise_or(inp1, scalar) + with flag_gems.use_gems(): + res_out = torch.bitwise_or(inp1, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0x000f, 0x7fff, -0x00ff], +) +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) +def test_accuracy_bitwiseor_scalar_tensor(shape, scalar, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + + ref_out = torch.bitwise_or(scalar, inp1) + with flag_gems.use_gems(): + res_out = torch.bitwise_or(scalar, inp1) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "batch, M, N, K", [ @@ -229,6 +318,29 @@ def test_accuracy_bmm(batch, M, N, K, dtype): @pytest.mark.parametrize("isnone", [None, "max", "min"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) def test_accuracy_clamp(shape, isnone, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + import random + maxi = random.random() + mini = random.random() + if isnone == "min": + mini = None + elif isnone == "max": + maxi = None + + ref_out = torch.clamp(inp, min=mini, max=maxi) + with flag_gems.use_gems(): + res_out = torch.clamp(inp, min=mini, max=maxi) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("isnone", [None, "max", "min"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_clamp_tensor(shape, isnone, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda") maxi = torch.randn(shape, dtype=dtype, device="cuda") mini = torch.randn(shape, dtype=dtype, device="cuda") @@ -384,6 +496,38 @@ def test_accuracy_dropout(shape, dtype, p): ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_eq(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + + ref_out = torch.eq(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_eq_scalar(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = 0 + + ref_out = torch.eq(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], @@ -399,6 +543,41 @@ def test_accuracy_exp(shape, dtype): allclose_with_dtype(res_out, ref_out, dtype) +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_ge(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.ge(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.ge(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0.5, 1.0, 100.9, -111.9], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_ge_scalar(shape, scalar, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.ge(inp, scalar) + with flag_gems.use_gems(): + res_out = torch.ge(inp, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], @@ -466,6 +645,41 @@ def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype): allclose_with_dtype(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW) +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_gt(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.gt(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.gt(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0.5, 1.0, 100.9, -111.9], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_gt_scalar(shape, scalar, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.gt(inp, scalar) + with flag_gems.use_gems(): + res_out = torch.gt(inp, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], @@ -475,7 +689,7 @@ def test_accuracy_isinf(shape, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda") inp = torch.masked_fill(inp, inp > 1.0, -float("inf")) - ref_out = torch.isinf(inp.to(torch.float64)) + ref_out = torch.isinf(inp) with flag_gems.use_gems(): res_out = torch.isinf(inp) @@ -491,7 +705,7 @@ def test_accuracy_isnan(shape, dtype): inp = torch.randn(shape, dtype=dtype, device="cuda") inp = torch.masked_fill(inp, inp > 1.0, float("nan")) - ref_out = torch.isnan(inp.to(torch.float64)) + ref_out = torch.isnan(inp) with flag_gems.use_gems(): res_out = torch.isnan(inp) @@ -548,6 +762,76 @@ def test_accuracy_layernorm(shape, dtype): allclose_with_dtype(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M) +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_le(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.le(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.le(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0.5, 1.0, 100.9, -111.9], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_le_scalar(shape, scalar, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.le(inp, scalar) + with flag_gems.use_gems(): + res_out = torch.le(inp, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_lt(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.lt(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.lt(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize( + "scalar", + [0.5, 1.0, 100.9, -111.9], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_lt_scalar(shape, scalar, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + + ref_out = torch.lt(inp, scalar) + with flag_gems.use_gems(): + res_out = torch.lt(inp, scalar) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(4096, i * 64) for i in range(1, 20)], @@ -727,6 +1011,29 @@ def test_accuracy_mul(shape, dtype): allclose_with_dtype(res_out, ref_out, dtype) +@pytest.mark.parametrize( + "shape", + [ + (256, 256), + (1024, 1024), + (1024, 128), + (1024, 64), + (640, 256), + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_mv(shape, dtype): + N, M = shape + matrix = torch.randn((N, M), dtype=dtype, device="cuda") + vector = torch.randn((M,), dtype=dtype, device="cuda") + + ref_out = torch.mv(matrix.to(torch.float64), vector.to(torch.float64)) + with flag_gems.use_gems(): + res_out = torch.mv(matrix, vector) + + allclose_with_dtype(res_out, ref_out, dtype) + + @pytest.mark.parametrize( "shape_a", [(16, 1024, 256)], @@ -787,6 +1094,38 @@ def test_accuracy_mul_scalar_tensor(shape, scalar, dtype): allclose_with_dtype(res_out, ref_out, dtype) +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_ne(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + + ref_out = torch.ne(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.ne(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_ne_scalar(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = 0 + + ref_out = torch.ne(inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.ne(inp1, inp2) + + assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" + + @pytest.mark.parametrize( "shape", [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], @@ -1514,6 +1853,7 @@ def test_accuracy_all(shape, dtype, kind): assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") @pytest.mark.parametrize( "shape", [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], @@ -1534,6 +1874,7 @@ def test_accuracy_all_dim(shape, dim, keepdim, dtype, kind): assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") @pytest.mark.parametrize( "shape", [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], @@ -1573,6 +1914,7 @@ def test_accuracy_any(shape, dtype, kind): assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") @pytest.mark.parametrize( "shape", [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], @@ -1593,6 +1935,7 @@ def test_accuracy_any_dim(shape, dim, keepdim, dtype, kind): assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" +@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") @pytest.mark.parametrize( "shape", [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], From 24a4f2ca0aa8386f202a73dadb513a29362f7981 Mon Sep 17 00:00:00 2001 From: FatJhon <156064001+FatJhon@users.noreply.github.com> Date: Tue, 28 May 2024 18:16:51 +0800 Subject: [PATCH 07/16] add cross_entropy_loss (#28) * add cross_entropy_loss * clean code * modify default value of reduction --------- Co-authored-by: jiangbin --- src/flag_gems/__init__.py | 1 + src/flag_gems/ops/__init__.py | 2 + src/flag_gems/ops/cross_entropy_loss.py | 205 ++++++++++++++++++++++++ tests/flag_gems/op_accu_test.py | 26 +++ 4 files changed, 234 insertions(+) create mode 100644 src/flag_gems/ops/cross_entropy_loss.py diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 42877311..79bde027 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -84,6 +84,7 @@ def enable(lib=aten_lib): lib.impl("any.dims", any_dims, "CUDA") lib.impl("log_softmax.int", log_softmax, "AutogradCUDA") lib.impl("outer", outer, "AutogradCUDA") + lib.impl("cross_entropy_loss", cross_entropy_loss, "AutogradCUDA") class use_gems: diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index df07c43e..069d8b66 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -52,6 +52,7 @@ from .prod import prod, prod_dim from .log_softmax import log_softmax from .outer import outer +from .cross_entropy_loss import cross_entropy_loss from .var_mean import var_mean from .vector_norm import vector_norm @@ -132,4 +133,5 @@ "vector_norm", "log_softmax", "outer", + "cross_entropy_loss", ] diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py new file mode 100644 index 00000000..951cbe72 --- /dev/null +++ b/src/flag_gems/ops/cross_entropy_loss.py @@ -0,0 +1,205 @@ +import torch +import triton +import triton.language as tl +import logging +from ..utils import libentry +from .sum import sum + + +@libentry() +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 1}, num_stages=4), + triton.Config({"BLOCK_M": 1}, num_stages=5), + triton.Config({"BLOCK_M": 2}, num_stages=4), + triton.Config({"BLOCK_M": 2}, num_stages=5), + triton.Config({"BLOCK_M": 4}, num_stages=4), + triton.Config({"BLOCK_M": 4}, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_stages=4), + triton.Config({"BLOCK_M": 8}, num_stages=5), + ], + key=[ + "M", + "N", + ], +) +@triton.heuristics( + values={ + "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]), + "num_warps": lambda args: ( + 4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16) + ), + }, +) +@triton.jit +def log_softmax_and_mul_kernel( + output_ptr, + input_ptr, + target_ptr, + mean_num, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + n_offset = tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + input_ptrs = input_ptr + offset + inp = tl.load(input_ptrs, mask=mask, other=-float("inf")) + row_minus_max = inp - tl.max(inp, axis=1)[:, None] + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1)[:, None] + softmax_output = tl.log(numerator / denominator) + target = tl.load(target_ptr + offset, mask=mask, other=0.0) + out = softmax_output * target / (-mean_num) + output_ptrs = output_ptr + offset + tl.store(output_ptrs, out, mask=mask) + + +@libentry() +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 1}, num_stages=4), + triton.Config({"BLOCK_M": 1}, num_stages=5), + triton.Config({"BLOCK_M": 2}, num_stages=4), + triton.Config({"BLOCK_M": 2}, num_stages=5), + triton.Config({"BLOCK_M": 4}, num_stages=4), + triton.Config({"BLOCK_M": 4}, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_stages=4), + triton.Config({"BLOCK_M": 8}, num_stages=5), + ], + key=[ + "M", + "N", + ], +) +@triton.heuristics( + values={ + "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]), + "num_warps": lambda args: ( + 4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16) + ), + }, +) +@triton.jit +def softmax_and_sub_kernel( + output_ptr, + input_ptr, + target_ptr, + out_grad, + mean_num, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + n_offset = tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + input_ptrs = input_ptr + offset + inp = tl.load(input_ptrs, mask=mask, other=-float("inf")) + row_minus_max = inp - tl.max(inp, axis=1)[:, None] + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1)[:, None] + # todo: reduce unnecessary calculations through mask operations to improve performance + softmax_output = numerator / denominator + target_ptrs = target_ptr + offset + target = tl.load(target_ptrs, mask=mask, other=0.0) + out_grad_value = tl.load(out_grad) + out = out_grad_value * (softmax_output - target) / mean_num + output_ptrs = output_ptr + offset + + tl.store(output_ptrs, out, mask=mask) + + +class CrossEntropyLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing): + logging.debug("GEMS CrossEntropyLoss") + assert isinstance(input, torch.Tensor), "input is not a tensor" + if input.ndim >= 2: + dim = 1 + else: + dim = 0 + + shape = list(input.shape) + shape[dim] = 1 + mean_num = target.numel() + target = torch.zeros_like(input).scatter(dim, target.view(shape), 1) + + M = 1 + N = input.shape[dim] + for i in range(dim): + M *= input.shape[i] + inp = input.contiguous() + out = torch.empty_like(inp, dtype=inp.dtype) + K = inp.numel() // M // N + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + K, + ) + log_softmax_and_mul_kernel[grid]( + out, + inp, + target, + mean_num, + M, + N, + K, + ) + out_result = sum(out) + + ctx.save_for_backward(input, target) + ctx.dim = dim + ctx.mean_num = mean_num + return out_result + + @staticmethod + def backward(ctx, out_grad): + logging.debug("GEMS CrossEntropyLoss VJP") + input, target = ctx.saved_tensors + dim = ctx.dim + mean_num = ctx.mean_num + + M = 1 + N = input.shape[dim] + for i in range(dim): + M *= input.shape[i] + inp = input.contiguous() + out = torch.empty_like(inp, dtype=input.dtype) + K = inp.numel() // M // N + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + K, + ) + softmax_and_sub_kernel[grid]( + out, + inp, + target, + out_grad, + mean_num, + M, + N, + K, + ) + return out, None, None, None, None, None + + +# todo: reducetion(dtype: int,default mean->1), support other scenarios as follows: (none->0, sum->2) +def cross_entropy_loss( + input, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0 +): + return CrossEntropyLoss.apply( + input, target, weight, reduction, ignore_index, label_smoothing + ) diff --git a/tests/flag_gems/op_accu_test.py b/tests/flag_gems/op_accu_test.py index 3dd7560d..4875cca8 100644 --- a/tests/flag_gems/op_accu_test.py +++ b/tests/flag_gems/op_accu_test.py @@ -1994,3 +1994,29 @@ def test_accuracy_gelu_and_mul(shape, approximate, dtype): allclose_with_dtype(res_out, ref_out, dtype) + +@pytest.mark.parametrize( + "shape", + [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 30)], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) +def test_accuracy_cross_entropy_loss(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + dim = 1 + up_limit = shape[dim] - 1 + target_shape = list(shape) + del target_shape[dim] + target = torch.randint(0, up_limit, target_shape, device="cuda") + + ref_inp = inp.to(torch.float64) + + criterion = torch.nn.CrossEntropyLoss() + + ref_out = criterion(ref_inp, target) + with flag_gems.use_gems(): + res_out = criterion(inp, target) + allclose_with_dtype(res_out, ref_out, dtype) + out_grad = torch.randn_like(res_out) + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + allclose_with_dtype(res_in_grad, ref_in_grad, dtype) From f4e26f9a65574a208f11a140f198d8cab77255ec Mon Sep 17 00:00:00 2001 From: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com> Date: Thu, 30 May 2024 10:51:48 +0800 Subject: [PATCH 08/16] [Tests] Reconstruct test files (#39) 1. reorganized test files in examples, benchmark, tests respectively 2. divide operators into pointwise, reduction, bias and so on 3. implement v1 bitwise ops add, sub, mul, div and pow by codegen --- CONTRIBUTING.md | 8 +- README.md | 49 +- README_cn.md | 42 +- benchmark/__init__.py | 0 benchmark/conftest.py | 15 + benchmark/performance_utils.py | 196 ++ benchmark/test_blas_perf.py | 69 + benchmark/test_pointwise_perf.py | 394 ++++ benchmark/test_reduction_perf.py | 225 ++ .../flag_gems => examples}/model_bert_test.py | 8 +- .../model_llama_test.py | 11 +- src/flag_gems/__init__.py | 2 +- src/flag_gems/fused/skip_rms_norm.py | 9 +- src/flag_gems/ops/__init__.py | 10 +- src/flag_gems/ops/add.py | 127 +- src/flag_gems/ops/all.py | 18 +- src/flag_gems/ops/amax.py | 8 +- src/flag_gems/ops/any.py | 18 +- src/flag_gems/ops/bitwise_and.py | 2 +- src/flag_gems/ops/clamp.py | 2 +- src/flag_gems/ops/div.py | 78 +- src/flag_gems/ops/dropout.py | 4 +- src/flag_gems/ops/mul.py | 46 +- src/flag_gems/ops/pow.py | 40 + src/flag_gems/ops/pow_scalar.py | 47 - src/flag_gems/ops/pow_tensor_scalar.py | 47 - src/flag_gems/ops/pow_tensor_tensor.py | 16 - src/flag_gems/ops/softmax.py | 12 +- src/flag_gems/ops/sub.py | 129 +- src/flag_gems/ops/sum.py | 4 +- src/flag_gems/utils/pointwise_dynamic.py | 126 +- src/flag_gems/utils/random_utils.py | 1 + tests/__init__.py | 0 tests/accuracy_utils.py | 51 + tests/conftest.py | 15 + tests/flag_gems/op_accu_test.py | 2022 ----------------- tests/flag_gems/op_perf_test.py | 553 ----- tests/test_binary_pointwise_ops.py | 569 +++++ tests/test_blas_ops.py | 101 + tests/test_reduction_ops.py | 607 +++++ tests/test_special_ops.py | 143 ++ tests/test_unary_pointwise_ops.py | 278 +++ 42 files changed, 2954 insertions(+), 3148 deletions(-) create mode 100644 benchmark/__init__.py create mode 100644 benchmark/conftest.py create mode 100644 benchmark/performance_utils.py create mode 100644 benchmark/test_blas_perf.py create mode 100644 benchmark/test_pointwise_perf.py create mode 100644 benchmark/test_reduction_perf.py rename {tests/flag_gems => examples}/model_bert_test.py (85%) rename {tests/flag_gems => examples}/model_llama_test.py (78%) create mode 100644 src/flag_gems/ops/pow.py delete mode 100644 src/flag_gems/ops/pow_scalar.py delete mode 100644 src/flag_gems/ops/pow_tensor_scalar.py delete mode 100644 src/flag_gems/ops/pow_tensor_tensor.py create mode 100644 tests/__init__.py create mode 100644 tests/accuracy_utils.py create mode 100644 tests/conftest.py delete mode 100644 tests/flag_gems/op_accu_test.py delete mode 100644 tests/flag_gems/op_perf_test.py create mode 100644 tests/test_binary_pointwise_ops.py create mode 100644 tests/test_blas_ops.py create mode 100644 tests/test_reduction_ops.py create mode 100644 tests/test_special_ops.py create mode 100644 tests/test_unary_pointwise_ops.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8b8d5b28..1c5f5b3a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,11 +24,9 @@ FlagGems │ │ ├──ops: single operators │ │ ├──fused: fused operators │ │ ├──__init__.py -├── tests -│ ├──flag_gems -│ │ ├──model_bert_test.py: test for BERT model running with flag_gems -│ │ ├──op_accu_test.py: test for accuracy of operators -│ │ ├──op_perf_test.py: test for performance of operators +├── tests: accuracy test files +├── benchmark: performance test files +├── examples: model test files ├── LICENSE ├── README.md ├── README_cn.md diff --git a/README.md b/README.md index 15a363c7..07efca9c 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,19 @@ By registering with the ATen backend of PyTorch, FlagGems facilitates a seamless - support pointwise operators: abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu - support reduction operators: cumsum, layernorm, mean, softmax +### v2.0 +- support BLAS operator: mv, outer +- support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid +- support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm +- support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding + ## Quick Start ### Requirements 1. Triton >= 2.2.0 2. PyTorch >= 2.1.2 -3. Transformers >= 4.31.0 +3. Transformers >= 4.40.2 ### Installation @@ -61,27 +67,41 @@ pip install . ### Execute -1. Run Tests - - Operator Accuracy +1. Test Operator Accuracy + - Run reference on cuda + ```shell + cd tests + pytest test_xx_ops.py + ``` + - Run reference on cpu ```shell - cd tests/flag_gems - pytest op_accu_test.py + cd tests + pytest test_xx_ops.py --device cpu ``` - - Model Accuracy + +2. Test Model Accuracy + ```shell + cd examples + pytest model_xx_test.py + ``` + +3. Test Operator Performance + - Test CUDA performance ```shell - cd tests/flag_gems - pytest model_bert_test.py + cd benchmark + pytest test_xx_perf.py -s ``` - - Operator Performance + - Test end-to-end performance ```shell - cd tests/flag_gems - python op_perf_test.py + cd benchmark + pytest test_xx_perf.py -s --mode cpu ``` -2. Run tests with logging infomation +4. Run tests with logging infomation ```shell pytest program.py --log-cli-level debug ``` + Not recommended in performance testing. ## Supported Operators @@ -89,9 +109,8 @@ Operators will be implemented according to [OperatorList.md](https://github.com/ ## Supported Models -| Model | float16 | float32 | bfloat16 | -| :---: | :---: | :---: | :---: | -| Bert_base | ✓ | ✓ | ✓ | +- Bert-base-uncased +- Llama-2-7b ## Supported Platforms diff --git a/README_cn.md b/README_cn.md index 0ebf6604..796c08d3 100644 --- a/README_cn.md +++ b/README_cn.md @@ -13,13 +13,19 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库 - 支持pointwise类算子:abs, add, div, dropout, exp, gelu, mul, pow, reciprocal, relu, rsqrt, silu, sub, triu - 支持reduction类算子:cumsum, layernorm, mean, softmax +### v2.0 +- 支持BLAS类算子: mv, outer +- 支持pointwise类算子: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid +- 支持reduction类算子: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm +- 支持融合算子: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding + ## 快速入门 ### 依赖 1. Triton >= 2.2.0 2. PyTorch >= 2.1.2 -3. Transformers >= 4.31.0 +3. Transformers >= 4.40.2 ### 安装 @@ -60,27 +66,40 @@ pip install . ### 执行 -1. 运行测试 - - 算子正确性测试 +1. 算子正确性测试 + - 在CUDA上运行参考实现 ```shell cd tests/flag_gems pytest op_accu_test.py ``` - - 模型正确性测试 + - 在CPU上运行参考实现 ```shell - cd tests/flag_gems - pytest model_bert_test.py + cd tests + pytest test_xx_ops.py --device cpu ``` - - 算子性能测试 +2. 模型正确性测试 + ```shell + cd examples + pytest model_xx_test.py + ``` + +3. 算子性能测试 + - 测试CUDA性能 ```shell - cd tests/flag_gems - python op_perf_test.py + cd benchmark + pytest test_xx_perf.py -s + ``` + - 测试端到端性能 + ```shell + cd benchmark + pytest test_xx_perf.py -s --mode cpu ``` 2. 运行时打印日志信息 ```shell pytest program.py --log-cli-level debug ``` + 测试性能时不建议打开。 ## 支持算子 @@ -88,9 +107,8 @@ pip install . ## 支持模型 -| Model | float16 | float32 | bfloat16 | -| :---: | :---: | :---: | :---: | -| Bert_base | ✓ | ✓ | ✓ | +- Bert-base-uncased +- Llama-2-7b ## 支持平台 diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmark/conftest.py b/benchmark/conftest.py new file mode 100644 index 00000000..88298796 --- /dev/null +++ b/benchmark/conftest.py @@ -0,0 +1,15 @@ +def pytest_addoption(parser): + parser.addoption( + "--mode", + action="store", + default="cuda", + required=False, + choices=["cuda", "cpu"], + help="record latency in cuda or cpu", + ) + + +def pytest_configure(config): + value = config.getoption("--mode") + global CPU_MODE + CPU_MODE = value == "cpu" diff --git a/benchmark/performance_utils.py b/benchmark/performance_utils.py new file mode 100644 index 00000000..34e90159 --- /dev/null +++ b/benchmark/performance_utils.py @@ -0,0 +1,196 @@ +import torch +import triton +import time +import flag_gems +from .conftest import CPU_MODE + + +WARMUP = 10 +REPETITION = 1000 + + +class Benchmark: + def __init__(self, op_name, torch_op, arg_func, dtype, batch, sizes): + self.op_name = op_name + self.torch_op = torch_op + self.arg_func = arg_func + self.dtype = dtype + self.batch = batch + self.sizes = sizes + + def profile(self, op, *args): + if CPU_MODE: + for i in range(WARMUP): + op(*args) + torch.cuda.synchronize() + start = time.time() + for i in range(REPETITION): + op(*args) + torch.cuda.synchronize() + end = time.time() + latency = (end - start) / REPETITION * 1000 + else: + latency = triton.testing.do_bench( + lambda: op(*args), warmup=WARMUP, rep=REPETITION, return_mode="median" + ) + # average latency in ms + return latency + + def run(self): + print(f"Operator {self.op_name} Performance Test ({self.dtype})") + print(f"Size Torch Latency (ms) Gems Latency (ms)") + print(f"--------------------------------------------------") + for size in self.sizes: + args = self.arg_func(self.dtype, self.batch, size) + torch_perf = self.profile(self.torch_op, *args) + with flag_gems.use_gems(): + gems_perf = self.profile(self.torch_op, *args) + print(f"{size: <10}{torch_perf: >20.6}{gems_perf: >20.6}") + + +FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16] +INT_DTYPES = [torch.int16, torch.int32] + + +DEFAULT_BATCH = 1 +POINTWISE_BATCH = 1024 +REDUCTION_BATCH = 1024 +BLAS_BATCH = 16 +SIZES = [i * 64 for i in range(1, 21)] + + +def unary_arg(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + return (inp,) + + +def unary_int_arg(dtype, batch, size): + inp = torch.randint( + low=0, high=0x7FFF, size=[batch, size], dtype=dtype, device="cuda" + ) + return (inp,) + + +def binary_args(dtype, batch, size): + inp1 = torch.randn([batch, size], dtype=dtype, device="cuda") + inp2 = torch.randn([batch, size], dtype=dtype, device="cuda") + return inp1, inp2 + + +def binary_int_args(dtype, batch, size): + inp1 = torch.randint( + low=0, high=0x7FFF, size=[batch, size], dtype=dtype, device="cuda" + ) + inp2 = torch.randint( + low=0, high=0x7FFF, size=[batch, size], dtype=dtype, device="cuda" + ) + return inp1, inp2 + + +def ternary_args(dtype, batch, size): + inp1 = torch.randn([batch, size], dtype=dtype, device="cuda") + inp2 = torch.randn([batch, size], dtype=dtype, device="cuda") + inp3 = torch.randn([batch, size], dtype=dtype, device="cuda") + return inp1, inp2, inp3 + + +def cross_entropy_loss_args(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + target = torch.randint( + 0, + size, + [ + batch, + ], + device="cuda", + ) + return inp, target + + +def cumsum_args(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + return inp, 1 + + +def group_norm_args(dtype, batch, size): + C = 16 + G = 16 + inp = torch.randn([batch, C, size], dtype=dtype, device="cuda") + weight = torch.randn( + [ + C, + ], + dtype=dtype, + device="cuda", + ) + bias = torch.randn( + [ + C, + ], + dtype=dtype, + device="cuda", + ) + return inp, G, weight, bias + + +def layer_norm_args(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + weight = torch.randn( + [ + size, + ], + dtype=dtype, + device="cuda", + ) + bias = torch.randn( + [ + size, + ], + dtype=dtype, + device="cuda", + ) + return ( + inp, + [ + size, + ], + weight, + bias, + ) + + +def addmm_args(dtype, batch, size): + bias = torch.randn( + [ + size, + ], + dtype=dtype, + device="cuda", + ) + inp1 = torch.randn([size, size], dtype=dtype, device="cuda") + inp2 = torch.randn([size, size], dtype=dtype, device="cuda") + return bias, inp1, inp2 + + +def bmm_args(dtype, batch, size): + inp1 = torch.randn([batch, size, size], dtype=dtype, device="cuda") + inp2 = torch.randn([batch, size, size], dtype=dtype, device="cuda") + return inp1, inp2 + + +def mm_args(dtype, batch, size): + inp1 = torch.randn([size, size], dtype=dtype, device="cuda") + inp2 = torch.randn([size, size], dtype=dtype, device="cuda") + return inp1, inp2 + + +def mv_args(dtype, batch, size): + inp1 = torch.randn([size, size], dtype=dtype, device="cuda") + inp2 = torch.randn([size], dtype=dtype, device="cuda") + return inp1, inp2 + + +def outer_args(dtype, batch, size): + inp1 = torch.randn([size], dtype=dtype, device="cuda") + inp2 = torch.randn([size], dtype=dtype, device="cuda") + return inp1, inp2 diff --git a/benchmark/test_blas_perf.py b/benchmark/test_blas_perf.py new file mode 100644 index 00000000..083a084a --- /dev/null +++ b/benchmark/test_blas_perf.py @@ -0,0 +1,69 @@ +import torch +import pytest +import flag_gems +from .performance_utils import * + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_addmm(dtype): + bench = Benchmark( + op_name="addmm", + torch_op=torch.addmm, + arg_func=addmm_args, + dtype=dtype, + batch=DEFAULT_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_bmm(dtype): + bench = Benchmark( + op_name="bmm", + torch_op=torch.bmm, + arg_func=bmm_args, + dtype=dtype, + batch=BLAS_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_mm(dtype): + bench = Benchmark( + op_name="mm", + torch_op=torch.mm, + arg_func=mm_args, + dtype=dtype, + batch=DEFAULT_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_mv(dtype): + bench = Benchmark( + op_name="mv", + torch_op=torch.mv, + arg_func=mv_args, + dtype=dtype, + batch=BLAS_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_outer(dtype): + bench = Benchmark( + op_name="outer", + torch_op=torch.outer, + arg_func=outer_args, + dtype=dtype, + batch=DEFAULT_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/benchmark/test_pointwise_perf.py b/benchmark/test_pointwise_perf.py new file mode 100644 index 00000000..cff64b14 --- /dev/null +++ b/benchmark/test_pointwise_perf.py @@ -0,0 +1,394 @@ +import torch +import pytest +import flag_gems +from .performance_utils import * + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_abs(dtype): + bench = Benchmark( + op_name="abs", + torch_op=torch.abs, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_add(dtype): + bench = Benchmark( + op_name="add", + torch_op=torch.add, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_perf_bitwiseand(dtype): + bench = Benchmark( + op_name="bitwiseand", + torch_op=torch.bitwise_and, + arg_func=binary_int_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_perf_bitwisenot(dtype): + bench = Benchmark( + op_name="bitwisenot", + torch_op=torch.bitwise_not, + arg_func=unary_int_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_perf_bitwiseor(dtype): + bench = Benchmark( + op_name="bitwiseor", + torch_op=torch.bitwise_or, + arg_func=binary_int_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_clamp(dtype): + bench = Benchmark( + op_name="clamp", + torch_op=torch.clamp, + arg_func=ternary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_cos(dtype): + bench = Benchmark( + op_name="cos", + torch_op=torch.cos, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_div(dtype): + bench = Benchmark( + op_name="div", + torch_op=torch.div, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_eq(dtype): + bench = Benchmark( + op_name="eq", + torch_op=torch.eq, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_exp(dtype): + bench = Benchmark( + op_name="exp", + torch_op=torch.exp, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_ge(dtype): + bench = Benchmark( + op_name="ge", + torch_op=torch.ge, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_gelu(dtype): + bench = Benchmark( + op_name="gelu", + torch_op=torch.nn.functional.gelu, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_gt(dtype): + bench = Benchmark( + op_name="gt", + torch_op=torch.gt, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_isinf(dtype): + bench = Benchmark( + op_name="isinf", + torch_op=torch.isinf, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_isnan(dtype): + bench = Benchmark( + op_name="isnan", + torch_op=torch.isnan, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_le(dtype): + bench = Benchmark( + op_name="le", + torch_op=torch.le, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_lt(dtype): + bench = Benchmark( + op_name="lt", + torch_op=torch.lt, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_mul(dtype): + bench = Benchmark( + op_name="mul", + torch_op=torch.mul, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_ne(dtype): + bench = Benchmark( + op_name="ne", + torch_op=torch.ne, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_neg(dtype): + bench = Benchmark( + op_name="neg", + torch_op=torch.neg, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_pow(dtype): + bench = Benchmark( + op_name="pow", + torch_op=torch.pow, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_reciprocal(dtype): + bench = Benchmark( + op_name="reciprocal", + torch_op=torch.reciprocal, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_relu(dtype): + bench = Benchmark( + op_name="relu", + torch_op=torch.nn.functional.relu, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_rsqrt(dtype): + bench = Benchmark( + op_name="rsqrt", + torch_op=torch.rsqrt, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_sigmoid(dtype): + bench = Benchmark( + op_name="sigmoid", + torch_op=torch.sigmoid, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_silu(dtype): + bench = Benchmark( + op_name="silu", + torch_op=torch.nn.functional.silu, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_sin(dtype): + bench = Benchmark( + op_name="sin", + torch_op=torch.sin, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_sub(dtype): + bench = Benchmark( + op_name="sub", + torch_op=torch.sub, + arg_func=binary_args, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_tanh(dtype): + bench = Benchmark( + op_name="tanh", + torch_op=torch.tanh, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_triu(dtype): + bench = Benchmark( + op_name="triu", + torch_op=torch.triu, + arg_func=unary_arg, + dtype=dtype, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py new file mode 100644 index 00000000..50ffbcfd --- /dev/null +++ b/benchmark/test_reduction_perf.py @@ -0,0 +1,225 @@ +import torch +import pytest +import flag_gems +from .performance_utils import * + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_all(dtype): + bench = Benchmark( + op_name="all", + torch_op=torch.all, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_amax(dtype): + bench = Benchmark( + op_name="amax", + torch_op=torch.amax, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_any(dtype): + bench = Benchmark( + op_name="any", + torch_op=torch.any, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_argmax(dtype): + bench = Benchmark( + op_name="argmax", + torch_op=torch.argmax, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_cross_entropy_loss(dtype): + bench = Benchmark( + op_name="cross_entropy_loss", + torch_op=torch.nn.CrossEntropyLoss(), + arg_func=cross_entropy_loss_args, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_cumsum(dtype): + bench = Benchmark( + op_name="cumsum", + torch_op=torch.cumsum, + arg_func=cumsum_args, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_groupnorm(dtype): + bench = Benchmark( + op_name="groupnorm", + torch_op=torch.nn.functional.group_norm, + arg_func=group_norm_args, + dtype=dtype, + batch=BLAS_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_layernorm(dtype): + bench = Benchmark( + op_name="layernorm", + torch_op=torch.layer_norm, + arg_func=layer_norm_args, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_log_softmax(dtype): + bench = Benchmark( + op_name="log_softmax", + torch_op=torch.nn.functional.log_softmax, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_max(dtype): + bench = Benchmark( + op_name="max", + torch_op=torch.max, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_mean(dtype): + bench = Benchmark( + op_name="mean", + torch_op=torch.mean, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_min(dtype): + bench = Benchmark( + op_name="min", + torch_op=torch.min, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_prod(dtype): + bench = Benchmark( + op_name="prod", + torch_op=torch.prod, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_softmax(dtype): + bench = Benchmark( + op_name="softmax", + torch_op=torch.nn.functional.softmax, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_sum(dtype): + bench = Benchmark( + op_name="sum", + torch_op=torch.sum, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_var_mean(dtype): + bench = Benchmark( + op_name="var_mean", + torch_op=torch.var_mean, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_perf_vector_norm(dtype): + bench = Benchmark( + op_name="vector_norm", + torch_op=torch.linalg.vector_norm, + arg_func=unary_arg, + dtype=dtype, + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/tests/flag_gems/model_bert_test.py b/examples/model_bert_test.py similarity index 85% rename from tests/flag_gems/model_bert_test.py rename to examples/model_bert_test.py index da10e08a..25c7f41f 100644 --- a/tests/flag_gems/model_bert_test.py +++ b/examples/model_bert_test.py @@ -5,12 +5,16 @@ from transformers import AutoTokenizer, BertConfig, BertModel +@pytest.mark.parametrize( + "prompt", + ["How are you today?", "What is your name?", "Who are you?", "Where are you from?"], +) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_bert(dtype): +def test_accuracy_bert(prompt, dtype): config = BertConfig() model = BertModel(config) tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - inputs = tokenizer("Hello, my dog is cute.", return_tensors="pt").to("cuda") + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") ref_model = copy.deepcopy(model) ref_model.to(torch.float64).to("cuda").eval() diff --git a/tests/flag_gems/model_llama_test.py b/examples/model_llama_test.py similarity index 78% rename from tests/flag_gems/model_llama_test.py rename to examples/model_llama_test.py index e0621916..3c18ae6e 100644 --- a/tests/flag_gems/model_llama_test.py +++ b/examples/model_llama_test.py @@ -5,23 +5,26 @@ from transformers import AutoTokenizer, AutoModelForCausalLM -@pytest.mark.parametrize("prompt", ["How are you today?", "What is your name?", "Who are you?", "Where are you from?"]) +@pytest.mark.parametrize( + "prompt", + ["How are you today?", "What is your name?", "Who are you?", "Where are you from?"], +) def test_accuracy_llama(prompt): tokenizer = AutoTokenizer.from_pretrained("sharpbai/Llama-2-7b-hf") model = AutoModelForCausalLM.from_pretrained("sharpbai/Llama-2-7b-hf") model.to("cuda").eval() - inputs = tokenizer(prompt, return_tensors='pt').to(device="cuda") + inputs = tokenizer(prompt, return_tensors="pt").to(device="cuda") with torch.no_grad(): ref_output = model.generate(**inputs, max_length=100, num_beams=5) with flag_gems.use_gems(): res_output = model.generate(**inputs, max_length=100, num_beams=5) - + maxdiff = torch.max(torch.abs(ref_output - res_output)) assert torch.allclose( ref_output, res_output, atol=1e-3, rtol=1e-3, - ), f"LLAMA FAIL with maxdiff {maxdiff} \nREF: {ref_output}\nRES: {res_output}" \ No newline at end of file + ), f"LLAMA FAIL with maxdiff {maxdiff} \nREF: {ref_output}\nRES: {res_output}" diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 79bde027..491a22a5 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -41,7 +41,7 @@ def enable(lib=aten_lib): lib.impl("lt.Tensor", lt, "CUDA") lib.impl("lt.Scalar", lt_scalar, "CUDA") lib.impl("rms_norm", rms_norm, "CUDA") - + lib.impl("mean", mean, "CUDA") lib.impl("mean.dim", mean_dim, "CUDA") lib.impl("mm", mm, "CUDA") diff --git a/src/flag_gems/fused/skip_rms_norm.py b/src/flag_gems/fused/skip_rms_norm.py index 35bd6531..1db47f0a 100644 --- a/src/flag_gems/fused/skip_rms_norm.py +++ b/src/flag_gems/fused/skip_rms_norm.py @@ -28,13 +28,12 @@ def skip_rms_norm_kernel( X += pid * x_stride_r R += pid * r_stride_r - mask = tl.arange(0, BLOCK_SIZE) < N cols = tl.arange(0, BLOCK_SIZE) x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32) r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32) - x += r + x += r var = tl.sum(x * x / N, axis=0) rrms = 1 / tl.sqrt(var + eps) @@ -58,9 +57,11 @@ def forward(ctx, x, residual, normalized_shape, weight, eps=1e-5): weight = weight.contiguous() y = torch.empty_like(x) - skip_rms_norm_kernel[M, ](y, x, residual, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE) + skip_rms_norm_kernel[M,]( + y, x, residual, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE + ) return y def skip_rms_norm(x, residual, normalized_shape, weight, eps=1e-5): - return SkipRmsNorm.apply(x, residual, normalized_shape, weight, eps) \ No newline at end of file + return SkipRmsNorm.apply(x, residual, normalized_shape, weight, eps) diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 069d8b66..773f96dc 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -3,7 +3,11 @@ from .abs import abs from .add import add from .addmm import addmm -from .bitwise_and import bitwise_and_tensor, bitwise_and_scalar, bitwise_and_scalar_tensor +from .bitwise_and import ( + bitwise_and_tensor, + bitwise_and_scalar, + bitwise_and_scalar_tensor, +) from .bitwise_not import bitwise_not from .bitwise_or import bitwise_or_tensor, bitwise_or_scalar, bitwise_or_scalar_tensor from .bmm import bmm @@ -30,9 +34,7 @@ from .mv import mv from .ne import ne, ne_scalar from .neg import neg -from .pow_scalar import pow_scalar -from .pow_tensor_scalar import pow_tensor_scalar -from .pow_tensor_tensor import pow_tensor_tensor +from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor from .reciprocal import reciprocal from .relu import relu from .rsqrt import rsqrt diff --git a/src/flag_gems/ops/add.py b/src/flag_gems/ops/add.py index d89894b9..7f849e64 100644 --- a/src/flag_gems/ops/add.py +++ b/src/flag_gems/ops/add.py @@ -2,136 +2,37 @@ import triton import triton.language as tl import logging -from ..utils import libentry +from ..utils import pointwise_dynamic -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, True, False]) @triton.jit -def add_kernel( - X, - Y, - alpha, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE) - mask = pid < M - Y_ptrs = Y + pid - X_ptrs = X + pid - O_ptrs = O + pid - X_val = tl.load(X_ptrs, mask) - Y_val = tl.load(Y_ptrs, mask) - O_val = X_val + Y_val * alpha - tl.store(O_ptrs, O_val, mask=mask) +def add_func(x, y, alpha): + return x + y * alpha -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, False, False]) @triton.jit -def add_tensor_scalar_kernel( - X, - Y_scalar, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE) - mask = pid < M - X_ptrs = X + pid - O_ptrs = O + pid - X_val = tl.load(X_ptrs, mask) - O_val = X_val + Y_scalar - tl.store(O_ptrs, O_val, mask=mask) +def add_func_tensor_scalar(x, y, alpha): + return x + y * alpha -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[False, True, False]) @triton.jit -def add_scalar_tensor_kernel( - X_scalar, - Y, - alpha, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE) - mask = pid < M - Y_ptrs = Y + pid - O_ptrs = O + pid - Y_val = tl.load(Y_ptrs, mask) - O_val = X_scalar + Y_val * alpha - tl.store(O_ptrs, O_val, mask=mask) +def add_func_scalar_tensor(x, y, alpha): + return x + y * alpha def add(A, B, *, alpha=1): logging.debug("GEMS ADD") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - try: - A, B = torch.broadcast_tensors(A, B) - except RuntimeError as e: - logging.error( - f"Add: Tensor shape {A.shape} and tensor shape {B.shape} cannot broadcast to each other." - ) - A = A.contiguous() - B = B.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - add_kernel[grid_fn](A, B, alpha, O, M) + O = add_func(A, B, alpha) return O elif isinstance(A, torch.Tensor): - A = A.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - add_tensor_scalar_kernel[grid_fn](A, B * alpha, O, M) + O = add_func_tensor_scalar(A, B, alpha) return O elif isinstance(B, torch.Tensor): - B = B.contiguous() - O = torch.empty_like(B) - M = B.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - add_scalar_tensor_kernel[grid_fn](A, B, alpha, O, M) + O = add_func_scalar_tensor(A, B, alpha) return O else: - # Both scalar - return A - B + return A + B * alpha diff --git a/src/flag_gems/ops/all.py b/src/flag_gems/ops/all.py index 45d97513..19360f39 100644 --- a/src/flag_gems/ops/all.py +++ b/src/flag_gems/ops/all.py @@ -8,7 +8,7 @@ # torch.all: Tests if all elements in input evaluate to True. # If the dtype of input is not BOOL, then test if all elements in input evaluate to non-zero value -# In triton function, test if all elements in input evaluate to non-zero value is ok. +# In triton function, test if all elements in input evaluate to non-zero value is ok. def cfggen(): block_m = [1, 2, 4, 8] configs = [ @@ -58,7 +58,7 @@ def all_kernel_1( inp, mid, n_elements, - mid_size, + mid_size, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -107,7 +107,7 @@ def all_dim(inp, dim=None, keepdim=False): if keepdim: out = torch.reshape(out, [1] * inp.ndim) else: - assert (dim >= -inp.ndim and dim < inp.ndim) , "Invalid dim" + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" dim = dim % inp.ndim order = list(range(0, inp.ndim)) order.remove(dim) @@ -119,9 +119,7 @@ def all_dim(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) all_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) @@ -132,7 +130,7 @@ def all_dims(inp, dim=None, keepdim=False): logging.debug("GEMS ALL DIMS") if dim is None or isinstance(dim, int): return all_dim(inp, dim=dim, keepdim=keepdim) - assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" + assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" shape = list(inp.shape) dim = [d % inp.ndim for d in dim] @@ -146,10 +144,8 @@ def all_dims(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) all_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) - return out \ No newline at end of file + return out diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py index 5b5851cf..ebb3eba3 100644 --- a/src/flag_gems/ops/amax.py +++ b/src/flag_gems/ops/amax.py @@ -57,7 +57,7 @@ def amax_kernel( # Map the program id to the row of inp it should compute. pid = tl.program_id(0) rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] - inp = (inp + rows * N) + inp = inp + rows * N out = out + rows row_mask = rows < M @@ -95,7 +95,7 @@ def amax(inp, dim=None, keepdim=False): else: if isinstance(dim, int): dim = [dim] - assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" + assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" dtype = inp.dtype shape = list(inp.shape) @@ -110,9 +110,7 @@ def amax(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=dtype, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) amax_kernel[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/ops/any.py b/src/flag_gems/ops/any.py index 86bd87cf..24741408 100644 --- a/src/flag_gems/ops/any.py +++ b/src/flag_gems/ops/any.py @@ -8,7 +8,7 @@ # torch.any: Tests if any elements in input evaluate to True. # If the dtype of input is not BOOL, then test if any elements in input evaluate to non-zero value -# In triton function, test if any elements in input evaluate to non-zero value is ok. +# In triton function, test if any elements in input evaluate to non-zero value is ok. def cfggen(): block_m = [1, 2, 4, 8] configs = [ @@ -58,7 +58,7 @@ def any_kernel_1( inp, mid, n_elements, - mid_size, + mid_size, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -107,7 +107,7 @@ def any_dim(inp, dim=None, keepdim=False): if keepdim: out = torch.reshape(out, [1] * inp.ndim) else: - assert (dim >= -inp.ndim and dim < inp.ndim) , "Invalid dim" + assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim" dim = dim % inp.ndim order = list(range(0, inp.ndim)) order.remove(dim) @@ -119,9 +119,7 @@ def any_dim(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) any_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) @@ -132,7 +130,7 @@ def any_dims(inp, dim=None, keepdim=False): logging.debug("GEMS ANY DIMS") if dim is None or isinstance(dim, int): return any_dim(inp, dim=dim, keepdim=keepdim) - assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" + assert ((i >= -inp.ndim and i < inp.ndim) for i in dim), "Invalid dim" shape = list(inp.shape) dim = [d % inp.ndim for d in dim] @@ -146,10 +144,8 @@ def any_dims(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) any_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) - return out \ No newline at end of file + return out diff --git a/src/flag_gems/ops/bitwise_and.py b/src/flag_gems/ops/bitwise_and.py index b8a4848d..b83dcead 100644 --- a/src/flag_gems/ops/bitwise_and.py +++ b/src/flag_gems/ops/bitwise_and.py @@ -30,4 +30,4 @@ def bitwise_and_scalar(A, B): def bitwise_and_scalar_tensor(A, B): logging.debug("GEMS BITWISE AND SCALAR TENSOR") O = bitwise_and_func_scalar(B, A) - return O \ No newline at end of file + return O diff --git a/src/flag_gems/ops/clamp.py b/src/flag_gems/ops/clamp.py index c8a9e706..a5d51edc 100644 --- a/src/flag_gems/ops/clamp.py +++ b/src/flag_gems/ops/clamp.py @@ -63,4 +63,4 @@ def clamp(A, mini=None, maxi=None): O = clamp_func_min(A, mini) else: O = clamp_func(A, mini, maxi) - return O \ No newline at end of file + return O diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index f9fb5b35..ad711a0b 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -2,7 +2,7 @@ import triton import triton.language as tl import logging -from ..utils import libentry, pointwise_dynamic +from ..utils import pointwise_dynamic @pointwise_dynamic @@ -11,68 +11,16 @@ def div_func(x, y): return x / y -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, False]) @triton.jit -def div_tensor_scalar_kernel( - X, - Y_scalar, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - X_ptrs = X + offset - O_ptrs = O + offset - X_val = tl.load(X_ptrs, mask=mask, other=0.0) - O_val = X_val / Y_scalar - tl.store(O_ptrs, O_val.to(X_val.dtype), mask=mask) +def div_func_tensor_scalar(x, y): + return x / y -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[False, True]) @triton.jit -def div_scalar_tensor_kernel( - X_scalar, - Y, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - Y_ptrs = Y + offset - O_ptrs = O + offset - Y_val = tl.load(Y_ptrs, mask=mask, other=0.0) - O_val = X_scalar / Y_val - tl.store(O_ptrs, O_val.to(Y_val.dtype), mask=mask) +def div_func_scalar_tensor(x, y): + return x / y def div(A, B): @@ -81,18 +29,10 @@ def div(A, B): O = div_func(A, B) return O elif isinstance(A, torch.Tensor): - A = A.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - div_tensor_scalar_kernel[grid_fn](A, B, O, M) + O = div_func_tensor_scalar(A, B) return O elif isinstance(B, torch.Tensor): - B = B.contiguous() - O = torch.empty_like(B) - M = B.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - div_scalar_tensor_kernel[grid_fn](A, B, O, M) + O = div_func_scalar_tensor(A, B) return O else: # Both scalar diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py index 910d3f49..7301679d 100644 --- a/src/flag_gems/ops/dropout.py +++ b/src/flag_gems/ops/dropout.py @@ -113,7 +113,9 @@ def backward(ctx, grad_outputs, kwargs): grad_inputs = torch.empty_like(grad_outputs) N = grad_outputs.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["N_BLOCK_SIZE"]),) - dropout_backward_kernel[grid_fn](grad_outputs, grad_inputs, N, ctx.p, ctx.philox_seed, ctx.philox_offset) + dropout_backward_kernel[grid_fn]( + grad_outputs, grad_inputs, N, ctx.p, ctx.philox_seed, ctx.philox_offset + ) return grad_inputs, None, None diff --git a/src/flag_gems/ops/mul.py b/src/flag_gems/ops/mul.py index 981126cf..bf65ecba 100644 --- a/src/flag_gems/ops/mul.py +++ b/src/flag_gems/ops/mul.py @@ -2,7 +2,7 @@ import triton import triton.language as tl import logging -from ..utils import libentry, pointwise_dynamic +from ..utils import pointwise_dynamic @pointwise_dynamic @@ -11,36 +11,10 @@ def mul_func(x, y): return x * y -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, False]) @triton.jit -def mul_scalar_kernel( - X, - Y_scalar, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - X_ptrs = X + offset - O_ptrs = O + offset - X_val = tl.load(X_ptrs, mask=mask, other=0.0) - O_val = X_val * Y_scalar - tl.store(O_ptrs, O_val, mask=mask) +def mul_func_scalar(x, y): + return x * y def mul(A, B): @@ -49,18 +23,10 @@ def mul(A, B): O = mul_func(A, B) return O elif isinstance(A, torch.Tensor): - A = A.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - mul_scalar_kernel[grid_fn](A, B, O, M) + O = mul_func_scalar(A, B) return O elif isinstance(B, torch.Tensor): - B = B.contiguous() - O = torch.empty_like(B) - M = B.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - mul_scalar_kernel[grid_fn](B, A, O, M) + O = mul_func_scalar(B, A) return O else: # Both scalar diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py new file mode 100644 index 00000000..9430d4d5 --- /dev/null +++ b/src/flag_gems/ops/pow.py @@ -0,0 +1,40 @@ +import triton +import triton.language as tl +import logging +from ..utils import pointwise_dynamic + + +@pointwise_dynamic +@triton.jit +def pow_func(x, exponent): + return tl.math.pow(x.to(tl.float32), exponent) + + +def pow_tensor_tensor(A, exponent): + logging.debug("GEMS POW_TENSOR_TENSOR") + O = pow_func(A, exponent) + return O + + +@pointwise_dynamic(is_tensor=[True, False]) +@triton.jit +def pow_func_tensor_scalar(x, exponent): + return tl.math.pow(x.to(tl.float32), exponent) + + +def pow_tensor_scalar(A, exponent): + logging.debug("GEMS POW_TENSOR_SCALAR") + O = pow_func_tensor_scalar(A, exponent) + return O + + +@pointwise_dynamic(is_tensor=[False, True]) +@triton.jit +def pow_func_scalar_tensor(x, exponent): + return tl.math.pow(x.to(tl.float32), exponent) + + +def pow_scalar(A, exponent): + logging.debug("GEMS POW_SCALAR") + O = pow_func_scalar_tensor(A, exponent) + return O diff --git a/src/flag_gems/ops/pow_scalar.py b/src/flag_gems/ops/pow_scalar.py deleted file mode 100644 index a1f81368..00000000 --- a/src/flag_gems/ops/pow_scalar.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import triton -import triton.language as tl -import logging -from ..utils import libentry - - -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) -@triton.jit -def pow_scalar_kernel( - X_val, - exponent, - Y, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - exp_ptrs = exponent + offset - Y_ptrs = Y + offset - exp_val = tl.load(exp_ptrs, mask=mask, other=0.0) - Y_val = tl.math.pow(X_val, exp_val) - tl.store(Y_ptrs, Y_val, mask=mask) - - -def pow_scalar(A, exponent): - logging.debug("GEMS POW_SCALAR") - exponent = exponent.contiguous() - O = torch.empty_like(exponent) - M = exponent.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - pow_scalar_kernel[grid_fn](A, exponent, O, M) - return O diff --git a/src/flag_gems/ops/pow_tensor_scalar.py b/src/flag_gems/ops/pow_tensor_scalar.py deleted file mode 100644 index fcd7bed7..00000000 --- a/src/flag_gems/ops/pow_tensor_scalar.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import triton -import triton.language as tl -import logging -from ..utils import libentry - - -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) -@triton.jit -def pow_tensor_scalar_kernel( - X, - exponent, - Y, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - X_ptrs = X + offset - Y_ptrs = Y + offset - X_val = tl.load(X_ptrs, mask=mask, other=0.0) - Y_val = tl.math.pow(X_val.to(tl.float32), exponent) - tl.store(Y_ptrs, Y_val, mask=mask) - - -def pow_tensor_scalar(A, exponent): - logging.debug("GEMS POW_TENSOR_SCALAR") - A = A.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - pow_tensor_scalar_kernel[grid_fn](A, exponent, O, M) - return O diff --git a/src/flag_gems/ops/pow_tensor_tensor.py b/src/flag_gems/ops/pow_tensor_tensor.py deleted file mode 100644 index 9b97fe13..00000000 --- a/src/flag_gems/ops/pow_tensor_tensor.py +++ /dev/null @@ -1,16 +0,0 @@ -import triton -import triton.language as tl -import logging -from ..utils import libentry, pointwise_dynamic - - -@pointwise_dynamic -@triton.jit -def pow_func(x, exponent): - return tl.math.pow(x.to(tl.float32), exponent) - - -def pow_tensor_tensor(A, exponent): - logging.debug("GEMS POW_TENSOR_TENSOR") - O = pow_func(A, exponent) - return O diff --git a/src/flag_gems/ops/softmax.py b/src/flag_gems/ops/softmax.py index 87cf5d07..491267f7 100644 --- a/src/flag_gems/ops/softmax.py +++ b/src/flag_gems/ops/softmax.py @@ -25,9 +25,9 @@ @triton.heuristics( values={ "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]), - "num_warps": lambda args: 4 - if args["N"] <= 1024 - else (8 if args["N"] <= 2048 else 16), + "num_warps": lambda args: ( + 4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16) + ), }, ) @triton.jit @@ -76,9 +76,9 @@ def softmax_kernel( @triton.heuristics( values={ "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]), - "num_warps": lambda args: 4 - if args["N"] <= 1024 - else (8 if args["N"] <= 2048 else 16), + "num_warps": lambda args: ( + 4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16) + ), }, ) @triton.jit diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index dbb2913a..c501c034 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -2,139 +2,38 @@ import triton import triton.language as tl import logging -from ..utils import libentry +from ..utils import pointwise_dynamic -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, True, False]) @triton.jit -def sub_kernel( - X, - Y, - alpha, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - X_ptr = X + offset - Y_ptr = Y + offset - O_ptr = O + offset - X_val = tl.load(X_ptr, mask=mask, other=0.0) - Y_val = tl.load(Y_ptr, mask=mask, other=0.0) - O_val = X_val - Y_val * alpha - tl.store(O_ptr, O_val, mask=mask) +def sub_func(x, y, alpha): + return x - y * alpha -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[True, False, False]) @triton.jit -def sub_tensor_scalar_kernel( - X, - Y_scalar, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - X_ptr = X + offset - O_ptr = O + offset - X_val = tl.load(X_ptr, mask=mask, other=0.0) - O_val = X_val - Y_scalar - tl.store(O_ptr, O_val, mask=mask) +def sub_func_tensor_scalar(x, y, alpha): + return x - y * alpha -@libentry() -@triton.autotune( - configs=[ - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 256}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 512}, num_warps=2, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 1024}, num_warps=4, num_stages=5), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=4), - triton.Config({"M_BLOCK_SIZE": 2048}, num_warps=4, num_stages=5), - ], - key=["M"], -) +@pointwise_dynamic(is_tensor=[False, True, False]) @triton.jit -def sub_scalar_tensor_kernel( - X_scalar, - Y, - alpha, - O, - M, - M_BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) * M_BLOCK_SIZE - offset = pid + tl.arange(0, M_BLOCK_SIZE) - mask = offset < M - Y_ptr = Y + offset - O_ptr = O + offset - Y_val = tl.load(Y_ptr, mask=mask, other=0.0) - O_val = X_scalar - Y_val * alpha - tl.store(O_ptr, O_val, mask=mask) +def sub_func_scalar_tensor(x, y, alpha): + return x - y * alpha def sub(A, B, *, alpha=1): logging.debug("GEMS SUB") if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): - try: - A, B = torch.broadcast_tensors(A, B) - except RuntimeError as e: - logging.error( - f"Sub: Tensor shape {A.shape} and tensor shape {B.shape} cannot broadcast to each other." - ) - A = A.contiguous() - B = B.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - sub_kernel[grid_fn](A, B, alpha, O, M) + O = sub_func(A, B, alpha) return O elif isinstance(A, torch.Tensor): - A = A.contiguous() - O = torch.empty_like(A) - M = A.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - sub_tensor_scalar_kernel[grid_fn](A, B * alpha, O, M) + O = sub_func_tensor_scalar(A, B, alpha) return O elif isinstance(B, torch.Tensor): - B = B.contiguous() - O = torch.empty_like(B) - M = B.numel() - grid_fn = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) - sub_scalar_tensor_kernel[grid_fn](A, B, alpha, O, M) + O = sub_func_scalar_tensor(A, B, alhpa) return O else: # Both scalar - return A - B + return A - B * alpha diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py index 408888ba..8830c46c 100644 --- a/src/flag_gems/ops/sum.py +++ b/src/flag_gems/ops/sum.py @@ -106,9 +106,7 @@ def sum_dim(inp, dim=None, keepdim=False, *, dtype=None): out = torch.empty(shape, dtype=dtype, device=inp.device) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) sum_kernel[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 5dcf69a4..60b172a8 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -21,13 +21,16 @@ def _type_name(type) -> str: return str(type) return str(type) + def _check_typed_list(container, type): for item in container: assert isinstance(item, type) + def _check_sized_list(container, size): assert len(container) == size + class OPDesc: _num_inputs: int _is_tensor: List[bool] @@ -83,7 +86,9 @@ def __init__( else: self._is_tensor = [item is None for item in dtypes] else: - raise ValueError("Cannot make OPDesc when none of (num_inputs, is_tensor, dtypes) is specified.") + raise ValueError( + "Cannot make OPDesc when none of (num_inputs, is_tensor, dtypes) is specified." + ) if output_dtypes is not None: _check_typed_list(output_dtypes, torch.dtype) @@ -94,7 +99,7 @@ def __init__( _check_sized_list(output_dtypes, num_outputs) self._output_dtypes = output_dtypes else: - self._output_dtypes = [None] * num_inputs # infer from the 1st input + self._output_dtypes = [None] * num_inputs # infer from the 1st input elif output_dtypes is not None: self._num_outputs = len(output_dtypes) self._output_dtypes = output_dtypes @@ -108,7 +113,6 @@ def __init__( self._num_input_tensors = sum(self._is_tensor) self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors - def num_inputs(self): # num of arguments, outputs not included return self._num_inputs @@ -159,6 +163,7 @@ def signature(self, outputs_in_arg: bool = False): def __str__(self) -> str: return self.signature(outputs_in_arg=False) + # --------------------------- pointwise wrapper genration ----------------------------------- def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: """Generate parameter declaration with type annotation for wrapper function. @@ -174,7 +179,9 @@ def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str input_tensor_index += 1 else: if op_desc.input_type(i) is not None: - parameters.append(f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}") + parameters.append( + f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}" + ) else: parameters.append(f"val{non_tensor_index}") non_tensor_index += 1 @@ -187,6 +194,7 @@ def parameter_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str return ", ".join(parameters) + def parameter_ref_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> str: """Generate parameter reference for wrapper function. Example: in0, val0, out0 @@ -211,6 +219,7 @@ def parameter_ref_for_wrapper(op_desc: OPDesc, include_outputs: bool = False) -> return ", ".join(parameters) + def output_ref_for_wrapper(op_desc: OPDesc) -> str: """Generate output variable refernece for wrapper function. Example: out0, out1 @@ -218,14 +227,17 @@ def output_ref_for_wrapper(op_desc: OPDesc) -> str: parameters: List[str] = [f"out{i}" for i in range(op_desc.num_outputs())] return ", ".join(parameters) + def docstring_for_functional_wrapper(op_desc: OPDesc): doc = f'"""Generated wrapper function with {str(op_desc)}"""' return doc + def docstring_for_destination_passing_wrapper(op_desc: OPDesc): doc = f'"""Generated wrapper function with {op_desc.signature(outputs_in_arg=True)}"""' return doc + def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("import math") code.writeline("import torch") @@ -239,11 +251,12 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.newline() return code + def generate_functional_pointwise_wrapper( op_desc: OPDesc, wrapper_name: str, destination_passing_func_name: str, - code: IndentedBuffer + code: IndentedBuffer, ) -> IndentedBuffer: # wrapper signature parameters: str = parameter_for_wrapper(op_desc, include_outputs=False) @@ -255,16 +268,22 @@ def generate_functional_pointwise_wrapper( wrapper_docstring = docstring_for_functional_wrapper(op_desc) code.writeline(wrapper_docstring) - shapes_str = ", ".join(f"in{i}.shape" for i in range(op_desc.num_input_tensors())) + shapes_str = ", ".join( + f"in{i}.shape" for i in range(op_desc.num_input_tensors()) + ) code.writeline(f"shape = broadcast_shapes([{shapes_str}])") # output allocation num_output_tensor_index = 0 for i in range(op_desc.num_outputs()): if op_desc.output_dtype(i) is None: - code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype=in0.dtype, device=in0.device)") + code.writeline( + f"out{num_output_tensor_index} = torch.empty(shape, dtype=in0.dtype, device=in0.device)" + ) else: - code.writeline(f"out{num_output_tensor_index} = torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)") + code.writeline( + f"out{num_output_tensor_index} = torch.empty(shape, dtype={_type_name(op_desc.output_dtype(i))}, device=in0.device)" + ) num_output_tensor_index += 1 # call destination_passing_func @@ -277,12 +296,13 @@ def generate_functional_pointwise_wrapper( code.newline() return code + def generate_destination_passing_pointwise_wrapper( op_desc: OPDesc, rank: int, wrapper_name: str, kernel_name: str, - code: IndentedBuffer + code: IndentedBuffer, ) -> IndentedBuffer: # wrapper signature parameters: str = parameter_for_wrapper(op_desc, include_outputs=True) @@ -292,7 +312,7 @@ def generate_destination_passing_pointwise_wrapper( # task partitioning, 1d task indexing tile_size = 512 num_warps = 4 - if rank == 0: # special case with rank-0, only 1 element to compute + if rank == 0: # special case with rank-0, only 1 element to compute tile_size = 32 num_warps = 1 @@ -301,7 +321,9 @@ def generate_destination_passing_pointwise_wrapper( wrapper_docstring = docstring_for_destination_passing_wrapper(op_desc) code.writeline(wrapper_docstring) - shapes_str = ", ".join(f"in{i}.shape" for i in range(op_desc.num_input_tensors())) + shapes_str = ", ".join( + f"in{i}.shape" for i in range(op_desc.num_input_tensors()) + ) code.writeline(f"shape = broadcast_shapes([{shapes_str}])") code.writeline(f"num_tasks = volume(shape)") code.newline() @@ -310,7 +332,9 @@ def generate_destination_passing_pointwise_wrapper( if rank > 0: code.writeline("# strides of each tensor argument w.r.t the task space") for i in range(op_desc.num_input_tensors()): - code.writeline(f"in{i}_strides = broadcasted_stride(in{i}.shape, in{i}.stride(), shape)") + code.writeline( + f"in{i}_strides = broadcasted_stride(in{i}.shape, in{i}.stride(), shape)" + ) for i in range(op_desc.num_output_tensors()): code.writeline(f"out{i}_strides = out{i}.stride()") @@ -326,9 +350,10 @@ def generate_destination_passing_pointwise_wrapper( kernel_launch: str = f"{kernel_name}[grid](" code.writeline(kernel_launch) - with code.indent(): - code.writeline(f"{parameter_ref_for_wrapper(op_desc, include_outputs=True)},") + code.writeline( + f"{parameter_ref_for_wrapper(op_desc, include_outputs=True)}," + ) if rank > 0: for i in range(op_desc.num_input_tensors()): @@ -353,16 +378,19 @@ def generate_destination_passing_pointwise_wrapper( code.newline() return code + def generate_pointwise_kernel( op_desc: OPDesc, scalar_fn: JITFunction, rank: int, kernel_name: str, - code: IndentedBuffer + code: IndentedBuffer, ) -> IndentedBuffer: code.writeline("@libentry()") if op_desc.num_non_tensor_args() > 0: - non_specialize_arg_names = [f"val{i}" for i in range(op_desc.num_non_tensor_args())] + non_specialize_arg_names = [ + f"val{i}" for i in range(op_desc.num_non_tensor_args()) + ] code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})") else: code.writeline("@triton.jit") @@ -377,12 +405,16 @@ def generate_pointwise_kernel( # inputs ptrs & non tensor inputs for i in range(op_desc.num_inputs()): if op_desc.is_tensor(i): - code.writeline(f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type") + code.writeline( + f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" + ) function_ns.create_name(f"in{input_tensor_index}_ptr") input_tensor_index += 1 else: if op_desc.input_type(i) is not None: - code.writeline(f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))},") + code.writeline( + f"val{non_tensor_index}: {_type_name(op_desc.input_type(i))}," + ) else: code.writeline(f"val{non_tensor_index},") function_ns.create_name(f"val{non_tensor_index}") @@ -390,11 +422,12 @@ def generate_pointwise_kernel( # output ptrs for i in range(op_desc.num_outputs()): - code.writeline(f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type") + code.writeline( + f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" + ) function_ns.create_name(f"out{output_tensor_index}_ptr") output_tensor_index += 1 - if rank > 0: # strides for inputs for i in range(op_desc.num_input_tensors()): @@ -459,7 +492,9 @@ def generate_pointwise_kernel( code.writeline("# loads") for i in range(op_desc.num_input_tensors()): if rank > 0: - ptrs_expr: str = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(rank)) + ptrs_expr: str = " + ".join( + f"i{j} * in{i}_stride{j}" for j in range(rank) + ) ptrs_expr: str = f"in{i}_ptr + {ptrs_expr}" load_stmt: str = f"in{i} = tl.load({ptrs_expr}, mask=mask)" else: @@ -485,7 +520,6 @@ def generate_pointwise_kernel( outputs_to_scalar_fn = [f"out{i}" for i in range(op_desc.num_outputs())] - compute_body = inline_function( scalar_fn, inputs_to_scalar_fn, @@ -500,7 +534,9 @@ def generate_pointwise_kernel( code.writeline("# stores") for i in range(op_desc.num_output_tensors()): if rank > 0: - ptrs_expr: str = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(rank)) + ptrs_expr: str = " + ".join( + f"i{j} * out{i}_stride{j}" for j in range(rank) + ) ptrs_expr: str = f"out{i}_ptr + {ptrs_expr}" store_stmt: str = f"tl.store({ptrs_expr}, out{i}, mask=mask)" else: @@ -510,16 +546,19 @@ def generate_pointwise_kernel( code.newline() return code + def generate_code( op_desc: OPDesc, scalar_fn: JITFunction, inputs: Tuple[Any], wrapper_name: str, - destination_passing_func_name:str, - kernel_name:str, - code: IndentedBuffer + destination_passing_func_name: str, + kernel_name: str, + code: IndentedBuffer, ) -> IndentedBuffer: - assert len(inputs) == op_desc.num_inputs(), "the number of inputs does not match {str(op_desc)}" + assert ( + len(inputs) == op_desc.num_inputs() + ), "the number of inputs does not match {str(op_desc)}" input_tensor_ids = [i for i in range(op_desc.num_inputs()) if op_desc.is_tensor(i)] tensor_shapes = [inputs[i].shape for i in input_tensor_ids] shape = broadcast_shapes(tensor_shapes) @@ -527,11 +566,16 @@ def generate_code( # the only runtime determined factor is the rank of the task space code = generate_imports(code) - code = generate_functional_pointwise_wrapper(op_desc, wrapper_name, destination_passing_func_name, code) - code = generate_destination_passing_pointwise_wrapper(op_desc, rank, destination_passing_func_name, kernel_name, code) + code = generate_functional_pointwise_wrapper( + op_desc, wrapper_name, destination_passing_func_name, code + ) + code = generate_destination_passing_pointwise_wrapper( + op_desc, rank, destination_passing_func_name, kernel_name, code + ) code = generate_pointwise_kernel(op_desc, scalar_fn, rank, kernel_name, code) return code + class PointwiseDynamicFunction: """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction which are specialized according to the rank of the task space(the broadcasted shape of all input tensors). @@ -563,7 +607,8 @@ def __call__(self, *args): "_wrapper", "_wrapper_out", "_jit_function", - code) + code, + ) file_name = f"pointwise_dynamic_{self._scalar_fn_cache_key}_rank_{key}.py" with open(cache_dir() / file_name, "wt", encoding="utf-8") as f: @@ -571,8 +616,8 @@ def __call__(self, *args): # load spec = importlib.util.spec_from_file_location( - f"_gen_module_{self._scalar_fn_cache_key}", - f.name) + f"_gen_module_{self._scalar_fn_cache_key}", f.name + ) m = importlib.util.module_from_spec(spec) # do not expose it to sys.modules # sys.modules["_add_module"] = m @@ -594,7 +639,7 @@ def pointwise_dynamic( is_tensor: Optional[List[bool]] = None, dtypes: Optional[List[Optional[type]]] = None, num_outputs: Optional[int] = None, - output_dtypes: Optional[List[type]] = None + output_dtypes: Optional[List[type]] = None, ): def decorator(fn): nonlocal num_inputs @@ -605,7 +650,8 @@ def decorator(fn): is_tensor=is_tensor, dtypes=dtypes, num_outputs=num_outputs, - output_dtypes=output_dtypes) + output_dtypes=output_dtypes, + ) return PointwiseDynamicFunction(op_desc, fn) if f is not None: @@ -614,9 +660,8 @@ def decorator(fn): if __name__ == "__main__": - @pointwise_dynamic( - is_tensor=[True, False, True], - dtypes=[None, float, None]) + + @pointwise_dynamic(is_tensor=[True, False, True], dtypes=[None, float, None]) @triton.jit def saxpy(x, alpha, y): return x * alpha + y @@ -642,7 +687,6 @@ def saxpy(x, alpha, y): torch.testing.assert_close(out1, out2) print() - @pointwise_dynamic(output_dtypes=[torch.bool]) @triton.jit def ge(x, y): @@ -696,7 +740,10 @@ def ordinary2(x, y): @pointwise_dynamic(is_tensor=[True, False], output_dtypes=[torch.bool]) @triton.jit def eq(x, y): - return x.to(tl.float32) == y.to(tl.float32) # ensures that y is not used for specialization + return x.to(tl.float32) == y.to( + tl.float32 + ) # ensures that y is not used for specialization + x = torch.arange(10, device="cuda") y = 1 # by default value 1 is treated as constexpr even thought it is not marked as constexpr @@ -707,4 +754,3 @@ def eq(x, y): print(out2) torch.testing.assert_close(out1, out2) print() - diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py index af55c3cc..b08c8cfd 100644 --- a/src/flag_gems/utils/random_utils.py +++ b/src/flag_gems/utils/random_utils.py @@ -1,5 +1,6 @@ import torch + # This function is roughly a python wrapper of CUDAGeneratorImpl::philox_cuda_state in Pytorch. # https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452 # It returns the current state of the default Philox RNG in seed and offset and diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py new file mode 100644 index 00000000..56819fab --- /dev/null +++ b/tests/accuracy_utils.py @@ -0,0 +1,51 @@ +import torch +from .conftest import TO_CPU + + +major, minor = torch.__version__.split(".")[:2] +skip_expr = major < "2" or minor < "2" +skip_reason = "PyTorch < 2.2.0 does not support" + + +RESOLUTION = { + torch.float16: 1e-3, + torch.float32: 1.3e-6, + torch.bfloat16: 0.016, +} + +POINTWISE_SHAPES = [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)] +REDUCTION_SHAPES = [(1024, 64 * i) for i in range(1, 10, 2)] +MNK_SHAPES = [15, 160, 1024] + +FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16] +INT_DTYPES = [torch.int16, torch.int32] + +SCALARS = [0.001, -0.999, 100.001, -111.999] +DIM_LIST = [0, 1] +DIMS_LIST = [0, 1, [0, 1], [1, 0]] + + +def to_reference(inp, upcast=False): + if inp is None: + return None + ref_inp = inp + if TO_CPU: + ref_inp = ref_inp.to("cpu") + if upcast: + ref_inp = ref_inp.to(torch.float64) + return ref_inp + + +def gems_assert_close(a, b, dtype, equal_nan=False, reduce_dim=1): + if TO_CPU: + a = a.to("cpu") + b = b.to(dtype) + atol = 1e-4 * reduce_dim + rtol = RESOLUTION[dtype] + torch.testing.assert_close(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) + + +def gems_assert_equal(a, b): + if TO_CPU: + a = a.to("cpu") + assert torch.equal(a, b) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..9045209f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +def pytest_addoption(parser): + parser.addoption( + "--device", + action="store", + default="cuda", + required=False, + choices=["cuda", "cpu"], + help="device to run reference tests on", + ) + + +def pytest_configure(config): + value = config.getoption("--device") + global TO_CPU + TO_CPU = value == "cpu" diff --git a/tests/flag_gems/op_accu_test.py b/tests/flag_gems/op_accu_test.py deleted file mode 100644 index 4875cca8..00000000 --- a/tests/flag_gems/op_accu_test.py +++ /dev/null @@ -1,2022 +0,0 @@ -import torch -import pytest -import flag_gems - - -major, minor = torch.__version__.split(".")[:2] -skip_expr = major < "2" or minor < "2" - - -RESOLUTION = { - torch.float16: 1e-3, - torch.float32: 1.3e-6, - torch.bfloat16: 0.016, - torch.int64: 1.3e-6, -} - - -def allclose_with_dtype(a, b, dtype, equal_nan=False, reduce_dim=1): - b = b.to(dtype) - atol = 1e-4 * reduce_dim - rtol = RESOLUTION[dtype] - torch.testing.assert_close(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_abs(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.abs(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.abs(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_add(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.add(inp1.to(torch.float64), inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.add(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape_a", - [(16, 1024, 256)], -) -@pytest.mark.parametrize( - "shape_b", - [(1, 256), (1, 1, 256), (16, 1, 256), (1, 1024, 256), (1024, 256)], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_add_broadcast(shape_a, shape_b, alpha, dtype): - inp1 = torch.randn(shape_a, dtype=dtype, device="cuda") - inp2 = torch.randn(shape_b, dtype=dtype, device="cuda") - - ref_out = torch.add(inp1.to(torch.float64), inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.add(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_add_tensor_scalar(shape, scalar, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = scalar - - ref_out = torch.add(inp1.to(torch.float64), inp2, alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.add(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype): - inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.add(inp1, inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.add(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "M, N, K", - [ - (256, 256, 256), - (1024, 1024, 1024), - (1024, 128, 2048), - (1024, 64, 1280), - (640, 256, 512), - ], -) -@pytest.mark.parametrize("alpha", [1.0, 0.5]) -@pytest.mark.parametrize("beta", [1.0, 0.5]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_addmm(M, N, K, alpha, beta, dtype): - mat1 = torch.randn((M, K), dtype=dtype, device="cuda") - mat2 = torch.randn((K, N), dtype=dtype, device="cuda") - bias = torch.randn((N,), dtype=dtype, device="cuda") - - ref_out = torch.addmm( - bias.to(torch.float64), - mat1.to(torch.float64), - mat2.to(torch.float64), - alpha=alpha, - beta=beta, - ) - with flag_gems.use_gems(): - res_out = torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta) - - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=K) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseand(shape, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_and(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.bitwise_and(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0x000f, 0x7fff, -0x00ff], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseand_scalar(shape, scalar, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_and(inp1, scalar) - with flag_gems.use_gems(): - res_out = torch.bitwise_and(inp1, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0x000f, 0x7fff, -0x00ff], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseand_scalar_tensor(shape, scalar, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_and(scalar, inp1) - with flag_gems.use_gems(): - res_out = torch.bitwise_and(scalar, inp1) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwisenot(shape, dtype): - inp = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_not(inp) - with flag_gems.use_gems(): - res_out = torch.bitwise_not(inp) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseor(shape, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_or(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.bitwise_or(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0x000f, 0x7fff, -0x00ff], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseor_scalar(shape, scalar, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_or(inp1, scalar) - with flag_gems.use_gems(): - res_out = torch.bitwise_or(inp1, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0x000f, 0x7fff, -0x00ff], -) -@pytest.mark.parametrize("dtype", [torch.int16, torch.int32]) -def test_accuracy_bitwiseor_scalar_tensor(shape, scalar, dtype): - inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" - ) - - ref_out = torch.bitwise_or(scalar, inp1) - with flag_gems.use_gems(): - res_out = torch.bitwise_or(scalar, inp1) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "batch, M, N, K", - [ - (1, 1024, 1024, 1024), - (3, 1024, 1024, 2048), - (4, 1024, 64, 1280), - (8, 640, 256, 512), - (16, 1024, 128, 2048), - ], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_bmm(batch, M, N, K, dtype): - tensor_A = torch.randn((batch, M, K), dtype=dtype, device="cuda") - tensor_B = torch.randn((batch, K, N), dtype=dtype, device="cuda") - - ref_out = torch.bmm(tensor_A.to(torch.float64), tensor_B.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.bmm(tensor_A, tensor_B) - - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=K) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("isnone", [None, "max", "min"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_clamp(shape, isnone, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - import random - maxi = random.random() - mini = random.random() - if isnone == "min": - mini = None - elif isnone == "max": - maxi = None - - ref_out = torch.clamp(inp, min=mini, max=maxi) - with flag_gems.use_gems(): - res_out = torch.clamp(inp, min=mini, max=maxi) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("isnone", [None, "max", "min"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_clamp_tensor(shape, isnone, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - maxi = torch.randn(shape, dtype=dtype, device="cuda") - mini = torch.randn(shape, dtype=dtype, device="cuda") - if isnone == "min": - mini = None - elif isnone == "max": - maxi = None - - ref_out = torch.clamp(inp, min=mini, max=maxi) - with flag_gems.use_gems(): - res_out = torch.clamp(inp, min=mini, max=maxi) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_cos(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_inp = inp.to(torch.float64) - ref_out = torch.cos(ref_inp) - with flag_gems.use_gems(): - res_out = torch.cos(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_cumsum(shape, dtype): - dim = 1 - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.cumsum(inp.to(torch.float64), dim=dim) - with flag_gems.use_gems(): - res_out = torch.cumsum(inp, dim=dim) - - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=shape[dim]) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_div(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.div(inp1.to(torch.float64), inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.div(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape_a", - [(16, 1024, 256)], -) -@pytest.mark.parametrize( - "shape_b", - [(1, 256), (1, 1, 256), (16, 1, 256), (1, 1024, 256), (1024, 256)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_div_broadcast(shape_a, shape_b, dtype): - inp1 = torch.randn(shape_a, dtype=dtype, device="cuda") - inp2 = torch.randn(shape_b, dtype=dtype, device="cuda") - - ref_out = torch.div(inp1.to(torch.float64), inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.div(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_div_tensor_scalar(shape, scalar, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = scalar - - ref_out = torch.div(inp1.to(torch.float64), inp2) - with flag_gems.use_gems(): - res_out = torch.div(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [200, 100], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_div_scalar_tensor(shape, scalar, dtype): - inp1 = scalar - inp2 = torch.randint(-5, 5, shape, dtype=dtype, device="cuda") - - ref_out = torch.div(inp1, inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.div(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("p", [0.3, 0.6, 0.9]) -def test_accuracy_dropout(shape, dtype, p): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_out = torch.nn.functional.dropout(inp, p, True) - with flag_gems.use_gems(): - res_out = torch.nn.functional.dropout(inp, p, True) - - # nz_ref = torch.sum(ref_out == 0.0) - # nz_res = torch.sum(res_out == 0.0) - - num_equal = torch.sum(torch.isclose(ref_out, res_out)).item() - exp_equal = (p * p + (1 - p) * (1 - p)) * inp.numel() - assert ( - abs(num_equal - exp_equal) / exp_equal <= 0.05 - ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, inp, out_grad) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - num_equal = torch.sum(torch.isclose(ref_in_grad, res_in_grad)).item() - exp_equal = (p * p + (1 - p) * (1 - p)) * inp.numel() - assert ( - abs(num_equal - exp_equal) / exp_equal <= 0.05 - ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_eq(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - - ref_out = torch.eq(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.eq(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_eq_scalar(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = 0 - - ref_out = torch.eq(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.eq(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_exp(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.exp(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.exp(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_ge(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.ge(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.ge(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0.5, 1.0, 100.9, -111.9], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_ge_scalar(shape, scalar, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.ge(inp, scalar) - with flag_gems.use_gems(): - res_out = torch.ge(inp, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_gelu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.nn.functional.gelu(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.nn.functional.gelu(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "N, C, H, W, num_groups", - [ - (16, 3, 16, 16, 1), - (32, 32, 32, 32, 8), - (1, 32, 32, 32, 8), - (1, 32, 32, 32, 16), - (1, 64, 32, 32, 16), - (1, 64, 32, 32, 32), - (1, 64, 32, 32, 64), - ], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype): - HW = H * W - inp = torch.randn(size=(N, C, H, W), dtype=dtype, device="cuda", requires_grad=True) - weight = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) - bias = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) - eps = 1e-5 - - ref_inp = inp.to(torch.float64) - ref_weight = weight.to(torch.float64) - ref_bias = bias.to(torch.float64) - - ref_out = torch.nn.functional.group_norm( - ref_inp, num_groups, weight=ref_weight, bias=ref_bias, eps=eps - ) - ref_mean = torch.mean(ref_inp.reshape([N, num_groups, -1]), dim=2) - ref_var = torch.var(ref_inp.reshape([N, num_groups, -1]), dim=2, correction=0) - ref_rstd = torch.rsqrt(ref_var + eps) - - (res_out, res_mean, res_rstd) = flag_gems.group_norm( - inp, weight, bias, N, C, HW, num_groups, eps - ) - - allclose_with_dtype(res_mean, ref_mean, dtype) - allclose_with_dtype(res_rstd, ref_rstd, dtype) - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad( - ref_out, (ref_inp, ref_weight, ref_bias), out_grad.to(torch.float64) - ) - (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad( - res_out, (inp, weight, bias), out_grad - ) - group_size = C // num_groups - allclose_with_dtype(res_in_grad, ref_in_grad, dtype, reduce_dim=group_size * HW) - allclose_with_dtype(res_weight_grad, ref_weight_grad, dtype, reduce_dim=N * HW) - allclose_with_dtype(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_gt(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.gt(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.gt(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0.5, 1.0, 100.9, -111.9], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_gt_scalar(shape, scalar, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.gt(inp, scalar) - with flag_gems.use_gems(): - res_out = torch.gt(inp, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_isinf(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - inp = torch.masked_fill(inp, inp > 1.0, -float("inf")) - - ref_out = torch.isinf(inp) - with flag_gems.use_gems(): - res_out = torch.isinf(inp) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_isnan(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - inp = torch.masked_fill(inp, inp > 1.0, float("nan")) - - ref_out = torch.isnan(inp) - with flag_gems.use_gems(): - res_out = torch.isnan(inp) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_layernorm(shape, dtype): - M = shape[0] - N = shape[1] - layer_shape = [ - N, - ] - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) - bias = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) - eps = 1e-5 - - ref_inp = inp.to(torch.float64) - ref_weight = weight.to(torch.float64) - ref_bias = bias.to(torch.float64) - - ref_out = torch.layer_norm( - ref_inp, - list(layer_shape), - weight=ref_weight, - bias=ref_bias, - eps=eps, - ) - (res_out, res_mean, res_rstd) = flag_gems.layer_norm( - inp, list(layer_shape), weight=weight, bias=bias, eps=eps - ) - - ref_mean = torch.mean(ref_inp, dim=1) - ref_var = torch.var(ref_inp, dim=1, correction=0) - ref_rstd = torch.rsqrt(ref_var + eps) - allclose_with_dtype(res_mean, ref_mean, dtype) - allclose_with_dtype(res_rstd, ref_rstd, dtype) - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad( - ref_out, (ref_inp, ref_weight, ref_bias), out_grad.to(torch.float64) - ) - (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad( - res_out, (inp, weight, bias), out_grad - ) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype, reduce_dim=N) - allclose_with_dtype(res_weight_grad, ref_weight_grad, dtype, reduce_dim=M) - allclose_with_dtype(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_le(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.le(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.le(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0.5, 1.0, 100.9, -111.9], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_le_scalar(shape, scalar, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.le(inp, scalar) - with flag_gems.use_gems(): - res_out = torch.le(inp, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_lt(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.lt(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.lt(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [0.5, 1.0, 100.9, -111.9], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_lt_scalar(shape, scalar, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.lt(inp, scalar) - with flag_gems.use_gems(): - res_out = torch.lt(inp, scalar) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_skip_layernorm(shape, dtype): - M = shape[0] - N = shape[1] - layer_shape = [ - N, - ] - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=False) - residual = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=False) - weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=False) - bias = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=False) - eps = 1e-5 - - ref_inp = inp.to(torch.float64) - ref_residual = residual.to(torch.float64) - ref_weight = weight.to(torch.float64) - ref_bias = bias.to(torch.float64) - - ref_out = torch.layer_norm( - ref_inp + ref_residual, - list(layer_shape), - weight=ref_weight, - bias=ref_bias, - eps=eps, - ) - res_out = flag_gems.skip_layer_norm( - inp, residual, list(layer_shape), weight=weight, bias=bias, eps=eps - ) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_skip_rmsnorm(shape, dtype): - M = shape[0] - N = shape[1] - layer_shape = [ - N, - ] - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=False) - residual = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=False) - weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=False) - eps = 1e-5 - - ref_inp = inp.to(torch.float64) - ref_residual = residual.to(torch.float64) - ref_weight = weight.to(torch.float64) - - def _torch_rms_norm(x, residual, weight, eps): - x = x + residual - variance = x.pow(2).mean(-1, keepdim=True) - hidden_states = x * torch.rsqrt(variance + eps) - return weight * hidden_states - - ref_out = _torch_rms_norm( - ref_inp, - ref_residual, - weight=ref_weight, - eps=eps, - ) - - res_out = flag_gems.skip_rms_norm( - inp, residual, list(layer_shape), weight=weight, eps=eps - ) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_rmsnorm(shape, dtype): - M = shape[0] - N = shape[1] - layer_shape = [ - N, - ] - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=False) - weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=False) - eps = 1e-5 - - ref_inp = inp.to(torch.float64) - ref_weight = weight.to(torch.float64) - - def _torch_rms_norm(x, weight, eps): - variance = x.pow(2).mean(-1, keepdim=True) - hidden_states = x * torch.rsqrt(variance + eps) - return weight * hidden_states - - ref_out = _torch_rms_norm( - ref_inp, - weight=ref_weight, - eps=eps, - ) - - res_out = flag_gems.rms_norm(inp, list(layer_shape), weight=weight, eps=eps) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mean(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.mean(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mean(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dim", [-1, 0, 1, None, [1, 0]]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_meandim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.mean(inp.to(torch.float64), dim, keepdim) - with flag_gems.use_gems(): - res_out = torch.mean(inp, dim, keepdim) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [ - (256, 256, 256), - (1024, 1024, 1024), - (1024, 128, 2048), - (1024, 64, 1280), - (640, 256, 512), - ], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mm(shape, dtype): - M, N, K = shape - tensor_a = torch.randn((M, K), dtype=dtype, device="cuda") - tensor_b = torch.randn((K, N), dtype=dtype, device="cuda") - - ref_out = torch.mm(tensor_a.to(torch.float64), tensor_b.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mm(tensor_a, tensor_b) - - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=K) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mul(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.mul(inp1.to(torch.float64), inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mul(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [ - (256, 256), - (1024, 1024), - (1024, 128), - (1024, 64), - (640, 256), - ], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mv(shape, dtype): - N, M = shape - matrix = torch.randn((N, M), dtype=dtype, device="cuda") - vector = torch.randn((M,), dtype=dtype, device="cuda") - - ref_out = torch.mv(matrix.to(torch.float64), vector.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mv(matrix, vector) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape_a", - [(16, 1024, 256)], -) -@pytest.mark.parametrize( - "shape_b", - [(1, 256), (1, 1, 256), (16, 1, 256), (1, 1024, 256), (1024, 256)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mul_broadcast(shape_a, shape_b, dtype): - inp1 = torch.randn(shape_a, dtype=dtype, device="cuda") - inp2 = torch.randn(shape_b, dtype=dtype, device="cuda") - - ref_out = torch.mul(inp1.to(torch.float64), inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mul(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mul_tensor_scalar(shape, scalar, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = scalar - - ref_out = torch.mul(inp1.to(torch.float64), inp2) - with flag_gems.use_gems(): - res_out = torch.mul(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_mul_scalar_tensor(shape, scalar, dtype): - inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.mul(inp1, inp2.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.mul(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_ne(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - - ref_out = torch.ne(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.ne(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_ne_scalar(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = 0 - - ref_out = torch.ne(inp1, inp2) - with flag_gems.use_gems(): - res_out = torch.ne(inp1, inp2) - - assert torch.equal(res_out, ref_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_neg(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.neg(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.neg(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "inp", - [0.9, 1.0, 100.9, -111.9], -) -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_accuracy_pow_scalar_tensor(inp, shape, dtype): - exponent = torch.randint(-5, 5, shape, dtype=dtype, device="cuda") - ref_out = torch.pow(inp, exponent.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.pow(inp, exponent) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "exponent", - [0.5, 1.5, 5.0, -1.0], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_pow_tensor_scalar(shape, exponent, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.pow(inp.to(torch.float64), exponent) - with flag_gems.use_gems(): - res_out = torch.pow(inp, exponent) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_pow_tensor_tensor(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - exponent = torch.randint(-10, 10, shape, dtype=dtype, device="cuda") - - ref_out = torch.pow(inp.to(torch.float64), exponent.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.pow(inp, exponent) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape_a", - [(16, 1024, 256)], -) -@pytest.mark.parametrize( - "shape_b", - [(1, 256), (1, 1, 256), (16, 1, 256), (1, 1024, 256), (1024, 256)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_pow_tensor_tensor_broadcast(shape_a, shape_b, dtype): - inp = torch.randn(shape_a, dtype=dtype, device="cuda") - exponent = torch.randint(-10, 10, shape_b, dtype=dtype, device="cuda") - - ref_out = torch.pow(inp.to(torch.float64), exponent.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.pow(inp, exponent) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_reciprocal(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.reciprocal(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.reciprocal(inp) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_relu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_inp = inp.to(torch.float64) - ref_out = torch.nn.functional.relu(ref_inp) - with flag_gems.use_gems(): - res_out = torch.relu(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_rsqrt(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.rsqrt(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.rsqrt(inp) - - allclose_with_dtype(res_out, ref_out, dtype, equal_nan=True) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_rsub(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.rsub(inp1.to(torch.float64), inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.rsub(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sigmoid(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_inp = inp.to(torch.float64) - ref_out = torch.sigmoid(ref_inp) - with flag_gems.use_gems(): - res_out = torch.sigmoid(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_silu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_inp = inp.to(torch.float64) - ref_out = torch.nn.functional.silu(ref_inp) - with flag_gems.use_gems(): - res_out = torch.nn.functional.silu(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sin(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_inp = inp.to(torch.float64) - ref_out = torch.sin(ref_inp) - with flag_gems.use_gems(): - res_out = torch.sin(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sub(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.sub(inp1.to(torch.float64), inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.sub(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape_a", - [(16, 1024, 256)], -) -@pytest.mark.parametrize( - "shape_b", - [(1, 256), (1, 1, 256), (16, 1, 256), (1, 1024, 256), (1024, 256)], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sub_broadcast(shape_a, shape_b, alpha, dtype): - inp1 = torch.randn(shape_a, dtype=dtype, device="cuda") - inp2 = torch.randn(shape_b, dtype=dtype, device="cuda") - - ref_out = torch.sub(inp1.to(torch.float64), inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.sub(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sub_tensor_scalar(shape, scalar, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = scalar - - ref_out = torch.sub(inp1.to(torch.float64), inp2, alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.sub(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize( - "scalar", - [111.111, -999.999], -) -@pytest.mark.parametrize("alpha", [0, 1, 4, -9]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype): - inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.sub(inp1, inp2.to(torch.float64), alpha=alpha) - with flag_gems.use_gems(): - res_out = torch.sub(inp1, inp2, alpha=alpha) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_softmax(shape, dtype): - dim = 1 - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_inp = inp.to(torch.float64) - ref_out = torch.nn.functional.softmax(ref_inp, dim=dim) - with flag_gems.use_gems(): - res_out = torch.nn.functional.softmax(inp, dim=dim) - - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim]) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_tanh(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - - ref_inp = inp.to(torch.float64) - ref_out = torch.tanh(ref_inp) - with flag_gems.use_gems(): - res_out = torch.tanh(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(inp) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("diagonal", [-3, -1, 0, 1, 3]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_triu(shape, diagonal, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.triu(inp.to(torch.float64), diagonal) - with flag_gems.use_gems(): - res_out = torch.triu(inp, diagonal) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_max(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.max(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.max(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1]) -def test_accuracy_max_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.max(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.max(inp, dim=dim, keepdim=keepdim) - ref_out_value, ref_out_index = ref_out - res_out_value, res_out_index = res_out - assert torch.equal( - ref_out_index, res_out_index - ), f"ref_out_index: {ref_out_index}, res_out_index: {res_out_index}" - allclose_with_dtype(ref_out_value, res_out_value, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_min(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.min(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.min(inp) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("dim", [-1, 0, 1, None, [1, 0]]) -@pytest.mark.parametrize("correction", [0, 1]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_varmean(shape, dim, correction, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_var, ref_mean = torch.var_mean(inp, dim, correction=correction, keepdim=keepdim) - with flag_gems.use_gems(): - res_var, res_mean = torch.var_mean( - inp, dim, correction=correction, keepdim=keepdim - ) - - allclose_with_dtype(res_mean, ref_mean, dtype) - allclose_with_dtype(res_var, ref_var, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1]) -def test_accuracy_min_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.min(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.min(inp, dim=dim, keepdim=keepdim) - ref_out_value, ref_out_index = ref_out - res_out_value, res_out_index = res_out - assert torch.equal( - ref_out_index, res_out_index - ), f"ref_out_index: {ref_out_index}, res_out_index: {res_out_index}" - - allclose_with_dtype(ref_out_value, res_out_value, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sum(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.sum(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.sum(inp) - - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=inp.numel()) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [[0, 1], 0, 1]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_sum_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.sum(inp.to(torch.float64), dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.sum(inp, dim=dim, keepdim=keepdim) - if isinstance(dim, int): - dim = [dim] - dim = [d % inp.ndim for d in dim] - _dim = 1 - for d in dim: - _dim *= shape[d] - allclose_with_dtype(res_out, ref_out, dtype, reduce_dim=_dim) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [[0, 1], [1, 0], 0, 1, None]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_amax(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.amax(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.amax(inp, dim=dim, keepdim=keepdim) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1, None]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_argmax(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.argmax(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.argmax(inp, dim=dim, keepdim=keepdim) - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_accuracy_prod(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.prod(inp.to(torch.float64)) - with flag_gems.use_gems(): - res_out = torch.prod(inp) - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(4096, i * 64) for i in range(1, 20)], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_prod_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - - ref_out = torch.prod(inp.to(torch.float64), dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.prod(inp, dim=dim, keepdim=keepdim) - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)], -) -@pytest.mark.parametrize("ord", [2, float("inf"), -float("inf"), 0, 1]) -@pytest.mark.parametrize("dim", [-1, 0, 1, None, [1, 0]]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - ref_out = torch.linalg.vector_norm(inp.to(torch.float64), ord, dim, keepdim) - with flag_gems.use_gems(): - res_out = torch.linalg.vector_norm(inp, ord, dim, keepdim) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (20, 320, 15)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_log_softmax(shape, dtype): - dim = 1 - # torch.manual_seed(0) - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - ref_inp = inp.to(torch.float64) - ref_out = torch.nn.functional.log_softmax(ref_inp, dim=dim) - with flag_gems.use_gems(): - res_out = torch.nn.functional.log_softmax(inp, dim=dim) - allclose_with_dtype(res_out, ref_out, dtype) - out_grad = torch.randn_like(res_out) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim]) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024), (16, 128), (20, 320)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_outer(shape, dtype): - inp1_shape, inp2_shape = list(shape) - inp1 = torch.randn(inp1_shape, dtype=dtype, device="cuda", requires_grad=True) - inp2 = torch.randn(inp2_shape, dtype=dtype, device="cuda", requires_grad=True) - - inp1_f64 = inp1.to(torch.float64) - inp2_f64 = inp2.to(torch.float64) - ref_out = torch.outer(inp1_f64, inp2_f64) - with flag_gems.use_gems(): - res_out = torch.outer(inp1, inp2) - allclose_with_dtype(res_out, ref_out, dtype) - - out_grad = torch.randn_like(res_out) - ref_in1_grad, ref_in2_grad = torch.autograd.grad( - ref_out, (inp1_f64, inp2_f64), out_grad.to(torch.float64) - ) - res_in1_grad, res_in2_grad = torch.autograd.grad(res_out, (inp1, inp2), out_grad) - allclose_with_dtype(res_in1_grad, ref_in1_grad, dtype) - allclose_with_dtype(res_in2_grad, ref_in2_grad, dtype) - - -def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"): - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) - freqs = torch.outer(t, inv_freq) - cos = freqs.cos().to(dtype) - sin = freqs.sin().to(dtype) - return cos, sin - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.cohere.modeling_cohere.rotate_half -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py -def rotate_interleave(x): - """Rotates interleave the hidden dims of the input.""" - x1 = x[..., ::2] - x2 = x[..., 1::2] - return torch.stack((-x2, x1), dim=-1).flatten(-2) - - -def torch_apply_rotary_pos_emb( - q, - k, - cos, - sin, - position_ids, - rotary_interleaved: bool = False, -): - - q = q.float() - k = k.float() - cos = cos[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] - sin = sin[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] - if rotary_interleaved: - cos = torch.repeat_interleave(cos, 2, dim=-1) # [bs, seq_len, 1, dim] - sin = torch.repeat_interleave(sin, 2, dim=-1) # [bs, seq_len, 1, dim] - rotate_fn = rotate_interleave - else: - cos = torch.cat([cos, cos], dim=-1) # [bs, seq_len, 1, dim] - sin = torch.cat([sin, sin], dim=-1) # [bs, seq_len, 1, dim] - rotate_fn = rotate_half - - q_embed = (q * cos) + (rotate_fn(q) * sin) - k_embed = (k * cos) + (rotate_fn(k) * sin) - - return q_embed, k_embed - - -@pytest.mark.parametrize("batch_size", [4, 8]) -@pytest.mark.parametrize("max_seq_len", [512, 2048]) -@pytest.mark.parametrize("q_heads,k_heads", [(8, 1), (6, 2), (1, 1), (8, 8)]) -@pytest.mark.parametrize("head_dim", [64, 96, 128, 256]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("rotary_interleaved", [True, False]) -def test_apply_rotary_pos_emb( - batch_size, - max_seq_len, - q_heads, - k_heads, - head_dim, - dtype, - rotary_interleaved, -): - seq_len = torch.randint(1, max_seq_len, (1,)).item() - q = torch.randn( - (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device="cuda" - ) - k = torch.randn( - (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device="cuda" - ) - - position_ids = torch.randint(0, max_seq_len, (batch_size, seq_len), device="cuda") - cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device="cuda") - - q_embed_ref, k_embed_ref = torch_apply_rotary_pos_emb( - q=q, - k=k, - cos=cos, - sin=sin, - position_ids=position_ids, - rotary_interleaved=rotary_interleaved, - ) - q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb( - q=q, - k=k, - cos=cos, - sin=sin, - position_ids=position_ids, - rotary_interleaved=rotary_interleaved, - ) - - allclose_with_dtype(q_embed_out, q_embed_ref, dtype) - allclose_with_dtype(k_embed_out, k_embed_ref, dtype) - - -@pytest.mark.parametrize( - "shape", - [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allTrue"]) -def test_accuracy_all(shape, dtype, kind): - if (kind == "allTrue"): - inp = torch.ones(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.all(inp) - with flag_gems.use_gems(): - res_out = torch.all(inp) - - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") -@pytest.mark.parametrize( - "shape", - [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1, None]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allTrue"]) -def test_accuracy_all_dim(shape, dim, keepdim, dtype, kind): - if (kind == "allTrue"): - inp = torch.ones(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.all(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.all(inp, dim=dim, keepdim=keepdim) - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") -@pytest.mark.parametrize( - "shape", - [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], -) -@pytest.mark.parametrize("dim", [[1, 0], [1, 2]]) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allTrue"]) -def test_accuracy_all_dims(shape, dim, keepdim, dtype, kind): - if (kind == "allTrue"): - inp = torch.ones(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.all(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.all(inp, dim=dim, keepdim=keepdim) - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allFalse"]) -def test_accuracy_any(shape, dtype, kind): - if (kind == "allFalse"): - inp = torch.zeros(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.any(inp) - with flag_gems.use_gems(): - res_out = torch.any(inp) - - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") -@pytest.mark.parametrize( - "shape", - [(i, j) for i in [2, 4096] for j in [2, 64, 1024]], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [0, 1, None]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allFalse"]) -def test_accuracy_any_dim(shape, dim, keepdim, dtype, kind): - if (kind == "allFalse"): - inp = torch.zeros(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.any(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.any(inp, dim=dim, keepdim=keepdim) - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.skipif(skip_expr, reason="PyTorch < 2.2.0 does not support") -@pytest.mark.parametrize( - "shape", - [(1024, 1024, 16), (16, 128, 64, 64), (2, 3, 5)], -) -@pytest.mark.parametrize("keepdim", [True, False]) -@pytest.mark.parametrize("dim", [[1, 0], [1, 2]]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16, torch.bool]) -@pytest.mark.parametrize("kind", ["normal", "allFalse"]) -def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): - if (kind == "allFalse"): - inp = torch.zeros(shape, dtype=dtype, device="cuda") - else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") - - ref_out = torch.any(inp, dim=dim, keepdim=keepdim) - with flag_gems.use_gems(): - res_out = torch.any(inp, dim=dim, keepdim=keepdim) - assert torch.equal(ref_out, res_out), f"ref_out: {ref_out}, res_out: {res_out}" - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 30)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_silu_and_mul(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - inp1, inp2 = inp.chunk(2, dim=-1) - - ref_out = torch.mul( - torch.nn.functional.silu(inp1.to(torch.float64)), - inp2.to(torch.float64), - ) - with flag_gems.use_gems(): - res_out = flag_gems.silu_and_mul(inp1, inp2) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 30)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("approximate", ["none", "tanh"]) -def test_accuracy_gelu_and_mul(shape, approximate, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - inp1, inp2 = inp.chunk(2, dim=-1) - - ref_out = torch.mul( - torch.nn.functional.gelu(inp1.to(torch.float64), approximate=approximate), - inp2.to(torch.float64), - ) - with flag_gems.use_gems(): - res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate) - - allclose_with_dtype(res_out, ref_out, dtype) - - -@pytest.mark.parametrize( - "shape", - [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 30)], -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) -def test_accuracy_cross_entropy_loss(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - dim = 1 - up_limit = shape[dim] - 1 - target_shape = list(shape) - del target_shape[dim] - target = torch.randint(0, up_limit, target_shape, device="cuda") - - ref_inp = inp.to(torch.float64) - - criterion = torch.nn.CrossEntropyLoss() - - ref_out = criterion(ref_inp, target) - with flag_gems.use_gems(): - res_out = criterion(inp, target) - allclose_with_dtype(res_out, ref_out, dtype) - out_grad = torch.randn_like(res_out) - (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, out_grad.to(torch.float64)) - (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) - allclose_with_dtype(res_in_grad, ref_in_grad, dtype) diff --git a/tests/flag_gems/op_perf_test.py b/tests/flag_gems/op_perf_test.py deleted file mode 100644 index 45d60a1a..00000000 --- a/tests/flag_gems/op_perf_test.py +++ /dev/null @@ -1,553 +0,0 @@ -import itertools -import torch -import time -import triton -import random -import op_accu_test -from flag_gems import * - - -def run_bench(op, *args, warmups=10, repetitions=1000, **kwargs): - for i in range(warmups): - ref_out = op(*args, **kwargs) - start = time.time() - for i in range(repetitions): - ref_out = op(*args, **kwargs) - torch.cuda.synchronize() - end = time.time() - ms = (end - start) * 1000 - return ms - - -class Benchmark: - def __init__(self, op_name): - self.op_name = op_name - - def provider_ops(self, gem=None, torch=None): - assert gem is not None - assert torch is not None - self.provider_ops = {"gem": gem, "torch": torch} - - def bench_params(self, **params): - self.bench_params = params - - def arg_names(self, *arg_names): - self.x_names = arg_names - - def arg_vals(self, arg_vals): - self.x_vals = arg_vals - - def extra_args(self, **args): - self.extra_args = args - - def perf(self, fn): - line_names, line_vals = zip(*self.provider_ops.items()) - bench_param_names, bench_param_vals = zip(*self.bench_params.items()) - benchmarks = ( - triton.testing.Benchmark( - x_names=self.x_names, - x_vals=self.x_vals, - line_arg="op", - line_names=list(line_names), - line_vals=list(line_vals), - styles=[("red", "-"), ("green", "-")], - ylabel="ms", - plot_name="test_performance_{}_{}".format( - self.op_name, "_".join(str(e) for e in bench_param_set) - ), - args={ - **self.extra_args, - **dict(zip(bench_param_names, bench_param_set)), - }, - ) - for bench_param_set in itertools.product(*bench_param_vals) - ) - return triton.testing.perf_report(benchmarks)(fn) - - -f16_f32_bf = (torch.float16, torch.float32, torch.bfloat16) -sizes = [i * 64 for i in range(1, 20)] -mnk_sizes = list(zip(sizes, sizes, sizes)) - - -abs_bench = Benchmark("abs") -abs_bench.bench_params(dtype=f16_f32_bf) -abs_bench.provider_ops(gem=abs, torch=torch.abs) -abs_bench.arg_names("N") -abs_bench.arg_vals(sizes) -abs_bench.extra_args(M=1024) - - -@abs_bench.perf -def bench_abs(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -add_bench = Benchmark("add") -add_bench.bench_params(dtype=f16_f32_bf) -add_bench.provider_ops(gem=add, torch=torch.add) -add_bench.arg_names("N") -add_bench.arg_vals(sizes) -add_bench.extra_args(M=1024) - - -@add_bench.perf -def bench_add(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - alpha = random.random() - ms = run_bench(op, inp1, inp2, alpha=alpha) - return ms - - -@add_bench.perf -def bench_add_scalar(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = random.random() - alpha = random.random() - ms = run_bench(op, inp1, inp2, alpha=alpha) - return ms - - -addmm_bench = Benchmark("addmm") -addmm_bench.bench_params(dtype=f16_f32_bf) -addmm_bench.provider_ops(gem=addmm, torch=torch.addmm) -addmm_bench.arg_names("M", "N", "K") -addmm_bench.arg_vals(mnk_sizes) -addmm_bench.extra_args(alpha=1.0, beta=1.0) - - -@addmm_bench.perf -def bench_addmm(op, M, N, K, alpha, beta, dtype): - mat1 = torch.randn((M, K), dtype=dtype, device="cuda") - mat2 = torch.randn((K, N), dtype=dtype, device="cuda") - bias = torch.randn((N,), dtype=dtype, device="cuda") - ms = run_bench(op, bias, mat1, mat2, alpha=alpha, beta=beta) - return ms - - -bmm_bench = Benchmark("bmm") -bmm_bench.bench_params(dtype=f16_f32_bf) -bmm_bench.provider_ops(gem=bmm, torch=torch.bmm) -bmm_bench.arg_names("M", "N", "K") -bmm_bench.arg_vals(mnk_sizes) -bmm_bench.extra_args(batch=4) - - -@bmm_bench.perf -def bench_bmm(op, batch, M, N, K, dtype): - tensor_A = torch.randn((batch, M, K), dtype=dtype, device="cuda") - tensor_B = torch.randn((batch, K, N), dtype=dtype, device="cuda") - ms = run_bench(op, tensor_A, tensor_B) - return ms - - -cumsum_bench = Benchmark("cumsum") -cumsum_bench.bench_params(dtype=f16_f32_bf) -cumsum_bench.provider_ops(gem=cumsum, torch=torch.cumsum) -cumsum_bench.arg_names("N") -cumsum_bench.arg_vals(sizes) -cumsum_bench.extra_args(M=1024, dim=1) - - -@cumsum_bench.perf -def bench_cumsum(op, M, N, dim, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp, dim=dim) - return ms - - -div_bench = Benchmark("div") -div_bench.bench_params(dtype=f16_f32_bf) -div_bench.provider_ops(gem=div, torch=torch.div) -div_bench.arg_names("N") -div_bench.arg_vals(sizes) -div_bench.extra_args(M=1024) - - -@div_bench.perf -def bench_div(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp1, inp2) - return ms - - -@div_bench.perf -def bench_div_scalar(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = random.randint(0, 1000) - ms = run_bench(op, inp1, inp2) - return ms - - -dropout_bench = Benchmark("dropout") -dropout_bench.bench_params(dtype=f16_f32_bf, p=(0.3, 0.6, 0.9)) -dropout_bench.provider_ops(gem=native_dropout, torch=torch.nn.functional.dropout) -dropout_bench.arg_names("N") -dropout_bench.arg_vals(sizes) -dropout_bench.extra_args(M=1024) - - -@dropout_bench.perf -def bench_dropout(op, M, N, p, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp, p, True) - return ms - - -exp_bench = Benchmark("exp") -exp_bench.bench_params(dtype=f16_f32_bf) -exp_bench.provider_ops(gem=exp, torch=torch.exp) -exp_bench.arg_names("N") -exp_bench.arg_vals(sizes) -exp_bench.extra_args(M=1024) - - -@exp_bench.perf -def bench_exp(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -gelu_bench = Benchmark("gelu") -gelu_bench.bench_params(dtype=f16_f32_bf) -gelu_bench.provider_ops(gem=gelu, torch=torch.nn.functional.gelu) -gelu_bench.arg_names("N") -gelu_bench.arg_vals(sizes) -gelu_bench.extra_args(M=1024) - - -@gelu_bench.perf -def bench_gelu(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -layernorm_bench = Benchmark("layernorm") -layernorm_bench.bench_params(dtype=f16_f32_bf) -layernorm_bench.provider_ops(gem=layer_norm, torch=torch.nn.functional.layer_norm) -layernorm_bench.arg_names("N") -layernorm_bench.arg_vals(sizes) -layernorm_bench.extra_args(M=1024) - - -@layernorm_bench.perf -def bench_layernorm(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - weight = torch.randn(N, dtype=dtype, device="cuda") - bias = torch.randn(N, dtype=dtype, device="cuda") - eps = 1e-5 - ms = run_bench( - op, - inp, - normalized_shape=[ - N, - ], - weight=weight, - bias=bias, - eps=eps, - ) - return ms - - -mean_bench = Benchmark("mean") -mean_bench.bench_params(dtype=f16_f32_bf) -mean_bench.provider_ops(gem=mean, torch=torch.mean) -mean_bench.arg_names("N") -mean_bench.arg_vals(sizes) -mean_bench.extra_args(M=1024) - - -@mean_bench.perf -def bench_mean(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -mm_bench = Benchmark("mm") -mm_bench.bench_params(dtype=f16_f32_bf) -mm_bench.provider_ops(gem=mm, torch=torch.mm) -mm_bench.arg_names("M", "N", "K") -mm_bench.arg_vals(mnk_sizes) -mm_bench.extra_args() - - -@mm_bench.perf -def bench_mm(op, M, N, K, dtype): - tensor_a = torch.randn((M, K), dtype=dtype, device="cuda") - tensor_b = torch.randn((K, N), dtype=dtype, device="cuda") - ms = run_bench(op, tensor_a, tensor_b) - return ms - - -mul_bench = Benchmark("mul") -mul_bench.bench_params(dtype=f16_f32_bf) -mul_bench.provider_ops(gem=mul, torch=torch.mul) -mul_bench.arg_names("N") -mul_bench.arg_vals(sizes) -mul_bench.extra_args(M=1024) - - -@mul_bench.perf -def bench_mul(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp1, inp2) - return ms - - -@mul_bench.perf -def bench_mul_scalar(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = random.randint(0, 10000) - ms = run_bench(op, inp1, inp2) - return ms - - -reciprocal_bench = Benchmark("reciprocal") -reciprocal_bench.bench_params(dtype=f16_f32_bf) -reciprocal_bench.provider_ops(gem=reciprocal, torch=torch.reciprocal) -reciprocal_bench.arg_names("N") -reciprocal_bench.arg_vals(sizes) -reciprocal_bench.extra_args(M=1024) - - -pow_bench = Benchmark("pow") -pow_bench.bench_params(dtype=f16_f32_bf) -pow_bench.provider_ops(gem=pow, torch=torch.pow) -pow_bench.arg_names("N") -pow_bench.arg_vals(sizes) -pow_bench.extra_args(M=1024) - - -@pow_bench.perf -def bench_pow_tensor_scalar(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - exponent = random.randint(-10, 10) - ms = run_bench(op, inp, exponent) - return ms - - -@pow_bench.perf -def bench_pow_tensor_tensor(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - exponent = torch.randint(-10, 10, (M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp, exponent) - return ms - - -@reciprocal_bench.perf -def bench_reciprocal(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -relu_bench = Benchmark("relu") -relu_bench.bench_params(dtype=f16_f32_bf) -relu_bench.provider_ops(gem=relu, torch=torch.relu) -relu_bench.arg_names("N") -relu_bench.arg_vals(sizes) -relu_bench.extra_args(M=1024) - - -@relu_bench.perf -def bench_relu(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -rsqrt_bench = Benchmark("rsqrt") -rsqrt_bench.bench_params(dtype=f16_f32_bf) -rsqrt_bench.provider_ops(gem=rsqrt, torch=torch.rsqrt) -rsqrt_bench.arg_names("N") -rsqrt_bench.arg_vals(sizes) -rsqrt_bench.extra_args(M=1024) - - -@rsqrt_bench.perf -def bench_rsqrt(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -silu_bench = Benchmark("silu") -silu_bench.bench_params(dtype=f16_f32_bf) -silu_bench.provider_ops(gem=silu, torch=torch.nn.functional.silu) -silu_bench.arg_names("N") -silu_bench.arg_vals(sizes) -silu_bench.extra_args(M=1024) - - -@silu_bench.perf -def bench_silu(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -sigmoid_bench = Benchmark("sigmoid") -sigmoid_bench.bench_params(dtype=f16_f32_bf) -sigmoid_bench.provider_ops(gem=sigmoid, torch=torch.sigmoid) -sigmoid_bench.arg_names("N") -sigmoid_bench.arg_vals(sizes) -sigmoid_bench.extra_args(M=1024) - - -@silu_bench.perf -def bench_sigmoid(op, M, N, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp) - return ms - - -softmax_bench = Benchmark("softmax") -softmax_bench.bench_params(dtype=f16_f32_bf) -softmax_bench.provider_ops(gem=softmax, torch=torch.nn.functional.softmax) -softmax_bench.arg_names("N") -softmax_bench.arg_vals(sizes) -softmax_bench.extra_args(M=1024, dim=1) - - -@softmax_bench.perf -def bench_softmax(op, M, N, dim, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp, dim=dim) - return ms - - -sub_bench = Benchmark("sub") -sub_bench.bench_params(dtype=f16_f32_bf) -sub_bench.provider_ops(gem=sub, torch=torch.sub) -sub_bench.arg_names("N") -sub_bench.arg_vals(sizes) -sub_bench.extra_args(M=1024) - - -@sub_bench.perf -def bench_sub(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - alpha = random.randint(0, 10000) - ms = run_bench(op, inp1, inp2, alpha=alpha) - return ms - - -@sub_bench.perf -def bench_sub_scalar(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = random.randint(0, 10000) - alpha = random.randint(0, 10000) - ms = run_bench(op, inp1, inp2, alpha=alpha) - return ms - - -triu_bench = Benchmark("triu") -triu_bench.bench_params(dtype=f16_f32_bf) -triu_bench.provider_ops(gem=triu, torch=torch.triu) -triu_bench.arg_names("N") -triu_bench.arg_vals(sizes) -triu_bench.extra_args(M=1024, diagonal=1) - - -@triu_bench.perf -def bench_triu(op, M, N, diagonal, dtype): - inp = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp, diagonal=diagonal) - return ms - - -rope_bench = Benchmark("rope") -rope_bench.bench_params(dtype=f16_f32_bf) -rope_bench.provider_ops( - gem=apply_rotary_pos_emb, torch=op_accu_test.torch_apply_rotary_pos_emb -) -rope_bench.arg_names("M") -rope_bench.arg_vals(sizes) -rope_bench.extra_args(num_heads=16, head_dim=128, max_seq_len=2048) - - -@rope_bench.perf -def bench_rope(op, M, num_heads, head_dim, max_seq_len, dtype): - q = torch.randn((M, num_heads, head_dim), dtype=dtype, device="cuda") - k = torch.randn((M, num_heads, head_dim), dtype=dtype, device="cuda") - position_ids = torch.randint(1, max_seq_len, (M,), device="cuda") - cos = torch.randn((max_seq_len, head_dim // 2), dtype=dtype, device="cuda") - sin = torch.randn((max_seq_len, head_dim // 2), dtype=dtype, device="cuda") - ms = run_bench(op, q, k, cos, sin, position_ids) - - -silu_and_mul_bench = Benchmark("silu_and_mul") -silu_and_mul_bench.bench_params(dtype=f16_f32_bf) -silu_and_mul_bench.provider_ops( - gem=silu_and_mul, torch=lambda a, b: torch.nn.functional.silu(a) * b -) -silu_and_mul_bench.arg_names("N") -silu_and_mul_bench.arg_vals(sizes) -silu_and_mul_bench.extra_args(M=1024) - - -@silu_and_mul_bench.perf -def bench_silu_and_mul(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp1, inp2) - return ms - - -gelu_and_mul_bench = Benchmark("gelu_and_mul") -gelu_and_mul_bench.bench_params(dtype=f16_f32_bf) -gelu_and_mul_bench.provider_ops( - gem=gelu_and_mul, torch=lambda a, b: torch.nn.functional.gelu(a) * b -) -gelu_and_mul_bench.arg_names("N") -gelu_and_mul_bench.arg_vals(sizes) -gelu_and_mul_bench.extra_args(M=1024) - - -@gelu_and_mul_bench.perf -def bench_gelu_and_mul(op, M, N, dtype): - inp1 = torch.randn((M, N), dtype=dtype, device="cuda") - inp2 = torch.randn((M, N), dtype=dtype, device="cuda") - ms = run_bench(op, inp1, inp2) - return ms - - -bench_abs.run(print_data=True) -bench_add.run(print_data=True) -bench_add_scalar.run(print_data=True) -bench_addmm.run(print_data=True) -bench_bmm.run(print_data=True) -bench_cumsum.run(print_data=True) -bench_exp.run(print_data=True) -bench_dropout.run(print_data=True) -bench_div.run(print_data=True) -bench_div_scalar.run(print_data=True) -bench_gelu.run(print_data=True) -bench_layernorm.run(print_data=True) -bench_mean.run(print_data=True) -bench_mm.run(print_data=True) -bench_mul.run(print_data=True) -bench_mul_scalar.run(print_data=True) -bench_pow_tensor_scalar.run(print_data=True) -bench_pow_tensor_tensor.run(print_data=True) -bench_reciprocal.run(print_data=True) -bench_relu.run(print_data=True) -bench_rsqrt.run(print_data=True) -bench_silu.run(print_data=True) -bench_sigmoid.run(print_data=True) -bench_softmax.run(print_data=True) -bench_sub.run(print_data=True) -bench_sub_scalar.run(print_data=True) -bench_triu.run(print_data=True) -bench_rope.run(print_data=True) -bench_silu_and_mul.run(print_data=True) -bench_gelu_and_mul.run(print_data=True) diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py new file mode 100644 index 00000000..9c3c099c --- /dev/null +++ b/tests/test_binary_pointwise_ops.py @@ -0,0 +1,569 @@ +import torch +import pytest +import flag_gems +from .accuracy_utils import * + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_add(shape, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.add(ref_inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_add_tensor_scalar(shape, scalar, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = scalar + ref_inp1 = to_reference(inp1, True) + + ref_out = torch.add(ref_inp1, inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype): + inp1 = scalar + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.add(inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.add(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseand(shape, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.bitwise_and(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_and(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseand_scalar(shape, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = 0x00FF + ref_inp1 = to_reference(inp1) + + ref_out = torch.bitwise_and(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_and(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseand_scalar_tensor(shape, dtype): + inp1 = 0x00FF + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + ref_inp2 = to_reference(inp2) + + ref_out = torch.bitwise_and(inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_and(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseor(shape, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.bitwise_or(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_or(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseor_scalar(shape, dtype): + inp1 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + inp2 = 0x00FF + ref_inp1 = to_reference(inp1) + + ref_out = torch.bitwise_or(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_or(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwiseor_scalar_tensor(shape, dtype): + inp1 = 0x00FF + inp2 = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + ref_inp2 = to_reference(inp2) + + ref_out = torch.bitwise_or(inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.bitwise_or(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("maxi", SCALARS) +@pytest.mark.parametrize("mini", SCALARS) +@pytest.mark.parametrize("isnone", [None, "max", "min"]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_clamp(shape, maxi, mini, isnone, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + if isnone == "min": + mini = None + elif isnone == "max": + maxi = None + ref_inp = to_reference(inp) + + ref_out = torch.clamp(ref_inp, min=mini, max=maxi) + with flag_gems.use_gems(): + res_out = torch.clamp(inp, min=mini, max=maxi) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("isnone", [None, "max", "min"]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_clamp_tensor(shape, isnone, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + maxi = torch.randn(shape, dtype=dtype, device="cuda") + mini = torch.randn(shape, dtype=dtype, device="cuda") + if isnone == "min": + mini = None + elif isnone == "max": + maxi = None + ref_inp = to_reference(inp) + ref_maxi = to_reference(maxi) + ref_mini = to_reference(mini) + + ref_out = torch.clamp(ref_inp, min=ref_mini, max=ref_maxi) + with flag_gems.use_gems(): + res_out = torch.clamp(inp, min=mini, max=maxi) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_div(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.div(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_div_tensor_scalar(shape, scalar, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = scalar + ref_inp1 = to_reference(inp1, True) + + ref_out = torch.div(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_div_scalar_tensor(shape, scalar, dtype): + inp1 = scalar + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.div(inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.div(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_eq(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.eq(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_eq_scalar(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = 0 + ref_inp1 = to_reference(inp1) + + ref_out = torch.eq(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.eq(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ge(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.ge(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.ge(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ge_scalar(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = 0 + ref_inp1 = to_reference(inp1) + + ref_out = torch.ge(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.ge(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_gt(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.gt(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.gt(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_gt_scalar(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + inp2 = 0 + + ref_out = torch.gt(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.gt(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_le(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.le(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.le(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_le_scalar(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = 0 + ref_inp1 = to_reference(inp1) + + ref_out = torch.le(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.le(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_lt(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.lt(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.lt(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_lt_scalar(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = 0 + ref_inp1 = to_reference(inp1) + + ref_out = torch.lt(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.lt(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mul(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.mul(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.mul(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mul_tensor_scalar(shape, scalar, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = scalar + ref_inp1 = to_reference(inp1, True) + + ref_out = torch.mul(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.mul(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mul_scalar_tensor(shape, scalar, dtype): + inp1 = scalar + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.mul(inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.mul(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ne(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1) + ref_inp2 = to_reference(inp2) + + ref_out = torch.ne(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.ne(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ne_scalar(shape, dtype): + inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp2 = 0 + ref_inp1 = to_reference(inp1) + + ref_out = torch.ne(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.ne(inp1, inp2) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_pow(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.pow(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.pow(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_pow_scalar_tensor(scalar, shape, dtype): + inp1 = scalar + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.pow(inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.pow(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_pow_tensor_scalar(scalar, shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = scalar + ref_inp1 = to_reference(inp1, True) + + ref_out = torch.pow(ref_inp1, inp2) + with flag_gems.use_gems(): + res_out = torch.pow(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_rsub(shape, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.rsub(ref_inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.rsub(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sub(shape, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.sub(ref_inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.sub(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sub_tensor_scalar(shape, scalar, alpha, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = scalar + ref_inp1 = to_reference(inp1, True) + + ref_out = torch.sub(ref_inp1, inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.sub(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("scalar", SCALARS) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype): + inp1 = scalar + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.sub(inp1, ref_inp2, alpha=alpha) + with flag_gems.use_gems(): + res_out = torch.sub(inp1, inp2, alpha=alpha) + + gems_assert_close(res_out, ref_out, dtype) diff --git a/tests/test_blas_ops.py b/tests/test_blas_ops.py new file mode 100644 index 00000000..a791adc4 --- /dev/null +++ b/tests/test_blas_ops.py @@ -0,0 +1,101 @@ +import torch +import pytest +import flag_gems +from .accuracy_utils import * + + +@pytest.mark.parametrize("M", MNK_SHAPES) +@pytest.mark.parametrize("N", MNK_SHAPES) +@pytest.mark.parametrize("K", MNK_SHAPES) +@pytest.mark.parametrize("alpha", SCALARS) +@pytest.mark.parametrize("beta", SCALARS) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_addmm(M, N, K, alpha, beta, dtype): + mat1 = torch.randn((M, K), dtype=dtype, device="cuda") + mat2 = torch.randn((K, N), dtype=dtype, device="cuda") + bias = torch.randn((N,), dtype=dtype, device="cuda") + ref_mat1 = to_reference(mat1, True) + ref_mat2 = to_reference(mat2, True) + ref_bias = to_reference(bias, True) + + ref_out = torch.addmm(ref_bias, ref_mat1, ref_mat2, alpha=alpha, beta=beta) + with flag_gems.use_gems(): + res_out = torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=K) + + +@pytest.mark.parametrize("M", MNK_SHAPES) +@pytest.mark.parametrize("N", MNK_SHAPES) +@pytest.mark.parametrize("K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_bmm(M, N, K, dtype): + batch = 4 + mat1 = torch.randn((batch, M, K), dtype=dtype, device="cuda") + mat2 = torch.randn((batch, K, N), dtype=dtype, device="cuda") + ref_mat1 = to_reference(mat1, True) + ref_mat2 = to_reference(mat2, True) + + ref_out = torch.bmm(ref_mat1, ref_mat2) + with flag_gems.use_gems(): + res_out = torch.bmm(mat1, mat2) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=K) + + +@pytest.mark.parametrize("M", MNK_SHAPES) +@pytest.mark.parametrize("N", MNK_SHAPES) +@pytest.mark.parametrize("K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mm(M, N, K, dtype): + mat1 = torch.randn((M, K), dtype=dtype, device="cuda") + mat2 = torch.randn((K, N), dtype=dtype, device="cuda") + ref_mat1 = to_reference(mat1, True) + ref_mat2 = to_reference(mat2, True) + + ref_out = torch.mm(ref_mat1, ref_mat2) + with flag_gems.use_gems(): + res_out = torch.mm(mat1, mat2) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=K) + + +@pytest.mark.parametrize("M", MNK_SHAPES) +@pytest.mark.parametrize("N", MNK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mv(M, N, dtype): + matrix = torch.randn((N, M), dtype=dtype, device="cuda") + vector = torch.randn((M,), dtype=dtype, device="cuda") + ref_matrix = to_reference(matrix, True) + ref_vector = to_reference(vector, True) + + ref_out = torch.mv(ref_matrix, ref_vector) + with flag_gems.use_gems(): + res_out = torch.mv(matrix, vector) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("M", MNK_SHAPES) +@pytest.mark.parametrize("N", MNK_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_outer(M, N, dtype): + inp1 = torch.randn(M, dtype=dtype, device="cuda", requires_grad=True) + inp2 = torch.randn(N, dtype=dtype, device="cuda", requires_grad=True) + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.outer(ref_inp1, ref_inp2) + with flag_gems.use_gems(): + res_out = torch.outer(inp1, inp2) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(res_out) + ref_grad = to_reference(out_grad, True) + + ref_in1_grad, ref_in2_grad = torch.autograd.grad( + ref_out, (ref_inp1, ref_inp2), ref_grad + ) + res_in1_grad, res_in2_grad = torch.autograd.grad(res_out, (inp1, inp2), out_grad) + gems_assert_close(res_in1_grad, ref_in1_grad, dtype) + gems_assert_close(res_in2_grad, ref_in2_grad, dtype) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py new file mode 100644 index 00000000..ebdae6ab --- /dev/null +++ b/tests/test_reduction_ops.py @@ -0,0 +1,607 @@ +import torch +import pytest +import flag_gems +from .accuracy_utils import * + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allTrue"]) +def test_accuracy_all(shape, dtype, kind): + if kind == "allTrue": + inp = torch.ones(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.all(ref_inp) + with flag_gems.use_gems(): + res_out = torch.all(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.skipif(skip_expr, reason=skip_reason) +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allTrue"]) +def test_accuracy_all_dim(shape, dim, keepdim, dtype, kind): + if kind == "allTrue": + inp = torch.ones(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.all(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.all(inp, dim=dim, keepdim=keepdim) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.skipif(skip_expr, reason=skip_reason) +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allTrue"]) +def test_accuracy_all_dims(shape, dim, keepdim, dtype, kind): + if kind == "allTrue": + inp = torch.ones(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.all(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.all(inp, dim=dim, keepdim=keepdim) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_amax(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.amax(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.amax(inp, dim=dim, keepdim=keepdim) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allFalse"]) +def test_accuracy_any(shape, dtype, kind): + if kind == "allFalse": + inp = torch.zeros(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.any(ref_inp) + with flag_gems.use_gems(): + res_out = torch.any(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.skipif(skip_expr, reason=skip_reason) +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allFalse"]) +def test_accuracy_any_dim(shape, dim, keepdim, dtype, kind): + if kind == "allFalse": + inp = torch.zeros(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.any(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.any(inp, dim=dim, keepdim=keepdim) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.skipif(skip_expr, reason=skip_reason) +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) +@pytest.mark.parametrize("kind", ["normal", "allFalse"]) +def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): + if kind == "allFalse": + inp = torch.zeros(shape, dtype=dtype, device="cuda") + else: + inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.any(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.any(inp, dim=dim, keepdim=keepdim) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_argmax(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.argmax(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.argmax(inp, dim=dim, keepdim=keepdim) + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_cross_entropy_loss(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + dim = 1 + up_limit = shape[dim] - 1 + target_shape = list(shape) + del target_shape[dim] + target = torch.randint(0, up_limit, target_shape, device="cuda") + + ref_inp = to_reference(inp, True) + ref_target = to_reference(target) + + criterion = torch.nn.CrossEntropyLoss() + + ref_out = criterion(ref_inp, ref_target) + with flag_gems.use_gems(): + res_out = criterion(inp, target) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(res_out) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_cumsum(shape, dtype): + dim = 1 + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.cumsum(ref_inp, dim=dim) + with flag_gems.use_gems(): + res_out = torch.cumsum(inp, dim=dim) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[dim]) + + +@pytest.mark.parametrize( + "N, C, H, W, num_groups", + [ + (16, 3, 16, 16, 1), + (32, 32, 32, 32, 8), + (1, 32, 32, 32, 8), + (1, 32, 32, 32, 16), + (1, 64, 32, 32, 16), + (1, 64, 32, 32, 32), + (1, 64, 32, 32, 64), + ], +) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype): + HW = H * W + inp = torch.randn(size=(N, C, H, W), dtype=dtype, device="cuda", requires_grad=True) + weight = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) + bias = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) + eps = 1e-5 + + ref_inp = to_reference(inp, True) + ref_weight = to_reference(weight, True) + ref_bias = to_reference(bias, True) + + ref_out = torch.nn.functional.group_norm( + ref_inp, num_groups, weight=ref_weight, bias=ref_bias, eps=eps + ) + ref_mean = torch.mean(ref_inp.reshape([N, num_groups, -1]), dim=2) + ref_var = torch.var(ref_inp.reshape([N, num_groups, -1]), dim=2, correction=0) + ref_rstd = torch.rsqrt(ref_var + eps) + + (res_out, res_mean, res_rstd) = flag_gems.group_norm( + inp, weight, bias, N, C, HW, num_groups, eps + ) + + gems_assert_close(res_mean, ref_mean, dtype) + gems_assert_close(res_rstd, ref_rstd, dtype) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad( + ref_out, (ref_inp, ref_weight, ref_bias), ref_grad + ) + (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad( + res_out, (inp, weight, bias), out_grad + ) + group_size = C // num_groups + gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=group_size * HW) + gems_assert_close(res_weight_grad, ref_weight_grad, dtype, reduce_dim=N * HW) + gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=N * HW) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_layernorm(shape, dtype): + M = shape[0] + N = shape[1] + layer_shape = [ + N, + ] + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) + bias = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) + eps = 1e-5 + + ref_inp = to_reference(inp, True) + ref_weight = to_reference(weight, True) + ref_bias = to_reference(bias, True) + + ref_out = torch.layer_norm( + ref_inp, + list(layer_shape), + weight=ref_weight, + bias=ref_bias, + eps=eps, + ) + (res_out, res_mean, res_rstd) = flag_gems.layer_norm( + inp, list(layer_shape), weight=weight, bias=bias, eps=eps + ) + + ref_mean = torch.mean(ref_inp, dim=1) + ref_var = torch.var(ref_inp, dim=1, correction=0) + ref_rstd = torch.rsqrt(ref_var + eps) + gems_assert_close(res_mean, ref_mean, dtype) + gems_assert_close(res_rstd, ref_rstd, dtype) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad( + ref_out, (ref_inp, ref_weight, ref_bias), ref_grad + ) + (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad( + res_out, (inp, weight, bias), out_grad + ) + gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=N) + gems_assert_close(res_weight_grad, ref_weight_grad, dtype, reduce_dim=M) + gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=M) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_log_softmax(shape, dtype): + dim = 1 + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.nn.functional.log_softmax(ref_inp, dim=dim) + with flag_gems.use_gems(): + res_out = torch.nn.functional.log_softmax(inp, dim=dim) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(res_out) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim]) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_max(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.max(ref_inp) + with flag_gems.use_gems(): + res_out = torch.max(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_max_dim(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.max(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.max(inp, dim=dim, keepdim=keepdim) + ref_out_value, ref_out_index = ref_out + res_out_value, res_out_index = res_out + gems_assert_equal(res_out_index, ref_out_index) + gems_assert_equal(res_out_value, ref_out_value) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mean(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.mean(ref_inp) + with flag_gems.use_gems(): + res_out = torch.mean(inp) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_mean_dim(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.mean(ref_inp, dim, keepdim) + with flag_gems.use_gems(): + res_out = torch.mean(inp, dim, keepdim) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_min(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.min(ref_inp) + with flag_gems.use_gems(): + res_out = torch.min(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_min_dim(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.min(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.min(inp, dim=dim, keepdim=keepdim) + ref_out_value, ref_out_index = ref_out + res_out_value, res_out_index = res_out + gems_assert_equal(res_out_index, ref_out_index) + gems_assert_equal(res_out_value, ref_out_value) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_prod(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.prod(ref_inp) + with flag_gems.use_gems(): + res_out = torch.prod(inp) + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIM_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_prod_dim(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.prod(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.prod(inp, dim=dim, keepdim=keepdim) + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_rmsnorm(shape, dtype): + N = shape[1] + layer_shape = [ + N, + ] + inp = torch.randn(shape, dtype=dtype, device="cuda") + weight = torch.randn(layer_shape, dtype=dtype, device="cuda") + eps = 1e-5 + + ref_inp = to_reference(inp, True) + ref_weight = to_reference(weight, True) + + def _torch_rms_norm(x, weight, eps): + variance = x.pow(2).mean(-1, keepdim=True) + hidden_states = x * torch.rsqrt(variance + eps) + return weight * hidden_states + + ref_out = _torch_rms_norm(ref_inp, weight=ref_weight, eps=eps) + + res_out = flag_gems.rms_norm(inp, list(layer_shape), weight=weight, eps=eps) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_skip_layernorm(shape, dtype): + N = shape[1] + layer_shape = [ + N, + ] + inp = torch.randn(shape, dtype=dtype, device="cuda") + residual = torch.randn(shape, dtype=dtype, device="cuda") + weight = torch.randn(layer_shape, dtype=dtype, device="cuda") + bias = torch.randn(layer_shape, dtype=dtype, device="cuda") + eps = 1e-5 + + ref_inp = to_reference(inp, True) + ref_residual = to_reference(residual, True) + ref_weight = to_reference(weight, True) + ref_bias = to_reference(bias, True) + + ref_out = torch.layer_norm( + ref_inp + ref_residual, + list(layer_shape), + weight=ref_weight, + bias=ref_bias, + eps=eps, + ) + res_out = flag_gems.skip_layer_norm( + inp, residual, list(layer_shape), weight=weight, bias=bias, eps=eps + ) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_skip_rmsnorm(shape, dtype): + N = shape[1] + layer_shape = [ + N, + ] + inp = torch.randn(shape, dtype=dtype, device="cuda") + residual = torch.randn(shape, dtype=dtype, device="cuda") + weight = torch.randn(layer_shape, dtype=dtype, device="cuda") + eps = 1e-5 + + ref_inp = to_reference(inp, True) + ref_residual = to_reference(residual, True) + ref_weight = to_reference(weight, True) + + def _torch_rms_norm(x, residual, weight, eps): + x = x + residual + variance = x.pow(2).mean(-1, keepdim=True) + hidden_states = x * torch.rsqrt(variance + eps) + return weight * hidden_states + + ref_out = _torch_rms_norm( + ref_inp, + ref_residual, + weight=ref_weight, + eps=eps, + ) + + res_out = flag_gems.skip_rms_norm( + inp, residual, list(layer_shape), weight=weight, eps=eps + ) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_softmax(shape, dtype): + dim = 1 + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.nn.functional.softmax(ref_inp, dim=dim) + with flag_gems.use_gems(): + res_out = torch.nn.functional.softmax(inp, dim=dim) + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim]) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sum(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.sum(ref_inp) + with flag_gems.use_gems(): + res_out = torch.sum(inp) + + gems_assert_close(res_out, ref_out, dtype, reduce_dim=inp.numel()) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sum_dim(shape, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.sum(ref_inp, dim=dim, keepdim=keepdim) + with flag_gems.use_gems(): + res_out = torch.sum(inp, dim=dim, keepdim=keepdim) + + if isinstance(dim, int): + dim = [dim] + dim = [d % inp.ndim for d in dim] + _dim = 1 + for d in dim: + _dim *= shape[d] + gems_assert_close(res_out, ref_out, dtype, reduce_dim=_dim) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("correction", [0, 1]) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_varmean(shape, dim, correction, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_var, ref_mean = torch.var_mean( + ref_inp, dim, correction=correction, keepdim=keepdim + ) + with flag_gems.use_gems(): + res_var, res_mean = torch.var_mean( + inp, dim, correction=correction, keepdim=keepdim + ) + + gems_assert_close(res_mean, ref_mean, dtype) + gems_assert_close(res_var, ref_var, dtype) + + +@pytest.mark.parametrize("shape", REDUCTION_SHAPES) +@pytest.mark.parametrize("ord", [2, float("inf"), -float("inf"), 0, 1]) +@pytest.mark.parametrize("dim", DIMS_LIST) +@pytest.mark.parametrize("keepdim", [True, False]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.linalg.vector_norm(ref_inp, ord, dim, keepdim) + with flag_gems.use_gems(): + res_out = torch.linalg.vector_norm(inp, ord, dim, keepdim) + + gems_assert_close(res_out, ref_out, dtype) diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py new file mode 100644 index 00000000..4ce1c082 --- /dev/null +++ b/tests/test_special_ops.py @@ -0,0 +1,143 @@ +import torch +import pytest +import flag_gems +from .accuracy_utils import * + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("p", [0.3, 0.6, 0.9]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_dropout(shape, p, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp) + + ref_out = torch.nn.functional.dropout(ref_inp, p, True) + with flag_gems.use_gems(): + res_out = torch.nn.functional.dropout(inp, p, True) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + + res_out = to_reference(res_out) + res_in_grad = to_reference(res_in_grad) + + exp_equal = (p * p + (1 - p) * (1 - p)) * inp.numel() + num_equal = torch.sum(torch.isclose(ref_out, res_out)).item() + assert ( + abs(num_equal - exp_equal) / exp_equal <= 0.05 + ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" + + num_equal = torch.sum(torch.isclose(ref_in_grad, res_in_grad)).item() + assert ( + abs(num_equal - exp_equal) / exp_equal <= 0.05 + ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" + + +def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + cos = freqs.cos().to(dtype) + sin = freqs.sin().to(dtype) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.cohere.modeling_cohere.rotate_half +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py +def rotate_interleave(x): + """Rotates interleave the hidden dims of the input.""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def torch_apply_rotary_pos_emb( + q, + k, + cos, + sin, + position_ids, + rotary_interleaved: bool = False, +): + q = q.float() + k = k.float() + cos = cos[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + sin = sin[position_ids].unsqueeze(-2) # [bs, seq_len, 1, dim/2] + if rotary_interleaved: + cos = torch.repeat_interleave(cos, 2, dim=-1) # [bs, seq_len, 1, dim] + sin = torch.repeat_interleave(sin, 2, dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_interleave + else: + cos = torch.cat([cos, cos], dim=-1) # [bs, seq_len, 1, dim] + sin = torch.cat([sin, sin], dim=-1) # [bs, seq_len, 1, dim] + rotate_fn = rotate_half + + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + + return q_embed, k_embed + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("max_seq_len", [512, 2048]) +@pytest.mark.parametrize("q_heads,k_heads", [(8, 1), (6, 2), (1, 1), (8, 8)]) +@pytest.mark.parametrize("head_dim", [64, 96, 128, 256]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("rotary_interleaved", [True, False]) +def test_apply_rotary_pos_emb( + batch_size, + max_seq_len, + q_heads, + k_heads, + head_dim, + dtype, + rotary_interleaved, +): + seq_len = torch.randint(1, max_seq_len, (1,)).item() + q = torch.randn( + (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device="cuda" + ) + k = torch.randn( + (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device="cuda" + ) + + position_ids = torch.randint(0, max_seq_len, (batch_size, seq_len), device="cuda") + cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device="cuda") + + ref_q = to_reference(q, True) + ref_k = to_reference(k, True) + ref_cos = to_reference(cos, True) + ref_sin = to_reference(sin, True) + ref_position_ids = to_reference(position_ids) + + q_embed_ref, k_embed_ref = torch_apply_rotary_pos_emb( + q=ref_q, + k=ref_k, + cos=ref_cos, + sin=ref_sin, + position_ids=ref_position_ids, + rotary_interleaved=rotary_interleaved, + ) + q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb( + q=q, + k=k, + cos=cos, + sin=sin, + position_ids=position_ids, + rotary_interleaved=rotary_interleaved, + ) + + gems_assert_close(q_embed_out, q_embed_ref, dtype) + gems_assert_close(k_embed_out, k_embed_ref, dtype) diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py new file mode 100644 index 00000000..e2e57255 --- /dev/null +++ b/tests/test_unary_pointwise_ops.py @@ -0,0 +1,278 @@ +import torch +import pytest +import flag_gems +from .accuracy_utils import * + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_abs(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.abs(ref_inp) + with flag_gems.use_gems(): + res_out = torch.abs(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", INT_DTYPES) +def test_accuracy_bitwisenot(shape, dtype): + inp = torch.randint( + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + ) + ref_inp = to_reference(inp) + + ref_out = torch.bitwise_not(ref_inp) + with flag_gems.use_gems(): + res_out = torch.bitwise_not(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_cos(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.cos(ref_inp) + with flag_gems.use_gems(): + res_out = torch.cos(inp) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_exp(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.exp(ref_inp) + with flag_gems.use_gems(): + res_out = torch.exp(inp) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_gelu(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.nn.functional.gelu(ref_inp) + with flag_gems.use_gems(): + res_out = torch.nn.functional.gelu(inp) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("approximate", ["none", "tanh"]) +def test_accuracy_gelu_and_mul(shape, approximate, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.mul( + torch.nn.functional.gelu(ref_inp1, approximate=approximate), ref_inp2 + ) + with flag_gems.use_gems(): + res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_isinf(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.masked_fill(inp, inp > 1.0, -float("inf")) + ref_inp = to_reference(inp) + + ref_out = torch.isinf(ref_inp) + with flag_gems.use_gems(): + res_out = torch.isinf(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_isnan(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.masked_fill(inp, inp > 1.0, float("nan")) + ref_inp = to_reference(inp) + + ref_out = torch.isnan(ref_inp) + with flag_gems.use_gems(): + res_out = torch.isnan(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_neg(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.neg(ref_inp) + with flag_gems.use_gems(): + res_out = torch.neg(inp) + + gems_assert_equal(res_out, ref_out) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_reciprocal(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.reciprocal(ref_inp) + with flag_gems.use_gems(): + res_out = torch.reciprocal(inp) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_relu(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.nn.functional.relu(ref_inp) + with flag_gems.use_gems(): + res_out = torch.relu(inp) + + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_rsqrt(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.rsqrt(ref_inp) + with flag_gems.use_gems(): + res_out = torch.rsqrt(inp) + + gems_assert_close(res_out, ref_out, dtype, equal_nan=True) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sigmoid(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.sigmoid(ref_inp) + with flag_gems.use_gems(): + res_out = torch.sigmoid(inp) + + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_silu(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.nn.functional.silu(ref_inp) + with flag_gems.use_gems(): + res_out = torch.nn.functional.silu(inp) + + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_silu_and_mul(shape, dtype): + inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp1 = to_reference(inp1, True) + ref_inp2 = to_reference(inp2, True) + + ref_out = torch.mul(torch.nn.functional.silu(ref_inp1), ref_inp2) + with flag_gems.use_gems(): + res_out = flag_gems.silu_and_mul(inp1, inp2) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_sin(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp, True) + + ref_out = torch.sin(ref_inp) + with flag_gems.use_gems(): + res_out = torch.sin(inp) + + gems_assert_close(res_out, ref_out, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_tanh(shape, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + ref_inp = to_reference(inp, True) + + ref_out = torch.tanh(ref_inp) + with flag_gems.use_gems(): + res_out = torch.tanh(inp) + + gems_assert_close(res_out, ref_out, dtype) + + out_grad = torch.randn_like(inp) + ref_grad = to_reference(out_grad, True) + + (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad) + (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad) + gems_assert_close(res_in_grad, ref_in_grad, dtype) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("diagonal", [-3, -1, 0, 1, 3]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_triu(shape, diagonal, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + ref_inp = to_reference(inp) + + ref_out = torch.triu(ref_inp, diagonal) + with flag_gems.use_gems(): + res_out = torch.triu(inp, diagonal) + + gems_assert_equal(res_out, ref_out) From 2768436019e481ba1ce69ea18361219ac69e9840 Mon Sep 17 00:00:00 2001 From: strongspoon Date: Thu, 30 May 2024 13:32:19 +0800 Subject: [PATCH 09/16] [bugfix] fix misspell --- src/flag_gems/ops/sub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flag_gems/ops/sub.py b/src/flag_gems/ops/sub.py index c501c034..a4e4404e 100644 --- a/src/flag_gems/ops/sub.py +++ b/src/flag_gems/ops/sub.py @@ -32,7 +32,7 @@ def sub(A, B, *, alpha=1): O = sub_func_tensor_scalar(A, B, alpha) return O elif isinstance(B, torch.Tensor): - O = sub_func_scalar_tensor(A, B, alhpa) + O = sub_func_scalar_tensor(A, B, alpha) return O else: # Both scalar From 6f1b1c6b906bba238375dea506a71b35c0c2994a Mon Sep 17 00:00:00 2001 From: strongspoon Date: Thu, 30 May 2024 13:32:47 +0800 Subject: [PATCH 10/16] [doc] update version --- OperatorList.md | 2 -- pyproject.toml | 2 +- src/flag_gems/__init__.py | 2 ++ 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/OperatorList.md b/OperatorList.md index 1dcf94ca..4eedc9c6 100644 --- a/OperatorList.md +++ b/OperatorList.md @@ -1,7 +1,5 @@ ## Operator List -FlagGems will implement the following operators as planned. Version 1.0 will be released within 6 months. - ## v1.0 - addmm - bmm diff --git a/pyproject.toml b/pyproject.toml index 593f823b..f7ef9fa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flag_gems" -version = "1.0.0" +version = "2.0" authors = [ {name = "Zhixin Li", email = "strongspoon@outlook.com"}, {name = "Tongxin Bai", email = "waffle.bai@gmail.com"}, diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 491a22a5..07440b29 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -3,6 +3,8 @@ from .ops import * from .fused import * +__version__ = "2.0" + aten_lib = torch.library.Library("aten", "IMPL") From b43999a605778499213be72bfb0b320ecd5d8609 Mon Sep 17 00:00:00 2001 From: Bowen12992 Date: Fri, 31 May 2024 10:14:34 +0800 Subject: [PATCH 11/16] add yaml file --- .github/workflows/python-test.yaml | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 .github/workflows/python-test.yaml diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml new file mode 100644 index 00000000..050547e6 --- /dev/null +++ b/.github/workflows/python-test.yaml @@ -0,0 +1,35 @@ + +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: flag-gems-test + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + container-test-job: + runs-on: [self-hosted, docker] + container: + image: localhost:5000/flag-gems-ci:v1.0 + ports: + - 80 + options: --gpus all --hostname flag-gems_cicd + steps: + - name: checkout-code + uses: actions/checkout@v2 + + - name: unit_test-flag-gems + run: | + pytest -s tests/test_* + + - name: benchmark-flag-gems + run: | + pytest -s benchmark/test_* + + - name: examples-flag-gems + run: | + pytest -s examples/model_* From 0b964e29dce01784e0921886bfd6c07b3a68369a Mon Sep 17 00:00:00 2001 From: strongspoon Date: Fri, 31 May 2024 11:00:51 +0800 Subject: [PATCH 12/16] [test] scale up accuracy test shapes --- tests/accuracy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 56819fab..9a4ed90e 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -14,7 +14,7 @@ } POINTWISE_SHAPES = [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)] -REDUCTION_SHAPES = [(1024, 64 * i) for i in range(1, 10, 2)] +REDUCTION_SHAPES = [(4096, 256 * i) for i in range(1, 10, 2)] MNK_SHAPES = [15, 160, 1024] FLOAT_DTYPES = [torch.float16, torch.float32, torch.bfloat16] From bbd53865140d83fff3046c303069ae4d7ac3511f Mon Sep 17 00:00:00 2001 From: Bowen <81504862+Bowen12992@users.noreply.github.com> Date: Fri, 31 May 2024 12:03:29 +0800 Subject: [PATCH 13/16] [CI/CD] remove benchmark test in ci/ce --- .github/workflows/python-test.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 050547e6..f555cbba 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -26,10 +26,6 @@ jobs: run: | pytest -s tests/test_* - - name: benchmark-flag-gems - run: | - pytest -s benchmark/test_* - - name: examples-flag-gems run: | pytest -s examples/model_* From 679ce8b478e445e780c48d15feb98b22e37e4760 Mon Sep 17 00:00:00 2001 From: FatJhon <156064001+FatJhon@users.noreply.github.com> Date: Fri, 31 May 2024 16:14:56 +0800 Subject: [PATCH 14/16] add reduction of sum and none for CrossEntropyLoss (#41) * modify name && add reduce function * add reduce none * add test * clean code * add reduce enum * Replacing the enum interface with Intenum & add illegal detection of reduction --------- Co-authored-by: Jiang Bin --- src/flag_gems/ops/argmax.py | 4 +- src/flag_gems/ops/cross_entropy_loss.py | 121 ++++++++++++++++++++---- src/flag_gems/ops/max.py | 8 +- src/flag_gems/ops/min.py | 8 +- tests/test_reduction_ops.py | 9 +- 5 files changed, 122 insertions(+), 28 deletions(-) diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index e31ce6ec..54e1b7ec 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -35,8 +35,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr mid_ptrs = mid_value + offset mask = offset < mid_size mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) - sum_val = tl.argmax(mid_val, axis=0) - mid_index_ptrs = mid_index + sum_val + index_val = tl.argmax(mid_val, axis=0) + mid_index_ptrs = mid_index + index_val out_val = tl.load(mid_index_ptrs) tl.store(out, out_val) diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py index 951cbe72..d224a4a8 100644 --- a/src/flag_gems/ops/cross_entropy_loss.py +++ b/src/flag_gems/ops/cross_entropy_loss.py @@ -2,8 +2,15 @@ import triton import triton.language as tl import logging +from enum import IntEnum from ..utils import libentry -from .sum import sum +from .sum import sum, sum_dim + + +class Reduction(IntEnum): + NONE = 0 + MEAN = 1 + SUM = 2 @libentry() @@ -56,7 +63,7 @@ def log_softmax_and_mul_kernel( denominator = tl.sum(numerator, axis=1)[:, None] softmax_output = tl.log(numerator / denominator) target = tl.load(target_ptr + offset, mask=mask, other=0.0) - out = softmax_output * target / (-mean_num) + out = softmax_output * target / (mean_num) output_ptrs = output_ptr + offset tl.store(output_ptrs, out, mask=mask) @@ -114,6 +121,68 @@ def softmax_and_sub_kernel( softmax_output = numerator / denominator target_ptrs = target_ptr + offset target = tl.load(target_ptrs, mask=mask, other=0.0) + out_grad_ptr = out_grad + m_offset[:, None] * K + pid_k + out_grad_value = tl.load(out_grad_ptr) + out = out_grad_value * (softmax_output - target) / mean_num + output_ptrs = output_ptr + offset + + tl.store(output_ptrs, out, mask=mask) + + +@libentry() +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 1}, num_stages=4), + triton.Config({"BLOCK_M": 1}, num_stages=5), + triton.Config({"BLOCK_M": 2}, num_stages=4), + triton.Config({"BLOCK_M": 2}, num_stages=5), + triton.Config({"BLOCK_M": 4}, num_stages=4), + triton.Config({"BLOCK_M": 4}, num_stages=5), + triton.Config({"BLOCK_M": 8}, num_stages=4), + triton.Config({"BLOCK_M": 8}, num_stages=5), + ], + key=[ + "M", + "N", + ], +) +@triton.heuristics( + values={ + "BLOCK_N": lambda args: triton.next_power_of_2(args["N"]), + "num_warps": lambda args: ( + 4 if args["N"] <= 1024 else (8 if args["N"] <= 2048 else 16) + ), + }, +) +@triton.jit +def softmax_and_sub_reduce_kernel( + output_ptr, + input_ptr, + target_ptr, + out_grad, + mean_num, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + n_offset = tl.arange(0, BLOCK_N) + offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k + mask = m_offset[:, None] < M and n_offset[None, :] < N + input_ptrs = input_ptr + offset + inp = tl.load(input_ptrs, mask=mask, other=-float("inf")) + row_minus_max = inp - tl.max(inp, axis=1)[:, None] + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=1)[:, None] + # todo: reduce unnecessary calculations through mask operations to improve performance + softmax_output = numerator / denominator + target_ptrs = target_ptr + offset + target = tl.load(target_ptrs, mask=mask, other=0.0) + out_grad_value = tl.load(out_grad) out = out_grad_value * (softmax_output - target) / mean_num output_ptrs = output_ptr + offset @@ -125,15 +194,18 @@ class CrossEntropyLoss(torch.autograd.Function): @staticmethod def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing): logging.debug("GEMS CrossEntropyLoss") + assert reduction in Reduction._value2member_map_, "Invalid reduction" assert isinstance(input, torch.Tensor), "input is not a tensor" if input.ndim >= 2: dim = 1 else: dim = 0 - + if reduction != Reduction.MEAN.value: + mean_num = -1 + else: + mean_num = -target.numel() shape = list(input.shape) shape[dim] = 1 - mean_num = target.numel() target = torch.zeros_like(input).scatter(dim, target.view(shape), 1) M = 1 @@ -157,11 +229,15 @@ def forward(ctx, input, target, weight, reduction, ignore_index, label_smoothing N, K, ) - out_result = sum(out) + if reduction != Reduction.NONE.value: + out_result = sum(out) + else: + out_result = sum_dim(out, dim=[dim]) ctx.save_for_backward(input, target) ctx.dim = dim - ctx.mean_num = mean_num + ctx.mean_num = -mean_num + ctx.reduction = reduction return out_result @staticmethod @@ -170,6 +246,7 @@ def backward(ctx, out_grad): input, target = ctx.saved_tensors dim = ctx.dim mean_num = ctx.mean_num + reduction = ctx.reduction M = 1 N = input.shape[dim] @@ -183,16 +260,28 @@ def backward(ctx, out_grad): triton.cdiv(M, meta["BLOCK_M"]), K, ) - softmax_and_sub_kernel[grid]( - out, - inp, - target, - out_grad, - mean_num, - M, - N, - K, - ) + if reduction != Reduction.NONE.value: + softmax_and_sub_reduce_kernel[grid]( + out, + inp, + target, + out_grad, + mean_num, + M, + N, + K, + ) + else: + softmax_and_sub_kernel[grid]( + out, + inp, + target, + out_grad, + mean_num, + M, + N, + K, + ) return out, None, None, None, None, None diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index 6ff4fcbd..8e27f7a2 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -20,9 +20,9 @@ def max_kernel_1( inp_ptrs = inp + offset mask = offset < M inp_val = tl.load(inp_ptrs, mask=mask, other=-float("inf")) - sum_val = tl.max(inp_val) + max_val = tl.max(inp_val) mid_ptr = mid + pid - tl.store(mid_ptr, sum_val) + tl.store(mid_ptr, max_val) @libentry() @@ -32,8 +32,8 @@ def max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): mid_ptrs = mid + offset mask = offset < mid_size mid_val = tl.load(mid_ptrs, mask=mask, other=-float("inf")) - sum_val = tl.max(mid_val) - tl.store(out, sum_val) + max_val = tl.max(mid_val) + tl.store(out, max_val) @libentry() diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index cb09eb00..1a17dde2 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -20,9 +20,9 @@ def min_kernel_1( inp_ptrs = inp + offset mask = offset < M inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf")) - sum_val = tl.min(inp_val) + min_val = tl.min(inp_val) mid_ptr = mid + pid - tl.store(mid_ptr, sum_val) + tl.store(mid_ptr, min_val) @libentry() @@ -32,8 +32,8 @@ def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr): mid_ptrs = mid + offset mask = offset < mid_size mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf")) - sum_val = tl.min(mid_val) - tl.store(out, sum_val) + min_val = tl.min(mid_val) + tl.store(out, min_val) @libentry() diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index ebdae6ab..8bdacc75 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -143,9 +143,12 @@ def test_accuracy_argmax(shape, dim, keepdim, dtype): gems_assert_equal(res_out, ref_out) +@pytest.mark.parametrize("size_average", [None, True, False]) +@pytest.mark.parametrize("reduce", [None, True, False]) +@pytest.mark.parametrize("reduction", ["mean", "none", "sum"]) @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_cross_entropy_loss(shape, dtype): +def test_accuracy_cross_entropy_loss(shape, dtype, size_average, reduce, reduction): inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) dim = 1 up_limit = shape[dim] - 1 @@ -156,7 +159,9 @@ def test_accuracy_cross_entropy_loss(shape, dtype): ref_inp = to_reference(inp, True) ref_target = to_reference(target) - criterion = torch.nn.CrossEntropyLoss() + criterion = torch.nn.CrossEntropyLoss( + size_average=size_average, reduce=reduce, reduction=reduction + ) ref_out = criterion(ref_inp, ref_target) with flag_gems.use_gems(): From b43cf9313869f9bd5fdff238035ff5ab77d2fb6e Mon Sep 17 00:00:00 2001 From: Bowen <81504862+Bowen12992@users.noreply.github.com> Date: Sat, 1 Jun 2024 21:30:26 +0800 Subject: [PATCH 15/16] [CI/CD] Use tow GPU cards for ci --- .github/workflows/python-test.yaml | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 050547e6..441e4e8d 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -11,25 +11,32 @@ on: branches: [ "master" ] jobs: - container-test-job: + container-unit-test: runs-on: [self-hosted, docker] container: image: localhost:5000/flag-gems-ci:v1.0 ports: - - 80 - options: --gpus all --hostname flag-gems_cicd + - 81 + options: --gpus all --hostname flag-gems_cicd_ut steps: - name: checkout-code uses: actions/checkout@v2 - name: unit_test-flag-gems run: | - pytest -s tests/test_* - - - name: benchmark-flag-gems - run: | - pytest -s benchmark/test_* + CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_* + + container-model-test: + runs-on: [self-hosted, docker] + container: + image: localhost:5000/flag-gems-ci:v1.0 + ports: + - 82 + options: --gpus all --hostname flag-gems_cicd_model + steps: + - name: checkout-code + uses: actions/checkout@v2 - name: examples-flag-gems run: | - pytest -s examples/model_* + CUDA_VISIBLE_DEVICES=1 pytest -s examples/model_bert_test.py From 5e70d2e6b60d2163538030b4e092c2bc0039aaf8 Mon Sep 17 00:00:00 2001 From: Bowen <81504862+Bowen12992@users.noreply.github.com> Date: Sat, 1 Jun 2024 22:12:03 +0800 Subject: [PATCH 16/16] add path --- .github/workflows/python-test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 441e4e8d..6f4f7545 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -32,7 +32,7 @@ jobs: image: localhost:5000/flag-gems-ci:v1.0 ports: - 82 - options: --gpus all --hostname flag-gems_cicd_model + options: --gpus all --hostname flag-gems_cicd_model -v /home/flaggems_cicd/huggingface_cache_bert:/__w/_temp/_github_home/.cache/huggingface steps: - name: checkout-code uses: actions/checkout@v2