diff --git a/test_generator/fuzzers/Echidna.py b/test_generator/fuzzers/Echidna.py index fbea8a7..c8f305e 100644 --- a/test_generator/fuzzers/Echidna.py +++ b/test_generator/fuzzers/Echidna.py @@ -1,19 +1,21 @@ """ Generates a test file from Echidna reproducers """ # type: ignore[misc] # Ignores 'Any' input parameter -import sys -from typing import Any +from typing import Any, NoReturn import jinja2 from slither import Slither from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.declarations.structure import Structure from slither.core.declarations.structure_contract import StructureContract from slither.core.declarations.enum import Enum +from test_generator.utils.crytic_print import CryticPrint from test_generator.templates.foundry_templates import templates from test_generator.utils.encoding import parse_echidna_byte_string +from test_generator.utils.error_handler import handle_exit class Echidna: @@ -36,8 +38,7 @@ def get_target_contract(self) -> Contract: if contract.name == self.target_name: return contract - # TODO throw error if no contract found - sys.exit(-1) + handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.") def parse_reproducer(self, calls: Any, index: int) -> str: """ @@ -46,18 +47,16 @@ def parse_reproducer(self, calls: Any, index: int) -> str: call_list = [] end = len(calls) - 1 function_name = "" + # 1. For each object in the list process the call object and add it to the call list for idx, call in enumerate(calls): call_str, fn_name = self._parse_call_object(call) call_list.append(call_str) if idx == end: function_name = fn_name + "_" + str(index) + # 2. Generate the test string and return it template = jinja2.Template(templates["TEST"]) return template.render(function_name=function_name, call_list=call_list) - # 1. Take a reproducer list and create a test file based on the name of the last function of the list e.g. test_auto_$function_name - # 2. For each object in the list process the call object and add it to the call list - # 3. Using the call list to generate a test string - # 4. Return the test string def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]: """ @@ -81,12 +80,17 @@ def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]: caller = call_dict["src"] value = int(call_dict["value"], 16) - slither_entry_point = None + slither_entry_point: FunctionContract for entry_point in self.target.functions_entry_points: if entry_point.name == function_name: slither_entry_point = entry_point + if 'slither_entry_point' not in locals(): + handle_exit( + f"\n* Slither could not find the function `{function_name}` specified in the call object" + ) + # 2. Decode the function parameters variable_definition, call_definition = self._decode_function_params( function_parameters, False, slither_entry_point @@ -111,7 +115,7 @@ def _parse_call_object(self, call_dict: dict[Any, Any]) -> tuple[str, str]: return call_str, function_name # pylint: disable=R0201 - def _match_elementary_types(self, param: dict, recursive: bool) -> str: + def _match_elementary_types(self, param: dict, recursive: bool) -> str | NoReturn: """ Returns a string which represents a elementary type literal value. e.g. "5" or "uint256(5)" @@ -161,11 +165,13 @@ def _match_elementary_types(self, param: dict, recursive: bool) -> str: interpreted_string = f'string(hex"{hex_string}")' return interpreted_string case _: - return "" + handle_exit( + f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) def _match_array_type( self, param: dict, index: int, input_parameter: Any - ) -> tuple[str, str, int]: + ) -> tuple[str, str, int] | NoReturn: match param["tag"]: case "AbiArray": # Consider cases where the array items are more complex types (bytes, string, tuples) @@ -187,9 +193,13 @@ def _match_array_type( return name, definitions, index case _: - return "", "", index + handle_exit( + f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) - def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[str, str]: + def _match_user_defined_type( + self, param: dict, input_parameter: Any + ) -> tuple[str, str] | NoReturn: match param["tag"]: case "AbiTuple": match input_parameter.type: @@ -199,16 +209,22 @@ def _match_user_defined_type(self, param: dict, input_parameter: Any) -> tuple[s ) return definitions, f"{input_parameter}({','.join(func_params)})" case _: - return "", "" + handle_exit( + f"\n* The parameter type `{input_parameter.type}` could not be found. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) case "AbiUInt": if isinstance(input_parameter.type, Enum): enum_uint = self._match_elementary_types(param, False) return "", f"{input_parameter}({enum_uint})" # TODO is this even reachable? - return "", "" + handle_exit( + f"\n* The parameter type `{input_parameter.type}` does not match the intended type `Enum`. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) case _: - return "", "" + handle_exit( + f"\n* The parameter tag `{param['tag']}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) def _decode_function_params( self, function_params: list, recursive: bool, entry_point: Any @@ -245,7 +261,9 @@ def _decode_function_params( params.append(func_params) case _: # TODO should handle all cases, but keeping this just in case - print("UNHANDLED INPUT TYPE -> DEFAULT CASE") + CryticPrint().print_information( + f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues" + ) continue # 3. Return a list of function parameters diff --git a/test_generator/fuzzers/Medusa.py b/test_generator/fuzzers/Medusa.py index b6740d2..b648ffb 100644 --- a/test_generator/fuzzers/Medusa.py +++ b/test_generator/fuzzers/Medusa.py @@ -1,6 +1,5 @@ """ Generates a test file from Medusa reproducers """ -import sys -from typing import Any +from typing import Any, NoReturn import jinja2 from slither import Slither @@ -13,8 +12,10 @@ from slither.core.declarations.structure_contract import StructureContract from slither.core.declarations.enum import Enum from slither.core.declarations.enum_contract import EnumContract +from test_generator.utils.crytic_print import CryticPrint from test_generator.templates.foundry_templates import templates from test_generator.utils.encoding import parse_medusa_byte_string +from test_generator.utils.error_handler import handle_exit class Medusa: @@ -38,8 +39,7 @@ def get_target_contract(self) -> Contract: if contract.name == self.target_name: return contract - # TODO throw error if no contract found - sys.exit(-1) + handle_exit(f"\n* Slither could not find the specified contract `{self.target_name}`.") def parse_reproducer(self, calls: Any, index: int) -> str: """ @@ -86,6 +86,11 @@ def _parse_call_object(self, call_dict: dict) -> tuple[str, str]: if entry_point.name == function_name: slither_entry_point = entry_point + if 'slither_entry_point' not in locals(): + handle_exit( + f"\n* Slither could not find the function `{function_name}` specified in the call object" + ) + # 2. Decode the function parameters variable_definition, call_definition = self._decode_function_params( function_parameters, False, slither_entry_point @@ -138,8 +143,12 @@ def _match_elementary_types(self, param: str, recursive: bool, input_parameter: hex_string = parse_medusa_byte_string(param) interpreted_string = f'string(hex"{hex_string}")' return interpreted_string + if "address" in input_type: + return param - return param + handle_exit( + f"\n* The parameter type `{input_type}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) def _match_array_type( self, param: dict, index: int, input_parameter: Any @@ -175,11 +184,13 @@ def _match_user_defined_type( case Enum() | EnumContract(): # type: ignore[misc] return "", f"{input_parameter}({param})" # type: ignore[unreachable] case _: - return "", "" + handle_exit( + f"\n* The parameter type `{input_parameter.type}` could not be found in the call object. This could indicate an issue in decoding the call sequence, or a missing feature. Please open an issue at https://github.com/crytic/test-generator/issues" + ) def _decode_function_params( self, function_params: list | dict, recursive: bool, entry_point: Any - ) -> tuple[str, list]: + ) -> tuple[str, list] | NoReturn: params = [] variable_definitions = "" index = 0 @@ -212,7 +223,9 @@ def _decode_function_params( params.append(func_params) case _: # TODO should handle all cases, but keeping this just in case - print("UNHANDLED INPUT TYPE -> DEFAULT CASE") + CryticPrint().print_information( + f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues" + ) continue else: for param_idx, param in enumerate(function_params): @@ -245,7 +258,9 @@ def _decode_function_params( params.append(func_params) case _: # TODO should handle all cases, but keeping this just in case - print("UNHANDLED INPUT TYPE -> DEFAULT CASE") + CryticPrint().print_information( + f"\n* Attempted to decode an unidentified type {input_parameter}, this call will be skipped. Please open an issue at https://github.com/crytic/test-generator/issues" + ) continue # 3. Return a list of function parameters diff --git a/test_generator/main.py b/test_generator/main.py index 2a8a4d8..a710ec2 100644 --- a/test_generator/main.py +++ b/test_generator/main.py @@ -13,6 +13,7 @@ from test_generator.templates.foundry_templates import templates from test_generator.fuzzers.Medusa import Medusa from test_generator.fuzzers.Echidna import Echidna +from test_generator.utils.error_handler import handle_exit class FoundryTest: @@ -94,7 +95,7 @@ def main() -> None: # type: ignore[func-returns-value] ) parser.add_argument("file_path", help="Path to the Echidna test harness.") parser.add_argument( - "-cd", "--corpus-dir", dest="corpus_dir", help="Path to the corpus directory" + "-cd", "--corpus-dir", dest="corpus_dir", help="Path to the corpus directory", required=True ) parser.add_argument("-c", "--contract", dest="target_contract", help="Define the contract name") parser.add_argument( @@ -123,6 +124,12 @@ def main() -> None: # type: ignore[func-returns-value] ) args = parser.parse_args() + + missing_args = [arg for arg, value in vars(args).items() if value is None] + if missing_args: + parser.print_help() + handle_exit(f"\n* Missing required arguments: {', '.join(missing_args)}") + file_path = args.file_path corpus_dir = args.corpus_dir test_directory = args.test_directory @@ -137,8 +144,9 @@ def main() -> None: # type: ignore[func-returns-value] case "medusa": fuzzer = Medusa(target_contract, corpus_dir, slither) case _: - # TODO create a descriptive error - sys.exit(-1) + handle_exit( + f"\n* The requested fuzzer {args.selected_fuzzer} is not supported. Supported fuzzers: echidna, medusa." + ) CryticPrint().print_information( f"Generating Foundry unit tests based on the {fuzzer.name} reproducers..." diff --git a/test_generator/utils/error_handler.py b/test_generator/utils/error_handler.py new file mode 100644 index 0000000..0fbea1c --- /dev/null +++ b/test_generator/utils/error_handler.py @@ -0,0 +1,10 @@ +""" Utility function for error handling""" +import sys +from typing import NoReturn +from test_generator.utils.crytic_print import CryticPrint + + +def handle_exit(reason: str) -> NoReturn: + """Print an error message to the console and exit""" + CryticPrint().print_error(reason) + sys.exit()