Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Document] README Feature: automatic code generation #53

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,45 @@ FlagGems is a high-performance general operator library implemented in [OpenAI T
By registering with the ATen backend of PyTorch, FlagGems facilitates a seamless transition, allowing users to switch to the Triton function library without the need to modify their model code. Users can still utilize the ATen backend as usual while experiencing significant performance enhancement. The Triton language offers benefits in readability, user-friendliness and performance comparable to CUDA. This convenience allows developers to engage in the development of FlagGems with minimal learning investment.


## Feature

### Automatic Codegen

In FlagGems, we provide automatic code generation that developers can use to conveniently generate pointwise single operators and pointwise fused operators. Automatic code generation can handle various needs such as normal pointwise computations, non-tensor arguments, and specifying output data types.

#### Normal Pointwise Operator

Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code.

```python
@pointwise_dynamic
@triton.jit
def abs_func(x):
return tl.abs(x)
```

#### Non-Tensor Argument

By default, `pointwise_dynamic` treats all parameters as tensors, and by passing a list of boolean values to the parameter `is_tensor`, developers can specify which parameters are tensors and which are not. Additionally, developers can pass in `dtypes` to indicate the data types of non-tensor parameters, but this is not required. For example, in the following code, the `alpha` parameter is defined as a non-tensor floating point number, while the `x` and `y` parameters are defined as tensors.

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### Output Data Type

By default, all output tensors have the same data type as the first input tensor, but it can also be customized by providing a list of data types to the parameter `output_dtypes`. For example, in the following code, the output tensor type is specified as `torch.bool`.

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
@triton.jit
def ge(x, y):
return x > y
```

## Changelog

### v1.0
Expand Down
40 changes: 40 additions & 0 deletions README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@ FlagGems是一个使用OpenAI推出的[Triton编程语言](https://github.com/op

FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库的无缝替换,使用户能够在不修改模型代码的情况下平稳地切换到triton算子库。FlagGems不会影响aten后端的正常使用,并且会带来良好的性能提升。Triton语言为算子库提供了更好的可读性和易用性,同时保持了不逊于CUDA的算子性能,因此开发者只需付出较低的学习成本,即可参与FlagGems的算子开发与建设。


## 特性

### 自动代码生成

在FlagGems中,我们提供了一套自动代码生成的机制,开发者可以使用它来便捷地生成pointwise类型的单算子与融合算子。自动代码生成可以处理常规的对位计算、非张量参数、指定输出类型等多种需求。

#### 常规对位计算

在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。

```python
@pointwise_dynamic
@triton.jit
def abs_func(x):
return tl.abs(x)
```

#### 非张量参数

在默认情况下,`pointwise_dynamic`将所有参数均处理为张量,而通过向参数`is_tensor`传递布尔值列表,开发者可以指定哪些参数是张量,哪些参数非张量。此外,开发者还可以传入`dtypes`说明非张量参数的数据类型,但这不是必要的。例如以下代码,将`alpha`参数定义为非张量的浮点数,而`x`和`y`参数定义为张量。

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### 输出数据类型

在默认情况下,输出张量使用与首个输入张量相同的数据类型,但也可向参数`output_dtypes`传入数据类型组成的列表来指定。例如以下代码,指定输出张量类型为`torch.bool`。

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
@triton.jit
def ge(x, y):
return x > y
```

## 更新日志

### v1.0
Expand Down
13 changes: 13 additions & 0 deletions benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ def test_perf_div(dtype):
bench.run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_dropout(dtype):
bench = Benchmark(
op_name="dropout",
torch_op=torch.nn.Dropout(p=0.5),
arg_func=unary_arg,
dtype=dtype,
batch=POINTWISE_BATCH,
sizes=SIZES,
)
bench.run()


@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_perf_eq(dtype):
bench = Benchmark(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_binary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,24 @@ def test_accuracy_ge_scalar(shape, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("approximate", ["none", "tanh"])
def test_accuracy_gelu_and_mul(shape, approximate, dtype):
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_out = torch.mul(
torch.nn.functional.gelu(ref_inp1, approximate=approximate), ref_inp2
)
with flag_gems.use_gems():
res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate)

gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_gt(shape, dtype):
Expand Down Expand Up @@ -521,6 +539,21 @@ def test_accuracy_rsub(shape, alpha, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_silu_and_mul(shape, dtype):
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_out = torch.mul(torch.nn.functional.silu(ref_inp1), ref_inp2)
with flag_gems.use_gems():
res_out = flag_gems.silu_and_mul(inp1, inp2)

gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("alpha", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
33 changes: 0 additions & 33 deletions tests/test_unary_pointwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,6 @@ def test_accuracy_gelu(shape, dtype):
gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("approximate", ["none", "tanh"])
def test_accuracy_gelu_and_mul(shape, approximate, dtype):
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_out = torch.mul(
torch.nn.functional.gelu(ref_inp1, approximate=approximate), ref_inp2
)
with flag_gems.use_gems():
res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate)

gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_isinf(shape, dtype):
Expand Down Expand Up @@ -216,21 +198,6 @@ def test_accuracy_silu(shape, dtype):
gems_assert_close(res_in_grad, ref_in_grad, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_silu_and_mul(shape, dtype):
inp1 = torch.randn(shape, dtype=dtype, device="cuda")
inp2 = torch.randn(shape, dtype=dtype, device="cuda")
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

ref_out = torch.mul(torch.nn.functional.silu(ref_inp1), ref_inp2)
with flag_gems.use_gems():
res_out = flag_gems.silu_and_mul(inp1, inp2)

gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_sin(shape, dtype):
Expand Down