Skip to content

Commit

Permalink
resolve formatting and typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tuturu-tech committed Jan 11, 2024
1 parent 5ab74c9 commit 35536be
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 41 deletions.
7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SHELL := /bin/bash

PY_MODULE := test-generator
PY_MODULE := test_generator
TEST_MODULE := tests

# Optionally overriden by the user, if they're using a virtual environment manager.
Expand Down Expand Up @@ -40,9 +40,8 @@ run: $(VENV)/pyvenv.cfg
lint: $(VENV)/pyvenv.cfg
. $(VENV_BIN)/activate && \
black --check . && \
pylint $(PY_MODULE) $(TEST_MODULE)
# ruff $(ALL_PY_SRCS) && \
# mypy $(PY_MODULE) &&
pylint $(PY_MODULE) $(TEST_MODULE) && \
mypy $(PY_MODULE)

.PHONY: reformat
reformat:
Expand Down
24 changes: 17 additions & 7 deletions test_generator/fuzzers/Echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_target_contract(self) -> Contract:
# TODO throw error if no contract found
exit(-1)

def parse_reproducer(self, calls: list, index: int) -> str:
def parse_reproducer(self, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
"""
Expand All @@ -57,7 +57,7 @@ def parse_reproducer(self, calls: list, index: int) -> str:
# 3. Using the call list to generate a test string
# 4. Return the test string

def _parse_call_object(self, call_dict) -> (str, str):
def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]:
"""
Takes a single call dictionary, parses it, and returns the series of function calls as a string, along with
the name of the last function, which is used as the name of the test.
Expand Down Expand Up @@ -158,8 +158,12 @@ def _match_elementary_types(self, param: dict, recursive: bool) -> str:
hex_string = parse_echidna_byte_string(param["contents"].strip('"'))
interpreted_string = f'string(hex"{hex_string}")'
return interpreted_string
case _:
return ""

def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[str, str, int]:
def _match_array_type(
self, param: dict, index: int, input_parameter: Any
) -> tuple[str, str, int]:
match param["tag"]:
case "AbiArray":
# Consider cases where the array items are more complex types (bytes, string, tuples)
Expand All @@ -180,8 +184,10 @@ def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[s
index += 1

return name, definitions, index
case _:
return "", "", index

def _match_user_defined_type(self, param: dict, input_parameter) -> tuple[str, str]:
def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[str, str]:
match param["tag"]:
case "AbiTuple":
match input_parameter.type:
Expand All @@ -190,17 +196,21 @@ def _match_user_defined_type(self, param: dict, input_parameter) -> tuple[str, s
param["contents"], True, input_parameter.type.elems_ordered
)
return definitions, f"{input_parameter}({','.join(func_params)})"
case _:
return "", ""
case "AbiUInt":
if isinstance(input_parameter.type, Enum):
enum_uint = self._match_elementary_types(param, False)
return "", f"{input_parameter}({enum_uint})"
else:
# TODO is this even reachable?
return "", ""
case _:
return "", ""

def _decode_function_params(
self, function_params: list, recursive: bool, entry_point: Any
) -> (str | None, list):
) -> tuple[str, list]:
params = []
variable_definitions = ""
index = 0
Expand Down Expand Up @@ -244,7 +254,7 @@ def _decode_function_params(
else:
return "", params

def _get_memarr(self, function_params: dict, index: int) -> (str | None, str | None):
def _get_memarr(self, function_params: dict, index: int) -> tuple[str, str]:
length = len(function_params[1])
match function_params[0]["tag"]:
case "AbiBoolType":
Expand Down Expand Up @@ -272,4 +282,4 @@ def _get_memarr(self, function_params: dict, index: int) -> (str | None, str | N
name = f"dynStringArr_{index}"
return name, f"string[] memory {name} = new string[]({length});\n"
case _:
return None, None
return "", ""
22 changes: 14 additions & 8 deletions test_generator/fuzzers/Medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_target_contract(self) -> Contract:
# TODO throw error if no contract found
exit(-1)

def parse_reproducer(self, calls: list, index: int) -> str:
def parse_reproducer(self, calls: Any, index: int) -> str:
"""
Takes a list of call dicts and returns a Foundry unit test string containing the call sequence.
"""
Expand All @@ -59,7 +59,7 @@ def parse_reproducer(self, calls: list, index: int) -> str:
# 3. Using the call list to generate a test string
# 4. Return the test string

def _parse_call_object(self, call_dict) -> (str, str):
def _parse_call_object(self, call_dict: dict) -> tuple[str, str]:
"""
Takes a single call dictionary, parses it, and returns the series of function calls as a string, along with
the name of the last function, which is used as the name of the test.
Expand Down Expand Up @@ -112,7 +112,7 @@ def _parse_call_object(self, call_dict) -> (str, str):

return call_str, function_name

def _match_elementary_types(self, param: str, recursive: bool, input_parameter) -> str:
def _match_elementary_types(self, param: str, recursive: bool, input_parameter: Any) -> str:
"""
Returns a string which represents a elementary type literal value. e.g. "5" or "uint256(5)"
Expand Down Expand Up @@ -143,7 +143,9 @@ def _match_elementary_types(self, param: str, recursive: bool, input_parameter)
else:
return param

def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[str, str, int]:
def _match_array_type(
self, param: dict, index: int, input_parameter: Any
) -> tuple[str, str, int]:
# TODO check if fixed arrays are considered dynamic or not
dynamic = input_parameter.is_dynamic_array
if not dynamic:
Expand All @@ -163,7 +165,9 @@ def _match_array_type(self, param: dict, index: int, input_parameter) -> tuple[s

return name, definitions, index

def _match_user_defined_type(self, param: dict | str, input_parameter) -> tuple[str, str]:
def _match_user_defined_type(
self, param: list[Any] | dict[Any, Any], input_parameter: Any
) -> tuple[str, str]:
match input_parameter.type:
case Structure() | StructureContract():
definitions, func_params = self._decode_function_params(
Expand All @@ -172,10 +176,12 @@ def _match_user_defined_type(self, param: dict | str, input_parameter) -> tuple[
return definitions, f"{input_parameter}({','.join(func_params)})"
case Enum() | EnumContract():
return "", f"{input_parameter}({param})"
case _:
return "", ""

def _decode_function_params(
self, function_params: list | dict, recursive: bool, entry_point: Any
) -> (str | None, list):
) -> tuple[str, list]:
params = []
variable_definitions = ""
index = 0
Expand Down Expand Up @@ -253,8 +259,8 @@ def _decode_function_params(
return "", params

def _get_memarr(
self, function_params: list, index: int, input_parameter
) -> (str | None, str | None):
self, function_params: dict | list, index: int, input_parameter: Any
) -> tuple[str, str]:
length = len(function_params)

input_type = input_parameter.type
Expand Down
12 changes: 8 additions & 4 deletions test_generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self.fuzzer = fuzzer

def get_target_contract(self) -> Contract:
""" Gets the Slither Contract object for the specified contract file"""
"""Gets the Slither Contract object for the specified contract file"""
contracts = self.slither.get_contract_from_name(self.target_name)
# Loop in case slither fetches multiple contracts for some reason (e.g., similar names?)
for contract in contracts:
Expand All @@ -58,7 +58,7 @@ def create_poc(self) -> str:
full_path = os.path.join(self.fuzzer.reproducer_dir, entry)

if os.path.isfile(full_path):
with open(full_path, "r", encoding="utf8") as file:
with open(full_path, "r", encoding="utf-8") as file:
file_list.append(json.load(file))

# 2. Parse each reproducer file and add each test function to the functions list
Expand All @@ -84,9 +84,11 @@ def create_poc(self) -> str:
f"Generated a test file in {write_path}_{self.fuzzer.name}_Test.t.sol"
)

return test_file_str


def main() -> None:
""" The main entry point """
"""The main entry point"""
parser = argparse.ArgumentParser(
prog="test-generator", description="Generate test harnesses for Echidna failed properties."
)
Expand Down Expand Up @@ -127,7 +129,7 @@ def main() -> None:
inheritance_path = args.inheritance_path
target_contract = args.target_contract
slither = Slither(file_path)
fuzzer = None
fuzzer: Echidna | Medusa

match args.selected_fuzzer.lower():
case "echidna":
Expand All @@ -147,6 +149,8 @@ def main() -> None:
test_generator.create_poc()
CryticPrint().print_success("Done!")

return None


if __name__ == "__main__":
sys.exit(main())
4 changes: 2 additions & 2 deletions test_generator/utils/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def parse_echidna_byte_string(s: str) -> str:
s = s.replace(key, value)

# Handle octal escapes (like \\135)
def octal_to_byte(match):
def octal_to_byte(match: re.Match) -> str:
octal_value = match.group(0)[1:] # Remove the backslash

return chr(int(octal_value, 8))
Expand All @@ -61,5 +61,5 @@ def octal_to_byte(match):


def parse_medusa_byte_string(s: str) -> str:
""" Decode bytes* or string type from Medusa format to Solidity hex literal"""
"""Decode bytes* or string type from Medusa format to Solidity hex literal"""
return s.encode("utf-8").hex()
17 changes: 9 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


class TestGenerator:
""" Helper class for testing all fuzzers with the tool"""
"""Helper class for testing all fuzzers with the tool"""

def __init__(self, target, target_path, corpus_dir):
slither = Slither(target_path)
echidna = Echidna(target, f"echidna-corpora/{corpus_dir}", slither)
Expand All @@ -21,17 +22,17 @@ def __init__(self, target, target_path, corpus_dir):
)

def echidna_generate_tests(self):
""" Runs the test-generator tool for an Echidna corpus"""
"""Runs the test-generator tool for an Echidna corpus"""
self.echidna_generator.create_poc()

def medusa_generate_tests(self):
""" Runs the test-generator tool for a Medusa corpus"""
"""Runs the test-generator tool for a Medusa corpus"""
self.medusa_generator.create_poc()


@pytest.fixture(autouse=True)
def change_test_dir(request, monkeypatch):
""" Helper fixture to change the working directory"""
"""Helper fixture to change the working directory"""
# Directory of the test file
test_dir = request.fspath.dirname

Expand All @@ -44,7 +45,7 @@ def change_test_dir(request, monkeypatch):

@pytest.fixture
def basic_types():
""" Fixture for the BasicTypes test contract"""
"""Fixture for the BasicTypes test contract"""
target = "BasicTypes"
target_path = "./src/BasicTypes.sol"
corpus_dir = "corpus-basic"
Expand All @@ -54,7 +55,7 @@ def basic_types():

@pytest.fixture
def fixed_size_arrays():
""" Fixture for the FixedArrays test contract"""
"""Fixture for the FixedArrays test contract"""
target = "FixedArrays"
target_path = "./src/FixedArrays.sol"
corpus_dir = "corpus-fixed-arr"
Expand All @@ -64,7 +65,7 @@ def fixed_size_arrays():

@pytest.fixture
def dynamic_arrays():
""" Fixture for the DynamicArrays test contract"""
"""Fixture for the DynamicArrays test contract"""
target = "DynamicArrays"
target_path = "./src/DynamicArrays.sol"
corpus_dir = "corpus-dyn-arr"
Expand All @@ -74,7 +75,7 @@ def dynamic_arrays():

@pytest.fixture
def structs_and_enums():
""" Fixture for the TupleTypes test contract"""
"""Fixture for the TupleTypes test contract"""
target = "TupleTypes"
target_path = "./src/TupleTypes.sol"
corpus_dir = "corpus-struct"
Expand Down
8 changes: 4 additions & 4 deletions tests/test_types_echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_echidna_basic_types(basic_types):
""" Tests the BasicTypes contract with an Echidna corpus"""
"""Tests the BasicTypes contract with an Echidna corpus"""
basic_types.echidna_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "BasicTypes_Echidna_Test.t.sol")
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_echidna_basic_types(basic_types):


def test_echidna_fixed_array_types(fixed_size_arrays):
""" Tests the FixedArrays contract with an Echidna corpus"""
"""Tests the FixedArrays contract with an Echidna corpus"""
fixed_size_arrays.echidna_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "FixedArrays_Echidna_Test.t.sol")
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_echidna_fixed_array_types(fixed_size_arrays):


def test_echidna_dynamic_array_types(dynamic_arrays):
""" Tests the DynamicArrays contract with an Echidna corpus"""
"""Tests the DynamicArrays contract with an Echidna corpus"""
dynamic_arrays.echidna_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "DynamicArrays_Echidna_Test.t.sol")
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_echidna_dynamic_array_types(dynamic_arrays):


