Skip to content

Commit

Permalink
Initial implementation for fluid interface check (#287)
Browse files Browse the repository at this point in the history
* initial implementation for fluid interface check

* support return statements and non-matching names

* performance optimization

* test results

* exclude other references

* fix tests and exclude `_`

* fix linting

* fix mypy

* add suggested test case

* restrict to top-level function definitions

* Remove `was_referenced` property, run `make docs`

---------

Co-authored-by: dosisod <[email protected]>
  • Loading branch information
sbrugman and dosisod authored Dec 22, 2023
1 parent b15ad49 commit 39cafb0
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 0 deletions.
52 changes: 52 additions & 0 deletions docs/checks.md
Original file line number Diff line number Diff line change
Expand Up @@ -2180,4 +2180,56 @@ Good:
```python
nums = [123, 456]
num = str(num[0])
```

## FURB184: `use-fluid-interface`

Categories: `readability`

When an API has a Fluent Interface (the ability to chain multiple calls together), you should
chain those calls instead of repeatedly assigning and using the value.
Sometimes a return statement can be written more succinctly:

Bad:

```python
def get_tensors(device: str) -> torch.Tensor:
t1 = torch.ones(2, 1)
t2 = t1.long()
t3 = t2.to(device)
return t3
def process(file_name: str):
common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
df = spark.read.parquet(file_name)
df = df \
.withColumnRenamed('col1', 'col1_renamed') \
.withColumnRenamed('col2', 'col2_renamed')
df = df \
.select(common_columns) \
.withColumn('service_type', F.lit('green'))
return df
```

Good:

```python
def get_tensors(device: str) -> torch.Tensor:
t3 = (
torch.ones(2, 1)
.long()
.to(device)
)
return t3
def process(file_name: str):
common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
df = (
spark.read.parquet(file_name)
.withColumnRenamed('col1', 'col1_renamed')
.withColumnRenamed('col2', 'col2_renamed')
.select(common_columns)
.withColumn('service_type', F.lit('green'))
)
return df
```
163 changes: 163 additions & 0 deletions refurb/checks/readability/fluid_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from dataclasses import dataclass

from mypy.nodes import (
AssignmentStmt,
CallExpr,
Expression,
FuncDef,
MemberExpr,
NameExpr,
ReturnStmt,
Statement,
)

from refurb.checks.common import ReadCountVisitor, check_block_like
from refurb.error import Error
from refurb.visitor import TraverserVisitor


@dataclass
class ErrorInfo(Error):
    r"""
    When an API has a Fluent Interface (the ability to chain multiple calls together), you should
    chain those calls instead of repeatedly assigning and using the value.
    Sometimes a return statement can be written more succinctly:
    Bad:
    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t1 = torch.ones(2, 1)
        t2 = t1.long()
        t3 = t2.to(device)
        return t3
    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = spark.read.parquet(file_name)
        df = df \
            .withColumnRenamed('col1', 'col1_renamed') \
            .withColumnRenamed('col2', 'col2_renamed')
        df = df \
            .select(common_columns) \
            .withColumn('service_type', F.lit('green'))
        return df
    ```
    Good:
    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t3 = (
            torch.ones(2, 1)
            .long()
            .to(device)
        )
        return t3
    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = (
            spark.read.parquet(file_name)
            .withColumnRenamed('col1', 'col1_renamed')
            .withColumnRenamed('col2', 'col2_renamed')
            .select(common_columns)
            .withColumn('service_type', F.lit('green'))
        )
        return df
    ```
    """

    # NOTE: the docstring above is rendered verbatim into docs/checks.md by
    # `make docs` — keep it in sync by editing this docstring only.
    name = "use-fluid-interface"  # slug shown in docs and `--explain` output
    code = 184                    # reported as FURB184
    categories = ("readability",)


def check(node: FuncDef, errors: list[Error]) -> None:
    """Entry point: run the chaining check over each statement block of a
    function body (restricted to top-level function definitions upstream).
    """
    check_block_like(check_stmts, node.body, errors)


def check_call(node: Expression, name: str | None = None) -> bool:
    """Return True if `node` is a method-call chain whose root is a bare name.

    If `name` is given, the chain root must be exactly that variable, and it
    must be read exactly once inside the whole expression (otherwise folding
    the chain would drop a needed reference).
    """
    # A chain link always looks like `<expr>.<attr>(...)`.
    if not isinstance(node, CallExpr) or not isinstance(node.callee, MemberExpr):
        return False

    root = node.callee.expr

    # Base of the chain: a plain variable reference.
    if isinstance(root, NameExpr):
        if name is not None and name != root.name:
            return False

        # Exclude other references: count reads of the name within this
        # expression; exactly one means the chain consumes it entirely.
        ref = NameExpr(root.name)
        ref.fullname = root.name
        counter = ReadCountVisitor(ref)
        counter.accept(node)
        return counter.read_count == 1

    # Nested chain: keep walking toward the root expression.
    return check_call(root, name=name)


class NameReferenceVisitor(TraverserVisitor):
    """AST visitor that records whether a particular name is ever referenced.

    `stmt` optionally remembers the statement that triggered the tracking so
    an error can be reported later if the name stays unreferenced.
    """

    name: NameExpr
    referenced: bool

    def __init__(self, name: NameExpr, stmt: Statement | None = None) -> None:
        super().__init__()
        self.referenced = False
        self.stmt = stmt
        self.name = name

    def visit_name_expr(self, node: NameExpr) -> None:
        # Once a reference has been seen there is nothing left to learn.
        if self.referenced:
            return
        if node.fullname == self.name.fullname:
            self.referenced = True


def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
    """Flag assignment/return statements that could be chained fluently.

    Walks one statement block tracking `last`, the variable assigned by the
    previous statement. When the current statement is a call chain rooted at
    `last` (and `last` is read nowhere else), the two statements could be
    merged into a single chained expression.
    """
    last = ""
    # Pending trackers: one per candidate assignment whose variable must be
    # proven unreferenced by all later statements before it may be reported.
    visitors: list[NameReferenceVisitor] = []

    for stmt in stmts:
        for visitor in visitors:
            visitor.accept(stmt)
        # No need to track referenced variables anymore
        visitors = [visitor for visitor in visitors if not visitor.referenced]

        match stmt:
            case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
                if last and check_call(rvalue, name=last):
                    # NOTE(review): mypy's renaming pass appears to suffix a
                    # redefined variable with an apostrophe (x -> x'), so this
                    # detects `x = x.f()` self-reassignment — confirm against
                    # mypy's variable renaming semantics.
                    if f"{last}'" == name:
                        errors.append(
                            ErrorInfo.from_node(
                                stmt,
                                "Assignment statement should be chained",
                            )
                        )
                    else:
                        # We need to ensure that the variable is not referenced somewhere else
                        name_expr = NameExpr(name=last)
                        name_expr.fullname = last
                        visitors.append(NameReferenceVisitor(name_expr, stmt))

                # `_` is a conventional throwaway name; never chain into it.
                last = name if name != "_" else ""
            case ReturnStmt(expr=rvalue):
                if last and rvalue is not None and check_call(rvalue, name=last):
                    errors.append(
                        ErrorInfo.from_node(
                            stmt,
                            "Return statement should be chained",
                        )
                    )
            case _:
                # Any other statement kind breaks the chain.
                last = ""

    # Ensure that variables are not referenced
    errors.extend(
        [
            ErrorInfo.from_node(
                visitor.stmt,
                "Assignment statement should be chained",
            )
            for visitor in visitors
            if not visitor.referenced and visitor.stmt is not None
        ]
    )
163 changes: 163 additions & 0 deletions test/data/err_184.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
class torch:  # stub mimicking the PyTorch API; only the call shapes matter here
    @staticmethod
    def ones(*args):
        return torch  # returns the class itself so calls can chain

    @staticmethod
    def long():
        return torch

    @staticmethod
    def to(device: str):
        return torch.Tensor()

    class Tensor:  # stand-in for torch.Tensor, used only in annotations
        pass


def transform(x):  # plain free function: calling it is NOT a fluent chain
    return x


class spark:  # stub mimicking the PySpark surface used by the fixtures below
    class read:
        @staticmethod
        def parquet(file_name: str):
            return spark.DataFrame()

    class functions:  # stands in for pyspark.sql.functions
        @staticmethod
        def lit(constant):
            return constant

        @staticmethod
        def col(col_name):
            return col_name

    class DataFrame:  # each method returns a new DataFrame, enabling chaining
        @staticmethod
        def withColumnRenamed(col_in, col_out):
            return spark.DataFrame()

        @staticmethod
        def withColumn(col_in, col_out):
            return spark.DataFrame()

        @staticmethod
        def select(*args):
            return spark.DataFrame()

class F:  # conventional short alias stub for pyspark.sql.functions
    @staticmethod
    def lit(value):
        return value


# these will match
def get_tensors(device: str) -> torch.Tensor:  # both re-assignments should be flagged
    a = torch.ones(2, 1)
    a = a.long()  # FURB184: chainable onto the previous line
    a = a.to(device)  # FURB184: chainable
    return a


def process(file_name: str):  # both `df = df...` re-assignments should be flagged
    common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
    df = spark.read.parquet(file_name)
    df = df \
        .withColumnRenamed('col1', 'col1_renamed') \
        .withColumnRenamed('col2', 'col2_renamed')
    df = df \
        .select(common_columns) \
        .withColumn('service_type', spark.functions.lit('green'))
    return df


def projection(df_in: spark.DataFrame) -> spark.DataFrame:  # return chains onto df
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return df.withColumn("col2a", spark.functions.col("col2").cast("date"))  # FURB184


def assign_multiple(df):  # different names, each used once: still chainable
    df = df.select("column")
    result_df = df.select("another_column")  # FURB184
    final_df = result_df.withColumn("column2", F.lit("abc"))  # FURB184
    return final_df


# not yet supported
def assign_alternating(df, df2):  # interleaved chains: not yet detected
    df = df.select("column")
    df2 = df2.select("another_column")  # breaks the df chain
    df = df.withColumn("column2", F.lit("abc"))
    return df, df2


# these will not
def ignored(x):  # `_` is a throwaway name and must never be chained
    _ = x.op1()
    _ = _.op2()
    return _

def _(x):  # `y` is read twice in the return expression, so no chaining
    y = x.m()
    return y.operation(*[v for v in y])


def assign_multiple_referenced(df, df2):  # df referenced again later: no error
    df = df.select("column")
    result_df = df.select("another_column")
    return df, result_df


def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame:
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date"))  # chain rooted at a different variable: no error


def no_match():  # transform(y) is a plain call, not a member-call chain
    y = 10
    y = transform(y)
    return y

def f(x):  # names reassigned across if/else branches must not confuse the check
    if x:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)

def g(x):  # same as f(), but across try/except blocks
    try:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    except ValueError:
        name = "bob"
        print(name)

def h(x):  # same as f(), but across for/else blocks
    for _ in (1, 2, 3):
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)

def assign_multiple_try(df):  # no diagnostic expected here (see err_184.txt)
    try:
        df = df.select("column")
        result_df = df.select("another_column")
        final_df = result_df.withColumn("column2", F.lit("abc"))
        return final_df
    except ValueError:
        return None
7 changes: 7 additions & 0 deletions test/data/err_184.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:81:5 [FURB184]: Return statement should be chained
test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained

0 comments on commit 39cafb0

Please sign in to comment.