[Code Format] Remove python version and ignore of F403 & E722 #66

Merged · 1 commit · Jun 14, 2024
7 changes: 1 addition & 6 deletions .pre-commit-config.yaml
@@ -6,25 +6,20 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
       - id: flake8
-        language_version: python3.11
-        args: ["--ignore=F405,E731,F403,W503,E722,E203", --max-line-length=120]
+        args: ["--ignore=F405,E731,W503,E203", --max-line-length=120]
         # F405 : Name may be undefined, or defined from star imports: module
         # E731 : Do not assign a lambda expression, use a def
-        # F403 : 'from module import *' used; unable to detect undefined names
         # W503 : Line break before binary operator
-        # E722 : Do not use bare 'except'
         # E203 : Whitespace before ':'

   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
     hooks:
       - id: isort
-        language_version: python3.11
         args: ["--profile", "black"]

   - repo: https://github.com/psf/black.git
     rev: 23.7.0
     hooks:
       - id: black
-        language_version: python3.11
       - id: black-jupyter
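With F403 and E722 dropped from the global ignore list, flake8 now flags star imports and bare excepts repo-wide (reproducible locally with `pre-commit run --all-files`); the per-file changes below are the cleanup that makes the repo pass. As a minimal sketch of the two import styles as flake8 sees them (a stdlib module stands in for the project's utils):

# What F403 reports once it is no longer ignored (kept as a comment so this
# sketch itself stays lint-clean):
#   from math import *     # F403: 'import *' used; undefined names undetectable

# The explicit form this PR switches the benchmark and test files to:
from math import sqrt

print(sqrt(2.0))  # 1.4142135623730951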
2 changes: 1 addition & 1 deletion benchmark/test_blas_perf.py
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from .performance_utils import *
+from .performance_utils import BLAS_BATCH, DEFAULT_BATCH, FLOAT_DTYPES, SIZES, Benchmark


 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
9 changes: 8 additions & 1 deletion benchmark/test_fused_perf.py
@@ -3,7 +3,14 @@

 import flag_gems

-from .performance_utils import *
+from .performance_utils import (
+    FLOAT_DTYPES,
+    POINTWISE_BATCH,
+    REDUCTION_BATCH,
+    SIZES,
+    Benchmark,
+    binary_args,
+)


 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
13 changes: 12 additions & 1 deletion benchmark/test_pointwise_perf.py
@@ -1,7 +1,18 @@
 import pytest
 import torch

-from .performance_utils import *
+from .performance_utils import (
+    FLOAT_DTYPES,
+    INT_DTYPES,
+    POINTWISE_BATCH,
+    SIZES,
+    Benchmark,
+    binary_args,
+    binary_int_args,
+    ternary_args,
+    unary_arg,
+    unary_int_arg,
+)


 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
9 changes: 8 additions & 1 deletion benchmark/test_reduction_perf.py
@@ -1,7 +1,14 @@
 import pytest
 import torch

-from .performance_utils import *
+from .performance_utils import (
+    BLAS_BATCH,
+    FLOAT_DTYPES,
+    REDUCTION_BATCH,
+    SIZES,
+    Benchmark,
+    unary_arg,
+)


 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
4 changes: 2 additions & 2 deletions src/flag_gems/__init__.py
@@ -1,7 +1,7 @@
 import torch

-from .fused import *
-from .ops import *
+from .fused import *  # noqa: F403
+from .ops import *  # noqa: F403

 __version__ = "2.0"

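The two star imports here are deliberate: the package `__init__` re-exports the public API of its submodules, so F403 is silenced per line with `# noqa: F403` rather than globally, leaving the check active everywhere else. A hypothetical, self-contained sketch of why such a re-export stays well-defined (module and names invented for illustration):

# ops.py (invented submodule): __all__ pins down exactly which names
# "from .ops import *" pulls in, keeping a deliberate star import safe.
__all__ = ["add"]


def add(a, b):
    return a + b


def _scratch():  # not re-exported: excluded from star imports by __all__
    pass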
2 changes: 1 addition & 1 deletion src/flag_gems/ops/dropout.py
@@ -17,7 +17,7 @@ def _rand(seed, offset):
     _grid = (1,)
     _seed, _offset = philox_cuda_seed_offset(0)
     _rand[_grid](_seed, _offset)
-except:
+except Exception:
     tl_rand_dtype = tl.int32

 del _grid
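A bare `except:` catches `BaseException`, so it would also swallow `KeyboardInterrupt` and `SystemExit` raised during this Triton capability probe; `except Exception:` keeps the same int32 fallback while letting those propagate. A minimal illustration of the E722 fix (invented probe, not flag_gems code):

def probe_feature():
    # Stand-in for the tl.rand dtype probe; pretend the feature is missing.
    raise RuntimeError("feature unavailable")


try:
    probe_feature()
    supported = True
except Exception:  # a bare "except:" here would also trap KeyboardInterrupt
    supported = False

print(supported)  # False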
10 changes: 9 additions & 1 deletion tests/test_binary_pointwise_ops.py
@@ -3,7 +3,15 @@

 import flag_gems

-from .accuracy_utils import *
+from .accuracy_utils import (
+    FLOAT_DTYPES,
+    INT_DTYPES,
+    POINTWISE_SHAPES,
+    SCALARS,
+    gems_assert_close,
+    gems_assert_equal,
+    to_reference,
+)


 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
8 changes: 7 additions & 1 deletion tests/test_blas_ops.py
@@ -3,7 +3,13 @@

 import flag_gems

-from .accuracy_utils import *
+from .accuracy_utils import (
+    FLOAT_DTYPES,
+    MNK_SHAPES,
+    SCALARS,
+    gems_assert_close,
+    to_reference,
+)


 @pytest.mark.parametrize("M", MNK_SHAPES)
12 changes: 11 additions & 1 deletion tests/test_reduction_ops.py
@@ -3,7 +3,17 @@

 import flag_gems

-from .accuracy_utils import *
+from .accuracy_utils import (
+    DIM_LIST,
+    DIMS_LIST,
+    FLOAT_DTYPES,
+    REDUCTION_SHAPES,
+    gems_assert_close,
+    gems_assert_equal,
+    skip_expr,
+    skip_reason,
+    to_reference,
+)


 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)
9 changes: 7 additions & 2 deletions tests/test_special_ops.py
@@ -3,7 +3,12 @@

 import flag_gems

-from .accuracy_utils import *
+from .accuracy_utils import (
+    FLOAT_DTYPES,
+    POINTWISE_SHAPES,
+    gems_assert_close,
+    to_reference,
+)


 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@@ -52,7 +57,7 @@ def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"):
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x2 = x[..., x.shape[-1] // 2 :]  # noqa: E203
     return torch.cat((-x2, x1), dim=-1)

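For context, E203 (whitespace before ':') fires on exactly the slice style black produces when a bound is a compound expression, which is why the code is kept ignored globally; the per-line `# noqa: E203` documents that on the line itself. A standalone example of the formatting that trips the check (not project code):

x = list(range(8))
x2 = x[len(x) // 2 :]  # black's style for compound slice bounds; flake8 sees E203
print(x2)  # [4, 5, 6, 7]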
9 changes: 8 additions & 1 deletion tests/test_unary_pointwise_ops.py
@@ -3,7 +3,14 @@

 import flag_gems

-from .accuracy_utils import *
+from .accuracy_utils import (
+    FLOAT_DTYPES,
+    INT_DTYPES,
+    POINTWISE_SHAPES,
+    gems_assert_close,
+    gems_assert_equal,
+    to_reference,
+)


 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)