diff --git a/refurb/checks/string/startswith.py b/refurb/checks/string/startswith.py index e48c40c..6b531a2 100644 --- a/refurb/checks/string/startswith.py +++ b/refurb/checks/string/startswith.py @@ -1,8 +1,8 @@ from dataclasses import dataclass -from mypy.nodes import CallExpr, Expression, MemberExpr, NameExpr, OpExpr, UnaryExpr, Var +from mypy.nodes import CallExpr, Expression, MemberExpr, OpExpr, UnaryExpr -from refurb.checks.common import extract_binary_oper, is_same_type +from refurb.checks.common import extract_binary_oper, get_mypy_type, is_equivalent, is_same_type from refurb.error import Error @@ -40,19 +40,15 @@ def are_startswith_or_endswith_calls( ) -> tuple[str, Expression] | None: match lhs, rhs: case ( - CallExpr( - callee=MemberExpr(expr=NameExpr(node=Var(type=ty)) as lhs, name=lhs_func), - args=args, - ), - CallExpr(callee=MemberExpr(expr=NameExpr() as rhs, name=rhs_func)), + CallExpr(callee=MemberExpr(expr=lhs, name=lhs_func), args=[first_arg]), + CallExpr(callee=MemberExpr(expr=rhs, name=rhs_func), args=[_]), ) if ( - lhs.fullname == rhs.fullname - and is_same_type(ty, str, bytes) + is_equivalent(lhs, rhs) + and is_same_type(get_mypy_type(lhs), str, bytes) and lhs_func == rhs_func and lhs_func in {"startswith", "endswith"} - and args ): - return lhs_func, args[0] + return lhs_func, first_arg return None @@ -77,9 +73,4 @@ def check(node: OpExpr, errors: list[Error]) -> None: old = f"not x.{func}(y) and not x.{func}(z)" new = f"not x.{func}((y, z))" - errors.append( - ErrorInfo.from_node( - arg, - msg=f"Replace `{old}` with `{new}`", - ) - ) + errors.append(ErrorInfo.from_node(arg, msg=f"Replace `{old}` with `{new}`")) diff --git a/test/data/err_102.py b/test/data/err_102.py index 68a5fce..07f96d7 100644 --- a/test/data/err_102.py +++ b/test/data/err_102.py @@ -9,6 +9,16 @@ _ = not name.startswith("a") and not name.startswith("b") +class C: + s: str + +c = C() + +_ = c.s.startswith("a") or c.s.startswith("b") + +# TODO: disallow this because C() differs between branches +_ = C().s.startswith("a") or C().s.startswith("b") + # these should not match _ = name.startswith("a") and name.startswith("b") @@ -23,3 +33,6 @@ _ = not name.startswith("a") or not name.startswith("b") _ = not name.startswith("a") and name.startswith("b") _ = name.startswith("a") and not name.startswith("b") + +_ = name.startswith("a", "b") or name.startswith("b") # type: ignore +_ = name.startswith("a") or name.startswith() # type: ignore diff --git a/test/data/err_102.txt b/test/data/err_102.txt index e0663d6..d1458a9 100644 --- a/test/data/err_102.txt +++ b/test/data/err_102.txt @@ -3,3 +3,5 @@ test/data/err_102.py:6:19 [FURB102]: Replace `x.endswith(y) or x.endswith(z)` wi test/data/err_102.py:7:26 [FURB102]: Replace `x.startswith(y) or x.startswith(z)` with `x.startswith((y, z))` test/data/err_102.py:8:21 [FURB102]: Replace `x.startswith(y) or x.startswith(z)` with `x.startswith((y, z))` test/data/err_102.py:10:25 [FURB102]: Replace `not x.startswith(y) and not x.startswith(z)` with `not x.startswith((y, z))` +test/data/err_102.py:17:20 [FURB102]: Replace `x.startswith(y) or x.startswith(z)` with `x.startswith((y, z))` +test/data/err_102.py:20:22 [FURB102]: Replace `x.startswith(y) or x.startswith(z)` with `x.startswith((y, z))`