def test_echidna_structs_and_enums(structs_and_enums):
""" Tests the TupleTypes contract with an Echidna corpus"""
"""Tests the TupleTypes contract with an Echidna corpus"""
structs_and_enums.echidna_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "TupleTypes_Echidna_Test.t.sol")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_types_medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_medusa_basic_types(basic_types):
""" Tests the BasicTypes contract with a Medusa corpus"""
"""Tests the BasicTypes contract with a Medusa corpus"""
basic_types.medusa_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "BasicTypes_Medusa_Test.t.sol")
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_medusa_basic_types(basic_types):


def test_medusa_fixed_array_types(fixed_size_arrays):
""" Tests the FixedArrays contract with a Medusa corpus"""
"""Tests the FixedArrays contract with a Medusa corpus"""
fixed_size_arrays.medusa_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "FixedArrays_Medusa_Test.t.sol")
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_medusa_fixed_array_types(fixed_size_arrays):


def test_medusa_dynamic_array_types(dynamic_arrays):
""" Tests the DynamicArrays contract with a Medusa corpus"""
"""Tests the DynamicArrays contract with a Medusa corpus"""
dynamic_arrays.medusa_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "DynamicArrays_Medusa_Test.t.sol")
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_medusa_dynamic_array_types(dynamic_arrays):


def test_medusa_structs_and_enums(structs_and_enums):
""" Tests the TupleTypes contract with a Medusa corpus"""
"""Tests the TupleTypes contract with a Medusa corpus"""
structs_and_enums.medusa_generate_tests()
# Ensure the file was created
path = os.path.join(os.getcwd(), "test", "TupleTypes_Medusa_Test.t.sol")
Expand Down

0 comments on commit 35536be

Please sign in to comment.