Initial implementation for fluid interface check (#287)
* initial implementation for fluid interface check
* support return statements and non-matching names
* performance optimization
* test results
* exclude other references
* fix tests and exclude `_`
* fix linting
* fix mypy
* add suggested test case
* restrict to top-level function definitions
* Remove `was_referenced` property, run `make docs`

--------

Co-authored-by: dosisod <[email protected]>
Showing 4 changed files with 385 additions and 0 deletions.
@@ -0,0 +1,163 @@
from dataclasses import dataclass

from mypy.nodes import (
    AssignmentStmt,
    CallExpr,
    Expression,
    FuncDef,
    MemberExpr,
    NameExpr,
    ReturnStmt,
    Statement,
)

from refurb.checks.common import ReadCountVisitor, check_block_like
from refurb.error import Error
from refurb.visitor import TraverserVisitor


@dataclass
class ErrorInfo(Error):
    r"""
    When an API has a Fluent Interface (the ability to chain multiple calls together), you should
    chain those calls instead of repeatedly assigning and using the value.
    Sometimes a return statement can be written more succinctly:

    Bad:

    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t1 = torch.ones(2, 1)
        t2 = t1.long()
        t3 = t2.to(device)
        return t3

    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = spark.read.parquet(file_name)
        df = df \
            .withColumnRenamed('col1', 'col1_renamed') \
            .withColumnRenamed('col2', 'col2_renamed')
        df = df \
            .select(common_columns) \
            .withColumn('service_type', F.lit('green'))
        return df
    ```

    Good:

    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t3 = (
            torch.ones(2, 1)
            .long()
            .to(device)
        )
        return t3

    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = (
            spark.read.parquet(file_name)
            .withColumnRenamed('col1', 'col1_renamed')
            .withColumnRenamed('col2', 'col2_renamed')
            .select(common_columns)
            .withColumn('service_type', F.lit('green'))
        )
        return df
    ```
    """

    name = "use-fluid-interface"
    code = 184
    categories = ("readability",)


def check(node: FuncDef, errors: list[Error]) -> None:
    check_block_like(check_stmts, node.body, errors)


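# A chained expression such as `df.select(cols).withColumn("x", y)` is parsed
# outside-in: the outermost CallExpr's callee is a MemberExpr whose `expr` holds
# the next call in the chain. check_call() walks down that chain until it finds
# a call made directly on a NameExpr, and accepts the chain only if that root
# name matches `name` (when given) and is read exactly once within that call,
# so expressions that also pass the variable as an argument are skipped.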
def check_call(node: Expression, name: str | None = None) -> bool:
    match node:
        # Single chain
        case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)):
            if name is None or name == x:
                # Exclude other references
                x_expr = NameExpr(x)
                x_expr.fullname = x
                visitor = ReadCountVisitor(x_expr)
                visitor.accept(node)
                return visitor.read_count == 1
            return False

        # Nested
        case CallExpr(callee=MemberExpr(expr=call_node, name=_)):
            return check_call(call_node, name=name)

    return False


class NameReferenceVisitor(TraverserVisitor):
    name: NameExpr
    referenced: bool

    def __init__(self, name: NameExpr, stmt: Statement | None = None) -> None:
        super().__init__()
        self.name = name
        self.stmt = stmt
        self.referenced = False

    def visit_name_expr(self, node: NameExpr) -> None:
        if not self.referenced and node.fullname == self.name.fullname:
            self.referenced = True


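# check_stmts() scans a statement block in order, remembering in `last` the
# variable assigned by the previous statement. A chained call that is
# re-assigned to that same variable (the redefined lvalue shows up as
# `f"{last}'"` once mypy has renamed it) or returned is reported right away; a
# chained call assigned to a *new* name is only reported at the end, via
# NameReferenceVisitor, if no later statement in the block reads the previous
# variable again. Assignments to `_` reset the tracking.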
def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
    last = ""
    visitors: list[NameReferenceVisitor] = []

    for stmt in stmts:
        for visitor in visitors:
            visitor.accept(stmt)
        # No need to track referenced variables anymore
        visitors = [visitor for visitor in visitors if not visitor.referenced]

        match stmt:
            case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
                if last and check_call(rvalue, name=last):
                    if f"{last}'" == name:
                        errors.append(
                            ErrorInfo.from_node(
                                stmt,
                                "Assignment statement should be chained",
                            )
                        )
                    else:
                        # We need to ensure that the variable is not referenced somewhere else
                        name_expr = NameExpr(name=last)
                        name_expr.fullname = last
                        visitors.append(NameReferenceVisitor(name_expr, stmt))

                last = name if name != "_" else ""
            case ReturnStmt(expr=rvalue):
                if last and rvalue is not None and check_call(rvalue, name=last):
                    errors.append(
                        ErrorInfo.from_node(
                            stmt,
                            "Return statement should be chained",
                        )
                    )
            case _:
                last = ""

    # Ensure that variables are not referenced
    errors.extend(
        [
            ErrorInfo.from_node(
                visitor.stmt,
                "Assignment statement should be chained",
            )
            for visitor in visitors
            if not visitor.referenced and visitor.stmt is not None
        ]
    )
@@ -0,0 +1,163 @@
class torch:
    @staticmethod
    def ones(*args):
        return torch

    @staticmethod
    def long():
        return torch

    @staticmethod
    def to(device: str):
        return torch.Tensor()

    class Tensor:
        pass


def transform(x):
    return x


class spark:
    class read:
        @staticmethod
        def parquet(file_name: str):
            return spark.DataFrame()

    class functions:
        @staticmethod
        def lit(constant):
            return constant

        @staticmethod
        def col(col_name):
            return col_name

    class DataFrame:
        @staticmethod
        def withColumnRenamed(col_in, col_out):
            return spark.DataFrame()

        @staticmethod
        def withColumn(col_in, col_out):
            return spark.DataFrame()

        @staticmethod
        def select(*args):
            return spark.DataFrame()

class F:
    @staticmethod
    def lit(value):
        return value


# these will match
def get_tensors(device: str) -> torch.Tensor:
    a = torch.ones(2, 1)
    a = a.long()
    a = a.to(device)
    return a


def process(file_name: str):
    common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
    df = spark.read.parquet(file_name)
    df = df \
        .withColumnRenamed('col1', 'col1_renamed') \
        .withColumnRenamed('col2', 'col2_renamed')
    df = df \
        .select(common_columns) \
        .withColumn('service_type', spark.functions.lit('green'))
    return df


def projection(df_in: spark.DataFrame) -> spark.DataFrame:
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return df.withColumn("col2a", spark.functions.col("col2").cast("date"))


def assign_multiple(df):
    df = df.select("column")
    result_df = df.select("another_column")
    final_df = result_df.withColumn("column2", F.lit("abc"))
    return final_df


# not yet supported
def assign_alternating(df, df2):
    df = df.select("column")
    df2 = df2.select("another_column")
    df = df.withColumn("column2", F.lit("abc"))
    return df, df2


# these will not
def ignored(x):
    _ = x.op1()
    _ = _.op2()
    return _

def _(x):
    y = x.m()
    return y.operation(*[v for v in y])


def assign_multiple_referenced(df, df2):
    df = df.select("column")
    result_df = df.select("another_column")
    return df, result_df


def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame:
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date"))


def no_match():
    y = 10
    y = transform(y)
    return y

def f(x):
    if x:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)

def g(x):
    try:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    except ValueError:
        name = "bob"
        print(name)

def h(x):
    for _ in (1, 2, 3):
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)

def assign_multiple_try(df):
    try:
        df = df.select("column")
        result_df = df.select("another_column")
        final_df = result_df.withColumn("column2", F.lit("abc"))
        return final_df
    except ValueError:
        return None
@@ -0,0 +1,7 @@
test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:81:5 [FURB184]: Return statement should be chained
test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained
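Each expected diagnostic above uses refurb's `path:line:column [FURBnnn]: message` output format; the reported lines all fall inside the functions under the `# these will match` comment in test/data/err_184.py, and presumably the test suite compares the new check's output on that file against this listing.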