Skip to content

Commit

Permalink
Improve FURB179 error messages:
Browse files Browse the repository at this point in the history
* Detect `reduce()` when used with `add`/`concat` operator

* Remove duplicated code

* Emit better error message for set conprehensions

* Show more accurate names in error messages
  • Loading branch information
dosisod committed Dec 25, 2023
1 parent 0929acf commit 42292a4
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 24 deletions.
5 changes: 5 additions & 0 deletions refurb/checks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,4 +335,9 @@ def _stringify(node: Node) -> str:

return f"lambda{args}: {body}"

case ListExpr(items=items):
inner = ", ".join(stringify(x) for x in items)

return f"[{inner}]"

raise ValueError
46 changes: 29 additions & 17 deletions refurb/checks/itertools/use_chain_from_iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
ListExpr,
NameExpr,
RefExpr,
SetComprehension,
)

from refurb.checks.common import stringify
from refurb.error import Error


Expand Down Expand Up @@ -81,7 +83,7 @@ def is_flatten_generator(node: GeneratorExpr) -> bool:


def check(
node: ListComprehension | GeneratorExpr | CallExpr,
node: ListComprehension | SetComprehension | GeneratorExpr | CallExpr,
errors: list[Error],
) -> None:
if id(node) in ignore:
Expand All @@ -92,41 +94,51 @@ def check(
old = "[... for ... in x for ... in ...]"
new = "list(chain.from_iterable(x))"

msg = f"Replace `{old}` with `{new}`"
ignore.add(id(g))

errors.append(ErrorInfo.from_node(node, msg))
case SetComprehension(generator=g) if is_flatten_generator(g):
old = "{... for ... in x for ... in ...}"
new = "set(chain.from_iterable(x))"

ignore.add(id(g))

case GeneratorExpr() if is_flatten_generator(node):
old = "... for ... in x for ... in ..."
new = "chain.from_iterable(x)"

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))

case CallExpr(
callee=RefExpr(fullname="builtins.sum"),
args=[_, ListExpr(items=[])],
args=[arg, ListExpr(items=[])],
):
old = "sum(x, [])"
new = "chain.from_iterable(x)"
old = f"sum({stringify(arg)}, [])"
new = f"chain.from_iterable({stringify(arg)})"

case CallExpr(
callee=RefExpr(fullname="functools.reduce"),
args=[op, arg] | [op, arg, ListExpr(items=[])],
):
match op:
case RefExpr(fullname="_operator.add" | "_operator.concat"):
pass

msg = f"Replace `{old}` with `{new}`"
case _:
return

errors.append(ErrorInfo.from_node(node, msg))
old = stringify(node)
new = f"chain.from_iterable({stringify(arg)})"

case CallExpr(
callee=RefExpr(fullname="itertools.chain") as callee,
args=[_],
args=[arg],
arg_kinds=[ArgKind.ARG_STAR],
):
chain = "chain" if isinstance(callee, NameExpr) else "itertools.chain"

old = f"{chain}(*x)"
new = f"{chain}.from_iterable(x)"
old = f"{chain}(*{stringify(arg)})"
new = f"{chain}.from_iterable({stringify(arg)})"

msg = f"Replace `{old}` with `{new}`"
case _:
return

errors.append(ErrorInfo.from_node(node, msg))
msg = f"Replace `{old}` with `{new}`"
errors.append(ErrorInfo.from_node(node, msg))
25 changes: 25 additions & 0 deletions test/data/err_179.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from functools import reduce
from operator import add, concat, iadd
from itertools import chain
import functools
import itertools
import operator


rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
Expand Down Expand Up @@ -30,6 +34,21 @@ def flatten_via_chain_splat(rows):
def flatten_via_chain_splat_2(rows):
return itertools.chain(*rows)

def flatten_via_reduce_add(rows):
return reduce(add, rows)

def flatten_via_reduce_add_with_default(rows):
return reduce(add, rows, [])

def flatten_via_reduce_concat(rows):
return reduce(concat, rows)

def flatten_via_reduce_concat_with_default(rows):
return reduce(concat, rows, [])

def flatten_via_reduce_full_namespace(rows):
return functools.reduce(operator.add, rows)


# these should not

Expand Down Expand Up @@ -68,3 +87,9 @@ def flatten_via_chain_without_splat(rows):

def flatten_via_chain_from_iterable(rows):
return chain.from_iterable(rows)

def flatten_via_reduce_iadd(rows):
return reduce(iadd, rows, [])

def flatten_via_reduce_non_empty_default(rows):
return reduce(add, rows, [1, 2, 3])
19 changes: 12 additions & 7 deletions test/data/err_179.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:16:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))`
test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)`
test/data/err_179.py:28:12 [FURB179]: Replace `chain(*x)` with `chain.from_iterable(x)`
test/data/err_179.py:31:12 [FURB179]: Replace `itertools.chain(*x)` with `itertools.chain.from_iterable(x)`
test/data/err_179.py:17:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:20:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))`
test/data/err_179.py:23:12 [FURB179]: Replace `{... for ... in x for ... in ...}` with `set(chain.from_iterable(x))`
test/data/err_179.py:26:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:29:12 [FURB179]: Replace `sum(rows, [])` with `chain.from_iterable(rows)`
test/data/err_179.py:32:12 [FURB179]: Replace `chain(*rows)` with `chain.from_iterable(rows)`
test/data/err_179.py:35:12 [FURB179]: Replace `itertools.chain(*rows)` with `itertools.chain.from_iterable(rows)`
test/data/err_179.py:38:12 [FURB179]: Replace `reduce(add, rows)` with `chain.from_iterable(rows)`
test/data/err_179.py:41:12 [FURB179]: Replace `reduce(add, rows, [])` with `chain.from_iterable(rows)`
test/data/err_179.py:44:12 [FURB179]: Replace `reduce(concat, rows)` with `chain.from_iterable(rows)`
test/data/err_179.py:47:12 [FURB179]: Replace `reduce(concat, rows, [])` with `chain.from_iterable(rows)`
test/data/err_179.py:50:12 [FURB179]: Replace `functools.reduce(operator.add, rows)` with `chain.from_iterable(rows)`

0 comments on commit 42292a4

Please sign in to comment.