diff --git a/.github/scripts/set_release_version.py b/.github/scripts/set_release_version.py
index 980a852..92c95dd 100644
--- a/.github/scripts/set_release_version.py
+++ b/.github/scripts/set_release_version.py
@@ -3,5 +3,5 @@
 import datetime
 
 now_utc = datetime.datetime.now(datetime.timezone.utc)
-version = now_utc.strftime('%Y.%m%d.%H%M%S')
-print('::set-output name=version::v{}'.format(version))
+version = now_utc.strftime("%Y.%m%d.%H%M%S")
+print("::set-output name=version::v{}".format(version))
diff --git a/compiler/back_end/__init__.py b/compiler/back_end/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/back_end/__init__.py
+++ b/compiler/back_end/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/compiler/back_end/cpp/__init__.py b/compiler/back_end/cpp/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/back_end/cpp/__init__.py
+++ b/compiler/back_end/cpp/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/compiler/back_end/cpp/attributes.py b/compiler/back_end/cpp/attributes.py
index 256451b..a722143 100644
--- a/compiler/back_end/cpp/attributes.py
+++ b/compiler/back_end/cpp/attributes.py
@@ -20,6 +20,7 @@
 
 class Attribute(str, Enum):
     """Attributes available in the C++ backend."""
+
     NAMESPACE = "namespace"
     ENUM_CASE = "enum_case"
 
@@ -37,6 +38,7 @@ class Scope(set, Enum):
     Each entry is a set of (Attribute, default?) tuples, the first value being
     the attribute itself, the second value being a boolean value indicating
     whether the attribute is allowed to be defaulted in that scope."""
+
     BITS = {
         # Bits may contain an enum definition.
         (Attribute.ENUM_CASE, True)
@@ -47,10 +49,12 @@ class Scope(set, Enum):
     ENUM_VALUE = {
         (Attribute.ENUM_CASE, False),
     }
-    MODULE = {
-        (Attribute.NAMESPACE, False),
-        (Attribute.ENUM_CASE, True),
-    },
+    MODULE = (
+        {
+            (Attribute.NAMESPACE, False),
+            (Attribute.ENUM_CASE, True),
+        },
+    )
     STRUCT = {
         # Struct may contain an enum definition.
         (Attribute.ENUM_CASE, True),
diff --git a/compiler/back_end/cpp/emboss_codegen_cpp.py b/compiler/back_end/cpp/emboss_codegen_cpp.py
index 6da9f7d..474bcc1 100644
--- a/compiler/back_end/cpp/emboss_codegen_cpp.py
+++ b/compiler/back_end/cpp/emboss_codegen_cpp.py
@@ -31,72 +31,79 @@
 
 
 def _parse_command_line(argv):
-  """Parses the given command-line arguments."""
-  parser = argparse.ArgumentParser(description="Emboss compiler C++ back end.",
-                                   prog=argv[0])
-  parser.add_argument("--input-file",
-                      type=str,
-                      help=".emb.ir file to compile.")
-  parser.add_argument("--output-file",
-                      type=str,
-                      help="Write header to file. If not specified, write " +
-                           "header to stdout.")
-  parser.add_argument("--color-output",
-                      default="if_tty",
-                      choices=["always", "never", "if_tty", "auto"],
-                      help="Print error messages using color. 'auto' is a "
-                           "synonym for 'if_tty'.")
-  parser.add_argument("--cc-enum-traits",
-                      action=argparse.BooleanOptionalAction,
-                      default=True,
-                      help="""Controls generation of EnumTraits by the C++
-                      backend""")
-  return parser.parse_args(argv[1:])
+    """Parses the given command-line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Emboss compiler C++ back end.", prog=argv[0]
+    )
+    parser.add_argument("--input-file", type=str, help=".emb.ir file to compile.")
+    parser.add_argument(
+        "--output-file",
+        type=str,
+        help="Write header to file. If not specified, write " "header to stdout.",
+    )
+    parser.add_argument(
+        "--color-output",
+        default="if_tty",
+        choices=["always", "never", "if_tty", "auto"],
+        help="Print error messages using color. 'auto' is a " "synonym for 'if_tty'.",
+    )
+    parser.add_argument(
+        "--cc-enum-traits",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="""Controls generation of EnumTraits by the C++
+                      backend""",
+    )
+    return parser.parse_args(argv[1:])
 
 
def _show_errors(errors, ir, color_output):
-  """Prints errors with source code snippets."""
-  source_codes = {}
-  for module in ir.module:
-    source_codes[module.source_file_name] = module.source_text
-  use_color = (color_output == "always" or
-               (color_output in ("auto", "if_tty") and
-                os.isatty(sys.stderr.fileno())))
-  print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+    """Prints errors with source code snippets."""
+    source_codes = {}
+    for module in ir.module:
+        source_codes[module.source_file_name] = module.source_text
+    use_color = color_output == "always" or (
+        color_output in ("auto", "if_tty") and os.isatty(sys.stderr.fileno())
+    )
+    print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+
 
def generate_headers_and_log_errors(ir, color_output, config: header_generator.Config):
-  """Generates a C++ header and logs any errors.
+    """Generates a C++ header and logs any errors.
+
+    Arguments:
+      ir: EmbossIr of the module.
+      color_output: "always", "never", "if_tty", "auto"
+      config: Header generation configuration.
 
-  Arguments:
-    ir: EmbossIr of the module.
-    color_output: "always", "never", "if_tty", "auto"
-    config: Header generation configuration.
+ Returns: + A tuple of (header, errors) + """ + header, errors = header_generator.generate_header(ir, config) + if errors: + _show_errors(errors, ir, color_output) + return (header, errors) - Returns: - A tuple of (header, errors) - """ - header, errors = header_generator.generate_header(ir, config) - if errors: - _show_errors(errors, ir, color_output) - return (header, errors) def main(flags): - if flags.input_file: - with open(flags.input_file) as f: - ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, f.read()) - else: - ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, sys.stdin.read()) - config = header_generator.Config(include_enum_traits=flags.cc_enum_traits) - header, errors = generate_headers_and_log_errors(ir, flags.color_output, config) - if errors: - return 1 - if flags.output_file: - with open(flags.output_file, "w") as f: - f.write(header) - else: - print(header) - return 0 - - -if __name__ == '__main__': - sys.exit(main(_parse_command_line(sys.argv))) + if flags.input_file: + with open(flags.input_file) as f: + ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, f.read()) + else: + ir = ir_data_utils.IrDataSerializer.from_json( + ir_data.EmbossIr, sys.stdin.read() + ) + config = header_generator.Config(include_enum_traits=flags.cc_enum_traits) + header, errors = generate_headers_and_log_errors(ir, flags.color_output, config) + if errors: + return 1 + if flags.output_file: + with open(flags.output_file, "w") as f: + f.write(header) + else: + print(header) + return 0 + + +if __name__ == "__main__": + sys.exit(main(_parse_command_line(sys.argv))) diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py index 11fcd17..320b15f 100644 --- a/compiler/back_end/cpp/header_generator.py +++ b/compiler/back_end/cpp/header_generator.py @@ -34,35 +34,125 @@ from compiler.util import resources from compiler.util import traverse_ir -_TEMPLATES = code_template.parse_templates(resources.load( - "compiler.back_end.cpp", "generated_code_templates")) - -_CPP_RESERVED_WORDS = set(( - # C keywords. A few of these are not (yet) C++ keywords, but some compilers - # accept the superset of C and C++, so we still want to avoid them. - "asm", "auto", "break", "case", "char", "const", "continue", "default", - "do", "double", "else", "enum", "extern", "float", "for", "fortran", "goto", - "if", "inline", "int", "long", "register", "restrict", "return", "short", - "signed", "sizeof", "static", "struct", "switch", "typedef", "union", - "unsigned", "void", "volatile", "while", "_Alignas", "_Alignof", "_Atomic", - "_Bool", "_Complex", "_Generic", "_Imaginary", "_Noreturn", "_Pragma", - "_Static_assert", "_Thread_local", - # The following are not technically reserved words, but collisions are - # likely due to the standard macros. - "complex", "imaginary", "noreturn", - # C++ keywords that are not also C keywords. 
- "alignas", "alignof", "and", "and_eq", "asm", "bitand", "bitor", "bool", - "catch", "char16_t", "char32_t", "class", "compl", "concept", "constexpr", - "const_cast", "decltype", "delete", "dynamic_cast", "explicit", "export", - "false", "friend", "mutable", "namespace", "new", "noexcept", "not", - "not_eq", "nullptr", "operator", "or", "or_eq", "private", "protected", - "public", "reinterpret_cast", "requires", "static_assert", "static_cast", - "template", "this", "thread_local", "throw", "true", "try", "typeid", - "typename", "using", "virtual", "wchar_t", "xor", "xor_eq", - # "NULL" is not a keyword, but is still very likely to cause problems if - # used as a namespace name. - "NULL", -)) +_TEMPLATES = code_template.parse_templates( + resources.load("compiler.back_end.cpp", "generated_code_templates") +) + +_CPP_RESERVED_WORDS = set( + ( + # C keywords. A few of these are not (yet) C++ keywords, but some compilers + # accept the superset of C and C++, so we still want to avoid them. + "asm", + "auto", + "break", + "case", + "char", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "float", + "for", + "fortran", + "goto", + "if", + "inline", + "int", + "long", + "register", + "restrict", + "return", + "short", + "signed", + "sizeof", + "static", + "struct", + "switch", + "typedef", + "union", + "unsigned", + "void", + "volatile", + "while", + "_Alignas", + "_Alignof", + "_Atomic", + "_Bool", + "_Complex", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Pragma", + "_Static_assert", + "_Thread_local", + # The following are not technically reserved words, but collisions are + # likely due to the standard macros. + "complex", + "imaginary", + "noreturn", + # C++ keywords that are not also C keywords. + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "bitand", + "bitor", + "bool", + "catch", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "constexpr", + "const_cast", + "decltype", + "delete", + "dynamic_cast", + "explicit", + "export", + "false", + "friend", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "private", + "protected", + "public", + "reinterpret_cast", + "requires", + "static_assert", + "static_cast", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typeid", + "typename", + "using", + "virtual", + "wchar_t", + "xor", + "xor_eq", + # "NULL" is not a keyword, but is still very likely to cause problems if + # used as a namespace name. + "NULL", + ) +) # The support namespace, as a C++ namespace prefix. This namespace contains the # Emboss C++ support classes. @@ -71,7 +161,7 @@ # Regex matching a C++ namespace component. Captures component name. _NS_COMPONENT_RE = r"(?:^\s*|::)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=\s*$|::)" # Regex matching a full C++ namespace (at least one namespace component). -_NS_RE = fr"^\s*(?:{_NS_COMPONENT_RE})+\s*$" +_NS_RE = rf"^\s*(?:{_NS_COMPONENT_RE})+\s*$" # Regex matching an empty C++ namespace. _NS_EMPTY_RE = r"^\s*$" # Regex matching only the global C++ namespace. @@ -87,1530 +177,1759 @@ # Verify that all supported enum cases have valid, implemented conversions. 
for _enum_case in _SUPPORTED_ENUM_CASES: - assert name_conversion.is_case_conversion_supported("SHOUTY_CASE", _enum_case) + assert name_conversion.is_case_conversion_supported("SHOUTY_CASE", _enum_case) class Config(NamedTuple): - """Configuration for C++ header generation.""" + """Configuration for C++ header generation.""" - include_enum_traits: bool = True - """Whether or not to include EnumTraits in the generated header.""" + include_enum_traits: bool = True + """Whether or not to include EnumTraits in the generated header.""" def _get_namespace_components(namespace): - """Gets the components of a C++ namespace + """Gets the components of a C++ namespace - Examples: - "::some::name::detail" -> ["some", "name", "detail"] - "product::name" -> ["product", "name"] - "simple" -> ["simple"] + Examples: + "::some::name::detail" -> ["some", "name", "detail"] + "product::name" -> ["product", "name"] + "simple" -> ["simple"] - Arguments: - namespace: A string containing the namespace. May be fully-qualified. + Arguments: + namespace: A string containing the namespace. May be fully-qualified. - Returns: - A list of strings, one per namespace component.""" - return re.findall(_NS_COMPONENT_RE, namespace) + Returns: + A list of strings, one per namespace component.""" + return re.findall(_NS_COMPONENT_RE, namespace) def _get_module_namespace(module): - """Returns the C++ namespace of the module, as a list of components. - - Arguments: - module: The IR of an Emboss module whose namespace should be returned. - - Returns: - A list of strings, one per namespace component. This list can be formatted - as appropriate by the caller. - """ - namespace_attr = ir_util.get_attribute(module.attribute, "namespace") - if namespace_attr and namespace_attr.string_constant.text: - namespace = namespace_attr.string_constant.text - else: - namespace = "emboss_generated_code" - return _get_namespace_components(namespace) + """Returns the C++ namespace of the module, as a list of components. + + Arguments: + module: The IR of an Emboss module whose namespace should be returned. + + Returns: + A list of strings, one per namespace component. This list can be formatted + as appropriate by the caller. 
+ """ + namespace_attr = ir_util.get_attribute(module.attribute, "namespace") + if namespace_attr and namespace_attr.string_constant.text: + namespace = namespace_attr.string_constant.text + else: + namespace = "emboss_generated_code" + return _get_namespace_components(namespace) def _cpp_string_escape(string): - return re.sub("['\"\\\\]", r"\\\0", string) + return re.sub("['\"\\\\]", r"\\\0", string) def _get_includes(module, config: Config): - """Returns the appropriate #includes based on module's imports.""" - includes = [] - for import_ in module.foreign_import: - if import_.file_name.text: - includes.append( - code_template.format_template( - _TEMPLATES.include, - file_name=_cpp_string_escape(import_.file_name.text + ".h"))) - else: - includes.append( - code_template.format_template( - _TEMPLATES.include, - file_name=_cpp_string_escape(_PRELUDE_INCLUDE_FILE))) - if config.include_enum_traits: - includes.extend( - [code_template.format_template( - _TEMPLATES.include, - file_name=_cpp_string_escape(file_name)) - for file_name in (_ENUM_VIEW_INCLUDE_FILE, _TEXT_UTIL_INCLUDE_FILE) - ]) - return "".join(includes) + """Returns the appropriate #includes based on module's imports.""" + includes = [] + for import_ in module.foreign_import: + if import_.file_name.text: + includes.append( + code_template.format_template( + _TEMPLATES.include, + file_name=_cpp_string_escape(import_.file_name.text + ".h"), + ) + ) + else: + includes.append( + code_template.format_template( + _TEMPLATES.include, + file_name=_cpp_string_escape(_PRELUDE_INCLUDE_FILE), + ) + ) + if config.include_enum_traits: + includes.extend( + [ + code_template.format_template( + _TEMPLATES.include, file_name=_cpp_string_escape(file_name) + ) + for file_name in ( + _ENUM_VIEW_INCLUDE_FILE, + _TEXT_UTIL_INCLUDE_FILE, + ) + ] + ) + return "".join(includes) def _render_namespace_prefix(namespace): - """Returns namespace rendered as a prefix, like ::foo::bar::baz.""" - return "".join(["::" + n for n in namespace]) + """Returns namespace rendered as a prefix, like ::foo::bar::baz.""" + return "".join(["::" + n for n in namespace]) def _render_integer(value): - """Returns a C++ string representation of a constant integer.""" - integer_type = _cpp_integer_type_for_range(value, value) - assert integer_type, ("Bug: value should never be outside [-2**63, 2**64), " - "got {}.".format(value)) - # C++ literals are always positive. Negative constants are actually the - # positive literal with the unary `-` operator applied. - # - # This means that C++ compilers for 2s-complement systems get finicky about - # minimum integers: if you feed `-9223372036854775808` into GCC, with -Wall, - # you get: - # - # warning: integer constant is so large that it is unsigned - # - # and Clang gives: - # - # warning: integer literal is too large to be represented in a signed - # integer type, interpreting as unsigned [-Wimplicitly-unsigned-literal] - # - # and MSVC: - # - # warning C4146: unary minus operator applied to unsigned type, result - # still unsigned - # - # So, workaround #1: -(2**63) must be written `(-9223372036854775807 - 1)`. - # - # The next problem is that MSVC (but not Clang or GCC) will pick `unsigned` - # as the type of a literal like `2147483648`. As far as I can tell, this is a - # violation of the C++11 standard, but it's possible that the final standard - # has different rules. (MSVC seems to treat decimal literals the way that the - # standard says octal and hexadecimal literals should be treated.) 
-  #
-  # Luckily, workaround #2: we can unconditionally append `LL` to all constants
-  # to force them to be interpreted as `long long` (or `unsigned long long` for
-  # `ULL`-suffixed constants), and then use a narrowing cast to the appropriate
-  # type, without any warnings on any major compilers.
-  #
-  # TODO(bolms): This suffix computation is kind of a hack.
-  suffix = "U" if "uint" in integer_type else ""
-  if value == -(2**63):
-    return "static_cast</**/{0}>({1}LL - 1)".format(integer_type, -(2**63 - 1))
-  else:
-    return "static_cast</**/{0}>({1}{2}LL)".format(integer_type, value, suffix)
+    """Returns a C++ string representation of a constant integer."""
+    integer_type = _cpp_integer_type_for_range(value, value)
+    assert (
+        integer_type
+    ), "Bug: value should never be outside [-2**63, 2**64), " "got {}.".format(value)
+    # C++ literals are always positive. Negative constants are actually the
+    # positive literal with the unary `-` operator applied.
+    #
+    # This means that C++ compilers for 2s-complement systems get finicky about
+    # minimum integers: if you feed `-9223372036854775808` into GCC, with -Wall,
+    # you get:
+    #
+    # warning: integer constant is so large that it is unsigned
+    #
+    # and Clang gives:
+    #
+    # warning: integer literal is too large to be represented in a signed
+    # integer type, interpreting as unsigned [-Wimplicitly-unsigned-literal]
+    #
+    # and MSVC:
+    #
+    # warning C4146: unary minus operator applied to unsigned type, result
+    # still unsigned
+    #
+    # So, workaround #1: -(2**63) must be written `(-9223372036854775807 - 1)`.
+    #
+    # The next problem is that MSVC (but not Clang or GCC) will pick `unsigned`
+    # as the type of a literal like `2147483648`. As far as I can tell, this is a
+    # violation of the C++11 standard, but it's possible that the final standard
+    # has different rules. (MSVC seems to treat decimal literals the way that the
+    # standard says octal and hexadecimal literals should be treated.)
+    #
+    # Luckily, workaround #2: we can unconditionally append `LL` to all constants
+    # to force them to be interpreted as `long long` (or `unsigned long long` for
+    # `ULL`-suffixed constants), and then use a narrowing cast to the appropriate
+    # type, without any warnings on any major compilers.
+    #
+    # TODO(bolms): This suffix computation is kind of a hack.
+    suffix = "U" if "uint" in integer_type else ""
+    if value == -(2**63):
+        return "static_cast</**/{0}>({1}LL - 1)".format(integer_type, -(2**63 - 1))
+    else:
+        return "static_cast</**/{0}>({1}{2}LL)".format(integer_type, value, suffix)
 
 
def _maybe_type(wrapped_type):
-  return "::emboss::support::Maybe</**/{}>".format(wrapped_type)
+    return "::emboss::support::Maybe</**/{}>".format(wrapped_type)
 
 
def _render_integer_for_expression(value):
-  integer_type = _cpp_integer_type_for_range(value, value)
-  return "{0}({1})".format(_maybe_type(integer_type), _render_integer(value))
+    integer_type = _cpp_integer_type_for_range(value, value)
+    return "{0}({1})".format(_maybe_type(integer_type), _render_integer(value))
 
 
def _wrap_in_namespace(body, namespace):
-  """Returns the given body wrapped in the given namespace."""
-  for component in reversed(namespace):
-    body = code_template.format_template(_TEMPLATES.namespace_wrap,
-                                         component=component,
-                                         body=body) + "\n"
-  return body
+    """Returns the given body wrapped in the given namespace."""
+    for component in reversed(namespace):
+        body = (
+            code_template.format_template(
+                _TEMPLATES.namespace_wrap, component=component, body=body
+            )
+            + "\n"
+        )
+    return body
 
 
def _get_type_size(type_ir, ir):
-  size = ir_util.fixed_size_of_type_in_bits(type_ir, ir)
-  assert size is not None, (
-      "_get_type_size should only be called for constant-sized types.")
-  return size
+    size = ir_util.fixed_size_of_type_in_bits(type_ir, ir)
+    assert (
+        size is not None
+    ), "_get_type_size should only be called for constant-sized types."
+    return size
 
 
def _offset_storage_adapter(buffer_type, alignment, static_offset):
-  return "{}::template OffsetStorageType</**/{}, {}>".format(
-      buffer_type, alignment, static_offset)
+    return "{}::template OffsetStorageType</**/{}, {}>".format(
+        buffer_type, alignment, static_offset
+    )
 
 
def _bytes_to_bits_convertor(buffer_type, byte_order, size):
-  assert byte_order, "byte_order should not be empty."
-  return "{}::BitBlock</**/{}::{}ByteOrderer<typename {}>, {}>".format(
-      _SUPPORT_NAMESPACE,
-      _SUPPORT_NAMESPACE,
-      byte_order,
-      buffer_type,
-      size)
+    assert byte_order, "byte_order should not be empty."
+    return "{}::BitBlock</**/{}::{}ByteOrderer<typename {}>, {}>".format(
+        _SUPPORT_NAMESPACE, _SUPPORT_NAMESPACE, byte_order, buffer_type, size
+    )
 
 
def _get_fully_qualified_namespace(name, ir):
-  module = ir_util.find_object((name.module_file,), ir)
-  namespace = _render_namespace_prefix(_get_module_namespace(module))
-  return namespace + "".join(["::" + str(s) for s in name.object_path[:-1]])
+    module = ir_util.find_object((name.module_file,), ir)
+    namespace = _render_namespace_prefix(_get_module_namespace(module))
+    return namespace + "".join(["::" + str(s) for s in name.object_path[:-1]])
 
 
def _get_unqualified_name(name):
-  return name.object_path[-1]
+    return name.object_path[-1]
 
 
def _get_fully_qualified_name(name, ir):
-  return (_get_fully_qualified_namespace(name, ir) + "::" +
-          _get_unqualified_name(name))
-
-
-def _get_adapted_cpp_buffer_type_for_field(type_definition, size_in_bits,
-                                           buffer_type, byte_order,
-                                           parent_addressable_unit):
-  """Returns the adapted C++ type information needed to construct a view."""
-  if (parent_addressable_unit == ir_data.AddressableUnit.BYTE and
-      type_definition.addressable_unit == ir_data.AddressableUnit.BIT):
-    assert byte_order
-    return _bytes_to_bits_convertor(buffer_type, byte_order, size_in_bits)
-  else:
-    assert parent_addressable_unit == type_definition.addressable_unit, (
-        "Addressable unit mismatch: {} vs {}".format(
-            parent_addressable_unit,
-            type_definition.addressable_unit))
-    return buffer_type
+    return _get_fully_qualified_namespace(name, ir) + "::" + _get_unqualified_name(name)
+
+
+def _get_adapted_cpp_buffer_type_for_field(
+    type_definition, size_in_bits, buffer_type, byte_order, parent_addressable_unit
+):
+    """Returns the adapted C++ type information needed to construct a view."""
+    if (
+        parent_addressable_unit == ir_data.AddressableUnit.BYTE
+        and type_definition.addressable_unit == ir_data.AddressableUnit.BIT
+    ):
+        assert byte_order
+        return _bytes_to_bits_convertor(buffer_type, byte_order, size_in_bits)
+    else:
+        assert (
+            parent_addressable_unit == type_definition.addressable_unit
+        ), "Addressable unit mismatch: {} vs {}".format(
+            parent_addressable_unit, type_definition.addressable_unit
+        )
+        return buffer_type
 
 
def _get_cpp_view_type_for_type_definition(
-    type_definition, size, ir, buffer_type, byte_order, parent_addressable_unit,
-    validator):
-  """Returns the C++ type information needed to construct a view.
-
-  Returns the C++ type for a view of the given Emboss TypeDefinition, and the
-  C++ types of its parameters, if any.
-
-  Arguments:
-    type_definition: The ir_data.TypeDefinition whose view should be
-      constructed.
-    size: The size, in type_definition.addressable_units, of the instantiated
-      type, or None if it is not known at compile time.
-    ir: The complete IR.
-    buffer_type: The C++ type to be used as the Storage parameter of the view
-      (e.g., "ContiguousBuffer<...>").
-    byte_order: For BIT types which are direct children of BYTE types,
-      "LittleEndian", "BigEndian", or "None". Otherwise, None.
-    parent_addressable_unit: The addressable_unit_size of the structure
-      containing this structure.
-    validator: The name of the validator type to be injected into the view.
-
-  Returns:
-    A tuple of: the C++ view type and a (possibly-empty) list of the C++ types
-    of Emboss parameters which must be passed to the view's constructor.
- """ - adapted_buffer_type = _get_adapted_cpp_buffer_type_for_field( - type_definition, size, buffer_type, byte_order, parent_addressable_unit) - if type_definition.HasField("external"): - # Externals do not (yet) support runtime parameters. - return code_template.format_template( - _TEMPLATES.external_view_type, - namespace=_get_fully_qualified_namespace( - type_definition.name.canonical_name, ir), - name=_get_unqualified_name(type_definition.name.canonical_name), - bits=size, - validator=validator, - buffer_type=adapted_buffer_type), [] - elif type_definition.HasField("structure"): - parameter_types = [] - for parameter in type_definition.runtime_parameter: - parameter_types.append( - _cpp_basic_type_for_expression_type(parameter.type, ir)) - return code_template.format_template( - _TEMPLATES.structure_view_type, - namespace=_get_fully_qualified_namespace( - type_definition.name.canonical_name, ir), - name=_get_unqualified_name(type_definition.name.canonical_name), - buffer_type=adapted_buffer_type), parameter_types - elif type_definition.HasField("enumeration"): - return code_template.format_template( - _TEMPLATES.enum_view_type, - support_namespace=_SUPPORT_NAMESPACE, - enum_type=_get_fully_qualified_name(type_definition.name.canonical_name, - ir), - bits=size, - validator=validator, - buffer_type=adapted_buffer_type), [] - else: - assert False, "Unknown variety of type {}".format(type_definition) + type_definition, + size, + ir, + buffer_type, + byte_order, + parent_addressable_unit, + validator, +): + """Returns the C++ type information needed to construct a view. + + Returns the C++ type for a view of the given Emboss TypeDefinition, and the + C++ types of its parameters, if any. + + Arguments: + type_definition: The ir_data.TypeDefinition whose view should be + constructed. + size: The size, in type_definition.addressable_units, of the instantiated + type, or None if it is not known at compile time. + ir: The complete IR. + buffer_type: The C++ type to be used as the Storage parameter of the view + (e.g., "ContiguousBuffer<...>"). + byte_order: For BIT types which are direct children of BYTE types, + "LittleEndian", "BigEndian", or "None". Otherwise, None. + parent_addressable_unit: The addressable_unit_size of the structure + containing this structure. + validator: The name of the validator type to be injected into the view. + + Returns: + A tuple of: the C++ view type and a (possibly-empty) list of the C++ types + of Emboss parameters which must be passed to the view's constructor. + """ + adapted_buffer_type = _get_adapted_cpp_buffer_type_for_field( + type_definition, size, buffer_type, byte_order, parent_addressable_unit + ) + if type_definition.HasField("external"): + # Externals do not (yet) support runtime parameters. 
+ return ( + code_template.format_template( + _TEMPLATES.external_view_type, + namespace=_get_fully_qualified_namespace( + type_definition.name.canonical_name, ir + ), + name=_get_unqualified_name(type_definition.name.canonical_name), + bits=size, + validator=validator, + buffer_type=adapted_buffer_type, + ), + [], + ) + elif type_definition.HasField("structure"): + parameter_types = [] + for parameter in type_definition.runtime_parameter: + parameter_types.append( + _cpp_basic_type_for_expression_type(parameter.type, ir) + ) + return ( + code_template.format_template( + _TEMPLATES.structure_view_type, + namespace=_get_fully_qualified_namespace( + type_definition.name.canonical_name, ir + ), + name=_get_unqualified_name(type_definition.name.canonical_name), + buffer_type=adapted_buffer_type, + ), + parameter_types, + ) + elif type_definition.HasField("enumeration"): + return ( + code_template.format_template( + _TEMPLATES.enum_view_type, + support_namespace=_SUPPORT_NAMESPACE, + enum_type=_get_fully_qualified_name( + type_definition.name.canonical_name, ir + ), + bits=size, + validator=validator, + buffer_type=adapted_buffer_type, + ), + [], + ) + else: + assert False, "Unknown variety of type {}".format(type_definition) def _get_cpp_view_type_for_physical_type( - type_ir, size, byte_order, ir, buffer_type, parent_addressable_unit, - validator): - """Returns the C++ type information needed to construct a field's view. - - Returns the C++ type of an ir_data.Type, and the C++ types of its parameters, - if any. - - Arguments: - type_ir: The ir_data.Type whose view should be constructed. - size: The size, in type_definition.addressable_units, of the instantiated - type, or None if it is not known at compile time. - byte_order: For BIT types which are direct children of BYTE types, - "LittleEndian", "BigEndian", or "None". Otherwise, None. - ir: The complete IR. - buffer_type: The C++ type to be used as the Storage parameter of the view - (e.g., "ContiguousBuffer<...>"). - parent_addressable_unit: The addressable_unit_size of the structure - containing this type. - validator: The name of the validator type to be injected into the view. - - Returns: - A tuple of: the C++ type for a view of the given Emboss Type and a list of - the C++ types of any parameters of the view type, which should be passed - to the view's constructor. - """ - if ir_util.is_array(type_ir): - # An array view is parameterized by the element's view type. - base_type = type_ir.array_type.base_type - element_size_in_bits = _get_type_size(base_type, ir) - assert element_size_in_bits, ( - "TODO(bolms): Implement arrays of dynamically-sized elements.") - assert element_size_in_bits % parent_addressable_unit == 0, ( - "Array elements must fall on byte boundaries.") - element_size = element_size_in_bits // parent_addressable_unit - element_view_type, element_view_parameter_types, element_view_parameters = ( - _get_cpp_view_type_for_physical_type( - base_type, element_size_in_bits, byte_order, ir, - _offset_storage_adapter(buffer_type, element_size, 0), - parent_addressable_unit, validator)) - return ( - code_template.format_template( - _TEMPLATES.array_view_adapter, - support_namespace=_SUPPORT_NAMESPACE, - # TODO(bolms): The element size should be calculable from the field - # size and array length. 
- element_view_type=element_view_type, - element_view_parameter_types="".join( - ", " + p for p in element_view_parameter_types), - element_size=element_size, - addressable_unit_size=int(parent_addressable_unit), - buffer_type=buffer_type), - element_view_parameter_types, - element_view_parameters - ) - else: - assert type_ir.HasField("atomic_type") - reference = type_ir.atomic_type.reference - referenced_type = ir_util.find_object(reference, ir) - if parent_addressable_unit > referenced_type.addressable_unit: - assert byte_order, repr(type_ir) - reader, parameter_types = _get_cpp_view_type_for_type_definition( - referenced_type, size, ir, buffer_type, byte_order, - parent_addressable_unit, validator) - return reader, parameter_types, list(type_ir.atomic_type.runtime_parameter) + type_ir, size, byte_order, ir, buffer_type, parent_addressable_unit, validator +): + """Returns the C++ type information needed to construct a field's view. + + Returns the C++ type of an ir_data.Type, and the C++ types of its parameters, + if any. + + Arguments: + type_ir: The ir_data.Type whose view should be constructed. + size: The size, in type_definition.addressable_units, of the instantiated + type, or None if it is not known at compile time. + byte_order: For BIT types which are direct children of BYTE types, + "LittleEndian", "BigEndian", or "None". Otherwise, None. + ir: The complete IR. + buffer_type: The C++ type to be used as the Storage parameter of the view + (e.g., "ContiguousBuffer<...>"). + parent_addressable_unit: The addressable_unit_size of the structure + containing this type. + validator: The name of the validator type to be injected into the view. + + Returns: + A tuple of: the C++ type for a view of the given Emboss Type and a list of + the C++ types of any parameters of the view type, which should be passed + to the view's constructor. + """ + if ir_util.is_array(type_ir): + # An array view is parameterized by the element's view type. + base_type = type_ir.array_type.base_type + element_size_in_bits = _get_type_size(base_type, ir) + assert ( + element_size_in_bits + ), "TODO(bolms): Implement arrays of dynamically-sized elements." + assert ( + element_size_in_bits % parent_addressable_unit == 0 + ), "Array elements must fall on byte boundaries." + element_size = element_size_in_bits // parent_addressable_unit + element_view_type, element_view_parameter_types, element_view_parameters = ( + _get_cpp_view_type_for_physical_type( + base_type, + element_size_in_bits, + byte_order, + ir, + _offset_storage_adapter(buffer_type, element_size, 0), + parent_addressable_unit, + validator, + ) + ) + return ( + code_template.format_template( + _TEMPLATES.array_view_adapter, + support_namespace=_SUPPORT_NAMESPACE, + # TODO(bolms): The element size should be calculable from the field + # size and array length. 
+                element_view_type=element_view_type,
+                element_view_parameter_types="".join(
+                    ", " + p for p in element_view_parameter_types
+                ),
+                element_size=element_size,
+                addressable_unit_size=int(parent_addressable_unit),
+                buffer_type=buffer_type,
+            ),
+            element_view_parameter_types,
+            element_view_parameters,
+        )
+    else:
+        assert type_ir.HasField("atomic_type")
+        reference = type_ir.atomic_type.reference
+        referenced_type = ir_util.find_object(reference, ir)
+        if parent_addressable_unit > referenced_type.addressable_unit:
+            assert byte_order, repr(type_ir)
+        reader, parameter_types = _get_cpp_view_type_for_type_definition(
+            referenced_type,
+            size,
+            ir,
+            buffer_type,
+            byte_order,
+            parent_addressable_unit,
+            validator,
+        )
+        return reader, parameter_types, list(type_ir.atomic_type.runtime_parameter)
 
 
def _render_variable(variable, prefix=""):
-  """Renders a variable reference (e.g., `foo` or `foo.bar.baz`) in C++ code."""
-  # A "variable" could be an immediate field or a subcomponent of an immediate
-  # field. For either case, in C++ it is valid to just use the last component
-  # of the name; it is not necessary to qualify the method with the type.
-  components = []
-  for component in variable:
-    components.append(_cpp_field_name(component[-1]) + "()")
-  components[-1] = prefix + components[-1]
-  return ".".join(components)
+    """Renders a variable reference (e.g., `foo` or `foo.bar.baz`) in C++ code."""
+    # A "variable" could be an immediate field or a subcomponent of an immediate
+    # field. For either case, in C++ it is valid to just use the last component
+    # of the name; it is not necessary to qualify the method with the type.
+    components = []
+    for component in variable:
+        components.append(_cpp_field_name(component[-1]) + "()")
+    components[-1] = prefix + components[-1]
+    return ".".join(components)
 
 
def _render_enum_value(enum_type, ir):
-  cpp_enum_type = _get_fully_qualified_name(enum_type.name.canonical_name, ir)
-  return "{}(static_cast</**/{}>({}))".format(
-      _maybe_type(cpp_enum_type), cpp_enum_type, enum_type.value)
+    cpp_enum_type = _get_fully_qualified_name(enum_type.name.canonical_name, ir)
+    return "{}(static_cast</**/{}>({}))".format(
+        _maybe_type(cpp_enum_type), cpp_enum_type, enum_type.value
+    )
 
 
def _builtin_function_name(function):
-  """Returns the C++ operator name corresponding to an Emboss operator."""
-  functions = {
-      ir_data.FunctionMapping.ADDITION: "Sum",
-      ir_data.FunctionMapping.SUBTRACTION: "Difference",
-      ir_data.FunctionMapping.MULTIPLICATION: "Product",
-      ir_data.FunctionMapping.EQUALITY: "Equal",
-      ir_data.FunctionMapping.INEQUALITY: "NotEqual",
-      ir_data.FunctionMapping.AND: "And",
-      ir_data.FunctionMapping.OR: "Or",
-      ir_data.FunctionMapping.LESS: "LessThan",
-      ir_data.FunctionMapping.LESS_OR_EQUAL: "LessThanOrEqual",
-      ir_data.FunctionMapping.GREATER: "GreaterThan",
-      ir_data.FunctionMapping.GREATER_OR_EQUAL: "GreaterThanOrEqual",
-      ir_data.FunctionMapping.CHOICE: "Choice",
-      ir_data.FunctionMapping.MAXIMUM: "Maximum",
-  }
-  return functions[function]
+    """Returns the C++ operator name corresponding to an Emboss operator."""
+    functions = {
+        ir_data.FunctionMapping.ADDITION: "Sum",
+        ir_data.FunctionMapping.SUBTRACTION: "Difference",
+        ir_data.FunctionMapping.MULTIPLICATION: "Product",
+        ir_data.FunctionMapping.EQUALITY: "Equal",
+        ir_data.FunctionMapping.INEQUALITY: "NotEqual",
+        ir_data.FunctionMapping.AND: "And",
+        ir_data.FunctionMapping.OR: "Or",
+        ir_data.FunctionMapping.LESS: "LessThan",
+        ir_data.FunctionMapping.LESS_OR_EQUAL: "LessThanOrEqual",
+        ir_data.FunctionMapping.GREATER: "GreaterThan",
+        ir_data.FunctionMapping.GREATER_OR_EQUAL: "GreaterThanOrEqual",
+        ir_data.FunctionMapping.CHOICE: "Choice",
+        ir_data.FunctionMapping.MAXIMUM: "Maximum",
+    }
+    return functions[function]
 
 
def _cpp_basic_type_for_expression_type(expression_type, ir):
-  """Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
-  if expression_type.WhichOneof("type") == "integer":
-    return _cpp_integer_type_for_range(
-        int(expression_type.integer.minimum_value),
-        int(expression_type.integer.maximum_value))
-  elif expression_type.WhichOneof("type") == "boolean":
-    return "bool"
-  elif expression_type.WhichOneof("type") == "enumeration":
-    return _get_fully_qualified_name(
-        expression_type.enumeration.name.canonical_name, ir)
-  else:
-    assert False, "Unknown expression type " + expression_type.WhichOneof(
-        "type")
+    """Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
+    if expression_type.WhichOneof("type") == "integer":
+        return _cpp_integer_type_for_range(
+            int(expression_type.integer.minimum_value),
+            int(expression_type.integer.maximum_value),
+        )
+    elif expression_type.WhichOneof("type") == "boolean":
+        return "bool"
+    elif expression_type.WhichOneof("type") == "enumeration":
+        return _get_fully_qualified_name(
+            expression_type.enumeration.name.canonical_name, ir
+        )
+    else:
+        assert False, "Unknown expression type " + expression_type.WhichOneof("type")
 
 
def _cpp_basic_type_for_expression(expression, ir):
-  """Returns the C++ basic type (int32_t, bool, etc.) for an Expression."""
-  return _cpp_basic_type_for_expression_type(expression.type, ir)
+    """Returns the C++ basic type (int32_t, bool, etc.) for an Expression."""
+    return _cpp_basic_type_for_expression_type(expression.type, ir)
 
 
def _cpp_integer_type_for_range(min_val, max_val):
-  """Returns the appropriate C++ integer type to hold min_val up to max_val."""
-  # The choice of int32_t, uint32_t, int64_t, then uint64_t is somewhat
-  # arbitrary here, and might not be perfectly ideal. I (bolms@) have chosen
-  # this set of types to a) minimize the number of casts that occur in
-  # arithmetic expressions, and b) favor 32-bit arithmetic, which is mostly
-  # "cheapest" on current (2018) systems. Signed integers are also preferred
-  # over unsigned so that the C++ compiler can take advantage of undefined
-  # overflow.
-  for size in (32, 64):
-    if min_val >= -(2**(size - 1)) and max_val <= 2**(size - 1) - 1:
-      return "::std::int{}_t".format(size)
-    elif min_val >= 0 and max_val <= 2**size - 1:
-      return "::std::uint{}_t".format(size)
-  return None
+    """Returns the appropriate C++ integer type to hold min_val up to max_val."""
+    # The choice of int32_t, uint32_t, int64_t, then uint64_t is somewhat
+    # arbitrary here, and might not be perfectly ideal. I (bolms@) have chosen
+    # this set of types to a) minimize the number of casts that occur in
+    # arithmetic expressions, and b) favor 32-bit arithmetic, which is mostly
+    # "cheapest" on current (2018) systems. Signed integers are also preferred
+    # over unsigned so that the C++ compiler can take advantage of undefined
+    # overflow.
+    for size in (32, 64):
+        if min_val >= -(2 ** (size - 1)) and max_val <= 2 ** (size - 1) - 1:
+            return "::std::int{}_t".format(size)
+        elif min_val >= 0 and max_val <= 2**size - 1:
+            return "::std::uint{}_t".format(size)
+    return None
 
 
def _cpp_integer_type_for_enum(max_bits, is_signed):
-  """Returns the appropriate C++ integer type to hold an enum."""
-  # This is used to determine the `X` in `enum class : X`.
-  #
-  # Unlike _cpp_integer_type_for_range, the type chosen here is used for actual
-  # storage. Further, sizes smaller than 64 are explicitly chosen by a human
-  # author, so take the smallest size that can hold the given number of bits.
-  #
-  # Technically, the C++ standard allows some of these sizes of integer to not
-  # exist, and other sizes (say, int24_t) might exist, but in practice this set
-  # is almost always available. If you're compiling for some exotic DSP that
-  # uses unusual int sizes, email emboss-dev@google.com.
-  for size in (8, 16, 32, 64):
-    if max_bits <= size:
-      return "::std::{}int{}_t".format("" if is_signed else "u", size)
-  assert False, f"Invalid value {max_bits} for maximum_bits"
+    """Returns the appropriate C++ integer type to hold an enum."""
+    # This is used to determine the `X` in `enum class : X`.
+    #
+    # Unlike _cpp_integer_type_for_range, the type chosen here is used for actual
+    # storage. Further, sizes smaller than 64 are explicitly chosen by a human
+    # author, so take the smallest size that can hold the given number of bits.
+    #
+    # Technically, the C++ standard allows some of these sizes of integer to not
+    # exist, and other sizes (say, int24_t) might exist, but in practice this set
+    # is almost always available. If you're compiling for some exotic DSP that
+    # uses unusual int sizes, email emboss-dev@google.com.
+    for size in (8, 16, 32, 64):
+        if max_bits <= size:
+            return "::std::{}int{}_t".format("" if is_signed else "u", size)
+    assert False, f"Invalid value {max_bits} for maximum_bits"
 
 
def _render_builtin_operation(expression, ir, field_reader, subexpressions):
-  """Renders a built-in operation (+, -, &&, etc.) into C++ code."""
-  assert expression.function.function not in (
-      ir_data.FunctionMapping.UPPER_BOUND, ir_data.FunctionMapping.LOWER_BOUND), (
-          "UPPER_BOUND and LOWER_BOUND should be constant.")
-  if expression.function.function == ir_data.FunctionMapping.PRESENCE:
-    return field_reader.render_existence(expression.function.args[0],
-                                         subexpressions)
-  args = expression.function.args
-  rendered_args = [
-      _render_expression(arg, ir, field_reader, subexpressions).rendered
-      for arg in args]
-  minimum_integers = []
-  maximum_integers = []
-  enum_types = set()
-  have_boolean_types = False
-  for subexpression in [expression] + list(args):
-    if subexpression.type.WhichOneof("type") == "integer":
-      minimum_integers.append(int(subexpression.type.integer.minimum_value))
-      maximum_integers.append(int(subexpression.type.integer.maximum_value))
-    elif subexpression.type.WhichOneof("type") == "enumeration":
-      enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
-    elif subexpression.type.WhichOneof("type") == "boolean":
-      have_boolean_types = True
-  # At present, all Emboss functions other than `$has` take and return one of
-  # the following:
-  #
-  # integers
-  # integers and booleans
-  # a single enum type
-  # a single enum type and booleans
-  # booleans
-  #
-  # Really, the intermediate type is only necessary for integers, but it
-  # simplifies the C++ somewhat if the appropriate enum/boolean type is provided
-  # as "IntermediateT" -- it means that, e.g., the choice ("?:") operator does
-  # not have to have two versions, one of which casts (some of) its arguments to
-  # IntermediateT, and one of which does not.
-  #
-  # This is not a particularly robust scheme, but it works for all of the Emboss
-  # functions I (bolms@) have written and am considering (division, modulus,
-  # exponentiation, logical negation, bit shifts, bitwise and/or/xor, $min,
-  # $floor, $ceil, $has).
-  if minimum_integers and not enum_types:
-    intermediate_type = _cpp_integer_type_for_range(min(minimum_integers),
-                                                    max(maximum_integers))
-  elif len(enum_types) == 1 and not minimum_integers:
-    intermediate_type = list(enum_types)[0]
-  else:
-    assert have_boolean_types
-    assert not enum_types
-    assert not minimum_integers
-    intermediate_type = "bool"
-  arg_types = [_cpp_basic_type_for_expression(arg, ir) for arg in args]
-  result_type = _cpp_basic_type_for_expression(expression, ir)
-  function_variant = "</**/{}, {}, {}>".format(
-      intermediate_type, result_type, ", ".join(arg_types))
-  return "::emboss::support::{}{}({})".format(
-      _builtin_function_name(expression.function.function),
-      function_variant, ", ".join(rendered_args))
+    """Renders a built-in operation (+, -, &&, etc.) into C++ code."""
+    assert expression.function.function not in (
+        ir_data.FunctionMapping.UPPER_BOUND,
+        ir_data.FunctionMapping.LOWER_BOUND,
+    ), "UPPER_BOUND and LOWER_BOUND should be constant."
+    if expression.function.function == ir_data.FunctionMapping.PRESENCE:
+        return field_reader.render_existence(
+            expression.function.args[0], subexpressions
+        )
+    args = expression.function.args
+    rendered_args = [
+        _render_expression(arg, ir, field_reader, subexpressions).rendered
+        for arg in args
+    ]
+    minimum_integers = []
+    maximum_integers = []
+    enum_types = set()
+    have_boolean_types = False
+    for subexpression in [expression] + list(args):
+        if subexpression.type.WhichOneof("type") == "integer":
+            minimum_integers.append(int(subexpression.type.integer.minimum_value))
+            maximum_integers.append(int(subexpression.type.integer.maximum_value))
+        elif subexpression.type.WhichOneof("type") == "enumeration":
+            enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
+        elif subexpression.type.WhichOneof("type") == "boolean":
+            have_boolean_types = True
+    # At present, all Emboss functions other than `$has` take and return one of
+    # the following:
+    #
+    # integers
+    # integers and booleans
+    # a single enum type
+    # a single enum type and booleans
+    # booleans
+    #
+    # Really, the intermediate type is only necessary for integers, but it
+    # simplifies the C++ somewhat if the appropriate enum/boolean type is provided
+    # as "IntermediateT" -- it means that, e.g., the choice ("?:") operator does
+    # not have to have two versions, one of which casts (some of) its arguments to
+    # IntermediateT, and one of which does not.
+    #
+    # This is not a particularly robust scheme, but it works for all of the Emboss
+    # functions I (bolms@) have written and am considering (division, modulus,
+    # exponentiation, logical negation, bit shifts, bitwise and/or/xor, $min,
+    # $floor, $ceil, $has).
+    if minimum_integers and not enum_types:
+        intermediate_type = _cpp_integer_type_for_range(
+            min(minimum_integers), max(maximum_integers)
+        )
+    elif len(enum_types) == 1 and not minimum_integers:
+        intermediate_type = list(enum_types)[0]
+    else:
+        assert have_boolean_types
+        assert not enum_types
+        assert not minimum_integers
+        intermediate_type = "bool"
+    arg_types = [_cpp_basic_type_for_expression(arg, ir) for arg in args]
+    result_type = _cpp_basic_type_for_expression(expression, ir)
+    function_variant = "</**/{}, {}, {}>".format(
+        intermediate_type, result_type, ", ".join(arg_types)
+    )
+    return "::emboss::support::{}{}({})".format(
+        _builtin_function_name(expression.function.function),
+        function_variant,
+        ", ".join(rendered_args),
+    )
 
 
class _FieldRenderer(object):
-  """Base class for rendering field reads."""
-
-  def render_field_read_with_context(self, expression, ir, prefix,
-                                     subexpressions):
-    field = (
-        prefix +
-        _render_variable(ir_util.hashable_form_of_field_reference(
-            expression.field_reference)))
-    if subexpressions is None:
-      field_expression = field
-    else:
-      field_expression = subexpressions.add(field)
-    expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
-    return ("({0}.Ok()"
-            " ? {1}(static_cast</**/{2}>({0}.UncheckedRead()))"
-            " : {1}())".format(
-                field_expression,
-                _maybe_type(expression_cpp_type),
-                expression_cpp_type))
+    """Base class for rendering field reads."""
+
+    def render_field_read_with_context(self, expression, ir, prefix, subexpressions):
+        field = prefix + _render_variable(
+            ir_util.hashable_form_of_field_reference(expression.field_reference)
+        )
+        if subexpressions is None:
+            field_expression = field
+        else:
+            field_expression = subexpressions.add(field)
+        expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
+        return (
+            "({0}.Ok()"
+            " ? {1}(static_cast</**/{2}>({0}.UncheckedRead()))"
+            " : {1}())".format(
+                field_expression, _maybe_type(expression_cpp_type), expression_cpp_type
+            )
+        )
 
-  def render_existence_with_context(self, expression, prefix, subexpressions):
-    return "{1}{0}".format(
-        _render_variable(
-            ir_util.hashable_form_of_field_reference(
-                expression.field_reference),
-            "has_"),
-        prefix)
+    def render_existence_with_context(self, expression, prefix, subexpressions):
+        return "{1}{0}".format(
+            _render_variable(
+                ir_util.hashable_form_of_field_reference(expression.field_reference),
+                "has_",
+            ),
+            prefix,
+        )
 
 
class _DirectFieldRenderer(_FieldRenderer):
-  """Renderer for fields read from inside a structure's View type."""
+    """Renderer for fields read from inside a structure's View type."""
 
-  def render_field(self, expression, ir, subexpressions):
-    return self.render_field_read_with_context(
-        expression, ir, "", subexpressions)
+    def render_field(self, expression, ir, subexpressions):
+        return self.render_field_read_with_context(expression, ir, "", subexpressions)
 
-  def render_existence(self, expression, subexpressions):
-    return self.render_existence_with_context(expression, "", subexpressions)
+    def render_existence(self, expression, subexpressions):
+        return self.render_existence_with_context(expression, "", subexpressions)
 
 
class _VirtualViewFieldRenderer(_FieldRenderer):
-  """Renderer for field reads from inside a virtual field's View."""
+    """Renderer for field reads from inside a virtual field's View."""
 
-  def render_existence(self, expression, subexpressions):
-    return self.render_existence_with_context(
-        expression, "view_.", subexpressions)
+    def render_existence(self, expression, subexpressions):
+        return self.render_existence_with_context(expression, "view_.", subexpressions)
 
-  def render_field(self, expression, ir, subexpressions):
-    return self.render_field_read_with_context(
-        expression, ir, "view_.", subexpressions)
+    def render_field(self, expression, ir, subexpressions):
+        return self.render_field_read_with_context(
+            expression, ir, "view_.", subexpressions
+        )
 
 
class _SubexpressionStore(object):
-  """Holder for subexpressions to be assigned to local variables."""
+    """Holder for subexpressions to be assigned to local variables."""
 
-  def __init__(self, prefix):
-    self._prefix = prefix
-    self._subexpr_to_name = {}
-    self._index_to_subexpr = []
+    def __init__(self, prefix):
+        self._prefix = prefix
+        self._subexpr_to_name = {}
+        self._index_to_subexpr = []
 
-  def add(self, subexpr):
-    if subexpr not in self._subexpr_to_name:
-      self._index_to_subexpr.append(subexpr)
-      self._subexpr_to_name[subexpr] = (
-          self._prefix + str(len(self._index_to_subexpr)))
-    return self._subexpr_to_name[subexpr]
+    def add(self, subexpr):
+        if subexpr not in self._subexpr_to_name:
+            self._index_to_subexpr.append(subexpr)
+            self._subexpr_to_name[subexpr] = self._prefix + str(
+                len(self._index_to_subexpr)
+            )
+        return self._subexpr_to_name[subexpr]
 
-  def subexprs(self):
-    return [(self._subexpr_to_name[subexpr], subexpr)
-            for subexpr in self._index_to_subexpr]
+    def subexprs(self):
+        return [
+            (self._subexpr_to_name[subexpr], subexpr)
+            for subexpr in self._index_to_subexpr
+        ]
 
 
-_ExpressionResult = collections.namedtuple("ExpressionResult",
-                                           ["rendered", "is_constant"])
+_ExpressionResult = collections.namedtuple(
+    "ExpressionResult", ["rendered", "is_constant"]
+)
 
 
def _render_expression(expression, ir, field_reader=None, subexpressions=None):
- """Renders an expression into C++ code. - - Arguments: - expression: The expression to render. - ir: The IR in which to look up references. - field_reader: An object with render_existence and render_field methods - appropriate for the C++ context of the expression. - subexpressions: A _SubexpressionStore in which to put subexpressions, or - None if subexpressions should be inline. - - Returns: - A tuple of (rendered_text, is_constant), where rendered_text is C++ code - that can be emitted, and is_constant is True if the expression is a - compile-time constant suitable for use in a C++11 constexpr context, - otherwise False. - """ - if field_reader is None: - field_reader = _DirectFieldRenderer() - - # If the expression is constant, there are no guarantees that subexpressions - # will fit into C++ types, or that operator arguments and return types can fit - # in the same type: expressions like `-0x8000_0000_0000_0000` and - # `0x1_0000_0000_0000_0000 - 1` can appear. - if expression.type.WhichOneof("type") == "integer": - if expression.type.integer.modulus == "infinity": - return _ExpressionResult(_render_integer_for_expression(int( - expression.type.integer.modular_value)), True) - elif expression.type.WhichOneof("type") == "boolean": - if expression.type.boolean.HasField("value"): - if expression.type.boolean.value: - return _ExpressionResult(_maybe_type("bool") + "(true)", True) - else: - return _ExpressionResult(_maybe_type("bool") + "(false)", True) - elif expression.type.WhichOneof("type") == "enumeration": - if expression.type.enumeration.HasField("value"): - return _ExpressionResult( - _render_enum_value(expression.type.enumeration, ir), True) - else: - # There shouldn't be any "opaque" type expressions here. - assert False, "Unhandled expression type {}".format( - expression.type.WhichOneof("type")) - - result = None - # Otherwise, render the operation. - if expression.WhichOneof("expression") == "function": - result = _render_builtin_operation( - expression, ir, field_reader, subexpressions) - elif expression.WhichOneof("expression") == "field_reference": - result = field_reader.render_field(expression, ir, subexpressions) - elif (expression.WhichOneof("expression") == "builtin_reference" and - expression.builtin_reference.canonical_name.object_path[-1] == - "$logical_value"): - return _ExpressionResult( - _maybe_type("decltype(emboss_reserved_local_value)") + - "(emboss_reserved_local_value)", False) - - # Any of the constant expression types should have been handled in the - # previous section. - assert result is not None, "Unable to render expression {}".format( - str(expression)) - - if subexpressions is None: - return _ExpressionResult(result, False) - else: - return _ExpressionResult(subexpressions.add(result), False) + """Renders an expression into C++ code. + + Arguments: + expression: The expression to render. + ir: The IR in which to look up references. + field_reader: An object with render_existence and render_field methods + appropriate for the C++ context of the expression. + subexpressions: A _SubexpressionStore in which to put subexpressions, or + None if subexpressions should be inline. + + Returns: + A tuple of (rendered_text, is_constant), where rendered_text is C++ code + that can be emitted, and is_constant is True if the expression is a + compile-time constant suitable for use in a C++11 constexpr context, + otherwise False. 
+ """ + if field_reader is None: + field_reader = _DirectFieldRenderer() + + # If the expression is constant, there are no guarantees that subexpressions + # will fit into C++ types, or that operator arguments and return types can fit + # in the same type: expressions like `-0x8000_0000_0000_0000` and + # `0x1_0000_0000_0000_0000 - 1` can appear. + if expression.type.WhichOneof("type") == "integer": + if expression.type.integer.modulus == "infinity": + return _ExpressionResult( + _render_integer_for_expression( + int(expression.type.integer.modular_value) + ), + True, + ) + elif expression.type.WhichOneof("type") == "boolean": + if expression.type.boolean.HasField("value"): + if expression.type.boolean.value: + return _ExpressionResult(_maybe_type("bool") + "(true)", True) + else: + return _ExpressionResult(_maybe_type("bool") + "(false)", True) + elif expression.type.WhichOneof("type") == "enumeration": + if expression.type.enumeration.HasField("value"): + return _ExpressionResult( + _render_enum_value(expression.type.enumeration, ir), True + ) + else: + # There shouldn't be any "opaque" type expressions here. + assert False, "Unhandled expression type {}".format( + expression.type.WhichOneof("type") + ) + + result = None + # Otherwise, render the operation. + if expression.WhichOneof("expression") == "function": + result = _render_builtin_operation(expression, ir, field_reader, subexpressions) + elif expression.WhichOneof("expression") == "field_reference": + result = field_reader.render_field(expression, ir, subexpressions) + elif ( + expression.WhichOneof("expression") == "builtin_reference" + and expression.builtin_reference.canonical_name.object_path[-1] + == "$logical_value" + ): + return _ExpressionResult( + _maybe_type("decltype(emboss_reserved_local_value)") + + "(emboss_reserved_local_value)", + False, + ) + + # Any of the constant expression types should have been handled in the + # previous section. + assert result is not None, "Unable to render expression {}".format(str(expression)) + + if subexpressions is None: + return _ExpressionResult(result, False) + else: + return _ExpressionResult(subexpressions.add(result), False) def _render_existence_test(field, ir, subexpressions=None): - return _render_expression(field.existence_condition, ir, subexpressions) + return _render_expression(field.existence_condition, ir, subexpressions) def _alignment_of_location(location): - constraints = location.start.type.integer - if constraints.modulus == "infinity": - # The C++ templates use 0 as a sentinel value meaning infinity for - # alignment. - return 0, constraints.modular_value - else: - return constraints.modulus, constraints.modular_value - - -def _get_cpp_type_reader_of_field(field_ir, ir, buffer_type, validator, - parent_addressable_unit): - """Returns the C++ view type for a field.""" - field_size = None - if field_ir.type.HasField("size_in_bits"): - field_size = ir_util.constant_value(field_ir.type.size_in_bits) - assert field_size is not None - elif ir_util.is_constant(field_ir.location.size): - # TODO(bolms): Normalize the IR so that this clause is unnecessary. 
- field_size = (ir_util.constant_value(field_ir.location.size) * - parent_addressable_unit) - byte_order_attr = ir_util.get_attribute(field_ir.attribute, "byte_order") - if byte_order_attr: - byte_order = byte_order_attr.string_constant.text - else: - byte_order = "" - field_alignment, field_offset = _alignment_of_location(field_ir.location) - return _get_cpp_view_type_for_physical_type( - field_ir.type, field_size, byte_order, ir, - _offset_storage_adapter(buffer_type, field_alignment, field_offset), - parent_addressable_unit, validator) - - -def _generate_structure_field_methods(enclosing_type_name, field_ir, ir, - parent_addressable_unit): - if ir_util.field_is_virtual(field_ir): - return _generate_structure_virtual_field_methods( - enclosing_type_name, field_ir, ir) - else: - return _generate_structure_physical_field_methods( - enclosing_type_name, field_ir, ir, parent_addressable_unit) + constraints = location.start.type.integer + if constraints.modulus == "infinity": + # The C++ templates use 0 as a sentinel value meaning infinity for + # alignment. + return 0, constraints.modular_value + else: + return constraints.modulus, constraints.modular_value + + +def _get_cpp_type_reader_of_field( + field_ir, ir, buffer_type, validator, parent_addressable_unit +): + """Returns the C++ view type for a field.""" + field_size = None + if field_ir.type.HasField("size_in_bits"): + field_size = ir_util.constant_value(field_ir.type.size_in_bits) + assert field_size is not None + elif ir_util.is_constant(field_ir.location.size): + # TODO(bolms): Normalize the IR so that this clause is unnecessary. + field_size = ( + ir_util.constant_value(field_ir.location.size) * parent_addressable_unit + ) + byte_order_attr = ir_util.get_attribute(field_ir.attribute, "byte_order") + if byte_order_attr: + byte_order = byte_order_attr.string_constant.text + else: + byte_order = "" + field_alignment, field_offset = _alignment_of_location(field_ir.location) + return _get_cpp_view_type_for_physical_type( + field_ir.type, + field_size, + byte_order, + ir, + _offset_storage_adapter(buffer_type, field_alignment, field_offset), + parent_addressable_unit, + validator, + ) -def _generate_custom_validator_expression_for(field_ir, ir): - """Returns a validator expression for the given field, or None.""" - requires_attr = ir_util.get_attribute(field_ir.attribute, "requires") - if requires_attr: - class _ValidatorFieldReader(object): - """A "FieldReader" that translates the current field to `value`.""" - - def render_existence(self, expression, subexpressions): - del expression # Unused. - assert False, "Shouldn't be here." 
-
- def render_field(self, expression, ir, subexpressions):
- assert len(expression.field_reference.path) == 1
- assert (expression.field_reference.path[0].canonical_name ==
- field_ir.name.canonical_name)
- expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
- return "{}(emboss_reserved_local_value)".format(
- _maybe_type(expression_cpp_type))

+def _generate_structure_field_methods(
+ enclosing_type_name, field_ir, ir, parent_addressable_unit
+):
+ if ir_util.field_is_virtual(field_ir):
+ return _generate_structure_virtual_field_methods(
+ enclosing_type_name, field_ir, ir
+ )
+ else:
+ return _generate_structure_physical_field_methods(
+ enclosing_type_name, field_ir, ir, parent_addressable_unit
+ )

- validation_body = _render_expression(requires_attr.expression, ir,
- _ValidatorFieldReader())
- return validation_body.rendered
- else:
- return None
+
+def _generate_custom_validator_expression_for(field_ir, ir):
+ """Returns a validator expression for the given field, or None."""
+ requires_attr = ir_util.get_attribute(field_ir.attribute, "requires")
+ if requires_attr:
+
+ class _ValidatorFieldReader(object):
+ """A "FieldReader" that translates the current field to `value`."""
+
+ def render_existence(self, expression, subexpressions):
+ del expression # Unused.
+ assert False, "Shouldn't be here."
+
+ def render_field(self, expression, ir, subexpressions):
+ assert len(expression.field_reference.path) == 1
+ assert (
+ expression.field_reference.path[0].canonical_name
+ == field_ir.name.canonical_name
+ )
+ expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
+ return "{}(emboss_reserved_local_value)".format(
+ _maybe_type(expression_cpp_type)
+ )
+
+ validation_body = _render_expression(
+ requires_attr.expression, ir, _ValidatorFieldReader()
+ )
+ return validation_body.rendered
+ else:
+ return None


def _generate_validator_expression_for(field_ir, ir):
- """Returns a validator expression for the given field."""
- result = _generate_custom_validator_expression_for(field_ir, ir)
- if result is None:
- return "::emboss::support::Maybe<bool>(true)"
- return result
-
-
-def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir,
- ir):
- """Generates C++ code for methods for a single virtual field.
-
- Arguments:
- enclosing_type_name: The text name of the enclosing type.
- field_ir: The IR for the field to generate methods for.
- ir: The full IR for the module.
-
- Returns:
- A tuple of ("", declarations, definitions). The declarations can be
- inserted into the class definition for the enclosing type's View. Any
- definitions should be placed after the class definition. These are
- separated to satisfy C++'s declaration-before-use requirements.
- """ - if field_ir.write_method.WhichOneof("method") == "alias": - return _generate_field_indirection(field_ir, enclosing_type_name, ir) - - read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_") - read_value = _render_expression( - field_ir.read_transform, ir, - field_reader=_VirtualViewFieldRenderer(), - subexpressions=read_subexpressions) - field_exists = _render_existence_test(field_ir, ir) - logical_type = _cpp_basic_type_for_expression(field_ir.read_transform, ir) - - if read_value.is_constant and field_exists.is_constant: - assert not read_subexpressions.subexprs() - declaration_template = ( - _TEMPLATES.structure_single_const_virtual_field_method_declarations) - definition_template = ( - _TEMPLATES.structure_single_const_virtual_field_method_definitions) - else: - declaration_template = ( - _TEMPLATES.structure_single_virtual_field_method_declarations) - definition_template = ( - _TEMPLATES.structure_single_virtual_field_method_definitions) - - if field_ir.write_method.WhichOneof("method") == "transform": - destination = _render_variable( - ir_util.hashable_form_of_field_reference( - field_ir.write_method.transform.destination)) - transform = _render_expression( - field_ir.write_method.transform.function_body, ir, - field_reader=_VirtualViewFieldRenderer()).rendered - write_methods = code_template.format_template( - _TEMPLATES.structure_single_virtual_field_write_methods, + """Returns a validator expression for the given field.""" + result = _generate_custom_validator_expression_for(field_ir, ir) + if result is None: + return "::emboss::support::Maybe(true)" + return result + + +def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir): + """Generates C++ code for methods for a single virtual field. + + Arguments: + enclosing_type_name: The text name of the enclosing type. + field_ir: The IR for the field to generate methods for. + ir: The full IR for the module. + + Returns: + A tuple of ("", declarations, definitions). The declarations can be + inserted into the class definition for the enclosing type's View. Any + definitions should be placed after the class definition. These are + separated to satisfy C++'s declaration-before-use requirements. 
+ """ + if field_ir.write_method.WhichOneof("method") == "alias": + return _generate_field_indirection(field_ir, enclosing_type_name, ir) + + read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_") + read_value = _render_expression( + field_ir.read_transform, + ir, + field_reader=_VirtualViewFieldRenderer(), + subexpressions=read_subexpressions, + ) + field_exists = _render_existence_test(field_ir, ir) + logical_type = _cpp_basic_type_for_expression(field_ir.read_transform, ir) + + if read_value.is_constant and field_exists.is_constant: + assert not read_subexpressions.subexprs() + declaration_template = ( + _TEMPLATES.structure_single_const_virtual_field_method_declarations + ) + definition_template = ( + _TEMPLATES.structure_single_const_virtual_field_method_definitions + ) + else: + declaration_template = ( + _TEMPLATES.structure_single_virtual_field_method_declarations + ) + definition_template = ( + _TEMPLATES.structure_single_virtual_field_method_definitions + ) + + if field_ir.write_method.WhichOneof("method") == "transform": + destination = _render_variable( + ir_util.hashable_form_of_field_reference( + field_ir.write_method.transform.destination + ) + ) + transform = _render_expression( + field_ir.write_method.transform.function_body, + ir, + field_reader=_VirtualViewFieldRenderer(), + ).rendered + write_methods = code_template.format_template( + _TEMPLATES.structure_single_virtual_field_write_methods, + logical_type=logical_type, + destination=destination, + transform=transform, + ) + else: + write_methods = "" + + name = field_ir.name.canonical_name.object_path[-1] + if name.startswith("$"): + name = _cpp_field_name(field_ir.name.name.text) + virtual_view_type_name = "EmbossReservedDollarVirtual{}View".format(name) + else: + virtual_view_type_name = "EmbossReservedVirtual{}View".format( + name_conversion.snake_to_camel(name) + ) + assert logical_type, "Could not find appropriate C++ type for {}".format( + field_ir.read_transform + ) + if field_ir.read_transform.type.WhichOneof("type") == "integer": + write_to_text_stream_function = "WriteIntegerViewToTextStream" + elif field_ir.read_transform.type.WhichOneof("type") == "boolean": + write_to_text_stream_function = "WriteBooleanViewToTextStream" + elif field_ir.read_transform.type.WhichOneof("type") == "enumeration": + write_to_text_stream_function = "WriteEnumViewToTextStream" + else: + assert False, "Unexpected read-only virtual field type {}".format( + field_ir.read_transform.type.WhichOneof("type") + ) + + value_is_ok = _generate_validator_expression_for(field_ir, ir) + declaration = code_template.format_template( + declaration_template, + visibility=_visibility_for_field(field_ir), + name=name, + virtual_view_type_name=virtual_view_type_name, + logical_type=logical_type, + read_subexpressions="".join( + [ + " const auto {} = {};\n".format(subexpr_name, subexpr) + for subexpr_name, subexpr in read_subexpressions.subexprs() + ] + ), + read_value=read_value.rendered, + write_to_text_stream_function=write_to_text_stream_function, + parent_type=enclosing_type_name, + write_methods=write_methods, + value_is_ok=value_is_ok, + ) + definition = code_template.format_template( + definition_template, + name=name, + virtual_view_type_name=virtual_view_type_name, logical_type=logical_type, - destination=destination, - transform=transform) - else: - write_methods = "" - - name = field_ir.name.canonical_name.object_path[-1] - if name.startswith("$"): - name = _cpp_field_name(field_ir.name.name.text) - 
- virtual_view_type_name = "EmbossReservedDollarVirtual{}View".format(name)
- else:
- virtual_view_type_name = "EmbossReservedVirtual{}View".format(
- name_conversion.snake_to_camel(name))
- assert logical_type, "Could not find appropriate C++ type for {}".format(
- field_ir.read_transform)
- if field_ir.read_transform.type.WhichOneof("type") == "integer":
- write_to_text_stream_function = "WriteIntegerViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "boolean":
- write_to_text_stream_function = "WriteBooleanViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "enumeration":
- write_to_text_stream_function = "WriteEnumViewToTextStream"
- else:
- assert False, "Unexpected read-only virtual field type {}".format(
- field_ir.read_transform.type.WhichOneof("type"))
-
- value_is_ok = _generate_validator_expression_for(field_ir, ir)
- declaration = code_template.format_template(
- declaration_template,
- visibility=_visibility_for_field(field_ir),
- name=name,
- virtual_view_type_name=virtual_view_type_name,
- logical_type=logical_type,
- read_subexpressions="".join(
- [" const auto {} = {};\n".format(subexpr_name, subexpr)
- for subexpr_name, subexpr in read_subexpressions.subexprs()]
- ),
- read_value=read_value.rendered,
- write_to_text_stream_function=write_to_text_stream_function,
- parent_type=enclosing_type_name,
- write_methods=write_methods,
- value_is_ok=value_is_ok)
- definition = code_template.format_template(
- definition_template,
- name=name,
- virtual_view_type_name=virtual_view_type_name,
- logical_type=logical_type,
- read_value=read_value.rendered,
- parent_type=enclosing_type_name,
- field_exists=field_exists.rendered)
- return "", declaration, definition
+ read_value=read_value.rendered,
+ parent_type=enclosing_type_name,
+ field_exists=field_exists.rendered,
+ )
+ return "", declaration, definition


def _generate_validator_type_for(enclosing_type_name, field_ir, ir):
- """Returns a validator type name and definition for the given field."""
- result_expression = _generate_custom_validator_expression_for(field_ir, ir)
- if result_expression is None:
- return "::emboss::support::AllValuesAreOk", ""
-
- field_name = field_ir.name.canonical_name.object_path[-1]
- validator_type_name = "EmbossReservedValidatorFor{}".format(
- name_conversion.snake_to_camel(field_name))
- qualified_validator_type_name = "{}::{}".format(enclosing_type_name,
- validator_type_name)
-
- validator_declaration = code_template.format_template(
- _TEMPLATES.structure_field_validator,
- name=validator_type_name,
- expression=result_expression,
- )
- validator_declaration = _wrap_in_namespace(validator_declaration,
- [enclosing_type_name])
- return qualified_validator_type_name, validator_declaration
-
-
-def _generate_structure_physical_field_methods(enclosing_type_name, field_ir,
- ir, parent_addressable_unit):
- """Generates C++ code for methods for a single physical field.
-
- Arguments:
- enclosing_type_name: The text name of the enclosing type.
- field_ir: The IR for the field to generate methods for.
- ir: The full IR for the module.
- parent_addressable_unit: The addressable unit (BIT or BYTE) of the enclosing
- structure.
-
- Returns:
- A tuple of (declarations, definitions). The declarations can be inserted
- into the class definition for the enclosing type's View. Any definitions
- should be placed after the class definition. These are separated to satisfy
- C++'s declaration-before-use requirements.
- """ - validator_type, validator_declaration = _generate_validator_type_for( - enclosing_type_name, field_ir, ir) - - type_reader, unused_parameter_types, parameter_expressions = ( - _get_cpp_type_reader_of_field(field_ir, ir, "Storage", validator_type, - parent_addressable_unit)) - - field_name = field_ir.name.canonical_name.object_path[-1] - - subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_") - parameter_values = [] - parameters_known = [] - for parameter in parameter_expressions: - parameter_cpp_expr = _render_expression( - parameter, ir, subexpressions=subexpressions) - parameter_values.append( - "{}.ValueOrDefault(), ".format(parameter_cpp_expr.rendered)) - parameters_known.append( - "{}.Known() && ".format(parameter_cpp_expr.rendered)) - parameter_subexpressions = "".join( - [" const auto {} = {};\n".format(name, subexpr) - for name, subexpr in subexpressions.subexprs()] - ) - - first_size_and_offset_subexpr = len(subexpressions.subexprs()) - offset = _render_expression( - field_ir.location.start, ir, subexpressions=subexpressions).rendered - size = _render_expression( - field_ir.location.size, ir, subexpressions=subexpressions).rendered - size_and_offset_subexpressions = "".join( - [" const auto {} = {};\n".format(name, subexpr) - for name, subexpr in subexpressions.subexprs()[ - first_size_and_offset_subexpr:]] - ) - - field_alignment, field_offset = _alignment_of_location(field_ir.location) - declaration = code_template.format_template( - _TEMPLATES.structure_single_field_method_declarations, - type_reader=type_reader, - visibility=_visibility_for_field(field_ir), - name=field_name) - definition = code_template.format_template( - _TEMPLATES.structure_single_field_method_definitions, - parent_type=enclosing_type_name, - name=field_name, - type_reader=type_reader, - offset=offset, - size=size, - size_and_offset_subexpressions=size_and_offset_subexpressions, - field_exists=_render_existence_test(field_ir, ir).rendered, - alignment=field_alignment, - parameters_known="".join(parameters_known), - parameter_values="".join(parameter_values), - parameter_subexpressions=parameter_subexpressions, - static_offset=field_offset) - return validator_declaration, declaration, definition + """Returns a validator type name and definition for the given field.""" + result_expression = _generate_custom_validator_expression_for(field_ir, ir) + if result_expression is None: + return "::emboss::support::AllValuesAreOk", "" + + field_name = field_ir.name.canonical_name.object_path[-1] + validator_type_name = "EmbossReservedValidatorFor{}".format( + name_conversion.snake_to_camel(field_name) + ) + qualified_validator_type_name = "{}::{}".format( + enclosing_type_name, validator_type_name + ) + + validator_declaration = code_template.format_template( + _TEMPLATES.structure_field_validator, + name=validator_type_name, + expression=result_expression, + ) + validator_declaration = _wrap_in_namespace( + validator_declaration, [enclosing_type_name] + ) + return qualified_validator_type_name, validator_declaration + + +def _generate_structure_physical_field_methods( + enclosing_type_name, field_ir, ir, parent_addressable_unit +): + """Generates C++ code for methods for a single physical field. + + Arguments: + enclosing_type_name: The text name of the enclosing type. + field_ir: The IR for the field to generate methods for. + ir: The full IR for the module. + parent_addressable_unit: The addressable unit (BIT or BYTE) of the enclosing + structure. 
+ + Returns: + A tuple of (declarations, definitions). The declarations can be inserted + into the class definition for the enclosing type's View. Any definitions + should be placed after the class definition. These are separated to satisfy + C++'s declaration-before-use requirements. + """ + validator_type, validator_declaration = _generate_validator_type_for( + enclosing_type_name, field_ir, ir + ) + + type_reader, unused_parameter_types, parameter_expressions = ( + _get_cpp_type_reader_of_field( + field_ir, ir, "Storage", validator_type, parent_addressable_unit + ) + ) + + field_name = field_ir.name.canonical_name.object_path[-1] + + subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_") + parameter_values = [] + parameters_known = [] + for parameter in parameter_expressions: + parameter_cpp_expr = _render_expression( + parameter, ir, subexpressions=subexpressions + ) + parameter_values.append( + "{}.ValueOrDefault(), ".format(parameter_cpp_expr.rendered) + ) + parameters_known.append("{}.Known() && ".format(parameter_cpp_expr.rendered)) + parameter_subexpressions = "".join( + [ + " const auto {} = {};\n".format(name, subexpr) + for name, subexpr in subexpressions.subexprs() + ] + ) + + first_size_and_offset_subexpr = len(subexpressions.subexprs()) + offset = _render_expression( + field_ir.location.start, ir, subexpressions=subexpressions + ).rendered + size = _render_expression( + field_ir.location.size, ir, subexpressions=subexpressions + ).rendered + size_and_offset_subexpressions = "".join( + [ + " const auto {} = {};\n".format(name, subexpr) + for name, subexpr in subexpressions.subexprs()[ + first_size_and_offset_subexpr: + ] + ] + ) + + field_alignment, field_offset = _alignment_of_location(field_ir.location) + declaration = code_template.format_template( + _TEMPLATES.structure_single_field_method_declarations, + type_reader=type_reader, + visibility=_visibility_for_field(field_ir), + name=field_name, + ) + definition = code_template.format_template( + _TEMPLATES.structure_single_field_method_definitions, + parent_type=enclosing_type_name, + name=field_name, + type_reader=type_reader, + offset=offset, + size=size, + size_and_offset_subexpressions=size_and_offset_subexpressions, + field_exists=_render_existence_test(field_ir, ir).rendered, + alignment=field_alignment, + parameters_known="".join(parameters_known), + parameter_values="".join(parameter_values), + parameter_subexpressions=parameter_subexpressions, + static_offset=field_offset, + ) + return validator_declaration, declaration, definition def _render_size_method(fields, ir): - """Renders the Size methods of a struct or bits, using the correct templates. - - Arguments: - fields: The list of fields in the struct or bits. This is used to find the - $size_in_bits or $size_in_bytes virtual field. - ir: The IR to which fields belong. - - Returns: - A string representation of the Size methods, suitable for inclusion in an - Emboss View class. - """ - # The SizeInBytes(), SizeInBits(), and SizeIsKnown() methods just forward to - # the generated IntrinsicSizeIn$_units_$() method, which returns a virtual - # field with Read() and Ok() methods. - # - # TODO(bolms): Remove these shims, rename IntrinsicSizeIn$_units_$ to - # SizeIn$_units_$, and update all callers to the new API. - for field in fields: - if field.name.name.text in ("$size_in_bits", "$size_in_bytes"): - # If the read_transform and existence_condition are constant, then the - # size is constexpr. 
- if (_render_expression(field.read_transform, ir).is_constant and - _render_expression(field.existence_condition, ir).is_constant): - template = _TEMPLATES.constant_structure_size_method - else: - template = _TEMPLATES.runtime_structure_size_method - return code_template.format_template( - template, - units="Bits" if field.name.name.text == "$size_in_bits" else "Bytes") - assert False, "Expected a $size_in_bits or $size_in_bytes field." + """Renders the Size methods of a struct or bits, using the correct templates. + + Arguments: + fields: The list of fields in the struct or bits. This is used to find the + $size_in_bits or $size_in_bytes virtual field. + ir: The IR to which fields belong. + + Returns: + A string representation of the Size methods, suitable for inclusion in an + Emboss View class. + """ + # The SizeInBytes(), SizeInBits(), and SizeIsKnown() methods just forward to + # the generated IntrinsicSizeIn$_units_$() method, which returns a virtual + # field with Read() and Ok() methods. + # + # TODO(bolms): Remove these shims, rename IntrinsicSizeIn$_units_$ to + # SizeIn$_units_$, and update all callers to the new API. + for field in fields: + if field.name.name.text in ("$size_in_bits", "$size_in_bytes"): + # If the read_transform and existence_condition are constant, then the + # size is constexpr. + if ( + _render_expression(field.read_transform, ir).is_constant + and _render_expression(field.existence_condition, ir).is_constant + ): + template = _TEMPLATES.constant_structure_size_method + else: + template = _TEMPLATES.runtime_structure_size_method + return code_template.format_template( + template, + units="Bits" if field.name.name.text == "$size_in_bits" else "Bytes", + ) + assert False, "Expected a $size_in_bits or $size_in_bytes field." def _visibility_for_field(field_ir): - """Returns the C++ visibility for field_ir within its parent view.""" - # Generally, the Google style guide for hand-written C++ forbids having - # multiple public: and private: sections, but trying to conform to that bit of - # the style guide would make this file significantly more complex. - # - # Alias fields are generated as simple methods that forward directly to the - # aliased field's method: - # - # auto alias() const -> decltype(parent().child().aliased_subchild()) { - # return parent().child().aliased_subchild(); - # } - # - # Figuring out the return type of `parent().child().aliased_subchild()` is - # quite complex, since there are several levels of template indirection - # involved. It is much easier to just leave it up to the C++ compiler. - # - # Unfortunately, the C++ compiler will complain if `parent()` is not declared - # before `alias()`. If the `parent` field happens to be anonymous, the Google - # style guide would put `parent()`'s declaration after `alias()`'s - # declaration, which causes the C++ compiler to complain that `parent` is - # unknown. - # - # The easy fix to this is just to declare `parent()` before `alias()`, and - # explicitly mark `parent()` as `private` and `alias()` as `public`. - # - # Perhaps surprisingly, this limitation does not apply when `parent()`'s type - # is not yet complete at the point where `alias()` is declared; I believe this - # is because both `parent()` and `alias()` exist in a templated `class`, and - # by the time `parent().child().aliased_subchild()` is actually resolved, the - # compiler is instantiating the class and has the full definitions of all the - # other classes available. 
- if field_ir.name.is_anonymous: - return "private" - else: - return "public" + """Returns the C++ visibility for field_ir within its parent view.""" + # Generally, the Google style guide for hand-written C++ forbids having + # multiple public: and private: sections, but trying to conform to that bit of + # the style guide would make this file significantly more complex. + # + # Alias fields are generated as simple methods that forward directly to the + # aliased field's method: + # + # auto alias() const -> decltype(parent().child().aliased_subchild()) { + # return parent().child().aliased_subchild(); + # } + # + # Figuring out the return type of `parent().child().aliased_subchild()` is + # quite complex, since there are several levels of template indirection + # involved. It is much easier to just leave it up to the C++ compiler. + # + # Unfortunately, the C++ compiler will complain if `parent()` is not declared + # before `alias()`. If the `parent` field happens to be anonymous, the Google + # style guide would put `parent()`'s declaration after `alias()`'s + # declaration, which causes the C++ compiler to complain that `parent` is + # unknown. + # + # The easy fix to this is just to declare `parent()` before `alias()`, and + # explicitly mark `parent()` as `private` and `alias()` as `public`. + # + # Perhaps surprisingly, this limitation does not apply when `parent()`'s type + # is not yet complete at the point where `alias()` is declared; I believe this + # is because both `parent()` and `alias()` exist in a templated `class`, and + # by the time `parent().child().aliased_subchild()` is actually resolved, the + # compiler is instantiating the class and has the full definitions of all the + # other classes available. + if field_ir.name.is_anonymous: + return "private" + else: + return "public" def _generate_field_indirection(field_ir, parent_type_name, ir): - """Renders a method which forwards to a field's view.""" - rendered_aliased_field = _render_variable( - ir_util.hashable_form_of_field_reference(field_ir.write_method.alias)) - declaration = code_template.format_template( - _TEMPLATES.structure_single_field_indirect_method_declarations, - aliased_field=rendered_aliased_field, - visibility=_visibility_for_field(field_ir), - parent_type=parent_type_name, - name=field_ir.name.name.text) - definition = code_template.format_template( - _TEMPLATES.struct_single_field_indirect_method_definitions, - parent_type=parent_type_name, - name=field_ir.name.name.text, - aliased_field=rendered_aliased_field, - field_exists=_render_existence_test(field_ir, ir).rendered) - return "", declaration, definition + """Renders a method which forwards to a field's view.""" + rendered_aliased_field = _render_variable( + ir_util.hashable_form_of_field_reference(field_ir.write_method.alias) + ) + declaration = code_template.format_template( + _TEMPLATES.structure_single_field_indirect_method_declarations, + aliased_field=rendered_aliased_field, + visibility=_visibility_for_field(field_ir), + parent_type=parent_type_name, + name=field_ir.name.name.text, + ) + definition = code_template.format_template( + _TEMPLATES.struct_single_field_indirect_method_definitions, + parent_type=parent_type_name, + name=field_ir.name.name.text, + aliased_field=rendered_aliased_field, + field_exists=_render_existence_test(field_ir, ir).rendered, + ) + return "", declaration, definition def _generate_subtype_definitions(type_ir, ir, config: Config): - """Generates C++ code for subtypes of type_ir.""" - subtype_bodies = [] - 
- subtype_forward_declarations = []
- subtype_method_definitions = []
- type_name = type_ir.name.name.text
- for subtype in type_ir.subtype:
- inner_defs = _generate_type_definition(subtype, ir, config)
- subtype_forward_declaration, subtype_body, subtype_methods = inner_defs
- subtype_forward_declarations.append(subtype_forward_declaration)
- subtype_bodies.append(subtype_body)
- subtype_method_definitions.append(subtype_methods)
- wrapped_forward_declarations = _wrap_in_namespace(
- "\n".join(subtype_forward_declarations), [type_name])
- wrapped_bodies = _wrap_in_namespace("\n".join(subtype_bodies), [type_name])
- wrapped_method_definitions = _wrap_in_namespace(
- "\n".join(subtype_method_definitions), [type_name])
- return (wrapped_bodies, wrapped_forward_declarations,
- wrapped_method_definitions)
+ """Generates C++ code for subtypes of type_ir."""
+ subtype_bodies = []
+ subtype_forward_declarations = []
+ subtype_method_definitions = []
+ type_name = type_ir.name.name.text
+ for subtype in type_ir.subtype:
+ inner_defs = _generate_type_definition(subtype, ir, config)
+ subtype_forward_declaration, subtype_body, subtype_methods = inner_defs
+ subtype_forward_declarations.append(subtype_forward_declaration)
+ subtype_bodies.append(subtype_body)
+ subtype_method_definitions.append(subtype_methods)
+ wrapped_forward_declarations = _wrap_in_namespace(
+ "\n".join(subtype_forward_declarations), [type_name]
+ )
+ wrapped_bodies = _wrap_in_namespace("\n".join(subtype_bodies), [type_name])
+ wrapped_method_definitions = _wrap_in_namespace(
+ "\n".join(subtype_method_definitions), [type_name]
+ )
+ return (wrapped_bodies, wrapped_forward_declarations, wrapped_method_definitions)


def _cpp_field_name(name):
- """Returns the C++ name for the given field name."""
- if name.startswith("$"):
- dollar_field_names = {
- "$size_in_bits": "IntrinsicSizeInBits",
- "$size_in_bytes": "IntrinsicSizeInBytes",
- "$max_size_in_bits": "MaxSizeInBits",
- "$min_size_in_bits": "MinSizeInBits",
- "$max_size_in_bytes": "MaxSizeInBytes",
- "$min_size_in_bytes": "MinSizeInBytes",
- }
- return dollar_field_names[name]
- else:
- return name
+ """Returns the C++ name for the given field name."""
+ if name.startswith("$"):
+ dollar_field_names = {
+ "$size_in_bits": "IntrinsicSizeInBits",
+ "$size_in_bytes": "IntrinsicSizeInBytes",
+ "$max_size_in_bits": "MaxSizeInBits",
+ "$min_size_in_bits": "MinSizeInBits",
+ "$max_size_in_bytes": "MaxSizeInBytes",
+ "$min_size_in_bytes": "MinSizeInBytes",
+ }
+ return dollar_field_names[name]
+ else:
+ return name


def _generate_structure_definition(type_ir, ir, config: Config):
- """Generates C++ for an Emboss structure (struct or bits).
-
- Arguments:
- type_ir: The IR for the struct definition.
- ir: The full IR; used for type lookups.
-
- Returns:
- A tuple of: (forward declaration for classes, class bodies, method bodies),
- suitable for insertion into the appropriate places in the generated header.
- """ - subtype_bodies, subtype_forward_declarations, subtype_method_definitions = ( - _generate_subtype_definitions(type_ir, ir, config)) - type_name = type_ir.name.name.text - field_helper_type_definitions = [] - field_method_declarations = [] - field_method_definitions = [] - virtual_field_type_definitions = [] - decode_field_clauses = [] - write_field_clauses = [] - ok_method_clauses = [] - equals_method_clauses = [] - unchecked_equals_method_clauses = [] - enum_using_statements = [] - parameter_fields = [] - constructor_parameters = [] - forwarded_parameters = [] - parameter_initializers = [] - parameter_copy_initializers = [] - units = {1: "Bits", 8: "Bytes"}[type_ir.addressable_unit] - - for subtype in type_ir.subtype: - if subtype.HasField("enumeration"): - enum_using_statements.append( - code_template.format_template( - _TEMPLATES.enum_using_statement, - component=_get_fully_qualified_name(subtype.name.canonical_name, - ir), - name=_get_unqualified_name(subtype.name.canonical_name))) - - # TODO(bolms): Reorder parameter fields to optimize packing in the view type. - for parameter in type_ir.runtime_parameter: - parameter_type = _cpp_basic_type_for_expression_type(parameter.type, ir) - parameter_name = parameter.name.name.text - parameter_fields.append("{} {}_;".format(parameter_type, parameter_name)) - constructor_parameters.append( - "{} {}, ".format(parameter_type, parameter_name)) - forwarded_parameters.append("::std::forward({}),".format( - parameter_type, parameter_name)) - parameter_initializers.append(", {0}_({0})".format(parameter_name)) - parameter_copy_initializers.append( - ", {0}_(emboss_reserved_local_other.{0}_)".format(parameter_name)) - - field_method_declarations.append( - code_template.format_template( - _TEMPLATES.structure_single_parameter_field_method_declarations, - name=parameter_name, - logical_type=parameter_type)) - # TODO(bolms): Should parameters appear in text format? - equals_method_clauses.append( - code_template.format_template(_TEMPLATES.equals_method_test, - field=parameter_name + "()")) - unchecked_equals_method_clauses.append( - code_template.format_template(_TEMPLATES.unchecked_equals_method_test, - field=parameter_name + "()")) - if type_ir.runtime_parameter: - flag_name = "parameters_initialized_" - parameter_copy_initializers.append( - ", {0}(emboss_reserved_local_other.{0})".format(flag_name)) - parameters_initialized_flag = "bool {} = false;".format(flag_name) - initialize_parameters_initialized_true = ", {}(true)".format(flag_name) - parameter_checks = ["if (!{}) return false;".format(flag_name)] - else: - parameters_initialized_flag = "" - initialize_parameters_initialized_true = "" - parameter_checks = [""] - - for field_index in type_ir.structure.fields_in_dependency_order: - field = type_ir.structure.field[field_index] - helper_types, declaration, definition = ( - _generate_structure_field_methods( - type_name, field, ir, type_ir.addressable_unit)) - field_helper_type_definitions.append(helper_types) - field_method_definitions.append(definition) - ok_method_clauses.append( - code_template.format_template( - _TEMPLATES.ok_method_test, - field=_cpp_field_name(field.name.name.text) + "()")) - if not ir_util.field_is_virtual(field): - # Virtual fields do not participate in equality tests -- they are equal by - # definition. 
- equals_method_clauses.append( - code_template.format_template( - _TEMPLATES.equals_method_test, field=field.name.name.text + "()")) - unchecked_equals_method_clauses.append( - code_template.format_template( - _TEMPLATES.unchecked_equals_method_test, - field=field.name.name.text + "()")) - field_method_declarations.append(declaration) - if not field.name.is_anonymous and not ir_util.field_is_read_only(field): - # As above, read-only fields cannot be decoded from text format. - decode_field_clauses.append( - code_template.format_template( - _TEMPLATES.decode_field, - field_name=field.name.canonical_name.object_path[-1])) - text_output_attr = ir_util.get_attribute(field.attribute, "text_output") - if not text_output_attr or text_output_attr.string_constant == "Emit": - if ir_util.field_is_read_only(field): - write_field_template = _TEMPLATES.write_read_only_field_to_text_stream - else: - write_field_template = _TEMPLATES.write_field_to_text_stream - write_field_clauses.append( - code_template.format_template( - write_field_template, - field_name=field.name.canonical_name.object_path[-1])) - - requires_attr = ir_util.get_attribute(type_ir.attribute, "requires") - if requires_attr is not None: - requires_clause = _render_expression( - requires_attr.expression, ir, _DirectFieldRenderer()).rendered - requires_check = (" if (!({}).ValueOr(false))\n" - " return false;").format(requires_clause) - else: - requires_check = "" - - if config.include_enum_traits: - text_stream_methods = code_template.format_template( - _TEMPLATES.struct_text_stream, - decode_fields="\n".join(decode_field_clauses), - write_fields="\n".join(write_field_clauses)) - else: - text_stream_methods = "" - - - class_forward_declarations = code_template.format_template( - _TEMPLATES.structure_view_declaration, - name=type_name) - class_bodies = code_template.format_template( - _TEMPLATES.structure_view_class, - name=type_ir.name.canonical_name.object_path[-1], - size_method=_render_size_method(type_ir.structure.field, ir), - field_method_declarations="".join(field_method_declarations), - field_ok_checks="\n".join(ok_method_clauses), - parameter_ok_checks="\n".join(parameter_checks), - requires_check=requires_check, - equals_method_body="\n".join(equals_method_clauses), - unchecked_equals_method_body="\n".join(unchecked_equals_method_clauses), - enum_usings="\n".join(enum_using_statements), - text_stream_methods=text_stream_methods, - parameter_fields="\n".join(parameter_fields), - constructor_parameters="".join(constructor_parameters), - forwarded_parameters="".join(forwarded_parameters), - parameter_initializers="\n".join(parameter_initializers), - parameter_copy_initializers="\n".join(parameter_copy_initializers), - parameters_initialized_flag=parameters_initialized_flag, - initialize_parameters_initialized_true=( - initialize_parameters_initialized_true), - units=units) - method_definitions = "\n".join(field_method_definitions) - early_virtual_field_types = "\n".join(virtual_field_type_definitions) - all_field_helper_type_definitions = "\n".join(field_helper_type_definitions) - return (early_virtual_field_types + subtype_forward_declarations + - class_forward_declarations, - all_field_helper_type_definitions + subtype_bodies + class_bodies, - subtype_method_definitions + method_definitions) + """Generates C++ for an Emboss structure (struct or bits). + + Arguments: + type_ir: The IR for the struct definition. + ir: The full IR; used for type lookups. 
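+ config: Header generation configuration.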
+
+ Returns:
+ A tuple of: (forward declaration for classes, class bodies, method bodies),
+ suitable for insertion into the appropriate places in the generated header.
+ """
+ subtype_bodies, subtype_forward_declarations, subtype_method_definitions = (
+ _generate_subtype_definitions(type_ir, ir, config)
+ )
+ type_name = type_ir.name.name.text
+ field_helper_type_definitions = []
+ field_method_declarations = []
+ field_method_definitions = []
+ virtual_field_type_definitions = []
+ decode_field_clauses = []
+ write_field_clauses = []
+ ok_method_clauses = []
+ equals_method_clauses = []
+ unchecked_equals_method_clauses = []
+ enum_using_statements = []
+ parameter_fields = []
+ constructor_parameters = []
+ forwarded_parameters = []
+ parameter_initializers = []
+ parameter_copy_initializers = []
+ units = {1: "Bits", 8: "Bytes"}[type_ir.addressable_unit]
+
+ for subtype in type_ir.subtype:
+ if subtype.HasField("enumeration"):
+ enum_using_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_using_statement,
+ component=_get_fully_qualified_name(
+ subtype.name.canonical_name, ir
+ ),
+ name=_get_unqualified_name(subtype.name.canonical_name),
+ )
+ )
+
+ # TODO(bolms): Reorder parameter fields to optimize packing in the view type.
+ for parameter in type_ir.runtime_parameter:
+ parameter_type = _cpp_basic_type_for_expression_type(parameter.type, ir)
+ parameter_name = parameter.name.name.text
+ parameter_fields.append("{} {}_;".format(parameter_type, parameter_name))
+ constructor_parameters.append("{} {}, ".format(parameter_type, parameter_name))
+ forwarded_parameters.append(
+ "::std::forward<{}>({}),".format(parameter_type, parameter_name)
+ )
+ parameter_initializers.append(", {0}_({0})".format(parameter_name))
+ parameter_copy_initializers.append(
+ ", {0}_(emboss_reserved_local_other.{0}_)".format(parameter_name)
+ )
+
+ field_method_declarations.append(
+ code_template.format_template(
+ _TEMPLATES.structure_single_parameter_field_method_declarations,
+ name=parameter_name,
+ logical_type=parameter_type,
+ )
+ )
+ # TODO(bolms): Should parameters appear in text format?
+ equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.equals_method_test, field=parameter_name + "()"
+ )
+ )
+ unchecked_equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.unchecked_equals_method_test, field=parameter_name + "()"
+ )
+ )
+ if type_ir.runtime_parameter:
+ flag_name = "parameters_initialized_"
+ parameter_copy_initializers.append(
+ ", {0}(emboss_reserved_local_other.{0})".format(flag_name)
+ )
+ parameters_initialized_flag = "bool {} = false;".format(flag_name)
+ initialize_parameters_initialized_true = ", {}(true)".format(flag_name)
+ parameter_checks = ["if (!{}) return false;".format(flag_name)]
+ else:
+ parameters_initialized_flag = ""
+ initialize_parameters_initialized_true = ""
+ parameter_checks = [""]
+
+ for field_index in type_ir.structure.fields_in_dependency_order:
+ field = type_ir.structure.field[field_index]
+ helper_types, declaration, definition = _generate_structure_field_methods(
+ type_name, field, ir, type_ir.addressable_unit
+ )
+ field_helper_type_definitions.append(helper_types)
+ field_method_definitions.append(definition)
+ ok_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.ok_method_test,
+ field=_cpp_field_name(field.name.name.text) + "()",
+ )
+ )
+ if not ir_util.field_is_virtual(field):
+ # Virtual fields do not participate in equality tests -- they are equal by
+ # definition.
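+ # Only physical fields contribute Equals()/UncheckedEquals() clauses
+ # here: a virtual field's value is computed from other fields, so it
+ # cannot differ between two otherwise-equal views.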
+ equals_method_clauses.append( + code_template.format_template( + _TEMPLATES.equals_method_test, field=field.name.name.text + "()" + ) + ) + unchecked_equals_method_clauses.append( + code_template.format_template( + _TEMPLATES.unchecked_equals_method_test, + field=field.name.name.text + "()", + ) + ) + field_method_declarations.append(declaration) + if not field.name.is_anonymous and not ir_util.field_is_read_only(field): + # As above, read-only fields cannot be decoded from text format. + decode_field_clauses.append( + code_template.format_template( + _TEMPLATES.decode_field, + field_name=field.name.canonical_name.object_path[-1], + ) + ) + text_output_attr = ir_util.get_attribute(field.attribute, "text_output") + if not text_output_attr or text_output_attr.string_constant == "Emit": + if ir_util.field_is_read_only(field): + write_field_template = _TEMPLATES.write_read_only_field_to_text_stream + else: + write_field_template = _TEMPLATES.write_field_to_text_stream + write_field_clauses.append( + code_template.format_template( + write_field_template, + field_name=field.name.canonical_name.object_path[-1], + ) + ) + + requires_attr = ir_util.get_attribute(type_ir.attribute, "requires") + if requires_attr is not None: + requires_clause = _render_expression( + requires_attr.expression, ir, _DirectFieldRenderer() + ).rendered + requires_check = ( + " if (!({}).ValueOr(false))\n" " return false;" + ).format(requires_clause) + else: + requires_check = "" + + if config.include_enum_traits: + text_stream_methods = code_template.format_template( + _TEMPLATES.struct_text_stream, + decode_fields="\n".join(decode_field_clauses), + write_fields="\n".join(write_field_clauses), + ) + else: + text_stream_methods = "" + + class_forward_declarations = code_template.format_template( + _TEMPLATES.structure_view_declaration, name=type_name + ) + class_bodies = code_template.format_template( + _TEMPLATES.structure_view_class, + name=type_ir.name.canonical_name.object_path[-1], + size_method=_render_size_method(type_ir.structure.field, ir), + field_method_declarations="".join(field_method_declarations), + field_ok_checks="\n".join(ok_method_clauses), + parameter_ok_checks="\n".join(parameter_checks), + requires_check=requires_check, + equals_method_body="\n".join(equals_method_clauses), + unchecked_equals_method_body="\n".join(unchecked_equals_method_clauses), + enum_usings="\n".join(enum_using_statements), + text_stream_methods=text_stream_methods, + parameter_fields="\n".join(parameter_fields), + constructor_parameters="".join(constructor_parameters), + forwarded_parameters="".join(forwarded_parameters), + parameter_initializers="\n".join(parameter_initializers), + parameter_copy_initializers="\n".join(parameter_copy_initializers), + parameters_initialized_flag=parameters_initialized_flag, + initialize_parameters_initialized_true=(initialize_parameters_initialized_true), + units=units, + ) + method_definitions = "\n".join(field_method_definitions) + early_virtual_field_types = "\n".join(virtual_field_type_definitions) + all_field_helper_type_definitions = "\n".join(field_helper_type_definitions) + return ( + early_virtual_field_types + + subtype_forward_declarations + + class_forward_declarations, + all_field_helper_type_definitions + subtype_bodies + class_bodies, + subtype_method_definitions + method_definitions, + ) def _split_enum_case_values_into_spans(enum_case_value): - """Yields spans containing each enum case in an enum_case attribute value. 
-
- Each span is of the form (start, end), which is the start and end position
- relative to the beginning of the enum_case_value string. To keep the grammar
- of this attribute simple, this only splits on delimiters and trims whitespace
- for each case.
-
- Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]"""
- # Scan the string from left to right, finding commas and trimming whitespace.
- # This is essentially equivalent to (x.strip() for x in str.split(','))
- # except that this yields spans within the string rather than the strings
- # themselves, and no span is yielded for a trailing comma.
- start, end = 0, len(enum_case_value)
- while start <= end:
- # Find a ',' delimiter to split on
- delimiter = enum_case_value.find(',', start, end)
- if delimiter < 0:
- delimiter = end
-
- substr_start = start
- substr_end = delimiter
-
- # Drop leading whitespace
- while (substr_start < substr_end and
- enum_case_value[substr_start].isspace()):
- substr_start += 1
- # Drop trailing whitespace
- while (substr_start < substr_end and
- enum_case_value[substr_end - 1].isspace()):
- substr_end -= 1
-
- # Skip a trailing comma
- if substr_start == end and start != 0:
- break
-
- yield substr_start, substr_end
- start = delimiter + 1
+ """Yields spans containing each enum case in an enum_case attribute value.
+
+ Each span is of the form (start, end), which is the start and end position
+ relative to the beginning of the enum_case_value string. To keep the grammar
+ of this attribute simple, this only splits on delimiters and trims whitespace
+ for each case.
+
+ Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]"""
+ # Scan the string from left to right, finding commas and trimming whitespace.
+ # This is essentially equivalent to (x.strip() for x in str.split(','))
+ # except that this yields spans within the string rather than the strings
+ # themselves, and no span is yielded for a trailing comma.
+ start, end = 0, len(enum_case_value)
+ while start <= end:
+ # Find a ',' delimiter to split on
+ delimiter = enum_case_value.find(",", start, end)
+ if delimiter < 0:
+ delimiter = end
+
+ substr_start = start
+ substr_end = delimiter
+
+ # Drop leading whitespace
+ while substr_start < substr_end and enum_case_value[substr_start].isspace():
+ substr_start += 1
+ # Drop trailing whitespace
+ while substr_start < substr_end and enum_case_value[substr_end - 1].isspace():
+ substr_end -= 1
+
+ # Skip a trailing comma
+ if substr_start == end and start != 0:
+ break
+
+ yield substr_start, substr_end
+ start = delimiter + 1


def _split_enum_case_values(enum_case_value):
- """Returns all enum cases in an enum case value.
+ """Returns all enum cases in an enum case value.
- Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']"""
- return [enum_case_value[start:end] for start, end
- in _split_enum_case_values_into_spans(enum_case_value)]
+ Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']"""
+ return [
+ enum_case_value[start:end]
+ for start, end in _split_enum_case_values_into_spans(enum_case_value)
+ ]


def _get_enum_value_names(enum_value):
- """Determines one or more enum names based on attributes"""
- cases = ["SHOUTY_CASE"]
- name = enum_value.name.name.text
- if enum_case := ir_util.get_attribute(enum_value.attribute,
- attributes.Attribute.ENUM_CASE):
- cases = _split_enum_case_values(enum_case.string_constant.text)
- return [name_conversion.convert_case("SHOUTY_CASE", case, name)
- for case in cases]
+ """Determines one or more enum names based on attributes"""
+ cases = ["SHOUTY_CASE"]
+ name = enum_value.name.name.text
+ if enum_case := ir_util.get_attribute(
+ enum_value.attribute, attributes.Attribute.ENUM_CASE
+ ):
+ cases = _split_enum_case_values(enum_case.string_constant.text)
+ return [name_conversion.convert_case("SHOUTY_CASE", case, name) for case in cases]


def _generate_enum_definition(type_ir, include_traits=True):
- """Generates C++ for an Emboss enum."""
- enum_values = []
- enum_from_string_statements = []
- string_from_enum_statements = []
- enum_is_known_statements = []
- previously_seen_numeric_values = set()
- max_bits = ir_util.get_integer_attribute(type_ir.attribute, "maximum_bits")
- is_signed = ir_util.get_boolean_attribute(type_ir.attribute, "is_signed")
- enum_type = _cpp_integer_type_for_enum(max_bits, is_signed)
- for value in type_ir.enumeration.value:
- numeric_value = ir_util.constant_value(value.value)
- enum_value_names = _get_enum_value_names(value)
-
- for enum_value_name in enum_value_names:
- enum_values.append(
- code_template.format_template(_TEMPLATES.enum_value,
- name=enum_value_name,
- value=_render_integer(numeric_value)))
- if include_traits:
- enum_from_string_statements.append(
- code_template.format_template(_TEMPLATES.enum_from_name_case,
- enum=type_ir.name.name.text,
- value=enum_value_name,
- name=value.name.name.text))
-
- if numeric_value not in previously_seen_numeric_values:
- string_from_enum_statements.append(
- code_template.format_template(_TEMPLATES.name_from_enum_case,
- enum=type_ir.name.name.text,
- value=enum_value_name,
- name=value.name.name.text))
-
- enum_is_known_statements.append(
- code_template.format_template(_TEMPLATES.enum_is_known_case,
- enum=type_ir.name.name.text,
- name=enum_value_name))
- previously_seen_numeric_values.add(numeric_value)
-
- declaration = code_template.format_template(
- _TEMPLATES.enum_declaration,
- enum=type_ir.name.name.text,
- enum_type=enum_type)
- definition = code_template.format_template(
- _TEMPLATES.enum_definition,
- enum=type_ir.name.name.text,
- enum_type=enum_type,
- enum_values="".join(enum_values))
- if include_traits:
- definition += code_template.format_template(
- _TEMPLATES.enum_traits,
- enum=type_ir.name.name.text,
- enum_from_name_cases="\n".join(enum_from_string_statements),
- name_from_enum_cases="\n".join(string_from_enum_statements),
- enum_is_known_cases="\n".join(enum_is_known_statements))
-
- return (declaration, definition, "")
+ """Generates C++ for an Emboss enum."""
+ enum_values = []
+ enum_from_string_statements = []
+ string_from_enum_statements = []
+ enum_is_known_statements = []
+ previously_seen_numeric_values = set()
+ max_bits = ir_util.get_integer_attribute(type_ir.attribute, "maximum_bits")
+ is_signed = ir_util.get_boolean_attribute(type_ir.attribute, "is_signed")
+ enum_type = _cpp_integer_type_for_enum(max_bits, is_signed)
+ for value in type_ir.enumeration.value:
+ numeric_value = ir_util.constant_value(value.value)
+ enum_value_names = _get_enum_value_names(value)
+
+ for enum_value_name in enum_value_names:
+ enum_values.append(
+ code_template.format_template(
+ _TEMPLATES.enum_value,
+ name=enum_value_name,
+ value=_render_integer(numeric_value),
+ )
+ )
+ if include_traits:
+ enum_from_string_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_from_name_case,
+ enum=type_ir.name.name.text,
+ value=enum_value_name,
+ name=value.name.name.text,
+ )
+ )
+
+ if numeric_value not in previously_seen_numeric_values:
+ string_from_enum_statements.append(
+ code_template.format_template(
+ _TEMPLATES.name_from_enum_case,
+ enum=type_ir.name.name.text,
+ value=enum_value_name,
+ name=value.name.name.text,
+ )
+ )
+
+ enum_is_known_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_is_known_case,
+ enum=type_ir.name.name.text,
+ name=enum_value_name,
+ )
+ )
+ previously_seen_numeric_values.add(numeric_value)
+
+ declaration = code_template.format_template(
+ _TEMPLATES.enum_declaration, enum=type_ir.name.name.text, enum_type=enum_type
+ )
+ definition = code_template.format_template(
+ _TEMPLATES.enum_definition,
+ enum=type_ir.name.name.text,
+ enum_type=enum_type,
+ enum_values="".join(enum_values),
+ )
+ if include_traits:
+ definition += code_template.format_template(
+ _TEMPLATES.enum_traits,
+ enum=type_ir.name.name.text,
+ enum_from_name_cases="\n".join(enum_from_string_statements),
+ name_from_enum_cases="\n".join(string_from_enum_statements),
+ enum_is_known_cases="\n".join(enum_is_known_statements),
+ )
+
+ return (declaration, definition, "")


def _generate_type_definition(type_ir, ir, config: Config):
- """Generates C++ for an Emboss type."""
- if type_ir.HasField("structure"):
- return _generate_structure_definition(type_ir, ir, config)
- elif type_ir.HasField("enumeration"):
- return _generate_enum_definition(type_ir, config.include_enum_traits)
- elif type_ir.HasField("external"):
- # TODO(bolms): This should probably generate an #include.
- return "", "", ""
- else:
- # TODO(bolms): provide error message instead of ICE
- assert False, "Unknown type {}".format(type_ir)
+ """Generates C++ for an Emboss type."""
+ if type_ir.HasField("structure"):
+ return _generate_structure_definition(type_ir, ir, config)
+ elif type_ir.HasField("enumeration"):
+ return _generate_enum_definition(type_ir, config.include_enum_traits)
+ elif type_ir.HasField("external"):
+ # TODO(bolms): This should probably generate an #include.
+ return "", "", ""
+ else:
+ # TODO(bolms): provide error message instead of ICE
+ assert False, "Unknown type {}".format(type_ir)


def _generate_header_guard(file_path):
- # TODO(bolms): Make this configurable.
- header_path = file_path + ".h"
- uppercased_path = header_path.upper()
- no_punctuation_path = re.sub(r"[^A-Za-z0-9_]", "_", uppercased_path)
- suffixed_path = no_punctuation_path + "_"
- no_double_underscore_path = re.sub(r"__+", "_", suffixed_path)
- return no_double_underscore_path
+ # TODO(bolms): Make this configurable.
+ header_path = file_path + ".h" + uppercased_path = header_path.upper() + no_punctuation_path = re.sub(r"[^A-Za-z0-9_]", "_", uppercased_path) + suffixed_path = no_punctuation_path + "_" + no_double_underscore_path = re.sub(r"__+", "_", suffixed_path) + return no_double_underscore_path def _add_missing_enum_case_attribute_on_enum_value(enum_value, defaults): - """Adds an `enum_case` attribute if there isn't one but a default is set.""" - if ir_util.get_attribute(enum_value.attribute, - attributes.Attribute.ENUM_CASE) is None: - if attributes.Attribute.ENUM_CASE in defaults: - enum_value.attribute.extend([defaults[attributes.Attribute.ENUM_CASE]]) + """Adds an `enum_case` attribute if there isn't one but a default is set.""" + if ( + ir_util.get_attribute(enum_value.attribute, attributes.Attribute.ENUM_CASE) + is None + ): + if attributes.Attribute.ENUM_CASE in defaults: + enum_value.attribute.extend([defaults[attributes.Attribute.ENUM_CASE]]) def _propagate_defaults(ir, targets, ancestors, add_fn): - """Propagates default values - - Traverses the IR to propagate default values to target nodes. - - Arguments: - targets: A list of target IR types to add attributes to. - ancestors: Ancestor types which may contain the default values. - add_fn: Function to add the attribute. May use any parameter available in - fast_traverse_ir_top_down actions as well as `defaults` containing the - default attributes set by ancestors. - - Returns: - None - """ - traverse_ir.fast_traverse_ir_top_down( - ir, targets, add_fn, - incidental_actions={ - ancestor: attribute_util.gather_default_attributes - for ancestor in ancestors - }, - parameters={"defaults": {}}) + """Propagates default values + + Traverses the IR to propagate default values to target nodes. + + Arguments: + targets: A list of target IR types to add attributes to. + ancestors: Ancestor types which may contain the default values. + add_fn: Function to add the attribute. May use any parameter available in + fast_traverse_ir_top_down actions as well as `defaults` containing the + default attributes set by ancestors. + + Returns: + None + """ + traverse_ir.fast_traverse_ir_top_down( + ir, + targets, + add_fn, + incidental_actions={ + ancestor: attribute_util.gather_default_attributes for ancestor in ancestors + }, + parameters={"defaults": {}}, + ) def _offset_source_location_column(source_location, offset): - """Adds offsets from the start column of the supplied source location + """Adds offsets from the start column of the supplied source location - Returns a new source location with all of the same properties as the provided - source location, but with the columns modified by offsets from the original - start column. + Returns a new source location with all of the same properties as the provided + source location, but with the columns modified by offsets from the original + start column. 
- Offset should be a tuple of (start, end), which are the offsets relative to - source_location.start.column to set the new start.column and end.column.""" + Offset should be a tuple of (start, end), which are the offsets relative to + source_location.start.column to set the new start.column and end.column.""" - new_location = ir_data_utils.copy(source_location) - new_location.start.column = source_location.start.column + offset[0] - new_location.end.column = source_location.start.column + offset[1] + new_location = ir_data_utils.copy(source_location) + new_location.start.column = source_location.start.column + offset[0] + new_location.end.column = source_location.start.column + offset[1] - return new_location + return new_location def _verify_namespace_attribute(attr, source_file_name, errors): - if attr.name.text != attributes.Attribute.NAMESPACE: - return - namespace_value = ir_data_utils.reader(attr).value.string_constant - if not re.match(_NS_RE, namespace_value.text): - if re.match(_NS_EMPTY_RE, namespace_value.text): - errors.append([error.error( - source_file_name, namespace_value.source_location, - 'Empty namespace value is not allowed.')]) - elif re.match(_NS_GLOBAL_RE, namespace_value.text): - errors.append([error.error( - source_file_name, namespace_value.source_location, - 'Global namespace is not allowed.')]) - else: - errors.append([error.error( - source_file_name, namespace_value.source_location, - 'Invalid namespace, must be a valid C++ namespace, such as "abc", ' - '"abc::def", or "::abc::def::ghi" (ISO/IEC 14882:2017 ' - 'enclosing-namespace-specifier).')]) - return - for word in _get_namespace_components(namespace_value.text): - if word in _CPP_RESERVED_WORDS: - errors.append([error.error( - source_file_name, namespace_value.source_location, - f'Reserved word "{word}" is not allowed as a namespace component.' 
- )]) + if attr.name.text != attributes.Attribute.NAMESPACE: + return + namespace_value = ir_data_utils.reader(attr).value.string_constant + if not re.match(_NS_RE, namespace_value.text): + if re.match(_NS_EMPTY_RE, namespace_value.text): + errors.append( + [ + error.error( + source_file_name, + namespace_value.source_location, + "Empty namespace value is not allowed.", + ) + ] + ) + elif re.match(_NS_GLOBAL_RE, namespace_value.text): + errors.append( + [ + error.error( + source_file_name, + namespace_value.source_location, + "Global namespace is not allowed.", + ) + ] + ) + else: + errors.append( + [ + error.error( + source_file_name, + namespace_value.source_location, + 'Invalid namespace, must be a valid C++ namespace, such as "abc", ' + '"abc::def", or "::abc::def::ghi" (ISO/IEC 14882:2017 ' + "enclosing-namespace-specifier).", + ) + ] + ) + return + for word in _get_namespace_components(namespace_value.text): + if word in _CPP_RESERVED_WORDS: + errors.append( + [ + error.error( + source_file_name, + namespace_value.source_location, + f'Reserved word "{word}" is not allowed as a namespace component.', + ) + ] + ) def _verify_enum_case_attribute(attr, source_file_name, errors): - """Verify that `enum_case` values are supported.""" - if attr.name.text != attributes.Attribute.ENUM_CASE: - return - - VALID_CASES = ', '.join(case for case in _SUPPORTED_ENUM_CASES) - enum_case_value = attr.value.string_constant - case_spans = _split_enum_case_values_into_spans(enum_case_value.text) - seen_cases = set() - - for start, end in case_spans: - case_source_location = _offset_source_location_column( - enum_case_value.source_location, (start, end)) - case = enum_case_value.text[start:end] - - if start == end: - errors.append([error.error( - source_file_name, case_source_location, - 'Empty enum case (or excess comma).')]) - continue - - if case in seen_cases: - errors.append([error.error( - source_file_name, case_source_location, - f'Duplicate enum case "{case}".')]) - continue - seen_cases.add(case) - - if case not in _SUPPORTED_ENUM_CASES: - errors.append([error.error( - source_file_name, case_source_location, - f'Unsupported enum case "{case}", ' - f'supported cases are: {VALID_CASES}.')]) + """Verify that `enum_case` values are supported.""" + if attr.name.text != attributes.Attribute.ENUM_CASE: + return + + VALID_CASES = ", ".join(case for case in _SUPPORTED_ENUM_CASES) + enum_case_value = attr.value.string_constant + case_spans = _split_enum_case_values_into_spans(enum_case_value.text) + seen_cases = set() + + for start, end in case_spans: + case_source_location = _offset_source_location_column( + enum_case_value.source_location, (start, end) + ) + case = enum_case_value.text[start:end] + + if start == end: + errors.append( + [ + error.error( + source_file_name, + case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ) + continue + + if case in seen_cases: + errors.append( + [ + error.error( + source_file_name, + case_source_location, + f'Duplicate enum case "{case}".', + ) + ] + ) + continue + seen_cases.add(case) + + if case not in _SUPPORTED_ENUM_CASES: + errors.append( + [ + error.error( + source_file_name, + case_source_location, + f'Unsupported enum case "{case}", ' + f"supported cases are: {VALID_CASES}.", + ) + ] + ) def _verify_attribute_values(ir): - """Verify backend attribute values.""" - errors = [] - - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Attribute], _verify_namespace_attribute, - parameters={"errors": errors}) - 
traverse_ir.fast_traverse_ir_top_down(
-      ir, [ir_data.Attribute], _verify_enum_case_attribute,
-      parameters={"errors": errors})
-
-  return errors
-
+    """Verify backend attribute values."""
+    errors = []
+
+    traverse_ir.fast_traverse_ir_top_down(
+        ir,
+        [ir_data.Attribute],
+        _verify_namespace_attribute,
+        parameters={"errors": errors},
+    )
+    traverse_ir.fast_traverse_ir_top_down(
+        ir,
+        [ir_data.Attribute],
+        _verify_enum_case_attribute,
+        parameters={"errors": errors},
+    )
+
+    return errors

-def _propagate_defaults_and_verify_attributes(ir):
-  """Verify attributes and ensure defaults are set when not overridden.
-
-  Returns a list of errors if there are errors present, or an empty list if
-  verification completed successfully."""
-  if errors := attribute_util.check_attributes_in_ir(
-      ir,
-      back_end="cpp",
-      types=attributes.TYPES,
-      module_attributes=attributes.Scope.MODULE,
-      struct_attributes=attributes.Scope.STRUCT,
-      bits_attributes=attributes.Scope.BITS,
-      enum_attributes=attributes.Scope.ENUM,
-      enum_value_attributes=attributes.Scope.ENUM_VALUE):
-    return errors
-  if errors := _verify_attribute_values(ir):
-    return errors
-  # Ensure defaults are set on EnumValues for `enum_case`.
-  _propagate_defaults(
-      ir,
-      targets=[ir_data.EnumValue],
-      ancestors=[ir_data.Module, ir_data.TypeDefinition],
-      add_fn=_add_missing_enum_case_attribute_on_enum_value)
+def _propagate_defaults_and_verify_attributes(ir):
+    """Verify attributes and ensure defaults are set when not overridden.
+
+    Returns a list of errors if there are errors present, or an empty list if
+    verification completed successfully."""
+    if errors := attribute_util.check_attributes_in_ir(
+        ir,
+        back_end="cpp",
+        types=attributes.TYPES,
+        module_attributes=attributes.Scope.MODULE,
+        struct_attributes=attributes.Scope.STRUCT,
+        bits_attributes=attributes.Scope.BITS,
+        enum_attributes=attributes.Scope.ENUM,
+        enum_value_attributes=attributes.Scope.ENUM_VALUE,
+    ):
+        return errors
+
+    if errors := _verify_attribute_values(ir):
+        return errors
+
+    # Ensure defaults are set on EnumValues for `enum_case`.
+    _propagate_defaults(
+        ir,
+        targets=[ir_data.EnumValue],
+        ancestors=[ir_data.Module, ir_data.TypeDefinition],
+        add_fn=_add_missing_enum_case_attribute_on_enum_value,
+    )

-  return []
+    return []


def generate_header(ir, config=Config()):
-  """Generates a C++ header from an Emboss module.
-
-  Arguments:
-    ir: An EmbossIr of the module.
-
-  Returns:
-    A tuple of (header, errors), where `header` is either a string containing
-    the text of a C++ header which implements Views for the types in the Emboss
-    module, or None, and `errors` is a possibly-empty list of error messages to
-    display to the user.
- """ - errors = _propagate_defaults_and_verify_attributes(ir) - if errors: - return None, errors - type_declarations = [] - type_definitions = [] - method_definitions = [] - for type_definition in ir.module[0].type: - declaration, definition, methods = _generate_type_definition( - type_definition, ir, config) - type_declarations.append(declaration) - type_definitions.append(definition) - method_definitions.append(methods) - body = code_template.format_template( - _TEMPLATES.body, - type_declarations="".join(type_declarations), - type_definitions="".join(type_definitions), - method_definitions="".join(method_definitions)) - body = _wrap_in_namespace(body, _get_module_namespace(ir.module[0])) - includes = _get_includes(ir.module[0], config) - return code_template.format_template( - _TEMPLATES.outline, - includes=includes, - body=body, - header_guard=_generate_header_guard(ir.module[0].source_file_name)), [] + """Generates a C++ header from an Emboss module. + + Arguments: + ir: An EmbossIr of the module. + + Returns: + A tuple of (header, errors), where `header` is either a string containing + the text of a C++ header which implements Views for the types in the Emboss + module, or None, and `errors` is a possibly-empty list of error messages to + display to the user. + """ + errors = _propagate_defaults_and_verify_attributes(ir) + if errors: + return None, errors + type_declarations = [] + type_definitions = [] + method_definitions = [] + for type_definition in ir.module[0].type: + declaration, definition, methods = _generate_type_definition( + type_definition, ir, config + ) + type_declarations.append(declaration) + type_definitions.append(definition) + method_definitions.append(methods) + body = code_template.format_template( + _TEMPLATES.body, + type_declarations="".join(type_declarations), + type_definitions="".join(type_definitions), + method_definitions="".join(method_definitions), + ) + body = _wrap_in_namespace(body, _get_module_namespace(ir.module[0])) + includes = _get_includes(ir.module[0], config) + return ( + code_template.format_template( + _TEMPLATES.outline, + includes=includes, + body=body, + header_guard=_generate_header_guard(ir.module[0].source_file_name), + ), + [], + ) diff --git a/compiler/back_end/cpp/header_generator_test.py b/compiler/back_end/cpp/header_generator_test.py index d58f798..6d31df8 100644 --- a/compiler/back_end/cpp/header_generator_test.py +++ b/compiler/back_end/cpp/header_generator_test.py @@ -22,359 +22,529 @@ from compiler.util import ir_data_utils from compiler.util import test_util + def _make_ir_from_emb(emb_text, name="m.emb"): - ir, unused_debug_info, errors = glue.parse_emboss_file( - name, - test_util.dict_file_reader({name: emb_text})) - assert not errors - return ir + ir, unused_debug_info, errors = glue.parse_emboss_file( + name, test_util.dict_file_reader({name: emb_text}) + ) + assert not errors + return ir class NormalizeIrTest(unittest.TestCase): - def test_accepts_string_attribute(self): - ir = _make_ir_from_emb('[(cpp) namespace: "foo"]\n') - self.assertEqual([], header_generator.generate_header(ir)[1]) - - def test_rejects_wrong_type_for_string_attribute(self): - ir = _make_ir_from_emb("[(cpp) namespace: 9]\n") - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.value.source_location, - "Attribute '(cpp) namespace' must have a string value.") - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_emboss_internal_attribute_with_back_end_specifier(self): - ir = _make_ir_from_emb('[(cpp) 
byte_order: "LittleEndian"]\n') - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.name.source_location, - "Unknown attribute '(cpp) byte_order' on module 'm.emb'.") - ]], header_generator.generate_header(ir)[1]) - - def test_accepts_enum_case(self): - mod_ir = _make_ir_from_emb('[(cpp) $default enum_case: "kCamelCase"]') - self.assertEqual([], header_generator.generate_header(mod_ir)[1]) - enum_ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - self.assertEqual([], header_generator.generate_header(enum_ir)[1]) - enum_value_ir = _make_ir_from_emb('enum Foo:\n' - ' BAR = 1 [(cpp) enum_case: "kCamelCase"]\n' - ' BAZ = 2\n' - ' [(cpp) enum_case: "kCamelCase"]\n') - self.assertEqual([], header_generator.generate_header(enum_value_ir)[1]) - enum_in_struct_ir = _make_ir_from_emb('struct Outer:\n' - ' [(cpp) $default enum_case: "kCamelCase"]\n' - ' enum Inner:\n' - ' BAR = 1\n' - ' BAZ = 2\n') - self.assertEqual([], header_generator.generate_header(enum_in_struct_ir)[1]) - enum_in_bits_ir = _make_ir_from_emb('bits Outer:\n' - ' [(cpp) $default enum_case: "kCamelCase"]\n' - ' enum Inner:\n' - ' BAR = 1\n' - ' BAZ = 2\n') - self.assertEqual([], header_generator.generate_header(enum_in_bits_ir)[1]) - enum_ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE,"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - self.assertEqual([], header_generator.generate_header(enum_ir)[1]) - enum_ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE ,kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - self.assertEqual([], header_generator.generate_header(enum_ir)[1]) - - def test_rejects_bad_enum_case_at_start(self): - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHORTY_CASE, kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - attr = ir.module[0].type[0].attribute[0] - - bad_case_source_location = ir_data.Location() - bad_case_source_location = ir_data_utils.builder(bad_case_source_location) - bad_case_source_location.CopyFrom(attr.value.source_location) - # Location of SHORTY_CASE in the attribute line. - bad_case_source_location.start.column = 30 - bad_case_source_location.end.column = 41 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Unsupported enum case "SHORTY_CASE", ' - 'supported cases are: SHOUTY_CASE, kCamelCase.') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_bad_enum_case_in_middle(self): - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE, bad_CASE, kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - attr = ir.module[0].type[0].attribute[0] - - bad_case_source_location = ir_data.Location() - bad_case_source_location = ir_data_utils.builder(bad_case_source_location) - bad_case_source_location.CopyFrom(attr.value.source_location) - # Location of bad_CASE in the attribute line. 
- bad_case_source_location.start.column = 43 - bad_case_source_location.end.column = 51 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Unsupported enum case "bad_CASE", ' - 'supported cases are: SHOUTY_CASE, kCamelCase.') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_bad_enum_case_at_end(self): - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase, BAD_case"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - attr = ir.module[0].type[0].attribute[0] - - bad_case_source_location = ir_data.Location() - bad_case_source_location = ir_data_utils.builder(bad_case_source_location) - bad_case_source_location.CopyFrom(attr.value.source_location) - # Location of BAD_case in the attribute line. - bad_case_source_location.start.column = 55 - bad_case_source_location.end.column = 63 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Unsupported enum case "BAD_case", ' - 'supported cases are: SHOUTY_CASE, kCamelCase.') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_duplicate_enum_case(self): - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE, SHOUTY_CASE"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - attr = ir.module[0].type[0].attribute[0] - - bad_case_source_location = ir_data.Location() - bad_case_source_location = ir_data_utils.builder(bad_case_source_location) - bad_case_source_location.CopyFrom(attr.value.source_location) - # Location of the second SHOUTY_CASE in the attribute line. - bad_case_source_location.start.column = 43 - bad_case_source_location.end.column = 54 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Duplicate enum case "SHOUTY_CASE".') - ]], header_generator.generate_header(ir)[1]) - - - def test_rejects_empty_enum_case(self): - # Double comma - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE,, kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - attr = ir.module[0].type[0].attribute[0] - - bad_case_source_location = ir_data.Location() - bad_case_source_location = ir_data_utils.builder(bad_case_source_location) - bad_case_source_location.CopyFrom(attr.value.source_location) - # Location of excess comma. 
- bad_case_source_location.start.column = 42 - bad_case_source_location.end.column = 42 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # Leading comma - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: ", SHOUTY_CASE, kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 30 - bad_case_source_location.end.column = 30 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # Excess trailing comma - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase,,"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 54 - bad_case_source_location.end.column = 54 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # Whitespace enum case - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: "SHOUTY_CASE, , kCamelCase"]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 45 - bad_case_source_location.end.column = 45 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # Empty enum_case string - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: ""]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 30 - bad_case_source_location.end.column = 30 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # Whitespace enum_case string - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: " "]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 35 - bad_case_source_location.end.column = 35 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - # One-character whitespace enum_case string - ir = _make_ir_from_emb('enum Foo:\n' - ' [(cpp) $default enum_case: " "]\n' - ' BAR = 1\n' - ' BAZ = 2\n') - - bad_case_source_location.start.column = 31 - bad_case_source_location.end.column = 31 - - self.assertEqual([[ - error.error("m.emb", bad_case_source_location, - 'Empty enum case (or excess comma).') - ]], header_generator.generate_header(ir)[1]) - - def test_accepts_namespace(self): - for test in [ - '[(cpp) namespace: "basic"]\n', - '[(cpp) namespace: "multiple::components"]\n', - '[(cpp) namespace: "::absolute"]\n', - '[(cpp) namespace: "::fully::qualified"]\n', - '[(cpp) namespace: "CAN::Be::cAPITAL"]\n', - '[(cpp) namespace: "trailingNumbers54321"]\n', - '[(cpp) namespace: "containing4321numbers"]\n', - '[(cpp) namespace: "can_have_underscores"]\n', - '[(cpp) namespace: "_initial_underscore"]\n', - '[(cpp) namespace: "_initial::_underscore"]\n', - '[(cpp) namespace: "::_initial::_underscore"]\n', - '[(cpp) namespace: "trailing_underscore_"]\n', - '[(cpp) namespace: "trailing_::underscore_"]\n', - '[(cpp) namespace: "::trailing_::underscore_"]\n', - '[(cpp) namespace: " spaces "]\n', - '[(cpp) namespace: "with :: spaces"]\n', - '[(cpp) namespace: " ::fully:: qualified :: with::spaces"]\n', - ]: - ir = 
_make_ir_from_emb(test) - self.assertEqual([], header_generator.generate_header(ir)[1]) - - def test_rejects_non_namespace_strings(self): - for test in [ - '[(cpp) namespace: "5th::avenue"]\n', - '[(cpp) namespace: "can\'t::have::apostrophe"]\n', - '[(cpp) namespace: "cannot-have-dash"]\n', - '[(cpp) namespace: "no/slashes"]\n', - '[(cpp) namespace: "no\\\\slashes"]\n', - '[(cpp) namespace: "apostrophes*are*rejected"]\n', - '[(cpp) namespace: "avoid.dot"]\n', - '[(cpp) namespace: "does5+5"]\n', - '[(cpp) namespace: "=10"]\n', - '[(cpp) namespace: "?"]\n', - '[(cpp) namespace: "reject::spaces in::components"]\n', - '[(cpp) namespace: "totally::valid::but::extra +"]\n', - '[(cpp) namespace: "totally::valid::but::extra ::?"]\n', - '[(cpp) namespace: "< totally::valid::but::extra"]\n', - '[(cpp) namespace: "< ::totally::valid::but::extra"]\n', - '[(cpp) namespace: "::totally::valid::but::extra::"]\n', - '[(cpp) namespace: ":::extra::colon"]\n', - '[(cpp) namespace: "::extra:::colon"]\n', - ]: - ir = _make_ir_from_emb(test) - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.value.source_location, - 'Invalid namespace, must be a valid C++ namespace, such ' - 'as "abc", "abc::def", or "::abc::def::ghi" (ISO/IEC ' - '14882:2017 enclosing-namespace-specifier).') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_empty_namespace(self): - for test in [ - '[(cpp) namespace: ""]\n', - '[(cpp) namespace: " "]\n', - '[(cpp) namespace: " "]\n', - ]: - ir = _make_ir_from_emb(test) - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.value.source_location, - 'Empty namespace value is not allowed.') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_global_namespace(self): - for test in [ - '[(cpp) namespace: "::"]\n', - '[(cpp) namespace: " ::"]\n', - '[(cpp) namespace: ":: "]\n', - '[(cpp) namespace: " :: "]\n', - ]: - ir = _make_ir_from_emb(test) - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.value.source_location, - 'Global namespace is not allowed.') - ]], header_generator.generate_header(ir)[1]) - - def test_rejects_reserved_namespace(self): - for test, expected in [ - # Only component - ('[(cpp) namespace: "class"]\n', 'class'), - # Only component, fully qualified name - ('[(cpp) namespace: "::const"]\n', 'const'), - # First component - ('[(cpp) namespace: "if::valid"]\n', 'if'), - # First component, fully qualified name - ('[(cpp) namespace: "::auto::pilot"]\n', 'auto'), - # Last component - ('[(cpp) namespace: "make::do"]\n', 'do'), - # Middle component - ('[(cpp) namespace: "our::new::product"]\n', 'new'), - ]: - ir = _make_ir_from_emb(test) - attr = ir.module[0].attribute[0] - - self.assertEqual([[ - error.error("m.emb", attr.value.source_location, - f'Reserved word "{expected}" is not allowed ' - f'as a namespace component.')]], - header_generator.generate_header(ir)[1]) + def test_accepts_string_attribute(self): + ir = _make_ir_from_emb('[(cpp) namespace: "foo"]\n') + self.assertEqual([], header_generator.generate_header(ir)[1]) + + def test_rejects_wrong_type_for_string_attribute(self): + ir = _make_ir_from_emb("[(cpp) namespace: 9]\n") + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Attribute '(cpp) namespace' must have a string value.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_emboss_internal_attribute_with_back_end_specifier(self): + ir = 
_make_ir_from_emb('[(cpp) byte_order: "LittleEndian"]\n') + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.name.source_location, + "Unknown attribute '(cpp) byte_order' on module 'm.emb'.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_accepts_enum_case(self): + mod_ir = _make_ir_from_emb('[(cpp) $default enum_case: "kCamelCase"]') + self.assertEqual([], header_generator.generate_header(mod_ir)[1]) + enum_ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + self.assertEqual([], header_generator.generate_header(enum_ir)[1]) + enum_value_ir = _make_ir_from_emb( + "enum Foo:\n" + ' BAR = 1 [(cpp) enum_case: "kCamelCase"]\n' + " BAZ = 2\n" + ' [(cpp) enum_case: "kCamelCase"]\n' + ) + self.assertEqual([], header_generator.generate_header(enum_value_ir)[1]) + enum_in_struct_ir = _make_ir_from_emb( + "struct Outer:\n" + ' [(cpp) $default enum_case: "kCamelCase"]\n' + " enum Inner:\n" + " BAR = 1\n" + " BAZ = 2\n" + ) + self.assertEqual([], header_generator.generate_header(enum_in_struct_ir)[1]) + enum_in_bits_ir = _make_ir_from_emb( + "bits Outer:\n" + ' [(cpp) $default enum_case: "kCamelCase"]\n' + " enum Inner:\n" + " BAR = 1\n" + " BAZ = 2\n" + ) + self.assertEqual([], header_generator.generate_header(enum_in_bits_ir)[1]) + enum_ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE,"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + self.assertEqual([], header_generator.generate_header(enum_ir)[1]) + enum_ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE ,kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + self.assertEqual([], header_generator.generate_header(enum_ir)[1]) + + def test_rejects_bad_enum_case_at_start(self): + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHORTY_CASE, kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + attr = ir.module[0].type[0].attribute[0] + + bad_case_source_location = ir_data.Location() + bad_case_source_location = ir_data_utils.builder(bad_case_source_location) + bad_case_source_location.CopyFrom(attr.value.source_location) + # Location of SHORTY_CASE in the attribute line. + bad_case_source_location.start.column = 30 + bad_case_source_location.end.column = 41 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + 'Unsupported enum case "SHORTY_CASE", ' + "supported cases are: SHOUTY_CASE, kCamelCase.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_bad_enum_case_in_middle(self): + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE, bad_CASE, kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + attr = ir.module[0].type[0].attribute[0] + + bad_case_source_location = ir_data.Location() + bad_case_source_location = ir_data_utils.builder(bad_case_source_location) + bad_case_source_location.CopyFrom(attr.value.source_location) + # Location of bad_CASE in the attribute line. 
+ bad_case_source_location.start.column = 43 + bad_case_source_location.end.column = 51 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + 'Unsupported enum case "bad_CASE", ' + "supported cases are: SHOUTY_CASE, kCamelCase.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_bad_enum_case_at_end(self): + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase, BAD_case"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + attr = ir.module[0].type[0].attribute[0] + + bad_case_source_location = ir_data.Location() + bad_case_source_location = ir_data_utils.builder(bad_case_source_location) + bad_case_source_location.CopyFrom(attr.value.source_location) + # Location of BAD_case in the attribute line. + bad_case_source_location.start.column = 55 + bad_case_source_location.end.column = 63 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + 'Unsupported enum case "BAD_case", ' + "supported cases are: SHOUTY_CASE, kCamelCase.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_duplicate_enum_case(self): + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE, SHOUTY_CASE"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + attr = ir.module[0].type[0].attribute[0] + + bad_case_source_location = ir_data.Location() + bad_case_source_location = ir_data_utils.builder(bad_case_source_location) + bad_case_source_location.CopyFrom(attr.value.source_location) + # Location of the second SHOUTY_CASE in the attribute line. + bad_case_source_location.start.column = 43 + bad_case_source_location.end.column = 54 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + 'Duplicate enum case "SHOUTY_CASE".', + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_empty_enum_case(self): + # Double comma + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE,, kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + attr = ir.module[0].type[0].attribute[0] + + bad_case_source_location = ir_data.Location() + bad_case_source_location = ir_data_utils.builder(bad_case_source_location) + bad_case_source_location.CopyFrom(attr.value.source_location) + # Location of excess comma. 
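# Illustrative sketch, not part of this patch: the column 42 used just
# below is the string constant's start column (30) plus the offset of
# the zero-width case between the doubled commas.
text = "SHOUTY_CASE,, kCamelCase"
empty_offset = text.index(",,") + 1
assert 30 + empty_offset == 42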
+ bad_case_source_location.start.column = 42 + bad_case_source_location.end.column = 42 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # Leading comma + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: ", SHOUTY_CASE, kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 30 + bad_case_source_location.end.column = 30 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # Excess trailing comma + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase,,"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 54 + bad_case_source_location.end.column = 54 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # Whitespace enum case + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: "SHOUTY_CASE, , kCamelCase"]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 45 + bad_case_source_location.end.column = 45 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # Empty enum_case string + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: ""]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 30 + bad_case_source_location.end.column = 30 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # Whitespace enum_case string + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: " "]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 35 + bad_case_source_location.end.column = 35 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + # One-character whitespace enum_case string + ir = _make_ir_from_emb( + "enum Foo:\n" + ' [(cpp) $default enum_case: " "]\n' + " BAR = 1\n" + " BAZ = 2\n" + ) + + bad_case_source_location.start.column = 31 + bad_case_source_location.end.column = 31 + + self.assertEqual( + [ + [ + error.error( + "m.emb", + bad_case_source_location, + "Empty enum case (or excess comma).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_accepts_namespace(self): + for test in [ + '[(cpp) namespace: "basic"]\n', + '[(cpp) namespace: "multiple::components"]\n', + '[(cpp) namespace: "::absolute"]\n', + '[(cpp) namespace: "::fully::qualified"]\n', + '[(cpp) namespace: "CAN::Be::cAPITAL"]\n', + '[(cpp) namespace: "trailingNumbers54321"]\n', + '[(cpp) namespace: "containing4321numbers"]\n', + '[(cpp) namespace: "can_have_underscores"]\n', + '[(cpp) namespace: "_initial_underscore"]\n', + '[(cpp) namespace: "_initial::_underscore"]\n', + '[(cpp) namespace: "::_initial::_underscore"]\n', + '[(cpp) namespace: "trailing_underscore_"]\n', + '[(cpp) namespace: "trailing_::underscore_"]\n', + '[(cpp) namespace: 
"::trailing_::underscore_"]\n', + '[(cpp) namespace: " spaces "]\n', + '[(cpp) namespace: "with :: spaces"]\n', + '[(cpp) namespace: " ::fully:: qualified :: with::spaces"]\n', + ]: + ir = _make_ir_from_emb(test) + self.assertEqual([], header_generator.generate_header(ir)[1]) + + def test_rejects_non_namespace_strings(self): + for test in [ + '[(cpp) namespace: "5th::avenue"]\n', + '[(cpp) namespace: "can\'t::have::apostrophe"]\n', + '[(cpp) namespace: "cannot-have-dash"]\n', + '[(cpp) namespace: "no/slashes"]\n', + '[(cpp) namespace: "no\\\\slashes"]\n', + '[(cpp) namespace: "apostrophes*are*rejected"]\n', + '[(cpp) namespace: "avoid.dot"]\n', + '[(cpp) namespace: "does5+5"]\n', + '[(cpp) namespace: "=10"]\n', + '[(cpp) namespace: "?"]\n', + '[(cpp) namespace: "reject::spaces in::components"]\n', + '[(cpp) namespace: "totally::valid::but::extra +"]\n', + '[(cpp) namespace: "totally::valid::but::extra ::?"]\n', + '[(cpp) namespace: "< totally::valid::but::extra"]\n', + '[(cpp) namespace: "< ::totally::valid::but::extra"]\n', + '[(cpp) namespace: "::totally::valid::but::extra::"]\n', + '[(cpp) namespace: ":::extra::colon"]\n', + '[(cpp) namespace: "::extra:::colon"]\n', + ]: + ir = _make_ir_from_emb(test) + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Invalid namespace, must be a valid C++ namespace, such " + 'as "abc", "abc::def", or "::abc::def::ghi" (ISO/IEC ' + "14882:2017 enclosing-namespace-specifier).", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_empty_namespace(self): + for test in [ + '[(cpp) namespace: ""]\n', + '[(cpp) namespace: " "]\n', + '[(cpp) namespace: " "]\n', + ]: + ir = _make_ir_from_emb(test) + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Empty namespace value is not allowed.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_global_namespace(self): + for test in [ + '[(cpp) namespace: "::"]\n', + '[(cpp) namespace: " ::"]\n', + '[(cpp) namespace: ":: "]\n', + '[(cpp) namespace: " :: "]\n', + ]: + ir = _make_ir_from_emb(test) + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Global namespace is not allowed.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) + + def test_rejects_reserved_namespace(self): + for test, expected in [ + # Only component + ('[(cpp) namespace: "class"]\n', "class"), + # Only component, fully qualified name + ('[(cpp) namespace: "::const"]\n', "const"), + # First component + ('[(cpp) namespace: "if::valid"]\n', "if"), + # First component, fully qualified name + ('[(cpp) namespace: "::auto::pilot"]\n', "auto"), + # Last component + ('[(cpp) namespace: "make::do"]\n', "do"), + # Middle component + ('[(cpp) namespace: "our::new::product"]\n', "new"), + ]: + ir = _make_ir_from_emb(test) + attr = ir.module[0].attribute[0] + + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + f'Reserved word "{expected}" is not allowed ' + f"as a namespace component.", + ) + ] + ], + header_generator.generate_header(ir)[1], + ) if __name__ == "__main__": diff --git a/compiler/back_end/util/__init__.py b/compiler/back_end/util/__init__.py index 2c31d84..086a24e 100644 --- a/compiler/back_end/util/__init__.py +++ b/compiler/back_end/util/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/compiler/back_end/util/code_template.py b/compiler/back_end/util/code_template.py index 519fb99..e69c8c3 100644 --- a/compiler/back_end/util/code_template.py +++ b/compiler/back_end/util/code_template.py @@ -23,83 +23,83 @@ def format_template(template, **kwargs): - """format_template acts like str.format, but uses ${name} instead of {name}. + """format_template acts like str.format, but uses ${name} instead of {name}. - format_template acts like a str.format, except that instead of using { and } - to delimit substitutions, format_template uses ${name}. This simplifies - templates of source code in most languages, which frequently use "{" and "}", - but very rarely use "$". + format_template acts like a str.format, except that instead of using { and } + to delimit substitutions, format_template uses ${name}. This simplifies + templates of source code in most languages, which frequently use "{" and "}", + but very rarely use "$". - See the documentation for string.Template for details about - template strings and the format of substitutions. + See the documentation for string.Template for details about + template strings and the format of substitutions. - Arguments: - template: A template to format. - **kwargs: Keyword arguments for string.Template.substitute. + Arguments: + template: A template to format. + **kwargs: Keyword arguments for string.Template.substitute. - Returns: - A formatted string. - """ - return template.substitute(**kwargs) + Returns: + A formatted string. + """ + return template.substitute(**kwargs) def parse_templates(text): - """Parses text into a namedtuple of templates. - - parse_templates will split its argument into templates by searching for lines - of the form: - - [punctuation] " ** " [name] " ** " [punctuation] - - e.g.: - - // ** struct_field_accessor ** //////// - - Leading and trailing punctuation is ignored, and [name] is used as the name - of the template. [name] should match [A-Za-z][A-Za-z0-9_]* -- that is, it - should be a valid ASCII Python identifier. - - Additionally any `//` style comment without leading space of the form: - ```C++ - // This is an emboss developer related comment, it's useful internally - // but not relevant to end-users of generated code. - ``` - will be stripped out of the generated code. - - If a template wants to define a comment that will be included in the - generated code a C-style comment is recommended: - ```C++ - /** This will be included in the generated source. */ - - /** - * So will this! - */ - ``` - - Arguments: - text: The text to parse into templates. - - Returns: - A namedtuple object whose attributes are the templates from text. - """ - delimiter_re = re.compile(r"^\W*\*\* ([A-Za-z][A-Za-z0-9_]*) \*\*\W*$") - comment_re = re.compile(r"^\s*//.*$") - templates = {} - name = None - template = [] - def finish_template(template): - return string.Template("\n".join(template)) - - for line in text.splitlines(): - if delimiter_re.match(line): - if name: + """Parses text into a namedtuple of templates. + + parse_templates will split its argument into templates by searching for lines + of the form: + + [punctuation] " ** " [name] " ** " [punctuation] + + e.g.: + + // ** struct_field_accessor ** //////// + + Leading and trailing punctuation is ignored, and [name] is used as the name + of the template. 
[name] should match [A-Za-z][A-Za-z0-9_]* -- that is, it + should be a valid ASCII Python identifier. + + Additionally any `//` style comment without leading space of the form: + ```C++ + // This is an emboss developer related comment, it's useful internally + // but not relevant to end-users of generated code. + ``` + will be stripped out of the generated code. + + If a template wants to define a comment that will be included in the + generated code a C-style comment is recommended: + ```C++ + /** This will be included in the generated source. */ + + /** + * So will this! + */ + ``` + + Arguments: + text: The text to parse into templates. + + Returns: + A namedtuple object whose attributes are the templates from text. + """ + delimiter_re = re.compile(r"^\W*\*\* ([A-Za-z][A-Za-z0-9_]*) \*\*\W*$") + comment_re = re.compile(r"^\s*//.*$") + templates = {} + name = None + template = [] + + def finish_template(template): + return string.Template("\n".join(template)) + + for line in text.splitlines(): + if delimiter_re.match(line): + if name: + templates[name] = finish_template(template) + name = delimiter_re.match(line).group(1) + template = [] + else: + if not comment_re.match(line): + template.append(line) + if name: templates[name] = finish_template(template) - name = delimiter_re.match(line).group(1) - template = [] - else: - if not comment_re.match(line): - template.append(line) - if name: - templates[name] = finish_template(template) - return collections.namedtuple("Templates", - list(templates.keys()))(**templates) + return collections.namedtuple("Templates", list(templates.keys()))(**templates) diff --git a/compiler/back_end/util/code_template_test.py b/compiler/back_end/util/code_template_test.py index c133096..e4354fe 100644 --- a/compiler/back_end/util/code_template_test.py +++ b/compiler/back_end/util/code_template_test.py @@ -18,82 +18,88 @@ import unittest from compiler.back_end.util import code_template + def _format_template_str(template: str, **kwargs) -> str: - return code_template.format_template(string.Template(template), **kwargs) + return code_template.format_template(string.Template(template), **kwargs) + class FormatTest(unittest.TestCase): - """Tests for code_template.format.""" + """Tests for code_template.format.""" - def test_no_replacement_fields(self): - self.assertEqual("foo", _format_template_str("foo")) - self.assertEqual("{foo}", _format_template_str("{foo}")) - self.assertEqual("${foo}", _format_template_str("$${foo}")) + def test_no_replacement_fields(self): + self.assertEqual("foo", _format_template_str("foo")) + self.assertEqual("{foo}", _format_template_str("{foo}")) + self.assertEqual("${foo}", _format_template_str("$${foo}")) - def test_one_replacement_field(self): - self.assertEqual("foo", _format_template_str("${bar}", bar="foo")) - self.assertEqual("bazfoo", - _format_template_str("baz${bar}", bar="foo")) - self.assertEqual("foobaz", - _format_template_str("${bar}baz", bar="foo")) - self.assertEqual("bazfooqux", - _format_template_str("baz${bar}qux", bar="foo")) + def test_one_replacement_field(self): + self.assertEqual("foo", _format_template_str("${bar}", bar="foo")) + self.assertEqual("bazfoo", _format_template_str("baz${bar}", bar="foo")) + self.assertEqual("foobaz", _format_template_str("${bar}baz", bar="foo")) + self.assertEqual("bazfooqux", _format_template_str("baz${bar}qux", bar="foo")) - def test_one_replacement_field_with_formatting(self): - # Basic string.Templates don't support formatting values. 
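# Illustrative sketch, not part of this patch: end-to-end behavior of
# parse_templates and format_template as documented above. "** name **"
# delimiter lines split the text into named string.Template objects, and
# "//" comments at the start of a line are stripped.
from compiler.back_end.util import code_template

templates = code_template.parse_templates(
    "// internal note, stripped from the template\n"
    "** greeting **\n"
    "Hello, ${who}!"
)
assert code_template.format_template(templates.greeting, who="world") == (
    "Hello, world!"
)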
- self.assertRaises(ValueError, - _format_template_str, "${bar:.6f}", bar=1) + def test_one_replacement_field_with_formatting(self): + # Basic string.Templates don't support formatting values. + self.assertRaises(ValueError, _format_template_str, "${bar:.6f}", bar=1) - def test_one_replacement_field_value_missing(self): - self.assertRaises(KeyError, _format_template_str, "${bar}") + def test_one_replacement_field_value_missing(self): + self.assertRaises(KeyError, _format_template_str, "${bar}") - def test_multiple_replacement_fields(self): - self.assertEqual(" aaa bbb ", - _format_template_str(" ${bar} ${baz} ", - bar="aaa", - baz="bbb")) + def test_multiple_replacement_fields(self): + self.assertEqual( + " aaa bbb ", + _format_template_str(" ${bar} ${baz} ", bar="aaa", baz="bbb"), + ) class ParseTemplatesTest(unittest.TestCase): - """Tests for code_template.parse_templates.""" - - def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name - """Compares the results of a parse_templates""" - # Extract the name and template from the result tuple - actual = { - k: v.template for k, v in actual._asdict().items() - } - self.assertEqual(expected, actual) - - def test_handles_no_template_case(self): - self.assertTemplatesEqual({}, code_template.parse_templates("")) - self.assertTemplatesEqual({}, code_template.parse_templates( - "this is not a template")) - - def test_handles_one_template_at_start(self): - self.assertTemplatesEqual({"foo": "bar"}, - code_template.parse_templates("** foo **\nbar")) - - def test_handles_one_template_after_start(self): - self.assertTemplatesEqual( - {"foo": "bar"}, - code_template.parse_templates("text\n** foo **\nbar")) - - def test_handles_delimiter_with_other_text(self): - self.assertTemplatesEqual( - {"foo": "bar"}, - code_template.parse_templates("text\n// ** foo ** ////\nbar")) - self.assertTemplatesEqual( - {"foo": "bar"}, - code_template.parse_templates("text\n# ** foo ** #####\nbar")) - - def test_handles_multiple_delimiters(self): - self.assertTemplatesEqual({"foo": "bar", - "baz": "qux"}, code_template.parse_templates( - "** foo **\nbar\n** baz **\nqux")) - - def test_returns_object_with_attributes(self): - self.assertEqual("bar", code_template.parse_templates( - "** foo **\nbar\n** baz **\nqux").foo.template) + """Tests for code_template.parse_templates.""" + + def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name + """Compares the results of a parse_templates""" + # Extract the name and template from the result tuple + actual = {k: v.template for k, v in actual._asdict().items()} + self.assertEqual(expected, actual) + + def test_handles_no_template_case(self): + self.assertTemplatesEqual({}, code_template.parse_templates("")) + self.assertTemplatesEqual( + {}, code_template.parse_templates("this is not a template") + ) + + def test_handles_one_template_at_start(self): + self.assertTemplatesEqual( + {"foo": "bar"}, code_template.parse_templates("** foo **\nbar") + ) + + def test_handles_one_template_after_start(self): + self.assertTemplatesEqual( + {"foo": "bar"}, code_template.parse_templates("text\n** foo **\nbar") + ) + + def test_handles_delimiter_with_other_text(self): + self.assertTemplatesEqual( + {"foo": "bar"}, + code_template.parse_templates("text\n// ** foo ** ////\nbar"), + ) + self.assertTemplatesEqual( + {"foo": "bar"}, + code_template.parse_templates("text\n# ** foo ** #####\nbar"), + ) + + def test_handles_multiple_delimiters(self): + self.assertTemplatesEqual( + {"foo": "bar", "baz": "qux"}, 
+ code_template.parse_templates("** foo **\nbar\n** baz **\nqux"), + ) + + def test_returns_object_with_attributes(self): + self.assertEqual( + "bar", + code_template.parse_templates( + "** foo **\nbar\n** baz **\nqux" + ).foo.template, + ) + if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/__init__.py b/compiler/front_end/__init__.py index 2c31d84..086a24e 100644 --- a/compiler/front_end/__init__.py +++ b/compiler/front_end/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py index d8c637c..9e7fec2 100644 --- a/compiler/front_end/attribute_checker.py +++ b/compiler/front_end/attribute_checker.py @@ -38,22 +38,29 @@ # Attribute type checkers _VALID_BYTE_ORDER = attribute_util.string_from_list( - {"BigEndian", "LittleEndian", "Null"}) + {"BigEndian", "LittleEndian", "Null"} +) _VALID_TEXT_OUTPUT = attribute_util.string_from_list({"Emit", "Skip"}) def _valid_back_ends(attr, module_source_file): - if not re.match( - r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$", - attr.value.string_constant.text): - return [[error.error( - module_source_file, - attr.value.source_location, - "Attribute '{name}' must be a comma-delimited list of back end " - "specifiers (like \"cpp, proto\")), not \"{value}\".".format( - name=attr.name.text, - value=attr.value.string_constant.text))]] - return [] + if not re.match( + r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$", + attr.value.string_constant.text, + ): + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + "Attribute '{name}' must be a comma-delimited list of back end " + 'specifiers (like "cpp, proto")), not "{value}".'.format( + name=attr.name.text, value=attr.value.string_constant.text + ), + ) + ] + ] + return [] # Attributes must be the same type no matter where they occur. 
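# Illustrative sketch, not part of this patch: a quick check of the
# back-ends regex in _valid_back_ends above. It accepts a comma-delimited
# list of lowercase specifiers, tolerating whitespace, a trailing comma,
# and the empty string, and rejects anything else.
import re

_BACK_ENDS = r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$"
for ok in ("cpp", "cpp, proto", "cpp,proto,", ""):
    assert re.match(_BACK_ENDS, ok)
for bad in ("C++", "cpp proto", "1cpp"):
    assert not re.match(_BACK_ENDS, bad)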
@@ -105,401 +112,539 @@ def _valid_back_ends(attr, module_source_file): def _construct_integer_attribute(name, value, source_location): - """Constructs an integer Attribute with the given name and value.""" - attr_value = ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value=str(value), - source_location=source_location), - type=ir_data.ExpressionType( - integer=ir_data.IntegerType(modular_value=str(value), - modulus="infinity", - minimum_value=str(value), - maximum_value=str(value))), - source_location=source_location), - source_location=source_location) - return ir_data.Attribute(name=ir_data.Word(text=name, - source_location=source_location), - value=attr_value, - source_location=source_location) + """Constructs an integer Attribute with the given name and value.""" + attr_value = ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant( + value=str(value), source_location=source_location + ), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value=str(value), + modulus="infinity", + minimum_value=str(value), + maximum_value=str(value), + ) + ), + source_location=source_location, + ), + source_location=source_location, + ) + return ir_data.Attribute( + name=ir_data.Word(text=name, source_location=source_location), + value=attr_value, + source_location=source_location, + ) def _construct_boolean_attribute(name, value, source_location): - """Constructs a boolean Attribute with the given name and value.""" - attr_value = ir_data.AttributeValue( - expression=ir_data.Expression( - boolean_constant=ir_data.BooleanConstant( - value=value, source_location=source_location), - type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=value)), - source_location=source_location), - source_location=source_location) - return ir_data.Attribute(name=ir_data.Word(text=name, - source_location=source_location), - value=attr_value, - source_location=source_location) + """Constructs a boolean Attribute with the given name and value.""" + attr_value = ir_data.AttributeValue( + expression=ir_data.Expression( + boolean_constant=ir_data.BooleanConstant( + value=value, source_location=source_location + ), + type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=value)), + source_location=source_location, + ), + source_location=source_location, + ) + return ir_data.Attribute( + name=ir_data.Word(text=name, source_location=source_location), + value=attr_value, + source_location=source_location, + ) def _construct_string_attribute(name, value, source_location): - """Constructs a string Attribute with the given name and value.""" - attr_value = ir_data.AttributeValue( - string_constant=ir_data.String(text=value, - source_location=source_location), - source_location=source_location) - return ir_data.Attribute(name=ir_data.Word(text=name, - source_location=source_location), - value=attr_value, - source_location=source_location) + """Constructs a string Attribute with the given name and value.""" + attr_value = ir_data.AttributeValue( + string_constant=ir_data.String(text=value, source_location=source_location), + source_location=source_location, + ) + return ir_data.Attribute( + name=ir_data.Word(text=name, source_location=source_location), + value=attr_value, + source_location=source_location, + ) def _fixed_size_of_struct_or_bits(struct, unit_size): - """Returns size of struct in bits or None, if struct is not fixed size.""" - size = 0 - for field in struct.field: - if not field.HasField("location"): - # 
Virtual fields do not contribute to the physical size of the struct. - continue - field_start = ir_util.constant_value(field.location.start) - field_size = ir_util.constant_value(field.location.size) - if field_start is None or field_size is None: - # Technically, start + size could be constant even if start and size are - # not; e.g. if start == x and size == 10 - x, but we don't handle that - # here. - return None - # TODO(bolms): knows_own_size - # TODO(bolms): compute min/max sizes for variable-sized arrays. - field_end = field_start + field_size - if field_end >= size: - size = field_end - return size * unit_size - - -def _verify_size_attributes_on_structure(struct, type_definition, - source_file_name, errors): - """Verifies size attributes on a struct or bits.""" - fixed_size = _fixed_size_of_struct_or_bits(struct, - type_definition.addressable_unit) - fixed_size_attr = ir_util.get_attribute(type_definition.attribute, - attributes.FIXED_SIZE) - if not fixed_size_attr: - return - if fixed_size is None: - errors.append([error.error( - source_file_name, fixed_size_attr.source_location, - "Struct is marked as fixed size, but contains variable-location " - "fields.")]) - elif ir_util.constant_value(fixed_size_attr.expression) != fixed_size: - errors.append([error.error( - source_file_name, fixed_size_attr.source_location, - "Struct is {} bits, but is marked as {} bits.".format( - fixed_size, ir_util.constant_value(fixed_size_attr.expression)))]) + """Returns size of struct in bits or None, if struct is not fixed size.""" + size = 0 + for field in struct.field: + if not field.HasField("location"): + # Virtual fields do not contribute to the physical size of the struct. + continue + field_start = ir_util.constant_value(field.location.start) + field_size = ir_util.constant_value(field.location.size) + if field_start is None or field_size is None: + # Technically, start + size could be constant even if start and size are + # not; e.g. if start == x and size == 10 - x, but we don't handle that + # here. + return None + # TODO(bolms): knows_own_size + # TODO(bolms): compute min/max sizes for variable-sized arrays. 
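# Illustrative sketch, not part of this patch: the computation completed
# just below reduces to "largest field end wins" when every field has a
# constant start and size. `fixed_size` is a hypothetical stand-alone
# helper; starts and sizes are in addressable units.
def fixed_size(fields, unit_size):
    # fields: iterable of (start, size) pairs; the result is in bits
    # when unit_size is the addressable unit's width in bits.
    return max((start + size for start, size in fields), default=0) * unit_size

assert fixed_size([(0, 4), (4, 2)], 8) == 48  # a 6-byte struct is 48 bits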
+ field_end = field_start + field_size + if field_end >= size: + size = field_end + return size * unit_size + + +def _verify_size_attributes_on_structure( + struct, type_definition, source_file_name, errors +): + """Verifies size attributes on a struct or bits.""" + fixed_size = _fixed_size_of_struct_or_bits(struct, type_definition.addressable_unit) + fixed_size_attr = ir_util.get_attribute( + type_definition.attribute, attributes.FIXED_SIZE + ) + if not fixed_size_attr: + return + if fixed_size is None: + errors.append( + [ + error.error( + source_file_name, + fixed_size_attr.source_location, + "Struct is marked as fixed size, but contains variable-location " + "fields.", + ) + ] + ) + elif ir_util.constant_value(fixed_size_attr.expression) != fixed_size: + errors.append( + [ + error.error( + source_file_name, + fixed_size_attr.source_location, + "Struct is {} bits, but is marked as {} bits.".format( + fixed_size, ir_util.constant_value(fixed_size_attr.expression) + ), + ) + ] + ) # TODO(bolms): remove [fixed_size]; it is superseded by $size_in_{bits,bytes} def _add_missing_size_attributes_on_structure(struct, type_definition): - """Adds missing size attributes on a struct.""" - fixed_size = _fixed_size_of_struct_or_bits(struct, - type_definition.addressable_unit) - if fixed_size is None: - return - fixed_size_attr = ir_util.get_attribute(type_definition.attribute, - attributes.FIXED_SIZE) - if not fixed_size_attr: - # TODO(bolms): Use the offset and length of the last field as the - # source_location of the fixed_size attribute? - type_definition.attribute.extend([ - _construct_integer_attribute(attributes.FIXED_SIZE, fixed_size, - type_definition.source_location)]) + """Adds missing size attributes on a struct.""" + fixed_size = _fixed_size_of_struct_or_bits(struct, type_definition.addressable_unit) + if fixed_size is None: + return + fixed_size_attr = ir_util.get_attribute( + type_definition.attribute, attributes.FIXED_SIZE + ) + if not fixed_size_attr: + # TODO(bolms): Use the offset and length of the last field as the + # source_location of the fixed_size attribute? + type_definition.attribute.extend( + [ + _construct_integer_attribute( + attributes.FIXED_SIZE, fixed_size, type_definition.source_location + ) + ] + ) def _field_needs_byte_order(field, type_definition, ir): - """Returns true if the given field needs a byte_order attribute.""" - if ir_util.field_is_virtual(field): - # Virtual fields have no physical type, and thus do not need a byte order. - return False - field_type = ir_util.find_object( - ir_util.get_base_type(field.type).atomic_type.reference.canonical_name, - ir) - assert field_type is not None - assert field_type.addressable_unit != ir_data.AddressableUnit.NONE - return field_type.addressable_unit != type_definition.addressable_unit + """Returns true if the given field needs a byte_order attribute.""" + if ir_util.field_is_virtual(field): + # Virtual fields have no physical type, and thus do not need a byte order. + return False + field_type = ir_util.find_object( + ir_util.get_base_type(field.type).atomic_type.reference.canonical_name, ir + ) + assert field_type is not None + assert field_type.addressable_unit != ir_data.AddressableUnit.NONE + return field_type.addressable_unit != type_definition.addressable_unit def _field_may_have_null_byte_order(field, type_definition, ir): - """Returns true if "Null" is a valid byte order for the given field.""" - # If the field is one unit in length, then byte order does not matter. 
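# Illustrative sketch, not part of this patch: the rule implemented by
# _field_may_have_null_byte_order around this note is that "Null" is
# permitted exactly when ordering cannot matter, i.e. the field spans one
# addressable unit or its base type is one unit wide. Hypothetical
# restatement:
def null_byte_order_ok(field_size_units, base_type_size_units):
    return field_size_units == 1 or base_type_size_units == 1

assert null_byte_order_ok(1, 4)      # one-unit field
assert null_byte_order_ok(8, 1)      # array of one-unit elements
assert not null_byte_order_ok(4, 4)  # multi-unit scalar: order matters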
- if (ir_util.is_constant(field.location.size) and - ir_util.constant_value(field.location.size) == 1): - return True - unit = type_definition.addressable_unit - # Otherwise, if the field's type is either a one-unit-sized type or an array - # of a one-unit-sized type, then byte order does not matter. - if (ir_util.fixed_size_of_type_in_bits(ir_util.get_base_type(field.type), ir) - == unit): - return True - # In all other cases, byte order does matter. - return False - - -def _add_missing_byte_order_attribute_on_field(field, type_definition, ir, - defaults): - """Adds missing byte_order attributes to fields that need them.""" - if _field_needs_byte_order(field, type_definition, ir): - byte_order_attr = ir_util.get_attribute(field.attribute, - attributes.BYTE_ORDER) - if byte_order_attr is None: - if attributes.BYTE_ORDER in defaults: - field.attribute.extend([defaults[attributes.BYTE_ORDER]]) - elif _field_may_have_null_byte_order(field, type_definition, ir): - field.attribute.extend( - [_construct_string_attribute(attributes.BYTE_ORDER, "Null", - field.source_location)]) + """Returns true if "Null" is a valid byte order for the given field.""" + # If the field is one unit in length, then byte order does not matter. + if ( + ir_util.is_constant(field.location.size) + and ir_util.constant_value(field.location.size) == 1 + ): + return True + unit = type_definition.addressable_unit + # Otherwise, if the field's type is either a one-unit-sized type or an array + # of a one-unit-sized type, then byte order does not matter. + if ( + ir_util.fixed_size_of_type_in_bits(ir_util.get_base_type(field.type), ir) + == unit + ): + return True + # In all other cases, byte order does matter. + return False + + +def _add_missing_byte_order_attribute_on_field(field, type_definition, ir, defaults): + """Adds missing byte_order attributes to fields that need them.""" + if _field_needs_byte_order(field, type_definition, ir): + byte_order_attr = ir_util.get_attribute(field.attribute, attributes.BYTE_ORDER) + if byte_order_attr is None: + if attributes.BYTE_ORDER in defaults: + field.attribute.extend([defaults[attributes.BYTE_ORDER]]) + elif _field_may_have_null_byte_order(field, type_definition, ir): + field.attribute.extend( + [ + _construct_string_attribute( + attributes.BYTE_ORDER, "Null", field.source_location + ) + ] + ) def _add_missing_back_ends_to_module(module): - """Sets the expected_back_ends attribute for a module, if not already set.""" - back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS) - if back_ends_attr is None: - module.attribute.extend( - [_construct_string_attribute(attributes.BACK_ENDS, _DEFAULT_BACK_ENDS, - module.source_location)]) + """Sets the expected_back_ends attribute for a module, if not already set.""" + back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS) + if back_ends_attr is None: + module.attribute.extend( + [ + _construct_string_attribute( + attributes.BACK_ENDS, _DEFAULT_BACK_ENDS, module.source_location + ) + ] + ) def _gather_expected_back_ends(module): - """Captures the expected_back_ends attribute for `module`.""" - back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS) - back_ends_str = back_ends_attr.string_constant.text - return { - "expected_back_ends": {x.strip() for x in back_ends_str.split(",")} | {""} - } + """Captures the expected_back_ends attribute for `module`.""" + back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS) + back_ends_str = 
back_ends_attr.string_constant.text + return {"expected_back_ends": {x.strip() for x in back_ends_str.split(",")} | {""}} def _add_addressable_unit_to_external(external, type_definition): - """Sets the addressable_unit field for an external TypeDefinition.""" - # Strictly speaking, addressable_unit isn't an "attribute," but it's close - # enough that it makes sense to handle it with attributes. - del external # Unused. - size = ir_util.get_integer_attribute(type_definition.attribute, - attributes.ADDRESSABLE_UNIT_SIZE) - if size == 1: - type_definition.addressable_unit = ir_data.AddressableUnit.BIT - elif size == 8: - type_definition.addressable_unit = ir_data.AddressableUnit.BYTE - # If the addressable_unit_size is not in (1, 8), it will be caught by - # _verify_addressable_unit_attribute_on_external, below. + """Sets the addressable_unit field for an external TypeDefinition.""" + # Strictly speaking, addressable_unit isn't an "attribute," but it's close + # enough that it makes sense to handle it with attributes. + del external # Unused. + size = ir_util.get_integer_attribute( + type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE + ) + if size == 1: + type_definition.addressable_unit = ir_data.AddressableUnit.BIT + elif size == 8: + type_definition.addressable_unit = ir_data.AddressableUnit.BYTE + # If the addressable_unit_size is not in (1, 8), it will be caught by + # _verify_addressable_unit_attribute_on_external, below. def _add_missing_width_and_sign_attributes_on_enum(enum, type_definition): - """Sets the maximum_bits and is_signed attributes for an enum, if needed.""" - max_bits_attr = ir_util.get_integer_attribute(type_definition.attribute, - attributes.ENUM_MAXIMUM_BITS) - if max_bits_attr is None: - type_definition.attribute.extend([ - _construct_integer_attribute(attributes.ENUM_MAXIMUM_BITS, - _DEFAULT_ENUM_MAXIMUM_BITS, - type_definition.source_location)]) - signed_attr = ir_util.get_boolean_attribute(type_definition.attribute, - attributes.IS_SIGNED) - if signed_attr is None: - for value in enum.value: - numeric_value = ir_util.constant_value(value.value) - if numeric_value < 0: - is_signed = True - break - else: - is_signed = False - type_definition.attribute.extend([ - _construct_boolean_attribute(attributes.IS_SIGNED, is_signed, - type_definition.source_location)]) - - -def _verify_byte_order_attribute_on_field(field, type_definition, - source_file_name, ir, errors): - """Verifies the byte_order attribute on the given field.""" - byte_order_attr = ir_util.get_attribute(field.attribute, - attributes.BYTE_ORDER) - field_needs_byte_order = _field_needs_byte_order(field, type_definition, ir) - if byte_order_attr and not field_needs_byte_order: - errors.append([error.error( - source_file_name, byte_order_attr.source_location, - "Attribute 'byte_order' not allowed on field which is not byte order " - "dependent.")]) - if not byte_order_attr and field_needs_byte_order: - errors.append([error.error( - source_file_name, field.source_location, - "Attribute 'byte_order' required on field which is byte order " - "dependent.")]) - if (byte_order_attr and byte_order_attr.string_constant.text == "Null" and - not _field_may_have_null_byte_order(field, type_definition, ir)): - errors.append([error.error( - source_file_name, byte_order_attr.source_location, - "Attribute 'byte_order' may only be 'Null' for one-byte fields.")]) + """Sets the maximum_bits and is_signed attributes for an enum, if needed.""" + max_bits_attr = ir_util.get_integer_attribute( + type_definition.attribute, 
attributes.ENUM_MAXIMUM_BITS + ) + if max_bits_attr is None: + type_definition.attribute.extend( + [ + _construct_integer_attribute( + attributes.ENUM_MAXIMUM_BITS, + _DEFAULT_ENUM_MAXIMUM_BITS, + type_definition.source_location, + ) + ] + ) + signed_attr = ir_util.get_boolean_attribute( + type_definition.attribute, attributes.IS_SIGNED + ) + if signed_attr is None: + for value in enum.value: + numeric_value = ir_util.constant_value(value.value) + if numeric_value < 0: + is_signed = True + break + else: + is_signed = False + type_definition.attribute.extend( + [ + _construct_boolean_attribute( + attributes.IS_SIGNED, is_signed, type_definition.source_location + ) + ] + ) + + +def _verify_byte_order_attribute_on_field( + field, type_definition, source_file_name, ir, errors +): + """Verifies the byte_order attribute on the given field.""" + byte_order_attr = ir_util.get_attribute(field.attribute, attributes.BYTE_ORDER) + field_needs_byte_order = _field_needs_byte_order(field, type_definition, ir) + if byte_order_attr and not field_needs_byte_order: + errors.append( + [ + error.error( + source_file_name, + byte_order_attr.source_location, + "Attribute 'byte_order' not allowed on field which is not byte order " + "dependent.", + ) + ] + ) + if not byte_order_attr and field_needs_byte_order: + errors.append( + [ + error.error( + source_file_name, + field.source_location, + "Attribute 'byte_order' required on field which is byte order " + "dependent.", + ) + ] + ) + if ( + byte_order_attr + and byte_order_attr.string_constant.text == "Null" + and not _field_may_have_null_byte_order(field, type_definition, ir) + ): + errors.append( + [ + error.error( + source_file_name, + byte_order_attr.source_location, + "Attribute 'byte_order' may only be 'Null' for one-byte fields.", + ) + ] + ) def _verify_requires_attribute_on_field(field, source_file_name, ir, errors): - """Verifies that [requires] is valid on the given field.""" - requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES) - if not requires_attr: - return - if ir_util.field_is_virtual(field): - field_expression_type = field.read_transform.type - else: - if not field.type.HasField("atomic_type"): - errors.append([ - error.error(source_file_name, requires_attr.source_location, - "Attribute 'requires' is only allowed on integer, " - "enumeration, or boolean fields, not arrays."), - error.note(source_file_name, field.type.source_location, - "Field type."), - ]) - return - field_type = ir_util.find_object(field.type.atomic_type.reference, ir) - assert field_type, "Field type should be non-None after name resolution." - field_expression_type = ( - type_check.unbounded_expression_type_for_physical_type(field_type)) - if field_expression_type.WhichOneof("type") not in ( - "integer", "enumeration", "boolean"): - errors.append([error.error( - source_file_name, requires_attr.source_location, - "Attribute 'requires' is only allowed on integer, enumeration, or " - "boolean fields.")]) - - -def _verify_addressable_unit_attribute_on_external(external, type_definition, - source_file_name, errors): - """Verifies the addressable_unit_size attribute on an external.""" - del external # Unused. 
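# Freestanding sketch of the signedness inference in
# _add_missing_width_and_sign_attributes_on_enum above: an enum is marked
# signed exactly when at least one value is negative (plain ints stand in
# for ir_util.constant_value results).
def infer_is_signed(values):
    for value in values:
        if value < 0:
            return True  # same effect as the for/else break in the pass
    return False

assert infer_is_signed([0, -1, 2]) is True
assert infer_is_signed([0, 1, 255]) is False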
- addressable_unit_size_attr = ir_util.get_integer_attribute( - type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE) - if addressable_unit_size_attr is None: - errors.append([error.error( - source_file_name, type_definition.source_location, - "Expected '{}' attribute for external type.".format( - attributes.ADDRESSABLE_UNIT_SIZE))]) - elif addressable_unit_size_attr not in (1, 8): - errors.append([ - error.error(source_file_name, type_definition.source_location, + """Verifies that [requires] is valid on the given field.""" + requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES) + if not requires_attr: + return + if ir_util.field_is_virtual(field): + field_expression_type = field.read_transform.type + else: + if not field.type.HasField("atomic_type"): + errors.append( + [ + error.error( + source_file_name, + requires_attr.source_location, + "Attribute 'requires' is only allowed on integer, " + "enumeration, or boolean fields, not arrays.", + ), + error.note( + source_file_name, field.type.source_location, "Field type." + ), + ] + ) + return + field_type = ir_util.find_object(field.type.atomic_type.reference, ir) + assert field_type, "Field type should be non-None after name resolution." + field_expression_type = type_check.unbounded_expression_type_for_physical_type( + field_type + ) + if field_expression_type.WhichOneof("type") not in ( + "integer", + "enumeration", + "boolean", + ): + errors.append( + [ + error.error( + source_file_name, + requires_attr.source_location, + "Attribute 'requires' is only allowed on integer, enumeration, or " + "boolean fields.", + ) + ] + ) + + +def _verify_addressable_unit_attribute_on_external( + external, type_definition, source_file_name, errors +): + """Verifies the addressable_unit_size attribute on an external.""" + del external # Unused. + addressable_unit_size_attr = ir_util.get_integer_attribute( + type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE + ) + if addressable_unit_size_attr is None: + errors.append( + [ + error.error( + source_file_name, + type_definition.source_location, + "Expected '{}' attribute for external type.".format( + attributes.ADDRESSABLE_UNIT_SIZE + ), + ) + ] + ) + elif addressable_unit_size_attr not in (1, 8): + errors.append( + [ + error.error( + source_file_name, + type_definition.source_location, "Only values '1' (bit) and '8' (byte) are allowed for the " - "'{}' attribute".format(attributes.ADDRESSABLE_UNIT_SIZE)) - ]) - - -def _verify_width_attribute_on_enum(enum, type_definition, source_file_name, - errors): - """Verifies the maximum_bits attribute for an enum TypeDefinition.""" - max_bits_value = ir_util.get_integer_attribute(type_definition.attribute, - attributes.ENUM_MAXIMUM_BITS) - # The attribute should already have been defaulted, if not originally present. 
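# The enum width bound checked just below, as a standalone predicate:
# 'maximum_bits' must be at least 1 and must fit a 64-bit representation.
def maximum_bits_ok(max_bits):
    return 1 <= max_bits <= 64

assert maximum_bits_ok(1) and maximum_bits_ok(64)
assert not maximum_bits_ok(0) and not maximum_bits_ok(65)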
- assert max_bits_value is not None, "maximum_bits not set" - if max_bits_value > 64 or max_bits_value < 1: - max_bits_attr = ir_util.get_attribute(type_definition.attribute, - attributes.ENUM_MAXIMUM_BITS) - errors.append([ - error.error(source_file_name, max_bits_attr.source_location, - "'maximum_bits' on an 'enum' must be between 1 and 64.") - ]) + "'{}' attribute".format(attributes.ADDRESSABLE_UNIT_SIZE), + ) + ] + ) + + +def _verify_width_attribute_on_enum(enum, type_definition, source_file_name, errors): + """Verifies the maximum_bits attribute for an enum TypeDefinition.""" + max_bits_value = ir_util.get_integer_attribute( + type_definition.attribute, attributes.ENUM_MAXIMUM_BITS + ) + # The attribute should already have been defaulted, if not originally present. + assert max_bits_value is not None, "maximum_bits not set" + if max_bits_value > 64 or max_bits_value < 1: + max_bits_attr = ir_util.get_attribute( + type_definition.attribute, attributes.ENUM_MAXIMUM_BITS + ) + errors.append( + [ + error.error( + source_file_name, + max_bits_attr.source_location, + "'maximum_bits' on an 'enum' must be between 1 and 64.", + ) + ] + ) def _add_missing_attributes_on_ir(ir): - """Adds missing attributes in a complete IR.""" - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Module], _add_missing_back_ends_to_module) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.External], _add_addressable_unit_to_external) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Enum], _add_missing_width_and_sign_attributes_on_enum) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], _add_missing_size_attributes_on_structure, - incidental_actions={ - ir_data.Module: attribute_util.gather_default_attributes, - ir_data.TypeDefinition: attribute_util.gather_default_attributes, - ir_data.Field: attribute_util.gather_default_attributes, - }, - parameters={"defaults": {}}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], _add_missing_byte_order_attribute_on_field, - incidental_actions={ - ir_data.Module: attribute_util.gather_default_attributes, - ir_data.TypeDefinition: attribute_util.gather_default_attributes, - ir_data.Field: attribute_util.gather_default_attributes, - }, - parameters={"defaults": {}}) - return [] - - -def _verify_field_attributes(field, type_definition, source_file_name, ir, - errors): - _verify_byte_order_attribute_on_field(field, type_definition, - source_file_name, ir, errors) - _verify_requires_attribute_on_field(field, source_file_name, ir, errors) - - -def _verify_back_end_attributes(attribute, expected_back_ends, source_file_name, - ir, errors): - back_end_text = ir_data_utils.reader(attribute).back_end.text - if back_end_text not in expected_back_ends: - expected_back_ends_for_error = expected_back_ends - {""} - errors.append([error.error( - source_file_name, attribute.back_end.source_location, - "Back end specifier '{back_end}' does not match any expected back end " - "specifier for this file: '{expected_back_ends}'. 
Add or update the " - "'[expected_back_ends: \"{new_expected_back_ends}\"]' attribute at the " - "file level if this back end specifier is intentional.".format( - back_end=attribute.back_end.text, - expected_back_ends="', '".join( - sorted(expected_back_ends_for_error)), - new_expected_back_ends=", ".join( - sorted(expected_back_ends_for_error | {back_end_text})), - ))]) + """Adds missing attributes in a complete IR.""" + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Module], _add_missing_back_ends_to_module + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.External], _add_addressable_unit_to_external + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Enum], _add_missing_width_and_sign_attributes_on_enum + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure], + _add_missing_size_attributes_on_structure, + incidental_actions={ + ir_data.Module: attribute_util.gather_default_attributes, + ir_data.TypeDefinition: attribute_util.gather_default_attributes, + ir_data.Field: attribute_util.gather_default_attributes, + }, + parameters={"defaults": {}}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Field], + _add_missing_byte_order_attribute_on_field, + incidental_actions={ + ir_data.Module: attribute_util.gather_default_attributes, + ir_data.TypeDefinition: attribute_util.gather_default_attributes, + ir_data.Field: attribute_util.gather_default_attributes, + }, + parameters={"defaults": {}}, + ) + return [] + + +def _verify_field_attributes(field, type_definition, source_file_name, ir, errors): + _verify_byte_order_attribute_on_field( + field, type_definition, source_file_name, ir, errors + ) + _verify_requires_attribute_on_field(field, source_file_name, ir, errors) + + +def _verify_back_end_attributes( + attribute, expected_back_ends, source_file_name, ir, errors +): + back_end_text = ir_data_utils.reader(attribute).back_end.text + if back_end_text not in expected_back_ends: + expected_back_ends_for_error = expected_back_ends - {""} + errors.append( + [ + error.error( + source_file_name, + attribute.back_end.source_location, + "Back end specifier '{back_end}' does not match any expected back end " + "specifier for this file: '{expected_back_ends}'. 
Add or update the " + "'[expected_back_ends: \"{new_expected_back_ends}\"]' attribute at the " + "file level if this back end specifier is intentional.".format( + back_end=attribute.back_end.text, + expected_back_ends="', '".join( + sorted(expected_back_ends_for_error) + ), + new_expected_back_ends=", ".join( + sorted(expected_back_ends_for_error | {back_end_text}) + ), + ), + ) + ] + ) def _verify_attributes_on_ir(ir): - """Verifies attributes in a complete IR.""" - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Attribute], _verify_back_end_attributes, - incidental_actions={ - ir_data.Module: _gather_expected_back_ends, - }, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], _verify_size_attributes_on_structure, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Enum], _verify_width_attribute_on_enum, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.External], _verify_addressable_unit_attribute_on_external, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], _verify_field_attributes, - parameters={"errors": errors}) - return errors + """Verifies attributes in a complete IR.""" + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Attribute], + _verify_back_end_attributes, + incidental_actions={ + ir_data.Module: _gather_expected_back_ends, + }, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure], + _verify_size_attributes_on_structure, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Enum], + _verify_width_attribute_on_enum, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.External], + _verify_addressable_unit_attribute_on_external, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Field], _verify_field_attributes, parameters={"errors": errors} + ) + return errors def normalize_and_verify(ir): - """Performs various normalizations and verifications on ir. - - Checks for duplicate attributes. - - Adds fixed_size_in_bits and addressable_unit_size attributes to types when - they are missing, and checks their correctness when they are not missing. - - Arguments: - ir: The IR object to normalize. - - Returns: - A list of validation errors, or an empty list if no errors were encountered. - """ - errors = attribute_util.check_attributes_in_ir( - ir, - types=_ATTRIBUTE_TYPES, - module_attributes=_MODULE_ATTRIBUTES, - struct_attributes=_STRUCT_ATTRIBUTES, - bits_attributes=_BITS_ATTRIBUTES, - enum_attributes=_ENUM_ATTRIBUTES, - external_attributes=_EXTERNAL_ATTRIBUTES, - structure_virtual_field_attributes=_STRUCT_VIRTUAL_FIELD_ATTRIBUTES, - structure_physical_field_attributes=_STRUCT_PHYSICAL_FIELD_ATTRIBUTES) - if errors: - return errors - _add_missing_attributes_on_ir(ir) - return _verify_attributes_on_ir(ir) + """Performs various normalizations and verifications on ir. + + Checks for duplicate attributes. + + Adds fixed_size_in_bits and addressable_unit_size attributes to types when + they are missing, and checks their correctness when they are not missing. + + Arguments: + ir: The IR object to normalize. + + Returns: + A list of validation errors, or an empty list if no errors were encountered. 
+    """
+    errors = attribute_util.check_attributes_in_ir(
+        ir,
+        types=_ATTRIBUTE_TYPES,
+        module_attributes=_MODULE_ATTRIBUTES,
+        struct_attributes=_STRUCT_ATTRIBUTES,
+        bits_attributes=_BITS_ATTRIBUTES,
+        enum_attributes=_ENUM_ATTRIBUTES,
+        external_attributes=_EXTERNAL_ATTRIBUTES,
+        structure_virtual_field_attributes=_STRUCT_VIRTUAL_FIELD_ATTRIBUTES,
+        structure_physical_field_attributes=_STRUCT_PHYSICAL_FIELD_ATTRIBUTES,
+    )
+    if errors:
+        return errors
+    _add_missing_attributes_on_ir(ir)
+    return _verify_attributes_on_ir(ir)
diff --git a/compiler/front_end/attribute_checker_test.py b/compiler/front_end/attribute_checker_test.py
index 4e7d8c7..4325ca4 100644
--- a/compiler/front_end/attribute_checker_test.py
+++ b/compiler/front_end/attribute_checker_test.py
@@ -31,667 +31,985 @@ def _make_ir_from_emb(emb_text, name="m.emb"):
-  ir, unused_debug_info, errors = glue.parse_emboss_file(
-      name,
-      test_util.dict_file_reader({name: emb_text}),
-      stop_before_step="normalize_and_verify")
-  assert not errors
-  return ir
+    ir, unused_debug_info, errors = glue.parse_emboss_file(
+        name,
+        test_util.dict_file_reader({name: emb_text}),
+        stop_before_step="normalize_and_verify",
+    )
+    assert not errors
+    return ir


 class NormalizeIrTest(unittest.TestCase):
-  def test_rejects_may_be_used_as_integer(self):
-    enum_ir = _make_ir_from_emb("enum Foo:\n"
-                                "  [may_be_used_as_integer: false]\n"
-                                "  VALUE = 1\n")
-    enum_type_ir = enum_ir.module[0].type[0]
-    self.assertEqual([[
-        error.error(
-            "m.emb", enum_type_ir.attribute[0].name.source_location,
-            "Unknown attribute 'may_be_used_as_integer' on enum 'Foo'.")
-    ]], attribute_checker.normalize_and_verify(enum_ir))
-
-  def test_adds_fixed_size_attribute_to_struct(self):
-    # field2 is intentionally after field3, in order to trigger certain code
-    # paths in attribute_checker.py.
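# Worked expectation for the test below: byte fields at 0 [+2], 4 [+4], and
# 2 [+2] have end offsets 2, 8, and 4, so the struct spans 8 bytes and the
# inferred [fixed_size_in_bits] is 8 * 8 = 64; textual field order does not
# matter, only the offsets.
assert max(0 + 2, 4 + 4, 2 + 2) * 8 == 64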
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+2] UInt field1\n" - " 4 [+4] UInt field2\n" - " 2 [+2] UInt field3\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute, - _FIXED_SIZE) - self.assertEqual(64, ir_util.constant_value(size_attr.expression)) - self.assertEqual(struct_ir.module[0].type[0].source_location, - size_attr.source_location) - - def test_adds_fixed_size_attribute_to_struct_with_virtual_field(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+2] UInt field1\n" - " let field2 = field1\n" - " 2 [+2] UInt field3\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute, - _FIXED_SIZE) - self.assertEqual(32, ir_util.constant_value(size_attr.expression)) - self.assertEqual(struct_ir.module[0].type[0].source_location, - size_attr.source_location) - - def test_adds_fixed_size_attribute_to_anonymous_bits(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+4] bits:\n" - " 0 [+8] UInt field\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute, - _FIXED_SIZE) - self.assertEqual(32, ir_util.constant_value(size_attr.expression)) - bits_size_attr = ir_util.get_attribute( - struct_ir.module[0].type[0].subtype[0].attribute, _FIXED_SIZE) - self.assertEqual(8, ir_util.constant_value(bits_size_attr.expression)) - self.assertEqual(struct_ir.module[0].type[0].source_location, - size_attr.source_location) - - def test_does_not_add_fixed_size_attribute_to_variable_size_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+4] UInt n\n" - " 4 [+n] UInt:8[] payload\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - self.assertIsNone(ir_util.get_attribute( - struct_ir.module[0].type[0].attribute, _FIXED_SIZE)) - - def test_accepts_correct_fixed_size_and_size_attributes_on_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: 64]\n" - " 0 [+2] UInt field1\n" - " 2 [+2] UInt field2\n" - " 4 [+4] UInt field3\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute, - _FIXED_SIZE) - self.assertTrue(size_attr) - self.assertEqual(64, ir_util.constant_value(size_attr.expression)) - - def test_accepts_correct_size_attribute_on_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: 64]\n" - " 0 [+2] UInt field1\n" - " 4 [+4] UInt field3\n") - self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) - size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute, - _FIXED_SIZE) - self.assertTrue(size_attr.expression) - self.assertEqual(64, ir_util.constant_value(size_attr.expression)) - - def test_rejects_incorrect_fixed_size_attribute_on_variable_size_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: 8]\n" - " 0 [+4] UInt n\n" - " 4 [+n] UInt:8[] payload\n") - struct_type_ir = struct_ir.module[0].type[0] - 
self.assertEqual([[error.error( - "m.emb", struct_type_ir.attribute[0].value.source_location, - "Struct is marked as fixed size, but contains variable-location " - "fields.")]], attribute_checker.normalize_and_verify(struct_ir)) - - def test_rejects_size_attribute_with_wrong_large_value_on_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: 80]\n" - " 0 [+2] UInt field1\n" - " 2 [+2] UInt field2\n" - " 4 [+4] UInt field3\n") - struct_type_ir = struct_ir.module[0].type[0] - self.assertEqual([ - [error.error("m.emb", struct_type_ir.attribute[0].value.source_location, - "Struct is 64 bits, but is marked as 80 bits.")] - ], attribute_checker.normalize_and_verify(struct_ir)) - - def test_rejects_size_attribute_with_wrong_small_value_on_struct(self): - struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: 40]\n" - " 0 [+2] UInt field1\n" - " 2 [+2] UInt field2\n" - " 4 [+4] UInt field3\n") - struct_type_ir = struct_ir.module[0].type[0] - self.assertEqual([ - [error.error("m.emb", struct_type_ir.attribute[0].value.source_location, - "Struct is 64 bits, but is marked as 40 bits.")] - ], attribute_checker.normalize_and_verify(struct_ir)) - - def test_accepts_variable_size_external(self): - external_ir = _make_ir_from_emb("external Foo:\n" - " [addressable_unit_size: 1]\n") - self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) - - def test_accepts_fixed_size_external(self): - external_ir = _make_ir_from_emb("external Foo:\n" - " [fixed_size_in_bits: 32]\n" - " [addressable_unit_size: 1]\n") - self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) - - def test_rejects_external_with_no_addressable_unit_size_attribute(self): - external_ir = _make_ir_from_emb("external Foo:\n" - " [is_integer: false]\n") - external_type_ir = external_ir.module[0].type[0] - self.assertEqual([ - [error.error( - "m.emb", external_type_ir.source_location, - "Expected 'addressable_unit_size' attribute for external type.")] - ], attribute_checker.normalize_and_verify(external_ir)) - - def test_rejects_is_integer_with_non_constant_value(self): - external_ir = _make_ir_from_emb( - "external Foo:\n" - " [is_integer: $static_size_in_bits == 1]\n" - " [addressable_unit_size: 1]\n") - external_type_ir = external_ir.module[0].type[0] - self.assertEqual([ - [error.error( - "m.emb", external_type_ir.attribute[0].value.source_location, - "Attribute 'is_integer' must have a constant boolean value.")] - ], attribute_checker.normalize_and_verify(external_ir)) - - def test_rejects_addressable_unit_size_with_non_constant_value(self): - external_ir = _make_ir_from_emb( - "external Foo:\n" - " [is_integer: true]\n" - " [addressable_unit_size: $static_size_in_bits]\n") - external_type_ir = external_ir.module[0].type[0] - self.assertEqual([ - [error.error( - "m.emb", external_type_ir.attribute[1].value.source_location, - "Attribute 'addressable_unit_size' must have a constant value.")] - ], attribute_checker.normalize_and_verify(external_ir)) - - def test_rejects_external_with_wrong_addressable_unit_size_attribute(self): - external_ir = _make_ir_from_emb("external Foo:\n" - " [addressable_unit_size: 4]\n") - external_type_ir = external_ir.module[0].type[0] - self.assertEqual([ - [error.error( - "m.emb", external_type_ir.source_location, - "Only values '1' (bit) and '8' (byte) are allowed for the " - "'addressable_unit_size' attribute")] - ], 
attribute_checker.normalize_and_verify(external_ir)) - - def test_rejects_duplicate_attribute(self): - ir = _make_ir_from_emb("external Foo:\n" - " [is_integer: true]\n" - " [is_integer: true]\n") - self.assertEqual([[ - error.error("m.emb", ir.module[0].type[0].attribute[1].source_location, - "Duplicate attribute 'is_integer'."), - error.note("m.emb", ir.module[0].type[0].attribute[0].source_location, - "Original attribute"), - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_duplicate_default_attribute(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - '[$default byte_order: "LittleEndian"]\n') - self.assertEqual( - [[ - error.error("m.emb", ir.module[0].attribute[1].source_location, - "Duplicate attribute 'byte_order'."), - error.note("m.emb", ir.module[0].attribute[0].source_location, - "Original attribute"), - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_unknown_attribute(self): - ir = _make_ir_from_emb("[gibberish: true]\n") - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.name.source_location, - "Unknown attribute 'gibberish' on module 'm.emb'.") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_non_constant_attribute(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " [fixed_size_in_bits: field1]\n" - " 0 [+2] UInt field1\n") - attr = ir.module[0].type[0].attribute[0] - self.assertEqual( - [[ - error.error( - "m.emb", attr.value.source_location, - "Attribute 'fixed_size_in_bits' must have a constant value.") - ]], - attribute_checker.normalize_and_verify(ir)) - - def test_rejects_attribute_missing_required_back_end_specifier(self): - ir = _make_ir_from_emb('[namespace: "abc"]\n') - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error("m.emb", attr.name.source_location, - "Unknown attribute 'namespace' on module 'm.emb'.") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_accepts_attribute_with_default_known_back_end_specifier(self): - ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_attribute_with_specified_back_end_specifier(self): - ir = _make_ir_from_emb('[(c) namespace: "abc"]\n' - '[expected_back_ends: "c, cpp"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_cpp_backend_attribute_when_not_in_expected_back_ends(self): - ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n' - '[expected_back_ends: "c"]\n') - attr = ir.module[0].attribute[0] - self.maxDiff = 200000 - self.assertEqual([[ - error.error( - "m.emb", attr.back_end.source_location, - "Back end specifier 'cpp' does not match any expected back end " - "specifier for this file: 'c'. 
Add or update the " - "'[expected_back_ends: \"c, cpp\"]' attribute at the file level if " - "this back end specifier is intentional.") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_expected_back_ends_with_bad_back_end(self): - ir = _make_ir_from_emb('[expected_back_ends: "c++"]\n') - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error( - "m.emb", attr.value.source_location, - "Attribute 'expected_back_ends' must be a comma-delimited list of " - "back end specifiers (like \"cpp, proto\")), not \"c++\".") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_expected_back_ends_with_no_comma(self): - ir = _make_ir_from_emb('[expected_back_ends: "cpp z"]\n') - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error( - "m.emb", attr.value.source_location, - "Attribute 'expected_back_ends' must be a comma-delimited list of " - "back end specifiers (like \"cpp, proto\")), not \"cpp z\".") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_rejects_expected_back_ends_with_extra_commas(self): - ir = _make_ir_from_emb('[expected_back_ends: "cpp,,z"]\n') - attr = ir.module[0].attribute[0] - self.assertEqual([[ - error.error( - "m.emb", attr.value.source_location, - "Attribute 'expected_back_ends' must be a comma-delimited list of " - "back end specifiers (like \"cpp, proto\")), not \"cpp,,z\".") - ]], attribute_checker.normalize_and_verify(ir)) - - def test_accepts_empty_expected_back_ends(self): - ir = _make_ir_from_emb('[expected_back_ends: ""]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_adds_byte_order_attributes_from_default(self): - ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n' - "struct Foo:\n" - " 0 [+2] UInt bar\n" - " 2 [+2] UInt baz\n" - ' [byte_order: "LittleEndian"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - byte_order_attr = ir_util.get_attribute( - ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("BigEndian", byte_order_attr.string_constant.text) - byte_order_attr = ir_util.get_attribute( - ir.module[0].type[0].structure.field[1].attribute, _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) - - def test_adds_null_byte_order_attributes(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt bar\n" - " 1 [+1] UInt baz\n" - ' [byte_order: "LittleEndian"]\n' - " 2 [+2] UInt:8[] baseball\n" - " 4 [+2] UInt:8[] bat\n" - ' [byte_order: "LittleEndian"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - structure = ir.module[0].type[0].structure - byte_order_attr = ir_util.get_attribute( - structure.field[0].attribute, _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("Null", byte_order_attr.string_constant.text) - self.assertEqual(structure.field[0].source_location, - byte_order_attr.source_location) - byte_order_attr = ir_util.get_attribute(structure.field[1].attribute, - _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) - byte_order_attr = ir_util.get_attribute(structure.field[2].attribute, - _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("Null", byte_order_attr.string_constant.text) - 
self.assertEqual(structure.field[2].source_location, - byte_order_attr.source_location) - byte_order_attr = ir_util.get_attribute(structure.field[3].attribute, - _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) - - def test_disallows_default_byte_order_on_field(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+2] UInt bar\n" - ' [$default byte_order: "LittleEndian"]\n') - default_byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", default_byte_order.name.source_location, - "Attribute 'byte_order' may not be defaulted on struct field 'bar'." - )]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_default_byte_order_on_bits(self): - ir = _make_ir_from_emb("bits Foo:\n" - ' [$default byte_order: "LittleEndian"]\n' - " 0 [+2] UInt bar\n") - default_byte_order = ir.module[0].type[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", default_byte_order.name.source_location, - "Attribute 'byte_order' may not be defaulted on bits 'Foo'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_default_byte_order_on_enum(self): - ir = _make_ir_from_emb("enum Foo:\n" - ' [$default byte_order: "LittleEndian"]\n' - " BAR = 1\n") - default_byte_order = ir.module[0].type[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", default_byte_order.name.source_location, - "Attribute 'byte_order' may not be defaulted on enum 'Foo'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_adds_byte_order_from_scoped_default(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - ' [$default byte_order: "BigEndian"]\n' - " 0 [+2] UInt bar\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - byte_order_attr = ir_util.get_attribute( - ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER) - self.assertTrue(byte_order_attr.HasField("string_constant")) - self.assertEqual("BigEndian", byte_order_attr.string_constant.text) - - def test_disallows_unknown_byte_order(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] UInt bar\n" - ' [byte_order: "NoEndian"]\n') - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or " - "'Null'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_unknown_default_byte_order(self): - ir = _make_ir_from_emb('[$default byte_order: "NoEndian"]\n') - default_byte_order = ir.module[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", default_byte_order.value.source_location, - "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or " - "'Null'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_byte_order_on_non_byte_order_dependent_fields(self): - ir = _make_ir_from_emb("struct Foo:\n" - ' [$default byte_order: "LittleEndian"]\n' - " 0 [+2] UInt uint\n" - "struct Bar:\n" - " 0 [+2] Foo foo\n" - ' [byte_order: "LittleEndian"]\n') - byte_order = ir.module[0].type[1].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'byte_order' not allowed on field which is not byte " - "order dependent.")]], - attribute_checker.normalize_and_verify(ir)) - - def 
test_disallows_byte_order_on_virtual_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " let x = 10\n" - ' [byte_order: "LittleEndian"]\n') - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.name.source_location, - "Unknown attribute 'byte_order' on virtual struct field 'x'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_null_byte_order_on_multibyte_fields(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] UInt uint\n" - ' [byte_order: "Null"]\n') - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'byte_order' may only be 'Null' for one-byte fields.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_null_byte_order_on_multibyte_array_elements(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+4] UInt:16[] uint\n" - ' [byte_order: "Null"]\n') - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'byte_order' may only be 'Null' for one-byte fields.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_requires_byte_order_on_byte_order_dependent_fields(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] UInt uint\n") - field = ir.module[0].type[0].structure.field[0] - self.assertEqual( - [[error.error( - "m.emb", field.source_location, - "Attribute 'byte_order' required on field which is byte order " - "dependent.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_unknown_text_output_attribute(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] UInt bar\n" - ' [text_output: "None"]\n') - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'text_output' must be 'Emit' or 'Skip'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_disallows_non_string_text_output_attribute(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] UInt bar\n" - " [text_output: 0]\n") - byte_order = ir.module[0].type[0].structure.field[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", byte_order.value.source_location, - "Attribute 'text_output' must be 'Emit' or 'Skip'.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_allows_skip_text_output_attribute_on_physical_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt bar\n" - ' [text_output: "Skip"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_allows_skip_text_output_attribute_on_virtual_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " let x = 10\n" - ' [text_output: "Skip"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_allows_emit_text_output_attribute_on_physical_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt bar\n" - ' [text_output: "Emit"]\n') - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - - def test_adds_bit_addressable_unit_to_external(self): - external_ir = _make_ir_from_emb("external Foo:\n" - " [addressable_unit_size: 1]\n") - self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) - self.assertEqual(ir_data.AddressableUnit.BIT, - external_ir.module[0].type[0].addressable_unit) - - def test_adds_byte_addressable_unit_to_external(self): - 
external_ir = _make_ir_from_emb("external Foo:\n" - " [addressable_unit_size: 8]\n") - self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) - self.assertEqual(ir_data.AddressableUnit.BYTE, - external_ir.module[0].type[0].addressable_unit) - - def test_rejects_requires_using_array(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+4] UInt:8[] array\n" - " [requires: this]\n") - field_ir = ir.module[0].type[0].structure.field[0] - self.assertEqual( - [[error.error("m.emb", field_ir.attribute[0].value.source_location, - "Attribute 'requires' must have a boolean value.")]], - attribute_checker.normalize_and_verify(ir)) - - def test_rejects_requires_on_array(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+4] UInt:8[] array\n" - " [requires: false]\n") - field_ir = ir.module[0].type[0].structure.field[0] - self.assertEqual( - [[ - error.error("m.emb", field_ir.attribute[0].value.source_location, + def test_rejects_may_be_used_as_integer(self): + enum_ir = _make_ir_from_emb( + "enum Foo:\n" " [may_be_used_as_integer: false]\n" " VALUE = 1\n" + ) + enum_type_ir = enum_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + enum_type_ir.attribute[0].name.source_location, + "Unknown attribute 'may_be_used_as_integer' on enum 'Foo'.", + ) + ] + ], + attribute_checker.normalize_and_verify(enum_ir), + ) + + def test_adds_fixed_size_attribute_to_struct(self): + # field2 is intentionally after field3, in order to trigger certain code + # paths in attribute_checker.py. + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+2] UInt field1\n" + " 4 [+4] UInt field2\n" + " 2 [+2] UInt field3\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].attribute, _FIXED_SIZE + ) + self.assertEqual(64, ir_util.constant_value(size_attr.expression)) + self.assertEqual( + struct_ir.module[0].type[0].source_location, size_attr.source_location + ) + + def test_adds_fixed_size_attribute_to_struct_with_virtual_field(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+2] UInt field1\n" + " let field2 = field1\n" + " 2 [+2] UInt field3\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].attribute, _FIXED_SIZE + ) + self.assertEqual(32, ir_util.constant_value(size_attr.expression)) + self.assertEqual( + struct_ir.module[0].type[0].source_location, size_attr.source_location + ) + + def test_adds_fixed_size_attribute_to_anonymous_bits(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+4] bits:\n" + " 0 [+8] UInt field\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].attribute, _FIXED_SIZE + ) + self.assertEqual(32, ir_util.constant_value(size_attr.expression)) + bits_size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].subtype[0].attribute, _FIXED_SIZE + ) + self.assertEqual(8, ir_util.constant_value(bits_size_attr.expression)) + self.assertEqual( + struct_ir.module[0].type[0].source_location, size_attr.source_location + ) + + def test_does_not_add_fixed_size_attribute_to_variable_size_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+4] UInt n\n" + " 
4 [+n] UInt:8[] payload\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + self.assertIsNone( + ir_util.get_attribute(struct_ir.module[0].type[0].attribute, _FIXED_SIZE) + ) + + def test_accepts_correct_fixed_size_and_size_attributes_on_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: 64]\n" + " 0 [+2] UInt field1\n" + " 2 [+2] UInt field2\n" + " 4 [+4] UInt field3\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].attribute, _FIXED_SIZE + ) + self.assertTrue(size_attr) + self.assertEqual(64, ir_util.constant_value(size_attr.expression)) + + def test_accepts_correct_size_attribute_on_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: 64]\n" + " 0 [+2] UInt field1\n" + " 4 [+4] UInt field3\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir)) + size_attr = ir_util.get_attribute( + struct_ir.module[0].type[0].attribute, _FIXED_SIZE + ) + self.assertTrue(size_attr.expression) + self.assertEqual(64, ir_util.constant_value(size_attr.expression)) + + def test_rejects_incorrect_fixed_size_attribute_on_variable_size_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: 8]\n" + " 0 [+4] UInt n\n" + " 4 [+n] UInt:8[] payload\n" + ) + struct_type_ir = struct_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct_type_ir.attribute[0].value.source_location, + "Struct is marked as fixed size, but contains variable-location " + "fields.", + ) + ] + ], + attribute_checker.normalize_and_verify(struct_ir), + ) + + def test_rejects_size_attribute_with_wrong_large_value_on_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: 80]\n" + " 0 [+2] UInt field1\n" + " 2 [+2] UInt field2\n" + " 4 [+4] UInt field3\n" + ) + struct_type_ir = struct_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct_type_ir.attribute[0].value.source_location, + "Struct is 64 bits, but is marked as 80 bits.", + ) + ] + ], + attribute_checker.normalize_and_verify(struct_ir), + ) + + def test_rejects_size_attribute_with_wrong_small_value_on_struct(self): + struct_ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: 40]\n" + " 0 [+2] UInt field1\n" + " 2 [+2] UInt field2\n" + " 4 [+4] UInt field3\n" + ) + struct_type_ir = struct_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct_type_ir.attribute[0].value.source_location, + "Struct is 64 bits, but is marked as 40 bits.", + ) + ] + ], + attribute_checker.normalize_and_verify(struct_ir), + ) + + def test_accepts_variable_size_external(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" " [addressable_unit_size: 1]\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) + + def test_accepts_fixed_size_external(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" + " [fixed_size_in_bits: 32]\n" + " [addressable_unit_size: 1]\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) + + def test_rejects_external_with_no_addressable_unit_size_attribute(self): + external_ir = _make_ir_from_emb("external 
Foo:\n" " [is_integer: false]\n") + external_type_ir = external_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + external_type_ir.source_location, + "Expected 'addressable_unit_size' attribute for external type.", + ) + ] + ], + attribute_checker.normalize_and_verify(external_ir), + ) + + def test_rejects_is_integer_with_non_constant_value(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" + " [is_integer: $static_size_in_bits == 1]\n" + " [addressable_unit_size: 1]\n" + ) + external_type_ir = external_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + external_type_ir.attribute[0].value.source_location, + "Attribute 'is_integer' must have a constant boolean value.", + ) + ] + ], + attribute_checker.normalize_and_verify(external_ir), + ) + + def test_rejects_addressable_unit_size_with_non_constant_value(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" + " [is_integer: true]\n" + " [addressable_unit_size: $static_size_in_bits]\n" + ) + external_type_ir = external_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + external_type_ir.attribute[1].value.source_location, + "Attribute 'addressable_unit_size' must have a constant value.", + ) + ] + ], + attribute_checker.normalize_and_verify(external_ir), + ) + + def test_rejects_external_with_wrong_addressable_unit_size_attribute(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" " [addressable_unit_size: 4]\n" + ) + external_type_ir = external_ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + external_type_ir.source_location, + "Only values '1' (bit) and '8' (byte) are allowed for the " + "'addressable_unit_size' attribute", + ) + ] + ], + attribute_checker.normalize_and_verify(external_ir), + ) + + def test_rejects_duplicate_attribute(self): + ir = _make_ir_from_emb( + "external Foo:\n" " [is_integer: true]\n" " [is_integer: true]\n" + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + ir.module[0].type[0].attribute[1].source_location, + "Duplicate attribute 'is_integer'.", + ), + error.note( + "m.emb", + ir.module[0].type[0].attribute[0].source_location, + "Original attribute", + ), + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_duplicate_default_attribute(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + '[$default byte_order: "LittleEndian"]\n' + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + ir.module[0].attribute[1].source_location, + "Duplicate attribute 'byte_order'.", + ), + error.note( + "m.emb", + ir.module[0].attribute[0].source_location, + "Original attribute", + ), + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_unknown_attribute(self): + ir = _make_ir_from_emb("[gibberish: true]\n") + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.name.source_location, + "Unknown attribute 'gibberish' on module 'm.emb'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_non_constant_attribute(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " [fixed_size_in_bits: field1]\n" + " 0 [+2] UInt field1\n" + ) + attr = ir.module[0].type[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Attribute 'fixed_size_in_bits' must have a constant value.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def 
test_rejects_attribute_missing_required_back_end_specifier(self): + ir = _make_ir_from_emb('[namespace: "abc"]\n') + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.name.source_location, + "Unknown attribute 'namespace' on module 'm.emb'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_accepts_attribute_with_default_known_back_end_specifier(self): + ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n') + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_rejects_attribute_with_specified_back_end_specifier(self): + ir = _make_ir_from_emb( + '[(c) namespace: "abc"]\n' '[expected_back_ends: "c, cpp"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_rejects_cpp_backend_attribute_when_not_in_expected_back_ends(self): + ir = _make_ir_from_emb( + '[(cpp) namespace: "abc"]\n' '[expected_back_ends: "c"]\n' + ) + attr = ir.module[0].attribute[0] + self.maxDiff = 200000 + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.back_end.source_location, + "Back end specifier 'cpp' does not match any expected back end " + "specifier for this file: 'c'. Add or update the " + "'[expected_back_ends: \"c, cpp\"]' attribute at the file level if " + "this back end specifier is intentional.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_expected_back_ends_with_bad_back_end(self): + ir = _make_ir_from_emb('[expected_back_ends: "c++"]\n') + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Attribute 'expected_back_ends' must be a comma-delimited list of " + 'back end specifiers (like "cpp, proto")), not "c++".', + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_expected_back_ends_with_no_comma(self): + ir = _make_ir_from_emb('[expected_back_ends: "cpp z"]\n') + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Attribute 'expected_back_ends' must be a comma-delimited list of " + 'back end specifiers (like "cpp, proto")), not "cpp z".', + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_expected_back_ends_with_extra_commas(self): + ir = _make_ir_from_emb('[expected_back_ends: "cpp,,z"]\n') + attr = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attr.value.source_location, + "Attribute 'expected_back_ends' must be a comma-delimited list of " + 'back end specifiers (like "cpp, proto")), not "cpp,,z".', + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_accepts_empty_expected_back_ends(self): + ir = _make_ir_from_emb('[expected_back_ends: ""]\n') + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_adds_byte_order_attributes_from_default(self): + ir = _make_ir_from_emb( + '[$default byte_order: "BigEndian"]\n' + "struct Foo:\n" + " 0 [+2] UInt bar\n" + " 2 [+2] UInt baz\n" + ' [byte_order: "LittleEndian"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + byte_order_attr = ir_util.get_attribute( + ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("BigEndian", byte_order_attr.string_constant.text) + byte_order_attr = ir_util.get_attribute( + ir.module[0].type[0].structure.field[1].attribute, _BYTE_ORDER + ) + 
self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) + + def test_adds_null_byte_order_attributes(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] UInt bar\n" + " 1 [+1] UInt baz\n" + ' [byte_order: "LittleEndian"]\n' + " 2 [+2] UInt:8[] baseball\n" + " 4 [+2] UInt:8[] bat\n" + ' [byte_order: "LittleEndian"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + structure = ir.module[0].type[0].structure + byte_order_attr = ir_util.get_attribute( + structure.field[0].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("Null", byte_order_attr.string_constant.text) + self.assertEqual( + structure.field[0].source_location, byte_order_attr.source_location + ) + byte_order_attr = ir_util.get_attribute( + structure.field[1].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) + byte_order_attr = ir_util.get_attribute( + structure.field[2].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("Null", byte_order_attr.string_constant.text) + self.assertEqual( + structure.field[2].source_location, byte_order_attr.source_location + ) + byte_order_attr = ir_util.get_attribute( + structure.field[3].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("LittleEndian", byte_order_attr.string_constant.text) + + def test_disallows_default_byte_order_on_field(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+2] UInt bar\n" + ' [$default byte_order: "LittleEndian"]\n' + ) + default_byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + default_byte_order.name.source_location, + "Attribute 'byte_order' may not be defaulted on struct field 'bar'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_default_byte_order_on_bits(self): + ir = _make_ir_from_emb( + "bits Foo:\n" + ' [$default byte_order: "LittleEndian"]\n' + " 0 [+2] UInt bar\n" + ) + default_byte_order = ir.module[0].type[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + default_byte_order.name.source_location, + "Attribute 'byte_order' may not be defaulted on bits 'Foo'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_default_byte_order_on_enum(self): + ir = _make_ir_from_emb( + "enum Foo:\n" ' [$default byte_order: "LittleEndian"]\n' " BAR = 1\n" + ) + default_byte_order = ir.module[0].type[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + default_byte_order.name.source_location, + "Attribute 'byte_order' may not be defaulted on enum 'Foo'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_adds_byte_order_from_scoped_default(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + ' [$default byte_order: "BigEndian"]\n' + " 0 [+2] UInt bar\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + byte_order_attr = ir_util.get_attribute( + ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER + ) + self.assertTrue(byte_order_attr.HasField("string_constant")) + self.assertEqual("BigEndian", byte_order_attr.string_constant.text) + + def 
test_disallows_unknown_byte_order(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+2] UInt bar\n" ' [byte_order: "NoEndian"]\n' + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or " + "'Null'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_unknown_default_byte_order(self): + ir = _make_ir_from_emb('[$default byte_order: "NoEndian"]\n') + default_byte_order = ir.module[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + default_byte_order.value.source_location, + "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or " + "'Null'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_byte_order_on_non_byte_order_dependent_fields(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + ' [$default byte_order: "LittleEndian"]\n' + " 0 [+2] UInt uint\n" + "struct Bar:\n" + " 0 [+2] Foo foo\n" + ' [byte_order: "LittleEndian"]\n' + ) + byte_order = ir.module[0].type[1].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'byte_order' not allowed on field which is not byte " + "order dependent.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_byte_order_on_virtual_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " let x = 10\n" ' [byte_order: "LittleEndian"]\n' + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.name.source_location, + "Unknown attribute 'byte_order' on virtual struct field 'x'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_null_byte_order_on_multibyte_fields(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+2] UInt uint\n" ' [byte_order: "Null"]\n' + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'byte_order' may only be 'Null' for one-byte fields.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_null_byte_order_on_multibyte_array_elements(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+4] UInt:16[] uint\n" ' [byte_order: "Null"]\n' + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'byte_order' may only be 'Null' for one-byte fields.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_requires_byte_order_on_byte_order_dependent_fields(self): + ir = _make_ir_from_emb("struct Foo:\n" " 0 [+2] UInt uint\n") + field = ir.module[0].type[0].structure.field[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + field.source_location, + "Attribute 'byte_order' required on field which is byte order " + "dependent.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_unknown_text_output_attribute(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+2] UInt bar\n" ' [text_output: "None"]\n' + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'text_output' must be 
'Emit' or 'Skip'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_disallows_non_string_text_output_attribute(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+2] UInt bar\n" " [text_output: 0]\n" + ) + byte_order = ir.module[0].type[0].structure.field[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + byte_order.value.source_location, + "Attribute 'text_output' must be 'Emit' or 'Skip'.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_allows_skip_text_output_attribute_on_physical_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+1] UInt bar\n" ' [text_output: "Skip"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_allows_skip_text_output_attribute_on_virtual_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " let x = 10\n" ' [text_output: "Skip"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_allows_emit_text_output_attribute_on_physical_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+1] UInt bar\n" ' [text_output: "Emit"]\n' + ) + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + + def test_adds_bit_addressable_unit_to_external(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" " [addressable_unit_size: 1]\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) + self.assertEqual( + ir_data.AddressableUnit.BIT, external_ir.module[0].type[0].addressable_unit + ) + + def test_adds_byte_addressable_unit_to_external(self): + external_ir = _make_ir_from_emb( + "external Foo:\n" " [addressable_unit_size: 8]\n" + ) + self.assertEqual([], attribute_checker.normalize_and_verify(external_ir)) + self.assertEqual( + ir_data.AddressableUnit.BYTE, external_ir.module[0].type[0].addressable_unit + ) + + def test_rejects_requires_using_array(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+4] UInt:8[] array\n" " [requires: this]\n" + ) + field_ir = ir.module[0].type[0].structure.field[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + field_ir.attribute[0].value.source_location, + "Attribute 'requires' must have a boolean value.", + ) + ] + ], + attribute_checker.normalize_and_verify(ir), + ) + + def test_rejects_requires_on_array(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+4] UInt:8[] array\n" " [requires: false]\n" + ) + field_ir = ir.module[0].type[0].structure.field[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + field_ir.attribute[0].value.source_location, "Attribute 'requires' is only allowed on integer, " - "enumeration, or boolean fields, not arrays."), - error.note("m.emb", field_ir.type.source_location, - "Field type."), - ]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) - - def test_rejects_requires_on_struct(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+4] Bar bar\n" - " [requires: false]\n" - "struct Bar:\n" - " 0 [+4] UInt uint\n") - field_ir = ir.module[0].type[0].structure.field[0] - self.assertEqual( - [[error.error("m.emb", field_ir.attribute[0].value.source_location, - "Attribute 'requires' is only allowed on integer, " - "enumeration, or boolean fields.")]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) - - def test_rejects_requires_on_float(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+4] Float float\n" - " [requires: false]\n") - field_ir = 
ir.module[0].type[0].structure.field[0] - self.assertEqual( - [[error.error("m.emb", field_ir.attribute[0].value.source_location, - "Attribute 'requires' is only allowed on integer, " - "enumeration, or boolean fields.")]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) - - def test_adds_false_is_signed_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " ZERO = 0\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - enum = ir.module[0].type[0] - is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) - self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) - self.assertFalse(is_signed_attr.expression.boolean_constant.value) - - def test_leaves_is_signed_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " [is_signed: true]\n" - " ZERO = 0\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - enum = ir.module[0].type[0] - is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) - self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) - self.assertTrue(is_signed_attr.expression.boolean_constant.value) - - def test_adds_true_is_signed_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " NEGATIVE_ONE = -1\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - enum = ir.module[0].type[0] - is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) - self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) - self.assertTrue(is_signed_attr.expression.boolean_constant.value) - - def test_adds_max_bits_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " ZERO = 0\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - enum = ir.module[0].type[0] - max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS) - self.assertTrue(max_bits_attr.expression.HasField("constant")) - self.assertEqual("64", max_bits_attr.expression.constant.value) - - def test_leaves_max_bits_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " [maximum_bits: 32]\n" - " ZERO = 0\n") - self.assertEqual([], attribute_checker.normalize_and_verify(ir)) - enum = ir.module[0].type[0] - max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS) - self.assertTrue(max_bits_attr.expression.HasField("constant")) - self.assertEqual("32", max_bits_attr.expression.constant.value) - - def test_rejects_too_small_max_bits(self): - ir = _make_ir_from_emb("enum Foo:\n" - " [maximum_bits: 0]\n" - " ZERO = 0\n") - attribute_ir = ir.module[0].type[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", attribute_ir.value.source_location, - "'maximum_bits' on an 'enum' must be between 1 and 64.")]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) - - def test_rejects_too_large_max_bits(self): - ir = _make_ir_from_emb("enum Foo:\n" - " [maximum_bits: 65]\n" - " ZERO = 0\n") - attribute_ir = ir.module[0].type[0].attribute[0] - self.assertEqual( - [[error.error( - "m.emb", attribute_ir.value.source_location, - "'maximum_bits' on an 'enum' must be between 1 and 64.")]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) - - def test_rejects_unknown_enum_value_attribute(self): - ir = _make_ir_from_emb("enum Foo:\n" - " BAR = 0 \n" - " [bad_attr: true]\n") - attribute_ir = ir.module[0].type[0].enumeration.value[0].attribute[0] - self.assertNotEqual([], attribute_checker.normalize_and_verify(ir)) - self.assertEqual( - [[error.error( - "m.emb", attribute_ir.name.source_location, - "Unknown attribute 
'bad_attr' on enum value 'BAR'.")]], - error.filter_errors(attribute_checker.normalize_and_verify(ir))) + "enumeration, or boolean fields, not arrays.", + ), + error.note("m.emb", field_ir.type.source_location, "Field type."), + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) + + def test_rejects_requires_on_struct(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+4] Bar bar\n" + " [requires: false]\n" + "struct Bar:\n" + " 0 [+4] UInt uint\n" + ) + field_ir = ir.module[0].type[0].structure.field[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + field_ir.attribute[0].value.source_location, + "Attribute 'requires' is only allowed on integer, " + "enumeration, or boolean fields.", + ) + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) + + def test_rejects_requires_on_float(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+4] Float float\n" + " [requires: false]\n" + ) + field_ir = ir.module[0].type[0].structure.field[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + field_ir.attribute[0].value.source_location, + "Attribute 'requires' is only allowed on integer, " + "enumeration, or boolean fields.", + ) + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) + + def test_adds_false_is_signed_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " ZERO = 0\n") + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + enum = ir.module[0].type[0] + is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) + self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) + self.assertFalse(is_signed_attr.expression.boolean_constant.value) + + def test_leaves_is_signed_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " [is_signed: true]\n" " ZERO = 0\n") + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + enum = ir.module[0].type[0] + is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) + self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) + self.assertTrue(is_signed_attr.expression.boolean_constant.value) + + def test_adds_true_is_signed_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " NEGATIVE_ONE = -1\n") + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + enum = ir.module[0].type[0] + is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED) + self.assertTrue(is_signed_attr.expression.HasField("boolean_constant")) + self.assertTrue(is_signed_attr.expression.boolean_constant.value) + + def test_adds_max_bits_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " ZERO = 0\n") + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + enum = ir.module[0].type[0] + max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS) + self.assertTrue(max_bits_attr.expression.HasField("constant")) + self.assertEqual("64", max_bits_attr.expression.constant.value) + + def test_leaves_max_bits_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " [maximum_bits: 32]\n" " ZERO = 0\n") + self.assertEqual([], attribute_checker.normalize_and_verify(ir)) + enum = ir.module[0].type[0] + max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS) + self.assertTrue(max_bits_attr.expression.HasField("constant")) + self.assertEqual("32", max_bits_attr.expression.constant.value) + + def test_rejects_too_small_max_bits(self): + ir = _make_ir_from_emb("enum Foo:\n" " 
[maximum_bits: 0]\n" " ZERO = 0\n") + attribute_ir = ir.module[0].type[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attribute_ir.value.source_location, + "'maximum_bits' on an 'enum' must be between 1 and 64.", + ) + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) + + def test_rejects_too_large_max_bits(self): + ir = _make_ir_from_emb("enum Foo:\n" " [maximum_bits: 65]\n" " ZERO = 0\n") + attribute_ir = ir.module[0].type[0].attribute[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + attribute_ir.value.source_location, + "'maximum_bits' on an 'enum' must be between 1 and 64.", + ) + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) + + def test_rejects_unknown_enum_value_attribute(self): + ir = _make_ir_from_emb("enum Foo:\n" " BAR = 0 \n" " [bad_attr: true]\n") + attribute_ir = ir.module[0].type[0].enumeration.value[0].attribute[0] + self.assertNotEqual([], attribute_checker.normalize_and_verify(ir)) + self.assertEqual( + [ + [ + error.error( + "m.emb", + attribute_ir.name.source_location, + "Unknown attribute 'bad_attr' on enum value 'BAR'.", + ) + ] + ], + error.filter_errors(attribute_checker.normalize_and_verify(ir)), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py index fabcf04..3249e6d 100644 --- a/compiler/front_end/constraints.py +++ b/compiler/front_end/constraints.py @@ -24,588 +24,764 @@ def _render_type(type_ir, ir): - """Returns the human-readable notation of the given type.""" - assert type_ir.HasField("atomic_type"), ( - "TODO(bolms): Implement _render_type for array types.") - if type_ir.HasField("size_in_bits"): - return _render_atomic_type_name( - type_ir, - ir, - suffix=":" + str(ir_util.constant_value(type_ir.size_in_bits))) - else: - return _render_atomic_type_name(type_ir, ir) + """Returns the human-readable notation of the given type.""" + assert type_ir.HasField( + "atomic_type" + ), "TODO(bolms): Implement _render_type for array types." + if type_ir.HasField("size_in_bits"): + return _render_atomic_type_name( + type_ir, ir, suffix=":" + str(ir_util.constant_value(type_ir.size_in_bits)) + ) + else: + return _render_atomic_type_name(type_ir, ir) def _render_atomic_type_name(type_ir, ir, suffix=None): - assert type_ir.HasField("atomic_type"), ( - "_render_atomic_type_name() requires an atomic type") - if not suffix: - suffix = "" - type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) - if type_definition.name.is_anonymous: - return "anonymous type" - else: - return "type '{}{}'".format(type_definition.name.name.text, suffix) - - -def _check_that_inner_array_dimensions_are_constant( - type_ir, source_file_name, errors): - """Checks that inner array dimensions are constant.""" - if type_ir.WhichOneof("size") == "automatic": - errors.append([error.error( - source_file_name, - ir_data_utils.reader(type_ir).element_count.source_location, - "Array dimensions can only be omitted for the outermost dimension.")]) - elif type_ir.WhichOneof("size") == "element_count": - if not ir_util.is_constant(type_ir.element_count): - errors.append([error.error(source_file_name, - type_ir.element_count.source_location, - "Inner array dimensions must be constant.")]) - else: - assert False, 'Expected "element_count" or "automatic" array size.' 
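Note on the function just above (its reformatted replacement appears later in this hunk): `_check_that_inner_array_dimensions_are_constant` enforces that only the outermost dimension of a multidimensional array may be automatic, so `UInt:8[3][]` is legal while `UInt:8[][3]` is not. A minimal standalone sketch of that rule follows; `size_kind` and `is_constant` are hypothetical stand-ins for `WhichOneof("size")` and `ir_util.is_constant(element_count)`, not compiler APIs.

# Illustrative sketch only (not part of this change): the inner-array
# dimension rule, with the IR traversal replaced by plain parameters.
def inner_dimension_problem(size_kind, is_constant):
    """Returns an error message for a bad *inner* dimension, else None."""
    if size_kind == "automatic":
        # A bare "[]" is only legal as the outermost dimension.
        return "Array dimensions can only be omitted for the outermost dimension."
    assert size_kind == "element_count", "unexpected array size kind"
    if not is_constant:
        return "Inner array dimensions must be constant."
    return None

assert inner_dimension_problem("element_count", True) is None   # the inner "[3]" of UInt:8[3][]
assert inner_dimension_problem("automatic", False) is not None  # the inner "[]" of UInt:8[][3]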
- - -def _check_that_array_base_types_are_fixed_size(type_ir, source_file_name, - errors, ir): - """Checks that the sizes of array elements are known at compile time.""" - if type_ir.base_type.HasField("array_type"): - # An array is fixed size if its base_type is fixed size and its array - # dimension is constant. This function will be called again on the inner - # array, and we do not want to cascade errors if the inner array's base_type - # is not fixed size. The array dimensions are separately checked by - # _check_that_inner_array_dimensions_are_constant, which will provide an - # appropriate error message for that case. - return - assert type_ir.base_type.HasField("atomic_type") - if type_ir.base_type.HasField("size_in_bits"): - # If the base_type has a size_in_bits, then it is fixed size. - return - base_type = ir_util.find_object(type_ir.base_type.atomic_type.reference, ir) - base_type_fixed_size = ir_util.get_integer_attribute( - base_type.attribute, attributes.FIXED_SIZE) - if base_type_fixed_size is None: - errors.append([error.error(source_file_name, - type_ir.base_type.atomic_type.source_location, - "Array elements must be fixed size.")]) + assert type_ir.HasField( + "atomic_type" + ), "_render_atomic_type_name() requires an atomic type" + if not suffix: + suffix = "" + type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) + if type_definition.name.is_anonymous: + return "anonymous type" + else: + return "type '{}{}'".format(type_definition.name.name.text, suffix) + + +def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, errors): + """Checks that inner array dimensions are constant.""" + if type_ir.WhichOneof("size") == "automatic": + errors.append( + [ + error.error( + source_file_name, + ir_data_utils.reader(type_ir).element_count.source_location, + "Array dimensions can only be omitted for the outermost dimension.", + ) + ] + ) + elif type_ir.WhichOneof("size") == "element_count": + if not ir_util.is_constant(type_ir.element_count): + errors.append( + [ + error.error( + source_file_name, + type_ir.element_count.source_location, + "Inner array dimensions must be constant.", + ) + ] + ) + else: + assert False, 'Expected "element_count" or "automatic" array size.' + + +def _check_that_array_base_types_are_fixed_size(type_ir, source_file_name, errors, ir): + """Checks that the sizes of array elements are known at compile time.""" + if type_ir.base_type.HasField("array_type"): + # An array is fixed size if its base_type is fixed size and its array + # dimension is constant. This function will be called again on the inner + # array, and we do not want to cascade errors if the inner array's base_type + # is not fixed size. The array dimensions are separately checked by + # _check_that_inner_array_dimensions_are_constant, which will provide an + # appropriate error message for that case. + return + assert type_ir.base_type.HasField("atomic_type") + if type_ir.base_type.HasField("size_in_bits"): + # If the base_type has a size_in_bits, then it is fixed size. 
+ return + base_type = ir_util.find_object(type_ir.base_type.atomic_type.reference, ir) + base_type_fixed_size = ir_util.get_integer_attribute( + base_type.attribute, attributes.FIXED_SIZE + ) + if base_type_fixed_size is None: + errors.append( + [ + error.error( + source_file_name, + type_ir.base_type.atomic_type.source_location, + "Array elements must be fixed size.", + ) + ] + ) def _check_that_array_base_types_in_structs_are_multiples_of_bytes( - type_ir, type_definition, source_file_name, errors, ir): - # TODO(bolms): Remove this limitation. - """Checks that the sizes of array elements are multiples of 8 bits.""" - if type_ir.base_type.HasField("array_type"): - # Only check the innermost array for multidimensional arrays. - return - assert type_ir.base_type.HasField("atomic_type") - if type_ir.base_type.HasField("size_in_bits"): - assert ir_util.is_constant(type_ir.base_type.size_in_bits) - base_type_size = ir_util.constant_value(type_ir.base_type.size_in_bits) - else: - fixed_size = ir_util.fixed_size_of_type_in_bits(type_ir.base_type, ir) - if fixed_size is None: - # Variable-sized elements are checked elsewhere. - return - base_type_size = fixed_size - if base_type_size % type_definition.addressable_unit != 0: - assert type_definition.addressable_unit == ir_data.AddressableUnit.BYTE - errors.append([error.error(source_file_name, - type_ir.base_type.source_location, - "Array elements in structs must have sizes " - "which are a multiple of 8 bits.")]) - - -def _check_constancy_of_constant_references(expression, source_file_name, - errors, ir): - """Checks that constant_references are constant.""" - if expression.WhichOneof("expression") != "constant_reference": - return - # This is a bit of a hack: really, we want to know that the referred-to object - # has no dependencies on any instance variables of its parent structure; i.e., - # that its value does not depend on having a view of the structure. - if not ir_util.is_constant_type(expression.type): - referred_name = expression.constant_reference.canonical_name - referred_object = ir_util.find_object(referred_name, ir) - errors.append([ - error.error( - source_file_name, expression.source_location, - "Static references must refer to constants."), - error.note( - referred_name.module_file, referred_object.source_location, - "{} is not constant.".format(referred_name.object_path[-1])) - ]) - - -def _check_that_enum_values_are_representable(enum_type, type_definition, - source_file_name, errors): - """Checks that enumeration values can fit in their specified int type.""" - values = [] - max_enum_size = ir_util.get_integer_attribute( - type_definition.attribute, attributes.ENUM_MAXIMUM_BITS) - is_signed = ir_util.get_boolean_attribute( - type_definition.attribute, attributes.IS_SIGNED) - if is_signed: - enum_range = (-(2**(max_enum_size-1)), 2**(max_enum_size-1)-1) - else: - enum_range = (0, 2**max_enum_size-1) - for value in enum_type.value: - values.append((ir_util.constant_value(value.value), value)) - out_of_range = [v for v in values - if not enum_range[0] <= v[0] <= enum_range[1]] - # If all values are in range, this loop will have zero iterations. - for value in out_of_range: - errors.append([ - error.error( - source_file_name, value[1].value.source_location, - "Value {} is out of range for {}-bit {} enumeration.".format( - value[0], max_enum_size, "signed" if is_signed else "unsigned")) - ]) + type_ir, type_definition, source_file_name, errors, ir +): + # TODO(bolms): Remove this limitation. 
+ """Checks that the sizes of array elements are multiples of 8 bits.""" + if type_ir.base_type.HasField("array_type"): + # Only check the innermost array for multidimensional arrays. + return + assert type_ir.base_type.HasField("atomic_type") + if type_ir.base_type.HasField("size_in_bits"): + assert ir_util.is_constant(type_ir.base_type.size_in_bits) + base_type_size = ir_util.constant_value(type_ir.base_type.size_in_bits) + else: + fixed_size = ir_util.fixed_size_of_type_in_bits(type_ir.base_type, ir) + if fixed_size is None: + # Variable-sized elements are checked elsewhere. + return + base_type_size = fixed_size + if base_type_size % type_definition.addressable_unit != 0: + assert type_definition.addressable_unit == ir_data.AddressableUnit.BYTE + errors.append( + [ + error.error( + source_file_name, + type_ir.base_type.source_location, + "Array elements in structs must have sizes " + "which are a multiple of 8 bits.", + ) + ] + ) + + +def _check_constancy_of_constant_references(expression, source_file_name, errors, ir): + """Checks that constant_references are constant.""" + if expression.WhichOneof("expression") != "constant_reference": + return + # This is a bit of a hack: really, we want to know that the referred-to object + # has no dependencies on any instance variables of its parent structure; i.e., + # that its value does not depend on having a view of the structure. + if not ir_util.is_constant_type(expression.type): + referred_name = expression.constant_reference.canonical_name + referred_object = ir_util.find_object(referred_name, ir) + errors.append( + [ + error.error( + source_file_name, + expression.source_location, + "Static references must refer to constants.", + ), + error.note( + referred_name.module_file, + referred_object.source_location, + "{} is not constant.".format(referred_name.object_path[-1]), + ), + ] + ) + + +def _check_that_enum_values_are_representable( + enum_type, type_definition, source_file_name, errors +): + """Checks that enumeration values can fit in their specified int type.""" + values = [] + max_enum_size = ir_util.get_integer_attribute( + type_definition.attribute, attributes.ENUM_MAXIMUM_BITS + ) + is_signed = ir_util.get_boolean_attribute( + type_definition.attribute, attributes.IS_SIGNED + ) + if is_signed: + enum_range = (-(2 ** (max_enum_size - 1)), 2 ** (max_enum_size - 1) - 1) + else: + enum_range = (0, 2**max_enum_size - 1) + for value in enum_type.value: + values.append((ir_util.constant_value(value.value), value)) + out_of_range = [v for v in values if not enum_range[0] <= v[0] <= enum_range[1]] + # If all values are in range, this loop will have zero iterations. 
+ for value in out_of_range: + errors.append( + [ + error.error( + source_file_name, + value[1].value.source_location, + "Value {} is out of range for {}-bit {} enumeration.".format( + value[0], max_enum_size, "signed" if is_signed else "unsigned" + ), + ) + ] + ) def _field_size(field, type_definition): - """Calculates the size of the given field in bits, if it is constant.""" - size = ir_util.constant_value(field.location.size) - if size is None: - return None - return size * type_definition.addressable_unit - - -def _check_type_requirements_for_field(type_ir, type_definition, field, ir, - source_file_name, errors): - """Checks that the `requires` attribute of each field's type is fulfilled.""" - if not type_ir.HasField("atomic_type"): - return - - if field.type.HasField("atomic_type"): - field_min_size = (int(field.location.size.type.integer.minimum_value) * - type_definition.addressable_unit) - field_max_size = (int(field.location.size.type.integer.maximum_value) * - type_definition.addressable_unit) - field_is_atomic = True - else: - field_is_atomic = False - - if type_ir.HasField("size_in_bits"): - element_size = ir_util.constant_value(type_ir.size_in_bits) - else: - element_size = None - - referenced_type_definition = ir_util.find_object( - type_ir.atomic_type.reference, ir) - type_is_anonymous = referenced_type_definition.name.is_anonymous - type_size_attr = ir_util.get_attribute( - referenced_type_definition.attribute, attributes.FIXED_SIZE) - if type_size_attr: - type_size = ir_util.constant_value(type_size_attr.expression) - else: - type_size = None - - if (element_size is not None and type_size is not None and - element_size != type_size): - errors.append([ - error.error( - source_file_name, type_ir.size_in_bits.source_location, - "Explicit size of {} bits does not match fixed size ({} bits) of " - "{}.".format(element_size, type_size, - _render_atomic_type_name(type_ir, ir))), - error.note( - type_ir.atomic_type.reference.canonical_name.module_file, - type_size_attr.source_location, - "Size specified here.") - ]) - return - - # If the type had no size specifier (the ':32' in 'UInt:32'), but the type is - # fixed size, then continue as if the type's size were explicitly stated. - if element_size is None: - element_size = type_size - - # TODO(bolms): When the full dynamic size expression for types is generated, - # add a check that dynamically-sized types can, at least potentially, fit in - # their fields. - - if field_is_atomic and element_size is not None: - # If the field has a fixed size, and the (atomic) type contained therein is - # also fixed size, then the sizes should match. - # - # TODO(bolms): Maybe change the case where the field is bigger than - # necessary into a warning? - if (field_max_size == field_min_size and - (element_size > field_max_size or - (element_size < field_min_size and not type_is_anonymous))): - errors.append([ - error.error( - source_file_name, type_ir.source_location, - "Fixed-size {} cannot be placed in field of size {} bits; " - "requires {} bits.".format( - _render_type(type_ir, ir), field_max_size, element_size)) - ]) - return - elif element_size > field_max_size: - errors.append([ - error.error( - source_file_name, type_ir.source_location, - "Field of maximum size {} bits cannot hold fixed-size {}, which " - "requires {} bits.".format( - field_max_size, _render_type(type_ir, ir), element_size)) - ]) - return - - # If we're here, then field/type sizes are consistent. 
- if (element_size is None and field_is_atomic and - field_min_size == field_max_size): - # From here down, we just use element_size. - element_size = field_min_size - - errors.extend(_check_physical_type_requirements( - type_ir, field.source_location, element_size, ir, source_file_name)) + """Calculates the size of the given field in bits, if it is constant.""" + size = ir_util.constant_value(field.location.size) + if size is None: + return None + return size * type_definition.addressable_unit + + +def _check_type_requirements_for_field( + type_ir, type_definition, field, ir, source_file_name, errors +): + """Checks that the `requires` attribute of each field's type is fulfilled.""" + if not type_ir.HasField("atomic_type"): + return + + if field.type.HasField("atomic_type"): + field_min_size = ( + int(field.location.size.type.integer.minimum_value) + * type_definition.addressable_unit + ) + field_max_size = ( + int(field.location.size.type.integer.maximum_value) + * type_definition.addressable_unit + ) + field_is_atomic = True + else: + field_is_atomic = False + + if type_ir.HasField("size_in_bits"): + element_size = ir_util.constant_value(type_ir.size_in_bits) + else: + element_size = None + + referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) + type_is_anonymous = referenced_type_definition.name.is_anonymous + type_size_attr = ir_util.get_attribute( + referenced_type_definition.attribute, attributes.FIXED_SIZE + ) + if type_size_attr: + type_size = ir_util.constant_value(type_size_attr.expression) + else: + type_size = None + + if element_size is not None and type_size is not None and element_size != type_size: + errors.append( + [ + error.error( + source_file_name, + type_ir.size_in_bits.source_location, + "Explicit size of {} bits does not match fixed size ({} bits) of " + "{}.".format( + element_size, type_size, _render_atomic_type_name(type_ir, ir) + ), + ), + error.note( + type_ir.atomic_type.reference.canonical_name.module_file, + type_size_attr.source_location, + "Size specified here.", + ), + ] + ) + return + + # If the type had no size specifier (the ':32' in 'UInt:32'), but the type is + # fixed size, then continue as if the type's size were explicitly stated. + if element_size is None: + element_size = type_size + + # TODO(bolms): When the full dynamic size expression for types is generated, + # add a check that dynamically-sized types can, at least potentially, fit in + # their fields. + + if field_is_atomic and element_size is not None: + # If the field has a fixed size, and the (atomic) type contained therein is + # also fixed size, then the sizes should match. + # + # TODO(bolms): Maybe change the case where the field is bigger than + # necessary into a warning? + if field_max_size == field_min_size and ( + element_size > field_max_size + or (element_size < field_min_size and not type_is_anonymous) + ): + errors.append( + [ + error.error( + source_file_name, + type_ir.source_location, + "Fixed-size {} cannot be placed in field of size {} bits; " + "requires {} bits.".format( + _render_type(type_ir, ir), field_max_size, element_size + ), + ) + ] + ) + return + elif element_size > field_max_size: + errors.append( + [ + error.error( + source_file_name, + type_ir.source_location, + "Field of maximum size {} bits cannot hold fixed-size {}, which " + "requires {} bits.".format( + field_max_size, _render_type(type_ir, ir), element_size + ), + ) + ] + ) + return + + # If we're here, then field/type sizes are consistent. 
+ if element_size is None and field_is_atomic and field_min_size == field_max_size: + # From here down, we just use element_size. + element_size = field_min_size + + errors.extend( + _check_physical_type_requirements( + type_ir, field.source_location, element_size, ir, source_file_name + ) + ) def _check_type_requirements_for_parameter_type( - runtime_parameter, ir, source_file_name, errors): - """Checks that the type of a parameter is valid.""" - physical_type = runtime_parameter.physical_type_alias - logical_type = runtime_parameter.type - size = ir_util.constant_value(physical_type.size_in_bits) - if logical_type.WhichOneof("type") == "integer": - integer_errors = _integer_bounds_errors( - logical_type.integer, "parameter", source_file_name, - physical_type.source_location) - if integer_errors: - errors.extend(integer_errors) - return - errors.extend(_check_physical_type_requirements( - physical_type, runtime_parameter.source_location, - size, ir, source_file_name)) - elif logical_type.WhichOneof("type") == "enumeration": - if physical_type.HasField("size_in_bits"): - # This seems a little weird: for `UInt`, `Int`, etc., the explicit size is - # required, but for enums it is banned. This is because enums have a - # "native" 64-bit size in expressions, so the physical size is just - # ignored. - errors.extend([[ - error.error( - source_file_name, physical_type.size_in_bits.source_location, - "Parameters with enum type may not have explicit size.") - - ]]) - else: - assert False, "Non-integer/enum parameters should have been caught earlier." + runtime_parameter, ir, source_file_name, errors +): + """Checks that the type of a parameter is valid.""" + physical_type = runtime_parameter.physical_type_alias + logical_type = runtime_parameter.type + size = ir_util.constant_value(physical_type.size_in_bits) + if logical_type.WhichOneof("type") == "integer": + integer_errors = _integer_bounds_errors( + logical_type.integer, + "parameter", + source_file_name, + physical_type.source_location, + ) + if integer_errors: + errors.extend(integer_errors) + return + errors.extend( + _check_physical_type_requirements( + physical_type, + runtime_parameter.source_location, + size, + ir, + source_file_name, + ) + ) + elif logical_type.WhichOneof("type") == "enumeration": + if physical_type.HasField("size_in_bits"): + # This seems a little weird: for `UInt`, `Int`, etc., the explicit size is + # required, but for enums it is banned. This is because enums have a + # "native" 64-bit size in expressions, so the physical size is just + # ignored. + errors.extend( + [ + [ + error.error( + source_file_name, + physical_type.size_in_bits.source_location, + "Parameters with enum type may not have explicit size.", + ) + ] + ] + ) + else: + assert False, "Non-integer/enum parameters should have been caught earlier." 
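Before the next hunk, a note on the 64-bit bounds rule that `_integer_bounds_errors` and the `_bounds_can_fit_*` helpers further down enforce: a computed integer range is acceptable only if it fits entirely in a 64-bit unsigned integer or entirely in a 64-bit signed one. A self-contained sketch of that rule (the function and variable names here are illustrative, not the compiler's):

# Sketch mirroring _bounds_can_fit_64_bit_unsigned,
# _bounds_can_fit_64_bit_signed, and their combination below.
def fits_some_64_bit_integer(minimum, maximum):
    fits_uint64 = minimum >= 0 and maximum <= 2**64 - 1
    fits_int64 = minimum >= -(2**63) and maximum <= 2**63 - 1
    return fits_uint64 or fits_int64

assert fits_some_64_bit_integer(0, 2**64 - 1)          # uint64-only range
assert fits_some_64_bit_integer(-(2**63), 2**63 - 1)   # int64-only range
assert not fits_some_64_bit_integer(-1, 2**64 - 1)     # straddles both; rejected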
def _check_physical_type_requirements( - type_ir, usage_source_location, size, ir, source_file_name): - """Checks that the given atomic `type_ir` is allowed to be `size` bits.""" - referenced_type_definition = ir_util.find_object( - type_ir.atomic_type.reference, ir) - if referenced_type_definition.HasField("enumeration"): + type_ir, usage_source_location, size, ir, source_file_name +): + """Checks that the given atomic `type_ir` is allowed to be `size` bits.""" + referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) + if referenced_type_definition.HasField("enumeration"): + if size is None: + return [ + [ + error.error( + source_file_name, + type_ir.source_location, + "Enumeration {} cannot be placed in a dynamically-sized " + "field.".format(_render_type(type_ir, ir)), + ) + ] + ] + else: + max_enum_size = ir_util.get_integer_attribute( + referenced_type_definition.attribute, attributes.ENUM_MAXIMUM_BITS + ) + if size < 1 or size > max_enum_size: + return [ + [ + error.error( + source_file_name, + type_ir.source_location, + "Enumeration {} cannot be {} bits; {} must be between " + "1 and {} bits, inclusive.".format( + _render_atomic_type_name(type_ir, ir), + size, + _render_atomic_type_name(type_ir, ir), + max_enum_size, + ), + ) + ] + ] + if size is None: - return [[ - error.error( - source_file_name, type_ir.source_location, - "Enumeration {} cannot be placed in a dynamically-sized " - "field.".format(_render_type(type_ir, ir))) - ]] + bindings = {"$is_statically_sized": False} else: - max_enum_size = ir_util.get_integer_attribute( - referenced_type_definition.attribute, attributes.ENUM_MAXIMUM_BITS) - if size < 1 or size > max_enum_size: - return [[ - error.error( - source_file_name, type_ir.source_location, - "Enumeration {} cannot be {} bits; {} must be between " - "1 and {} bits, inclusive.".format( - _render_atomic_type_name(type_ir, ir), size, - _render_atomic_type_name(type_ir, ir), max_enum_size)) - ]] - - if size is None: - bindings = {"$is_statically_sized": False} - else: - bindings = { - "$is_statically_sized": True, - "$static_size_in_bits": size - } - requires_attr = ir_util.get_attribute( - referenced_type_definition.attribute, attributes.STATIC_REQUIREMENTS) - if requires_attr and not ir_util.constant_value(requires_attr.expression, - bindings): - # TODO(bolms): Figure out a better way to build this error message. - # The "Requirements specified here." message should print out the actual - # source text of the requires attribute, so that should help, but it's still - # a bit generic and unfriendly. 
- return [[ - error.error( - source_file_name, usage_source_location, - "Requirements of {} not met.".format( - type_ir.atomic_type.reference.canonical_name.object_path[-1])), - error.note( - type_ir.atomic_type.reference.canonical_name.module_file, - requires_attr.source_location, - "Requirements specified here.") - ]] - return [] - - -def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir, - errors): - if not type_ir.HasField("atomic_type"): - return - referenced_type_definition = ir_util.find_object( - type_ir.atomic_type.reference, ir) - if (type_definition.addressable_unit % - referenced_type_definition.addressable_unit != 0): - assert type_definition.addressable_unit == ir_data.AddressableUnit.BIT - assert (referenced_type_definition.addressable_unit == - ir_data.AddressableUnit.BYTE) - errors.append([ - error.error(source_file_name, type_ir.source_location, + bindings = {"$is_statically_sized": True, "$static_size_in_bits": size} + requires_attr = ir_util.get_attribute( + referenced_type_definition.attribute, attributes.STATIC_REQUIREMENTS + ) + if requires_attr and not ir_util.constant_value(requires_attr.expression, bindings): + # TODO(bolms): Figure out a better way to build this error message. + # The "Requirements specified here." message should print out the actual + # source text of the requires attribute, so that should help, but it's still + # a bit generic and unfriendly. + return [ + [ + error.error( + source_file_name, + usage_source_location, + "Requirements of {} not met.".format( + type_ir.atomic_type.reference.canonical_name.object_path[-1] + ), + ), + error.note( + type_ir.atomic_type.reference.canonical_name.module_file, + requires_attr.source_location, + "Requirements specified here.", + ), + ] + ] + return [] + + +def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir, errors): + if not type_ir.HasField("atomic_type"): + return + referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir) + if ( + type_definition.addressable_unit % referenced_type_definition.addressable_unit + != 0 + ): + assert type_definition.addressable_unit == ir_data.AddressableUnit.BIT + assert ( + referenced_type_definition.addressable_unit == ir_data.AddressableUnit.BYTE + ) + errors.append( + [ + error.error( + source_file_name, + type_ir.source_location, "Byte-oriented {} cannot be used in a bits field.".format( - _render_type(type_ir, ir))) - ]) + _render_type(type_ir, ir) + ), + ) + ] + ) def _check_size_of_bits(type_ir, type_definition, source_file_name, errors): - """Checks that `bits` types are fixed size, less than 64 bits.""" - del type_ir # Unused - if type_definition.addressable_unit != ir_data.AddressableUnit.BIT: - return - fixed_size = ir_util.get_integer_attribute( - type_definition.attribute, attributes.FIXED_SIZE) - if fixed_size is None: - errors.append([error.error(source_file_name, - type_definition.source_location, - "`bits` types must be fixed size.")]) - return - if fixed_size > 64: - errors.append([error.error(source_file_name, - type_definition.source_location, - "`bits` types must be 64 bits or smaller.")]) + """Checks that `bits` types are fixed size, 64 bits or smaller.""" + del type_ir # Unused + if type_definition.addressable_unit != ir_data.AddressableUnit.BIT: + return + fixed_size = ir_util.get_integer_attribute( + type_definition.attribute, attributes.FIXED_SIZE + ) + if fixed_size is None: + errors.append( + [ + error.error( + source_file_name, + type_definition.source_location, + "`bits` types
must be fixed size.", + ) + ] + ) + return + if fixed_size > 64: + errors.append( + [ + error.error( + source_file_name, + type_definition.source_location, + "`bits` types must be 64 bits or smaller.", + ) + ] + ) _RESERVED_WORDS = None def get_reserved_word_list(): - if _RESERVED_WORDS is None: - _initialize_reserved_word_list() - return _RESERVED_WORDS + if _RESERVED_WORDS is None: + _initialize_reserved_word_list() + return _RESERVED_WORDS def _initialize_reserved_word_list(): - global _RESERVED_WORDS - _RESERVED_WORDS = {} - language = None - for line in resources.load( - "compiler.front_end", "reserved_words").splitlines(): - stripped_line = line.partition("#")[0].strip() - if not stripped_line: - continue - if stripped_line.startswith("--"): - language = stripped_line.partition("--")[2].strip() - else: - # For brevity's sake, only use the first language for error messages. - if stripped_line not in _RESERVED_WORDS: - _RESERVED_WORDS[stripped_line] = language + global _RESERVED_WORDS + _RESERVED_WORDS = {} + language = None + for line in resources.load("compiler.front_end", "reserved_words").splitlines(): + stripped_line = line.partition("#")[0].strip() + if not stripped_line: + continue + if stripped_line.startswith("--"): + language = stripped_line.partition("--")[2].strip() + else: + # For brevity's sake, only use the first language for error messages. + if stripped_line not in _RESERVED_WORDS: + _RESERVED_WORDS[stripped_line] = language def _check_name_for_reserved_words(obj, source_file_name, errors, context_name): - if obj.name.name.text in get_reserved_word_list(): - errors.append([ - error.error( - source_file_name, obj.name.name.source_location, - "{} reserved word may not be used as {}.".format( - get_reserved_word_list()[obj.name.name.text], - context_name)) - ]) + if obj.name.name.text in get_reserved_word_list(): + errors.append( + [ + error.error( + source_file_name, + obj.name.name.source_location, + "{} reserved word may not be used as {}.".format( + get_reserved_word_list()[obj.name.name.text], context_name + ), + ) + ] + ) def _check_field_name_for_reserved_words(field, source_file_name, errors): - return _check_name_for_reserved_words(field, source_file_name, errors, - "a field name") + return _check_name_for_reserved_words( + field, source_file_name, errors, "a field name" + ) def _check_enum_name_for_reserved_words(enum, source_file_name, errors): - return _check_name_for_reserved_words(enum, source_file_name, errors, - "an enum name") + return _check_name_for_reserved_words( + enum, source_file_name, errors, "an enum name" + ) -def _check_type_name_for_reserved_words(type_definition, source_file_name, - errors): - return _check_name_for_reserved_words( - type_definition, source_file_name, errors, "a type name") +def _check_type_name_for_reserved_words(type_definition, source_file_name, errors): + return _check_name_for_reserved_words( + type_definition, source_file_name, errors, "a type name" + ) def _bounds_can_fit_64_bit_unsigned(minimum, maximum): - return minimum >= 0 and maximum <= 2**64 - 1 + return minimum >= 0 and maximum <= 2**64 - 1 def _bounds_can_fit_64_bit_signed(minimum, maximum): - return minimum >= -(2**63) and maximum <= 2**63 - 1 + return minimum >= -(2**63) and maximum <= 2**63 - 1 def _bounds_can_fit_any_64_bit_integer_type(minimum, maximum): - return (_bounds_can_fit_64_bit_unsigned(minimum, maximum) or - _bounds_can_fit_64_bit_signed(minimum, maximum)) + return _bounds_can_fit_64_bit_unsigned( + minimum, maximum + ) or 
_bounds_can_fit_64_bit_signed(minimum, maximum) def _integer_bounds_errors_for_expression(expression, source_file_name): - """Checks that `expression` is in range for int64_t or uint64_t.""" - # Only check non-constant subexpressions. - if (expression.WhichOneof("expression") == "function" and - not ir_util.is_constant_type(expression.type)): - errors = [] - for arg in expression.function.args: - errors += _integer_bounds_errors_for_expression(arg, source_file_name) - if errors: - # Don't cascade bounds errors: report them at the lowest level they - # appear. - return errors - if expression.type.WhichOneof("type") == "integer": - errors = _integer_bounds_errors(expression.type.integer, "expression", - source_file_name, - expression.source_location) - if errors: - return errors - if (expression.WhichOneof("expression") == "function" and - not ir_util.is_constant_type(expression.type)): - int64_only_clauses = [] - uint64_only_clauses = [] - for clause in [expression] + list(expression.function.args): - if clause.type.WhichOneof("type") == "integer": - arg_minimum = int(clause.type.integer.minimum_value) - arg_maximum = int(clause.type.integer.maximum_value) - if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum): - uint64_only_clauses.append(clause) - elif not _bounds_can_fit_64_bit_unsigned(arg_minimum, arg_maximum): - int64_only_clauses.append(clause) - if int64_only_clauses and uint64_only_clauses: - error_set = [ - error.error( - source_file_name, expression.source_location, - "Either all arguments to '{}' and its result must fit in a " - "64-bit unsigned integer, or all must fit in a 64-bit signed " - "integer.".format(expression.function.function_name.text)) - ] - for signedness, clause_list in (("unsigned", uint64_only_clauses), - ("signed", int64_only_clauses)): - for clause in clause_list: - error_set.append(error.note( - source_file_name, clause.source_location, - "Requires {} 64-bit integer.".format(signedness))) - return [error_set] - return [] - - -def _integer_bounds_errors(bounds, name, source_file_name, - error_source_location): - """Returns appropriate errors, if any, for the given integer bounds.""" - assert bounds.minimum_value, "{}".format(bounds) - assert bounds.maximum_value, "{}".format(bounds) - if (bounds.minimum_value == "-infinity" or - bounds.maximum_value == "infinity"): - return [[ - error.error( - source_file_name, error_source_location, - "Integer range of {} must not be unbounded; it must fit " - "in a 64-bit signed or unsigned integer.".format(name)) - ]] - if not _bounds_can_fit_any_64_bit_integer_type(int(bounds.minimum_value), - int(bounds.maximum_value)): - if int(bounds.minimum_value) == int(bounds.maximum_value): - return [[ - error.error( - source_file_name, error_source_location, - "Constant value {} of {} cannot fit in a 64-bit signed or " - "unsigned integer.".format(bounds.minimum_value, name)) - ]] - else: - return [[ - error.error( - source_file_name, error_source_location, - "Potential range of {} is {} to {}, which cannot fit " - "in a 64-bit signed or unsigned integer.".format( - name, bounds.minimum_value, bounds.maximum_value)) - ]] - return [] - - -def _check_bounds_on_runtime_integer_expressions(expression, source_file_name, - in_attribute, errors): - if in_attribute and in_attribute.name.text == attributes.STATIC_REQUIREMENTS: - # [static_requirements] is never evaluated at runtime, and $size_in_bits is - # unbounded, so it should not be checked. 
- return - # The logic for gathering errors and suppressing cascades is simpler if - # errors are just returned, rather than appended to a shared list. - errors += _integer_bounds_errors_for_expression(expression, source_file_name) + """Checks that `expression` is in range for int64_t or uint64_t.""" + # Only check non-constant subexpressions. + if expression.WhichOneof( + "expression" + ) == "function" and not ir_util.is_constant_type(expression.type): + errors = [] + for arg in expression.function.args: + errors += _integer_bounds_errors_for_expression(arg, source_file_name) + if errors: + # Don't cascade bounds errors: report them at the lowest level they + # appear. + return errors + if expression.type.WhichOneof("type") == "integer": + errors = _integer_bounds_errors( + expression.type.integer, + "expression", + source_file_name, + expression.source_location, + ) + if errors: + return errors + if expression.WhichOneof( + "expression" + ) == "function" and not ir_util.is_constant_type(expression.type): + int64_only_clauses = [] + uint64_only_clauses = [] + for clause in [expression] + list(expression.function.args): + if clause.type.WhichOneof("type") == "integer": + arg_minimum = int(clause.type.integer.minimum_value) + arg_maximum = int(clause.type.integer.maximum_value) + if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum): + uint64_only_clauses.append(clause) + elif not _bounds_can_fit_64_bit_unsigned(arg_minimum, arg_maximum): + int64_only_clauses.append(clause) + if int64_only_clauses and uint64_only_clauses: + error_set = [ + error.error( + source_file_name, + expression.source_location, + "Either all arguments to '{}' and its result must fit in a " + "64-bit unsigned integer, or all must fit in a 64-bit signed " + "integer.".format(expression.function.function_name.text), + ) + ] + for signedness, clause_list in ( + ("unsigned", uint64_only_clauses), + ("signed", int64_only_clauses), + ): + for clause in clause_list: + error_set.append( + error.note( + source_file_name, + clause.source_location, + "Requires {} 64-bit integer.".format(signedness), + ) + ) + return [error_set] + return [] + + +def _integer_bounds_errors(bounds, name, source_file_name, error_source_location): + """Returns appropriate errors, if any, for the given integer bounds.""" + assert bounds.minimum_value, "{}".format(bounds) + assert bounds.maximum_value, "{}".format(bounds) + if bounds.minimum_value == "-infinity" or bounds.maximum_value == "infinity": + return [ + [ + error.error( + source_file_name, + error_source_location, + "Integer range of {} must not be unbounded; it must fit " + "in a 64-bit signed or unsigned integer.".format(name), + ) + ] + ] + if not _bounds_can_fit_any_64_bit_integer_type( + int(bounds.minimum_value), int(bounds.maximum_value) + ): + if int(bounds.minimum_value) == int(bounds.maximum_value): + return [ + [ + error.error( + source_file_name, + error_source_location, + "Constant value {} of {} cannot fit in a 64-bit signed or " + "unsigned integer.".format(bounds.minimum_value, name), + ) + ] + ] + else: + return [ + [ + error.error( + source_file_name, + error_source_location, + "Potential range of {} is {} to {}, which cannot fit " + "in a 64-bit signed or unsigned integer.".format( + name, bounds.minimum_value, bounds.maximum_value + ), + ) + ] + ] + return [] + + +def _check_bounds_on_runtime_integer_expressions( + expression, source_file_name, in_attribute, errors +): + if in_attribute and in_attribute.name.text == attributes.STATIC_REQUIREMENTS: + # 
[static_requirements] is never evaluated at runtime, and $size_in_bits is + # unbounded, so it should not be checked. + return + # The logic for gathering errors and suppressing cascades is simpler if + # errors are just returned, rather than appended to a shared list. + errors += _integer_bounds_errors_for_expression(expression, source_file_name) + def _attribute_in_attribute_action(a): - return {"in_attribute": a} + return {"in_attribute": a} + def check_constraints(ir): - """Checks miscellaneous validity constraints in ir. - - Checks that auto array sizes are only used for the outermost size of - multidimensional arrays. That is, Type[3][] is OK, but Type[][3] is not. - - Checks that fixed-size fields are a correct size to hold statically-sized - types. - - Checks that inner array dimensions are constant. - - Checks that only constant-size types are used in arrays. - - Arguments: - ir: An ir_data.EmbossIr object to check. - - Returns: - A list of ConstraintViolations, or an empty list if there are none. - """ - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure, ir_data.Type], _check_allowed_in_bits, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - # TODO(bolms): look for [ir_data.ArrayType], [ir_data.AtomicType], and - # simplify _check_that_array_base_types_are_fixed_size. - ir, [ir_data.ArrayType], _check_that_array_base_types_are_fixed_size, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure, ir_data.ArrayType], - _check_that_array_base_types_in_structs_are_multiples_of_bytes, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.ArrayType, ir_data.ArrayType], - _check_that_inner_array_dimensions_are_constant, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], _check_size_of_bits, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure, ir_data.Type], _check_type_requirements_for_field, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], _check_field_name_for_reserved_words, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.EnumValue], _check_enum_name_for_reserved_words, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.TypeDefinition], _check_type_name_for_reserved_words, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Expression], _check_constancy_of_constant_references, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Enum], _check_that_enum_values_are_representable, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Expression], _check_bounds_on_runtime_integer_expressions, - incidental_actions={ir_data.Attribute: _attribute_in_attribute_action}, - skip_descendants_of={ir_data.EnumValue, ir_data.Expression}, - parameters={"errors": errors, "in_attribute": None}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.RuntimeParameter], - _check_type_requirements_for_parameter_type, - parameters={"errors": errors}) - return errors + """Checks miscellaneous validity constraints in ir. + + Checks that auto array sizes are only used for the outermost size of + multidimensional arrays. That is, Type[3][] is OK, but Type[][3] is not. + + Checks that fixed-size fields are a correct size to hold statically-sized + types. 
+ + Checks that inner array dimensions are constant. + + Checks that only constant-size types are used in arrays. + + Arguments: + ir: An ir_data.EmbossIr object to check. + + Returns: + A list of ConstraintViolations, or an empty list if there are none. + """ + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure, ir_data.Type], + _check_allowed_in_bits, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + # TODO(bolms): look for [ir_data.ArrayType], [ir_data.AtomicType], and + # simplify _check_that_array_base_types_are_fixed_size. + ir, + [ir_data.ArrayType], + _check_that_array_base_types_are_fixed_size, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure, ir_data.ArrayType], + _check_that_array_base_types_in_structs_are_multiples_of_bytes, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.ArrayType, ir_data.ArrayType], + _check_that_inner_array_dimensions_are_constant, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Structure], _check_size_of_bits, parameters={"errors": errors} + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure, ir_data.Type], + _check_type_requirements_for_field, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Field], + _check_field_name_for_reserved_words, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.EnumValue], + _check_enum_name_for_reserved_words, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.TypeDefinition], + _check_type_name_for_reserved_words, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Expression], + _check_constancy_of_constant_references, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Enum], + _check_that_enum_values_are_representable, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Expression], + _check_bounds_on_runtime_integer_expressions, + incidental_actions={ir_data.Attribute: _attribute_in_attribute_action}, + skip_descendants_of={ir_data.EnumValue, ir_data.Expression}, + parameters={"errors": errors, "in_attribute": None}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.RuntimeParameter], + _check_type_requirements_for_parameter_type, + parameters={"errors": errors}, + ) + return errors diff --git a/compiler/front_end/constraints_test.py b/compiler/front_end/constraints_test.py index eac1232..4f4df7e 100644 --- a/compiler/front_end/constraints_test.py +++ b/compiler/front_end/constraints_test.py @@ -25,819 +25,1291 @@ def _make_ir_from_emb(emb_text, name="m.emb"): - ir, unused_debug_info, errors = glue.parse_emboss_file( - name, - test_util.dict_file_reader({name: emb_text}), - stop_before_step="check_constraints") - assert not errors, repr(errors) - return ir + ir, unused_debug_info, errors = glue.parse_emboss_file( + name, + test_util.dict_file_reader({name: emb_text}), + stop_before_step="check_constraints", + ) + assert not errors, repr(errors) + return ir class ConstraintsTest(unittest.TestCase): - """Tests constraints.check_constraints and helpers.""" - - def test_error_on_missing_inner_array_size(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt:8[][1] one_byte\n") - # There is a latent issue here where the source 
location reported in this - # error is using a default value of 0:0. An issue is filed at - # https://github.com/google/emboss/issues/153 for further investigation. - # In the meantime we use `ir_data_utils.reader` to mimic this legacy - # behavior. - error_array = ir_data_utils.reader( - ir.module[0].type[0].structure.field[0].type.array_type) - self.assertEqual([[ - error.error( - "m.emb", - error_array.base_type.array_type.element_count.source_location, - "Array dimensions can only be omitted for the outermost dimension.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_no_error_on_ok_array_size(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt:8[1][1] one_byte\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_no_error_on_ok_missing_outer_array_size(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt:8[1][] one_byte\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_no_error_on_dynamically_sized_struct_in_dynamically_sized_field( - self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt size\n" - " 1 [+size] Bar bar\n" - "struct Bar:\n" - " 0 [+1] UInt size\n" - " 1 [+size] UInt:8[] payload\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_no_error_on_dynamically_sized_struct_in_statically_sized_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+10] Bar bar\n" - "struct Bar:\n" - " 0 [+1] UInt size\n" - " 1 [+size] UInt:8[] payload\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_no_error_non_fixed_size_outer_array_dimension(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt size\n" - " 1 [+size] UInt:8[1][size-1] one_byte\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_error_non_fixed_size_inner_array_dimension(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt size\n" - " 1 [+size] UInt:8[size-1][1] one_byte\n") - error_array = ir.module[0].type[0].structure.field[1].type.array_type - self.assertEqual([[ - error.error( - "m.emb", - error_array.base_type.array_type.element_count.source_location, - "Inner array dimensions must be constant.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_error_non_constant_inner_array_dimensions(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] Bar[1] one_byte\n" - # There is no dynamically-sized byte-oriented type in - # the Prelude, so this test has to make its own. 
- "external Bar:\n" - " [is_integer: true]\n" - " [addressable_unit_size: 8]\n") - error_array = ir.module[0].type[0].structure.field[0].type.array_type - self.assertEqual([[ - error.error( - "m.emb", error_array.base_type.atomic_type.source_location, - "Array elements must be fixed size.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_error_dynamically_sized_array_elements(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] Bar[1] bar\n" - "struct Bar:\n" - " 0 [+1] UInt size\n" - " 1 [+size] UInt:8[] payload\n") - error_array = ir.module[0].type[0].structure.field[0].type.array_type - self.assertEqual([[ - error.error( - "m.emb", error_array.base_type.atomic_type.source_location, - "Array elements must be fixed size.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_field_too_small_for_type(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] Bar bar\n" - "struct Bar:\n" - " 0 [+2] UInt value\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Fixed-size type 'Bar' cannot be placed in field of size 8 bits; " - "requires 16 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_dynamically_sized_field_always_too_small_for_type(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+1] UInt x\n" - " 0 [+x] Bar bar\n" - "struct Bar:\n" - " 0 [+2] UInt value\n") - error_type = ir.module[0].type[0].structure.field[2].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Field of maximum size 8 bits cannot hold fixed-size type 'Bar', " - "which requires 16 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_struct_field_too_big_for_type(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] Byte double_byte\n" - "struct Byte:\n" - " 0 [+1] UInt b\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Fixed-size type 'Byte' cannot be placed in field of size 16 bits; " - "requires 8 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_bits_field_too_big_for_type(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+9] UInt uint72\n" - ' [byte_order: "LittleEndian"]\n') - error_field = ir.module[0].type[0].structure.field[0] - uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) - uint_requirements = ir_util.get_attribute(uint_type.attribute, - attributes.STATIC_REQUIREMENTS) - self.assertEqual([[ - error.error("m.emb", error_field.source_location, - "Requirements of UInt not met."), - error.note("", uint_requirements.source_location, - "Requirements specified here."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_field_type_not_allowed_in_bits(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "bits Foo:\n" - " 0 [+16] Bar bar\n" - "external Bar:\n" - " [addressable_unit_size: 8]\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Byte-oriented type 'Bar' cannot be used in a bits field.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_arrays_allowed_in_bits(self): - ir = 
_make_ir_from_emb("bits Foo:\n" - " 0 [+16] Flag[16] bar\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_oversized_anonymous_bit_field(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+4] bits:\n" - " 0 [+8] UInt field\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_undersized_anonymous_bit_field(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+32] UInt field\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Fixed-size anonymous type cannot be placed in field of size 8 " - "bits; requires 32 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_reserved_field_name(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+8] UInt restrict\n") - error_name = ir.module[0].type[0].structure.field[0].name.name - self.assertEqual([[ - error.error( - "m.emb", error_name.source_location, - "C reserved word may not be used as a field name.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_reserved_type_name(self): - ir = _make_ir_from_emb("struct False:\n" - " 0 [+1] UInt foo\n") - error_name = ir.module[0].type[0].name.name - self.assertEqual([[ - error.error( - "m.emb", error_name.source_location, - "Python 3 reserved word may not be used as a type name.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_reserved_enum_name(self): - ir = _make_ir_from_emb("enum Foo:\n" - " NULL = 1\n") - error_name = ir.module[0].type[0].enumeration.value[0].name.name - self.assertEqual([[ - error.error( - "m.emb", error_name.source_location, - "C reserved word may not be used as an enum name.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_bits_type_in_struct_array(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+10] UInt:8[10] array\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_bits_type_in_bits_array(self): - ir = _make_ir_from_emb("bits Foo:\n" - " 0 [+10] UInt:8[10] array\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_explicit_size_too_small(self): - ir = _make_ir_from_emb("bits Foo:\n" - " 0 [+0] UInt:0 zero_bit\n") - error_field = ir.module[0].type[0].structure.field[0] - uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) - uint_requirements = ir_util.get_attribute(uint_type.attribute, - attributes.STATIC_REQUIREMENTS) - self.assertEqual([[ - error.error("m.emb", error_field.source_location, - "Requirements of UInt not met."), - error.note("", uint_requirements.source_location, - "Requirements specified here."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_enumeration_size_too_small(self): - ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n' - "bits Foo:\n" - " 0 [+0] Bar:0 zero_bit\n" - "enum Bar:\n" - " BAZ = 0\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error("m.emb", error_type.source_location, - "Enumeration type 'Bar' cannot be 0 bits; type 'Bar' " - "must be between 1 and 64 bits, inclusive."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_too_big_for_field(self): - ir = _make_ir_from_emb("bits Foo:\n" - " 0 [+8] 
UInt:32 thirty_two_bit\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.source_location, - "Fixed-size type 'UInt:32' cannot be placed in field of size 8 " - "bits; requires 32 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_too_small_for_field(self): - ir = _make_ir_from_emb("bits Foo:\n" - " 0 [+64] UInt:32 thirty_two_bit\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error("m.emb", error_type.source_location, - "Fixed-size type 'UInt:32' cannot be placed in field of " - "size 64 bits; requires 32 bits.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_too_big(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+16] UInt:128 one_twenty_eight_bit\n" - ' [byte_order: "LittleEndian"]\n') - error_field = ir.module[0].type[0].structure.field[0] - uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) - uint_requirements = ir_util.get_attribute(uint_type.attribute, - attributes.STATIC_REQUIREMENTS) - self.assertEqual([[ - error.error("m.emb", error_field.source_location, - "Requirements of UInt not met."), - error.note("", uint_requirements.source_location, - "Requirements specified here."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_enumeration_size_too_big(self): - ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n' - "struct Foo:\n" - " 0 [+9] Bar seventy_two_bit\n" - "enum Bar:\n" - " BAZ = 0\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error("m.emb", error_type.source_location, - "Enumeration type 'Bar' cannot be 72 bits; type 'Bar' " + - "must be between 1 and 64 bits, inclusive."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_enumeration_size_too_big_for_small_enum(self): - ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n' - "struct Foo:\n" - " 0 [+8] Bar sixty_four_bit\n" - "enum Bar:\n" - " [maximum_bits: 63]\n" - " BAZ = 0\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error("m.emb", error_type.source_location, - "Enumeration type 'Bar' cannot be 64 bits; type 'Bar' " + - "must be between 1 and 63 bits, inclusive."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_on_fixed_size_type(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] Byte:8 one_byte\n" - "struct Byte:\n" - " 0 [+1] UInt b\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_explicit_size_too_small_on_fixed_size_type(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+0] Byte:0 null_byte\n" - "struct Byte:\n" - " 0 [+1] UInt b\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", error_type.size_in_bits.source_location, - "Explicit size of 0 bits does not match fixed size (8 bits) of " - "type 'Byte'."), - error.note("m.emb", ir.module[0].type[1].source_location, - "Size specified here."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_too_big_on_fixed_size_type(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+2] Byte:16 double_byte\n" - "struct Byte:\n" - " 0 [+1] UInt b\n") - error_type = ir.module[0].type[0].structure.field[0].type - self.assertEqual([[ - error.error( - "m.emb", 
error_type.size_in_bits.source_location, - "Explicit size of 16 bits does not match fixed size (8 bits) of " - "type 'Byte'."), - error.note( - "m.emb", ir.module[0].type[1].source_location, - "Size specified here."), - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_size_ignored_on_variable_size_type(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] UInt n\n" - " 1 [+n] UInt:8[] d\n" - "struct Bar:\n" - " 0 [+10] Foo:80 foo\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_fixed_size_type_in_dynamically_sized_field(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt bar\n" - " 0 [+bar] Byte one_byte\n" - "struct Byte:\n" - " 0 [+1] UInt b\n") - self.assertEqual([], constraints.check_constraints(ir)) - - def test_enum_in_dynamically_sized_field(self): - ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n' - "struct Foo:\n" - " 0 [+1] UInt bar\n" - " 0 [+bar] Baz baz\n" - "enum Baz:\n" - " QUX = 0\n") - error_type = ir.module[0].type[0].structure.field[1].type - self.assertEqual( - [[ - error.error("m.emb", error_type.source_location, + """Tests constraints.check_constraints and helpers.""" + + def test_error_on_missing_inner_array_size(self): + ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[][1] one_byte\n") + # There is a latent issue here where the source location reported in this + # error is using a default value of 0:0. An issue is filed at + # https://github.com/google/emboss/issues/153 for further investigation. + # In the meantime we use `ir_data_utils.reader` to mimic this legacy + # behavior. + error_array = ir_data_utils.reader( + ir.module[0].type[0].structure.field[0].type.array_type + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_array.base_type.array_type.element_count.source_location, + "Array dimensions can only be omitted for the outermost dimension.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_no_error_on_ok_array_size(self): + ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[1][1] one_byte\n") + self.assertEqual([], constraints.check_constraints(ir)) + + def test_no_error_on_ok_missing_outer_array_size(self): + ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[1][] one_byte\n") + self.assertEqual([], constraints.check_constraints(ir)) + + def test_no_error_on_dynamically_sized_struct_in_dynamically_sized_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] UInt size\n" + " 1 [+size] Bar bar\n" + "struct Bar:\n" + " 0 [+1] UInt size\n" + " 1 [+size] UInt:8[] payload\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_no_error_on_dynamically_sized_struct_in_statically_sized_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+10] Bar bar\n" + "struct Bar:\n" + " 0 [+1] UInt size\n" + " 1 [+size] UInt:8[] payload\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_no_error_non_fixed_size_outer_array_dimension(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] UInt size\n" + " 1 [+size] UInt:8[1][size-1] one_byte\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_error_non_fixed_size_inner_array_dimension(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] UInt size\n" + " 1 [+size] UInt:8[size-1][1] one_byte\n" + ) + error_array = ir.module[0].type[0].structure.field[1].type.array_type + self.assertEqual( + [ + [ + error.error( + 
"m.emb", + error_array.base_type.array_type.element_count.source_location, + "Inner array dimensions must be constant.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_error_non_constant_inner_array_dimensions(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] Bar[1] one_byte\n" + # There is no dynamically-sized byte-oriented type in + # the Prelude, so this test has to make its own. + "external Bar:\n" + " [is_integer: true]\n" + " [addressable_unit_size: 8]\n" + ) + error_array = ir.module[0].type[0].structure.field[0].type.array_type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_array.base_type.atomic_type.source_location, + "Array elements must be fixed size.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_error_dynamically_sized_array_elements(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] Bar[1] bar\n" + "struct Bar:\n" + " 0 [+1] UInt size\n" + " 1 [+size] UInt:8[] payload\n" + ) + error_array = ir.module[0].type[0].structure.field[0].type.array_type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_array.base_type.atomic_type.source_location, + "Array elements must be fixed size.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_field_too_small_for_type(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] Bar bar\n" + "struct Bar:\n" + " 0 [+2] UInt value\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Fixed-size type 'Bar' cannot be placed in field of size 8 bits; " + "requires 16 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_dynamically_sized_field_always_too_small_for_type(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] bits:\n" + " 0 [+1] UInt x\n" + " 0 [+x] Bar bar\n" + "struct Bar:\n" + " 0 [+2] UInt value\n" + ) + error_type = ir.module[0].type[0].structure.field[2].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Field of maximum size 8 bits cannot hold fixed-size type 'Bar', " + "which requires 16 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_struct_field_too_big_for_type(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+2] Byte double_byte\n" + "struct Byte:\n" + " 0 [+1] UInt b\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Fixed-size type 'Byte' cannot be placed in field of size 16 bits; " + "requires 8 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_bits_field_too_big_for_type(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+9] UInt uint72\n" + ' [byte_order: "LittleEndian"]\n' + ) + error_field = ir.module[0].type[0].structure.field[0] + uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) + uint_requirements = ir_util.get_attribute( + uint_type.attribute, attributes.STATIC_REQUIREMENTS + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_field.source_location, + "Requirements of UInt not met.", + ), + error.note( + "", + uint_requirements.source_location, + "Requirements specified here.", + 
), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_field_type_not_allowed_in_bits(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "bits Foo:\n" + " 0 [+16] Bar bar\n" + "external Bar:\n" + " [addressable_unit_size: 8]\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Byte-oriented type 'Bar' cannot be used in a bits field.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_arrays_allowed_in_bits(self): + ir = _make_ir_from_emb("bits Foo:\n" " 0 [+16] Flag[16] bar\n") + self.assertEqual([], constraints.check_constraints(ir)) + + def test_oversized_anonymous_bit_field(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+4] bits:\n" + " 0 [+8] UInt field\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_undersized_anonymous_bit_field(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] bits:\n" + " 0 [+32] UInt field\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Fixed-size anonymous type cannot be placed in field of size 8 " + "bits; requires 32 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_reserved_field_name(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+8] UInt restrict\n" + ) + error_name = ir.module[0].type[0].structure.field[0].name.name + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_name.source_location, + "C reserved word may not be used as a field name.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_reserved_type_name(self): + ir = _make_ir_from_emb("struct False:\n" " 0 [+1] UInt foo\n") + error_name = ir.module[0].type[0].name.name + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_name.source_location, + "Python 3 reserved word may not be used as a type name.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_reserved_enum_name(self): + ir = _make_ir_from_emb("enum Foo:\n" " NULL = 1\n") + error_name = ir.module[0].type[0].enumeration.value[0].name.name + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_name.source_location, + "C reserved word may not be used as an enum name.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_bits_type_in_struct_array(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+10] UInt:8[10] array\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_bits_type_in_bits_array(self): + ir = _make_ir_from_emb("bits Foo:\n" " 0 [+10] UInt:8[10] array\n") + self.assertEqual([], constraints.check_constraints(ir)) + + def test_explicit_size_too_small(self): + ir = _make_ir_from_emb("bits Foo:\n" " 0 [+0] UInt:0 zero_bit\n") + error_field = ir.module[0].type[0].structure.field[0] + uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) + uint_requirements = ir_util.get_attribute( + uint_type.attribute, attributes.STATIC_REQUIREMENTS + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_field.source_location, + "Requirements of UInt not met.", + ), + 
error.note( + "", + uint_requirements.source_location, + "Requirements specified here.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_enumeration_size_too_small(self): + ir = _make_ir_from_emb( + '[$default byte_order: "BigEndian"]\n' + "bits Foo:\n" + " 0 [+0] Bar:0 zero_bit\n" + "enum Bar:\n" + " BAZ = 0\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Enumeration type 'Bar' cannot be 0 bits; type 'Bar' " + "must be between 1 and 64 bits, inclusive.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_too_big_for_field(self): + ir = _make_ir_from_emb("bits Foo:\n" " 0 [+8] UInt:32 thirty_two_bit\n") + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Fixed-size type 'UInt:32' cannot be placed in field of size 8 " + "bits; requires 32 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_too_small_for_field(self): + ir = _make_ir_from_emb("bits Foo:\n" " 0 [+64] UInt:32 thirty_two_bit\n") + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Fixed-size type 'UInt:32' cannot be placed in field of " + "size 64 bits; requires 32 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_too_big(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+16] UInt:128 one_twenty_eight_bit\n" + ' [byte_order: "LittleEndian"]\n' + ) + error_field = ir.module[0].type[0].structure.field[0] + uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir) + uint_requirements = ir_util.get_attribute( + uint_type.attribute, attributes.STATIC_REQUIREMENTS + ) + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_field.source_location, + "Requirements of UInt not met.", + ), + error.note( + "", + uint_requirements.source_location, + "Requirements specified here.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_enumeration_size_too_big(self): + ir = _make_ir_from_emb( + '[$default byte_order: "BigEndian"]\n' + "struct Foo:\n" + " 0 [+9] Bar seventy_two_bit\n" + "enum Bar:\n" + " BAZ = 0\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Enumeration type 'Bar' cannot be 72 bits; type 'Bar' " + + "must be between 1 and 64 bits, inclusive.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_enumeration_size_too_big_for_small_enum(self): + ir = _make_ir_from_emb( + '[$default byte_order: "BigEndian"]\n' + "struct Foo:\n" + " 0 [+8] Bar sixty_four_bit\n" + "enum Bar:\n" + " [maximum_bits: 63]\n" + " BAZ = 0\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "Enumeration type 'Bar' cannot be 64 bits; type 'Bar' " + + "must be between 1 and 63 bits, inclusive.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_on_fixed_size_type(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] Byte:8 one_byte\n" + "struct Byte:\n" + " 0 [+1] UInt 
b\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_explicit_size_too_small_on_fixed_size_type(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+0] Byte:0 null_byte\n" + "struct Byte:\n" + " 0 [+1] UInt b\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.size_in_bits.source_location, + "Explicit size of 0 bits does not match fixed size (8 bits) of " + "type 'Byte'.", + ), + error.note( + "m.emb", + ir.module[0].type[1].source_location, + "Size specified here.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_too_big_on_fixed_size_type(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+2] Byte:16 double_byte\n" + "struct Byte:\n" + " 0 [+1] UInt b\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.size_in_bits.source_location, + "Explicit size of 16 bits does not match fixed size (8 bits) of " + "type 'Byte'.", + ), + error.note( + "m.emb", + ir.module[0].type[1].source_location, + "Size specified here.", + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_size_ignored_on_variable_size_type(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] UInt n\n" + " 1 [+n] UInt:8[] d\n" + "struct Bar:\n" + " 0 [+10] Foo:80 foo\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_fixed_size_type_in_dynamically_sized_field(self): + ir = _make_ir_from_emb( + "struct Foo:\n" + " 0 [+1] UInt bar\n" + " 0 [+bar] Byte one_byte\n" + "struct Byte:\n" + " 0 [+1] UInt b\n" + ) + self.assertEqual([], constraints.check_constraints(ir)) + + def test_enum_in_dynamically_sized_field(self): + ir = _make_ir_from_emb( + '[$default byte_order: "BigEndian"]\n' + "struct Foo:\n" + " 0 [+1] UInt bar\n" + " 0 [+bar] Baz baz\n" + "enum Baz:\n" + " QUX = 0\n" + ) + error_type = ir.module[0].type[0].structure.field[1].type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, "Enumeration type 'Baz' cannot be placed in a " - "dynamically-sized field.") - ]], - error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_high(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " HIGH = 0x1_0000_0000_0000_0000\n") - error_value = ir.module[0].type[0].enumeration.value[0].value - self.assertEqual([ - [error.error( - "m.emb", error_value.source_location, - # TODO(bolms): Try to print numbers like 2**64 in hex? (I.e., if a - # number is a round number in hex, but not in decimal, print in - # hex?) 
- "Value 18446744073709551616 is out of range for 64-bit unsigned " + - "enumeration.")] - ], constraints.check_constraints(ir)) - - def test_enum_value_too_low(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -0x8000_0000_0000_0001\n") - error_value = ir.module[0].type[0].enumeration.value[0].value - self.assertEqual([ - [error.error( - "m.emb", error_value.source_location, - "Value -9223372036854775809 is out of range for 64-bit signed " + - "enumeration.")] - ], constraints.check_constraints(ir)) - - def test_enum_value_too_wide(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -1\n" - " HIGH = 0x8000_0000_0000_0000\n") - error_value = ir.module[0].type[0].enumeration.value[1].value - self.assertEqual([[ - error.error( - "m.emb", error_value.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " + - "enumeration.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_wide_unsigned_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -2\n" - " LOW2 = -1\n" - " HIGH = 0x8000_0000_0000_0000\n") - error_value = ir.module[0].type[0].enumeration.value[2].value - self.assertEqual([[ - error.error( - "m.emb", error_value.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " + - "enumeration.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_wide_small_size_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " [maximum_bits: 8]\n" - " HIGH = 0x100\n") - error_value = ir.module[0].type[0].enumeration.value[0].value - self.assertEqual([[ - error.error( - "m.emb", error_value.source_location, - "Value 256 is out of range for 8-bit unsigned enumeration.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_wide_small_size_signed_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " [maximum_bits: 8]\n" - " [is_signed: true]\n" - " HIGH = 0x80\n") - error_value = ir.module[0].type[0].enumeration.value[0].value - self.assertEqual([[ - error.error( - "m.emb", error_value.source_location, - "Value 128 is out of range for 8-bit signed enumeration.") - ]], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_wide_multiple(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -2\n" - " LOW2 = -1\n" - " HIGH = 0x8000_0000_0000_0000\n" - " HIGH2 = 0x8000_0000_0000_0001\n") - error_value = ir.module[0].type[0].enumeration.value[2].value - error_value2 = ir.module[0].type[0].enumeration.value[3].value - self.assertEqual([ - [error.error( - "m.emb", error_value.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " + - "enumeration.")], - [error.error( - "m.emb", error_value2.source_location, - "Value 9223372036854775809 is out of range for 64-bit signed " + - "enumeration.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_too_wide_multiple_signed_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -3\n" - " LOW2 = -2\n" - " LOW3 = -1\n" - " HIGH = 0x8000_0000_0000_0000\n" - " HIGH2 = 0x8000_0000_0000_0001\n") - error_value = ir.module[0].type[0].enumeration.value[3].value - 
error_value2 = ir.module[0].type[0].enumeration.value[4].value - self.assertEqual([ - [error.error( - "m.emb", error_value.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " - "enumeration.")], - [error.error( - "m.emb", error_value2.source_location, - "Value 9223372036854775809 is out of range for 64-bit signed " - "enumeration.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_mixed_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " LOW = -1\n" - " HIGH = 0x8000_0000_0000_0000\n" - " HIGH2 = 0x1_0000_0000_0000_0000\n") - error_value1 = ir.module[0].type[0].enumeration.value[1].value - error_value2 = ir.module[0].type[0].enumeration.value[2].value - self.assertEqual([ - [error.error( - "m.emb", error_value1.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " + - "enumeration.")], - [error.error( - "m.emb", error_value2.source_location, - "Value 18446744073709551616 is out of range for 64-bit signed " + - "enumeration.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_explicitly_signed_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " [is_signed: true]\n" - " HIGH = 0x8000_0000_0000_0000\n" - " HIGH2 = 0x1_0000_0000_0000_0000\n") - error_value0 = ir.module[0].type[0].enumeration.value[0].value - error_value1 = ir.module[0].type[0].enumeration.value[1].value - self.assertEqual([ - [error.error( - "m.emb", error_value0.source_location, - "Value 9223372036854775808 is out of range for 64-bit signed " + - "enumeration.")], - [error.error( - "m.emb", error_value1.source_location, - "Value 18446744073709551616 is out of range for 64-bit signed " + - "enumeration.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_enum_value_explicitly_unsigned_error_message(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "enum Foo:\n" - " [is_signed: false]\n" - " LOW = -1\n" - " HIGH = 0x8000_0000_0000_0000\n" - " HIGH2 = 0x1_0000_0000_0000_0000\n") - error_value0 = ir.module[0].type[0].enumeration.value[0].value - error_value2 = ir.module[0].type[0].enumeration.value[2].value - self.assertEqual([ - [error.error( - "m.emb", error_value0.source_location, - "Value -1 is out of range for 64-bit unsigned enumeration.")], - [error.error( - "m.emb", error_value2.source_location, - "Value 18446744073709551616 is out of range for 64-bit unsigned " + - "enumeration.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_explicit_non_byte_size_array_element(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+2] UInt:4[4] nibbles\n") - error_type = ir.module[0].type[0].structure.field[0].type.array_type - self.assertEqual([ - [error.error( - "m.emb", error_type.base_type.source_location, - "Array elements in structs must have sizes which are a multiple of " - "8 bits.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_implicit_non_byte_size_array_element(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "bits Nibble:\n" - " 0 [+4] UInt nibble\n" - "struct Foo:\n" - " 0 [+2] Nibble[4] nibbles\n") - error_type = ir.module[0].type[1].structure.field[0].type.array_type - self.assertEqual([ - [error.error( - "m.emb", error_type.base_type.source_location, - "Array elements in structs must have sizes which 
are a multiple of " - "8 bits.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_bits_must_be_fixed_size(self): - ir = _make_ir_from_emb("bits Dynamic:\n" - " 0 [+3] UInt x\n" - " 3 [+3 * x] UInt:3[x] a\n") - error_type = ir.module[0].type[0] - self.assertEqual([ - [error.error("m.emb", error_type.source_location, - "`bits` types must be fixed size.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_bits_must_be_small(self): - ir = _make_ir_from_emb("bits Big:\n" - " 0 [+64] UInt x\n" - " 64 [+1] UInt y\n") - error_type = ir.module[0].type[0] - self.assertEqual([ - [error.error("m.emb", error_type.source_location, - "`bits` types must be 64 bits or smaller.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_constant_expressions_must_be_small(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+8] UInt x\n" - " if x < 0x1_0000_0000_0000_0000:\n" - " 8 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_location = condition.function.args[1].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Constant value {} of expression cannot fit in a 64-bit signed or " - "unsigned integer.".format(2**64))] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_variable_expression_out_of_range_for_uint64(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+8] UInt x\n" - " if x + 1 < 0xffff_ffff_ffff_ffff:\n" - " 8 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_location = condition.function.args[0].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Potential range of expression is {} to {}, which cannot fit in a " - "64-bit signed or unsigned integer.".format(1, 2**64))] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_variable_expression_out_of_range_for_int64(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+8] UInt x\n" - " if x - 0x8000_0000_0000_0001 < 0:\n" - " 8 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_location = condition.function.args[0].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Potential range of expression is {} to {}, which cannot fit in a " - "64-bit signed or unsigned integer.".format(-(2**63) - 1, - 2**63 - 2))] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_requires_expression_out_of_range_for_uint64(self): - ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+8] UInt x\n" - " [requires: this * 2 < 0x1_0000]\n") - attribute_list = ir.module[0].type[0].structure.field[0].attribute - error_arg = attribute_list[0].value.expression.function.args[0] - error_location = error_arg.source_location - self.assertEqual( - [[ - error.error( - "m.emb", error_location, - "Potential range of expression is {} to {}, which cannot fit " - "in a 64-bit signed or unsigned integer.".format(0, 2**65-2)) - ]], - error.filter_errors(constraints.check_constraints(ir))) - - def test_arguments_require_different_signedness_64_bits(self): - ir = _make_ir_from_emb( - '[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] UInt x\n" - # Left side requires uint64, right side requires int64. 
- " if (x + 0x8000_0000_0000_0000) + (x - 0x7fff_ffff_ffff_ffff) < 10:\n" - " 1 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_expression = condition.function.args[0] - error_location = error_expression.source_location - arg0_location = error_expression.function.args[0].source_location - arg1_location = error_expression.function.args[1].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Either all arguments to '+' and its result must fit in a 64-bit " - "unsigned integer, or all must fit in a 64-bit signed integer."), - error.note("m.emb", arg0_location, - "Requires unsigned 64-bit integer."), - error.note("m.emb", arg1_location, - "Requires signed 64-bit integer.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_return_value_requires_different_signedness_from_arguments(self): - ir = _make_ir_from_emb( - '[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] UInt x\n" - # Both arguments require uint64; result fits in int64. - " if (x + 0x7fff_ffff_ffff_ffff) - 0x8000_0000_0000_0000 < 10:\n" - " 1 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_expression = condition.function.args[0] - error_location = error_expression.source_location - arg0_location = error_expression.function.args[0].source_location - arg1_location = error_expression.function.args[1].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Either all arguments to '-' and its result must fit in a 64-bit " - "unsigned integer, or all must fit in a 64-bit signed integer."), - error.note("m.emb", arg0_location, - "Requires unsigned 64-bit integer."), - error.note("m.emb", arg1_location, - "Requires unsigned 64-bit integer."), - error.note("m.emb", error_location, - "Requires signed 64-bit integer.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_return_value_requires_different_signedness_from_one_argument(self): - ir = _make_ir_from_emb( - '[$default byte_order: "LittleEndian"]\n' - "struct Foo:\n" - " 0 [+1] UInt x\n" - # One argument requires uint64; result fits in int64. 
- " if (x + 0x7fff_ffff_ffff_fff0) - 0x7fff_ffff_ffff_ffff < 10:\n" - " 1 [+1] UInt y\n") - condition = ir.module[0].type[0].structure.field[1].existence_condition - error_expression = condition.function.args[0] - error_location = error_expression.source_location - arg0_location = error_expression.function.args[0].source_location - self.assertEqual([ - [error.error( - "m.emb", error_location, - "Either all arguments to '-' and its result must fit in a 64-bit " - "unsigned integer, or all must fit in a 64-bit signed integer."), - error.note("m.emb", arg0_location, - "Requires unsigned 64-bit integer."), - error.note("m.emb", error_location, - "Requires signed 64-bit integer.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_checks_constancy_of_constant_references(self): - ir = _make_ir_from_emb("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = x\n" - " let z = Foo.y\n") - error_expression = ir.module[0].type[0].structure.field[2].read_transform - error_location = error_expression.source_location - note_field = ir.module[0].type[0].structure.field[1] - note_location = note_field.source_location - self.assertEqual([ - [error.error("m.emb", error_location, - "Static references must refer to constants."), - error.note("m.emb", note_location, "y is not constant.")] - ], error.filter_errors(constraints.check_constraints(ir))) - - def test_checks_for_explicit_size_on_parameters(self): - ir = _make_ir_from_emb("struct Foo(y: UInt):\n" - " 0 [+1] UInt x\n") - error_parameter = ir.module[0].type[0].runtime_parameter[0] - error_location = error_parameter.physical_type_alias.source_location - self.assertEqual( - [[error.error("m.emb", error_location, - "Integer range of parameter must not be unbounded; it " - "must fit in a 64-bit signed or unsigned integer.")]], - error.filter_errors(constraints.check_constraints(ir))) - - def test_checks_for_correct_explicit_size_on_parameters(self): - ir = _make_ir_from_emb("struct Foo(y: UInt:300):\n" - " 0 [+1] UInt x\n") - error_parameter = ir.module[0].type[0].runtime_parameter[0] - error_location = error_parameter.physical_type_alias.source_location - self.assertEqual( - [[error.error("m.emb", error_location, - "Potential range of parameter is 0 to {}, which cannot " - "fit in a 64-bit signed or unsigned integer.".format( - 2**300-1))]], - error.filter_errors(constraints.check_constraints(ir))) - - def test_checks_for_explicit_enum_size_on_parameters(self): - ir = _make_ir_from_emb("struct Foo(y: Bar:8):\n" - " 0 [+1] UInt x\n" - "enum Bar:\n" - " QUX = 1\n") - error_parameter = ir.module[0].type[0].runtime_parameter[0] - error_size = error_parameter.physical_type_alias.size_in_bits - error_location = error_size.source_location - self.assertEqual( - [[error.error( - "m.emb", error_location, - "Parameters with enum type may not have explicit size.")]], - error.filter_errors(constraints.check_constraints(ir))) + "dynamically-sized field.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_high(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " HIGH = 0x1_0000_0000_0000_0000\n" + ) + error_value = ir.module[0].type[0].enumeration.value[0].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + # TODO(bolms): Try to print numbers like 2**64 in hex? (I.e., if a + # number is a round number in hex, but not in decimal, print in + # hex?) 
+ "Value 18446744073709551616 is out of range for 64-bit unsigned " + + "enumeration.", + ) + ] + ], + constraints.check_constraints(ir), + ) + + def test_enum_value_too_low(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " LOW = -0x8000_0000_0000_0001\n" + ) + error_value = ir.module[0].type[0].enumeration.value[0].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value -9223372036854775809 is out of range for 64-bit signed " + + "enumeration.", + ) + ] + ], + constraints.check_constraints(ir), + ) + + def test_enum_value_too_wide(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " LOW = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + ) + error_value = ir.module[0].type[0].enumeration.value[1].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + + "enumeration.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_wide_unsigned_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " LOW = -2\n" + " LOW2 = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + ) + error_value = ir.module[0].type[0].enumeration.value[2].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + + "enumeration.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_wide_small_size_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " [maximum_bits: 8]\n" + " HIGH = 0x100\n" + ) + error_value = ir.module[0].type[0].enumeration.value[0].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 256 is out of range for 8-bit unsigned enumeration.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_wide_small_size_signed_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " [maximum_bits: 8]\n" + " [is_signed: true]\n" + " HIGH = 0x80\n" + ) + error_value = ir.module[0].type[0].enumeration.value[0].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 128 is out of range for 8-bit signed enumeration.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_wide_multiple(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " LOW = -2\n" + " LOW2 = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + " HIGH2 = 0x8000_0000_0000_0001\n" + ) + error_value = ir.module[0].type[0].enumeration.value[2].value + error_value2 = ir.module[0].type[0].enumeration.value[3].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + [ + error.error( + "m.emb", + error_value2.source_location, + "Value 9223372036854775809 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_too_wide_multiple_signed_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + 
" LOW = -3\n" + " LOW2 = -2\n" + " LOW3 = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + " HIGH2 = 0x8000_0000_0000_0001\n" + ) + error_value = ir.module[0].type[0].enumeration.value[3].value + error_value2 = ir.module[0].type[0].enumeration.value[4].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + "enumeration.", + ) + ], + [ + error.error( + "m.emb", + error_value2.source_location, + "Value 9223372036854775809 is out of range for 64-bit signed " + "enumeration.", + ) + ], + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_mixed_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " LOW = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + " HIGH2 = 0x1_0000_0000_0000_0000\n" + ) + error_value1 = ir.module[0].type[0].enumeration.value[1].value + error_value2 = ir.module[0].type[0].enumeration.value[2].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value1.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + [ + error.error( + "m.emb", + error_value2.source_location, + "Value 18446744073709551616 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_explicitly_signed_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " [is_signed: true]\n" + " HIGH = 0x8000_0000_0000_0000\n" + " HIGH2 = 0x1_0000_0000_0000_0000\n" + ) + error_value0 = ir.module[0].type[0].enumeration.value[0].value + error_value1 = ir.module[0].type[0].enumeration.value[1].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value0.source_location, + "Value 9223372036854775808 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + [ + error.error( + "m.emb", + error_value1.source_location, + "Value 18446744073709551616 is out of range for 64-bit signed " + + "enumeration.", + ) + ], + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_enum_value_explicitly_unsigned_error_message(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "enum Foo:\n" + " [is_signed: false]\n" + " LOW = -1\n" + " HIGH = 0x8000_0000_0000_0000\n" + " HIGH2 = 0x1_0000_0000_0000_0000\n" + ) + error_value0 = ir.module[0].type[0].enumeration.value[0].value + error_value2 = ir.module[0].type[0].enumeration.value[2].value + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_value0.source_location, + "Value -1 is out of range for 64-bit unsigned enumeration.", + ) + ], + [ + error.error( + "m.emb", + error_value2.source_location, + "Value 18446744073709551616 is out of range for 64-bit unsigned " + + "enumeration.", + ) + ], + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_explicit_non_byte_size_array_element(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+2] UInt:4[4] nibbles\n" + ) + error_type = ir.module[0].type[0].structure.field[0].type.array_type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.base_type.source_location, + "Array elements in structs must have sizes which are a multiple of " + "8 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_implicit_non_byte_size_array_element(self): + 
ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "bits Nibble:\n" + " 0 [+4] UInt nibble\n" + "struct Foo:\n" + " 0 [+2] Nibble[4] nibbles\n" + ) + error_type = ir.module[0].type[1].structure.field[0].type.array_type + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.base_type.source_location, + "Array elements in structs must have sizes which are a multiple of " + "8 bits.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_bits_must_be_fixed_size(self): + ir = _make_ir_from_emb( + "bits Dynamic:\n" + " 0 [+3] UInt x\n" + " 3 [+3 * x] UInt:3[x] a\n" + ) + error_type = ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "`bits` types must be fixed size.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_bits_must_be_small(self): + ir = _make_ir_from_emb( + "bits Big:\n" " 0 [+64] UInt x\n" " 64 [+1] UInt y\n" + ) + error_type = ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_type.source_location, + "`bits` types must be 64 bits or smaller.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_constant_expressions_must_be_small(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+8] UInt x\n" + " if x < 0x1_0000_0000_0000_0000:\n" + " 8 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_location = condition.function.args[1].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Constant value {} of expression cannot fit in a 64-bit signed or " + "unsigned integer.".format(2**64), + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_variable_expression_out_of_range_for_uint64(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+8] UInt x\n" + " if x + 1 < 0xffff_ffff_ffff_ffff:\n" + " 8 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_location = condition.function.args[0].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Potential range of expression is {} to {}, which cannot fit in a " + "64-bit signed or unsigned integer.".format(1, 2**64), + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_variable_expression_out_of_range_for_int64(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+8] UInt x\n" + " if x - 0x8000_0000_0000_0001 < 0:\n" + " 8 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_location = condition.function.args[0].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Potential range of expression is {} to {}, which cannot fit in a " + "64-bit signed or unsigned integer.".format( + -(2**63) - 1, 2**63 - 2 + ), + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_requires_expression_out_of_range_for_uint64(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+8] UInt x\n" + " [requires: this * 2 < 0x1_0000]\n" + ) + attribute_list = ir.module[0].type[0].structure.field[0].attribute + error_arg = attribute_list[0].value.expression.function.args[0] + error_location = 
error_arg.source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Potential range of expression is {} to {}, which cannot fit " + "in a 64-bit signed or unsigned integer.".format(0, 2**65 - 2), + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_arguments_require_different_signedness_64_bits(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] UInt x\n" + # Left side requires uint64, right side requires int64. + " if (x + 0x8000_0000_0000_0000) + (x - 0x7fff_ffff_ffff_ffff) < 10:\n" + " 1 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_expression = condition.function.args[0] + error_location = error_expression.source_location + arg0_location = error_expression.function.args[0].source_location + arg1_location = error_expression.function.args[1].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Either all arguments to '+' and its result must fit in a 64-bit " + "unsigned integer, or all must fit in a 64-bit signed integer.", + ), + error.note( + "m.emb", arg0_location, "Requires unsigned 64-bit integer." + ), + error.note( + "m.emb", arg1_location, "Requires signed 64-bit integer." + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_return_value_requires_different_signedness_from_arguments(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] UInt x\n" + # Both arguments require uint64; result fits in int64. + " if (x + 0x7fff_ffff_ffff_ffff) - 0x8000_0000_0000_0000 < 10:\n" + " 1 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_expression = condition.function.args[0] + error_location = error_expression.source_location + arg0_location = error_expression.function.args[0].source_location + arg1_location = error_expression.function.args[1].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Either all arguments to '-' and its result must fit in a 64-bit " + "unsigned integer, or all must fit in a 64-bit signed integer.", + ), + error.note( + "m.emb", arg0_location, "Requires unsigned 64-bit integer." + ), + error.note( + "m.emb", arg1_location, "Requires unsigned 64-bit integer." + ), + error.note( + "m.emb", error_location, "Requires signed 64-bit integer." + ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_return_value_requires_different_signedness_from_one_argument(self): + ir = _make_ir_from_emb( + '[$default byte_order: "LittleEndian"]\n' + "struct Foo:\n" + " 0 [+1] UInt x\n" + # One argument requires uint64; result fits in int64. + " if (x + 0x7fff_ffff_ffff_fff0) - 0x7fff_ffff_ffff_ffff < 10:\n" + " 1 [+1] UInt y\n" + ) + condition = ir.module[0].type[0].structure.field[1].existence_condition + error_expression = condition.function.args[0] + error_location = error_expression.source_location + arg0_location = error_expression.function.args[0].source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Either all arguments to '-' and its result must fit in a 64-bit " + "unsigned integer, or all must fit in a 64-bit signed integer.", + ), + error.note( + "m.emb", arg0_location, "Requires unsigned 64-bit integer." + ), + error.note( + "m.emb", error_location, "Requires signed 64-bit integer." 
+ ), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_checks_constancy_of_constant_references(self): + ir = _make_ir_from_emb( + "struct Foo:\n" " 0 [+1] UInt x\n" " let y = x\n" " let z = Foo.y\n" + ) + error_expression = ir.module[0].type[0].structure.field[2].read_transform + error_location = error_expression.source_location + note_field = ir.module[0].type[0].structure.field[1] + note_location = note_field.source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Static references must refer to constants.", + ), + error.note("m.emb", note_location, "y is not constant."), + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_checks_for_explicit_size_on_parameters(self): + ir = _make_ir_from_emb("struct Foo(y: UInt):\n" " 0 [+1] UInt x\n") + error_parameter = ir.module[0].type[0].runtime_parameter[0] + error_location = error_parameter.physical_type_alias.source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Integer range of parameter must not be unbounded; it " + "must fit in a 64-bit signed or unsigned integer.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_checks_for_correct_explicit_size_on_parameters(self): + ir = _make_ir_from_emb("struct Foo(y: UInt:300):\n" " 0 [+1] UInt x\n") + error_parameter = ir.module[0].type[0].runtime_parameter[0] + error_location = error_parameter.physical_type_alias.source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Potential range of parameter is 0 to {}, which cannot " + "fit in a 64-bit signed or unsigned integer.".format( + 2**300 - 1 + ), + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) + + def test_checks_for_explicit_enum_size_on_parameters(self): + ir = _make_ir_from_emb( + "struct Foo(y: Bar:8):\n" " 0 [+1] UInt x\n" "enum Bar:\n" " QUX = 1\n" + ) + error_parameter = ir.module[0].type[0].runtime_parameter[0] + error_size = error_parameter.physical_type_alias.size_in_bits + error_location = error_size.source_location + self.assertEqual( + [ + [ + error.error( + "m.emb", + error_location, + "Parameters with enum type may not have explicit size.", + ) + ] + ], + error.filter_errors(constraints.check_constraints(ir)), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/dependency_checker.py b/compiler/front_end/dependency_checker.py index 963d8bf..8a9e903 100644 --- a/compiler/front_end/dependency_checker.py +++ b/compiler/front_end/dependency_checker.py @@ -20,261 +20,302 @@ from compiler.util import traverse_ir -def _add_reference_to_dependencies(reference, dependencies, name, - source_file_name, errors): - if reference.canonical_name.object_path[0] in {"$is_statically_sized", - "$static_size_in_bits", - "$next"}: - # This error is a bit opaque, but given that the compiler used to crash on - # this case -- for a couple of years -- and no one complained, it seems - # safe to assume that this is a rare error. 
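(Reviewer aside, not part of the patch: the decimal constants asserted in the enum-range tests above are just the 64-bit integer boundaries. A quick plain-Python sanity check of the values quoted in the expected messages:)

assert 2**64 == 18446744073709551616         # first value too wide for uint64
assert -(2**63) - 1 == -9223372036854775809  # first value too low for int64
assert 2**63 == 9223372036854775808          # first value too wide for int64
assert 2**64 - 1 == 18446744073709551615     # top of the uint64 range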
- errors.append([ - error.error(source_file_name, reference.source_location, - "Keyword `" + reference.canonical_name.object_path[0] + - "` may not be used in this context."), - ]) - return - dependencies[name] |= {ir_util.hashable_form_of_reference(reference)} +def _add_reference_to_dependencies( + reference, dependencies, name, source_file_name, errors +): + if reference.canonical_name.object_path[0] in { + "$is_statically_sized", + "$static_size_in_bits", + "$next", + }: + # This error is a bit opaque, but given that the compiler used to crash on + # this case -- for a couple of years -- and no one complained, it seems + # safe to assume that this is a rare error. + errors.append( + [ + error.error( + source_file_name, + reference.source_location, + "Keyword `" + + reference.canonical_name.object_path[0] + + "` may not be used in this context.", + ), + ] + ) + return + dependencies[name] |= {ir_util.hashable_form_of_reference(reference)} def _add_field_reference_to_dependencies(reference, dependencies, name): - dependencies[name] |= {ir_util.hashable_form_of_reference(reference.path[0])} + dependencies[name] |= {ir_util.hashable_form_of_reference(reference.path[0])} def _add_name_to_dependencies(proto, dependencies): - name = ir_util.hashable_form_of_reference(proto.name) - dependencies.setdefault(name, set()) - return {"name": name} + name = ir_util.hashable_form_of_reference(proto.name) + dependencies.setdefault(name, set()) + return {"name": name} def _find_dependencies(ir): - """Constructs a dependency graph for the entire IR.""" - dependencies = {} - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Reference], _add_reference_to_dependencies, - # TODO(bolms): Add handling for references inside of attributes, once - # there are attributes with non-constant values. - skip_descendants_of={ - ir_data.AtomicType, ir_data.Attribute, ir_data.FieldReference - }, - incidental_actions={ - ir_data.Field: _add_name_to_dependencies, - ir_data.EnumValue: _add_name_to_dependencies, - ir_data.RuntimeParameter: _add_name_to_dependencies, - }, - parameters={ - "dependencies": dependencies, - "errors": errors, - }) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.FieldReference], _add_field_reference_to_dependencies, - skip_descendants_of={ir_data.Attribute}, - incidental_actions={ - ir_data.Field: _add_name_to_dependencies, - ir_data.EnumValue: _add_name_to_dependencies, - ir_data.RuntimeParameter: _add_name_to_dependencies, - }, - parameters={"dependencies": dependencies}) - return dependencies, errors + """Constructs a dependency graph for the entire IR.""" + dependencies = {} + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Reference], + _add_reference_to_dependencies, + # TODO(bolms): Add handling for references inside of attributes, once + # there are attributes with non-constant values. 
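(Aside: a minimal standalone sketch, with hypothetical names, of the keyword guard that `_add_reference_to_dependencies` applies before recording a dependency:)

RESERVED_HEADS = {"$is_statically_sized", "$static_size_in_bits", "$next"}

def head_is_reserved(object_path):
    # Only the first component of the reference path is checked, as above.
    return object_path[0] in RESERVED_HEADS

assert head_is_reserved(["$next"])
assert not head_is_reserved(["Foo", "field1"])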
+ skip_descendants_of={ + ir_data.AtomicType, + ir_data.Attribute, + ir_data.FieldReference, + }, + incidental_actions={ + ir_data.Field: _add_name_to_dependencies, + ir_data.EnumValue: _add_name_to_dependencies, + ir_data.RuntimeParameter: _add_name_to_dependencies, + }, + parameters={ + "dependencies": dependencies, + "errors": errors, + }, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.FieldReference], + _add_field_reference_to_dependencies, + skip_descendants_of={ir_data.Attribute}, + incidental_actions={ + ir_data.Field: _add_name_to_dependencies, + ir_data.EnumValue: _add_name_to_dependencies, + ir_data.RuntimeParameter: _add_name_to_dependencies, + }, + parameters={"dependencies": dependencies}, + ) + return dependencies, errors def _find_dependency_ordering_for_fields_in_structure( - structure, type_definition, dependencies): - """Populates structure.fields_in_dependency_order.""" - # For fields which appear before their dependencies in the original source - # text, this algorithm moves them to immediately after their dependencies. - # - # This is one of many possible schemes for constructing a dependency ordering; - # it has the advantage that all of the generated fields (e.g., $size_in_bytes) - # stay at the end of the ordering, which makes testing easier. - order = [] - added = set() - for parameter in type_definition.runtime_parameter: - added.add(ir_util.hashable_form_of_reference(parameter.name)) - needed = list(range(len(structure.field))) - while True: - for i in range(len(needed)): - field_number = needed[i] - field = ir_util.hashable_form_of_reference( - structure.field[field_number].name) - assert field in dependencies, "dependencies = {}".format(dependencies) - if all(dependency in added for dependency in dependencies[field]): - order.append(field_number) - added.add(field) - del needed[i] - break - else: - break - # If a non-local-field dependency were in dependencies[field], then not all - # fields would be added to the dependency ordering. This shouldn't happen. - assert len(order) == len(structure.field), ( - "order: {}\nlen(structure.field: {})".format(order, len(structure.field))) - del structure.fields_in_dependency_order[:] - structure.fields_in_dependency_order.extend(order) + structure, type_definition, dependencies +): + """Populates structure.fields_in_dependency_order.""" + # For fields which appear before their dependencies in the original source + # text, this algorithm moves them to immediately after their dependencies. + # + # This is one of many possible schemes for constructing a dependency ordering; + # it has the advantage that all of the generated fields (e.g., $size_in_bytes) + # stay at the end of the ordering, which makes testing easier. + order = [] + added = set() + for parameter in type_definition.runtime_parameter: + added.add(ir_util.hashable_form_of_reference(parameter.name)) + needed = list(range(len(structure.field))) + while True: + for i in range(len(needed)): + field_number = needed[i] + field = ir_util.hashable_form_of_reference( + structure.field[field_number].name + ) + assert field in dependencies, "dependencies = {}".format(dependencies) + if all(dependency in added for dependency in dependencies[field]): + order.append(field_number) + added.add(field) + del needed[i] + break + else: + break + # If a non-local-field dependency were in dependencies[field], then not all + # fields would be added to the dependency ordering. This shouldn't happen. 
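(Aside: a self-contained toy version, with hypothetical field names, of the ordering scheme described in the comment above — repeatedly emit the first not-yet-emitted field whose dependencies have all been emitted, which keeps trailing generated fields at the end:)

def order_fields(fields, deps):
    order, added, needed = [], set(), list(range(len(fields)))
    while needed:
        for i, n in enumerate(needed):
            if all(d in added for d in deps[fields[n]]):
                order.append(n)
                added.add(fields[n])
                del needed[i]
                break
        else:
            break  # anything left over would be a dependency cycle
    return order

# "a" depends on "b", so "b" (index 1) is emitted first:
assert order_fields(["a", "b"], {"a": {"b"}, "b": set()}) == [1, 0]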
+ assert len(order) == len( + structure.field + ), "order: {}\nlen(structure.field: {})".format(order, len(structure.field)) + del structure.fields_in_dependency_order[:] + structure.fields_in_dependency_order.extend(order) def _find_dependency_ordering_for_fields(ir): - """Populates the fields_in_dependency_order fields throughout ir.""" - dependencies = {} - # TODO(bolms): This duplicates work in _find_dependencies that could be - # shared. - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.FieldReference], _add_field_reference_to_dependencies, - skip_descendants_of={ir_data.Attribute}, - incidental_actions={ - ir_data.Field: _add_name_to_dependencies, - ir_data.EnumValue: _add_name_to_dependencies, - ir_data.RuntimeParameter: _add_name_to_dependencies, - }, - parameters={"dependencies": dependencies}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], - _find_dependency_ordering_for_fields_in_structure, - parameters={"dependencies": dependencies}) + """Populates the fields_in_dependency_order fields throughout ir.""" + dependencies = {} + # TODO(bolms): This duplicates work in _find_dependencies that could be + # shared. + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.FieldReference], + _add_field_reference_to_dependencies, + skip_descendants_of={ir_data.Attribute}, + incidental_actions={ + ir_data.Field: _add_name_to_dependencies, + ir_data.EnumValue: _add_name_to_dependencies, + ir_data.RuntimeParameter: _add_name_to_dependencies, + }, + parameters={"dependencies": dependencies}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Structure], + _find_dependency_ordering_for_fields_in_structure, + parameters={"dependencies": dependencies}, + ) def _find_module_import_dependencies(ir): - """Constructs a dependency graph of module imports.""" - dependencies = {} - for module in ir.module: - foreign_imports = set() - for foreign_import in module.foreign_import: - # The prelude gets an automatic self-import that shouldn't cause any - # problems. No other self-imports are allowed, however. - if foreign_import.file_name.text or module.source_file_name: - foreign_imports |= {(foreign_import.file_name.text,)} - dependencies[module.source_file_name,] = foreign_imports - return dependencies + """Constructs a dependency graph of module imports.""" + dependencies = {} + for module in ir.module: + foreign_imports = set() + for foreign_import in module.foreign_import: + # The prelude gets an automatic self-import that shouldn't cause any + # problems. No other self-imports are allowed, however. + if foreign_import.file_name.text or module.source_file_name: + foreign_imports |= {(foreign_import.file_name.text,)} + dependencies[module.source_file_name,] = foreign_imports + return dependencies def _find_cycles(graph): - """Finds cycles in graph. - - The graph does not need to be fully connected. - - Arguments: - graph: A dictionary whose keys are node labels. Values are sets of node - labels, representing edges from the key node to the value nodes. - - Returns: - A set of sets of nodes which form strongly-connected components (subgraphs - where every node is directly or indirectly reachable from every other node). - No node will be included in more than one strongly-connected component, by - definition. Strongly-connected components of size 1, where the node in the - component does not have a self-edge, are not included in the result. - - Note that a strongly-connected component may have a more complex structure - than a single loop. 
For example: - - +-- A <-+ +-> B --+ - | | | | - v C v - D ^ ^ E - | | | | - +-> F --+ +-- G <-+ - """ - # This uses Tarjan's strongly-connected components algorithm, as described by - # Wikipedia. This is a depth-first traversal of the graph with a node stack - # that is independent of the call stack; nodes are added to the stack when - # they are first encountered, but not removed until all nodes they can reach - # have been checked. - next_index = [0] - node_indices = {} - node_lowlinks = {} - nodes_on_stack = set() - stack = [] - nontrivial_components = set() - - def strong_connect(node): - """Implements the STRONGCONNECT routine of Tarjan's algorithm.""" - node_indices[node] = next_index[0] - node_lowlinks[node] = next_index[0] - next_index[0] += 1 - stack.append(node) - nodes_on_stack.add(node) - - for destination_node in graph[node]: - if destination_node not in node_indices: - strong_connect(destination_node) - node_lowlinks[node] = min(node_lowlinks[node], - node_lowlinks[destination_node]) - elif destination_node in nodes_on_stack: - node_lowlinks[node] = min(node_lowlinks[node], - node_indices[destination_node]) - - strongly_connected_component = [] - if node_lowlinks[node] == node_indices[node]: - while True: - popped_node = stack.pop() - nodes_on_stack.remove(popped_node) - strongly_connected_component.append(popped_node) - if popped_node == node: - break - if (len(strongly_connected_component) > 1 or - strongly_connected_component[0] in - graph[strongly_connected_component[0]]): - nontrivial_components.add(frozenset(strongly_connected_component)) - - for node in graph: - if node not in node_indices: - strong_connect(node) - return nontrivial_components + """Finds cycles in graph. + + The graph does not need to be fully connected. + + Arguments: + graph: A dictionary whose keys are node labels. Values are sets of node + labels, representing edges from the key node to the value nodes. + + Returns: + A set of sets of nodes which form strongly-connected components (subgraphs + where every node is directly or indirectly reachable from every other node). + No node will be included in more than one strongly-connected component, by + definition. Strongly-connected components of size 1, where the node in the + component does not have a self-edge, are not included in the result. + + Note that a strongly-connected component may have a more complex structure + than a single loop. For example: + + +-- A <-+ +-> B --+ + | | | | + v C v + D ^ ^ E + | | | | + +-> F --+ +-- G <-+ + """ + # This uses Tarjan's strongly-connected components algorithm, as described by + # Wikipedia. This is a depth-first traversal of the graph with a node stack + # that is independent of the call stack; nodes are added to the stack when + # they are first encountered, but not removed until all nodes they can reach + # have been checked. 
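(Aside: the documented contract on a toy module-import graph. Module graphs key on 1-tuples of file names — hence the trailing comma in `dependencies[module.source_file_name,]` above; the import path here is assumed from this repository's layout:)

from compiler.front_end import dependency_checker

graph = {("m.emb",): {("n.emb",)}, ("n.emb",): {("m.emb",)}, ("p.emb",): set()}
# Only the nontrivial strongly-connected component is reported:
assert dependency_checker._find_cycles(graph) == {
    frozenset({("m.emb",), ("n.emb",)})
}
# A single node forms a cycle only via a self-edge:
assert dependency_checker._find_cycles({("q.emb",): {("q.emb",)}}) == {
    frozenset({("q.emb",)})
}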
+ next_index = [0] + node_indices = {} + node_lowlinks = {} + nodes_on_stack = set() + stack = [] + nontrivial_components = set() + + def strong_connect(node): + """Implements the STRONGCONNECT routine of Tarjan's algorithm.""" + node_indices[node] = next_index[0] + node_lowlinks[node] = next_index[0] + next_index[0] += 1 + stack.append(node) + nodes_on_stack.add(node) + + for destination_node in graph[node]: + if destination_node not in node_indices: + strong_connect(destination_node) + node_lowlinks[node] = min( + node_lowlinks[node], node_lowlinks[destination_node] + ) + elif destination_node in nodes_on_stack: + node_lowlinks[node] = min( + node_lowlinks[node], node_indices[destination_node] + ) + + strongly_connected_component = [] + if node_lowlinks[node] == node_indices[node]: + while True: + popped_node = stack.pop() + nodes_on_stack.remove(popped_node) + strongly_connected_component.append(popped_node) + if popped_node == node: + break + if ( + len(strongly_connected_component) > 1 + or strongly_connected_component[0] + in graph[strongly_connected_component[0]] + ): + nontrivial_components.add(frozenset(strongly_connected_component)) + + for node in graph: + if node not in node_indices: + strong_connect(node) + return nontrivial_components def _find_object_dependency_cycles(ir): - """Finds dependency cycles in types in the ir.""" - dependencies, find_dependency_errors = _find_dependencies(ir) - if find_dependency_errors: - return find_dependency_errors - errors = [] - cycles = _find_cycles(dict(dependencies)) - for cycle in cycles: - # TODO(bolms): This lists the entire strongly-connected component in a - # fairly arbitrary order. This is simple, and handles components that - # aren't simple cycles, but may not be the most user-friendly way to - # present this information. - cycle_list = sorted(list(cycle)) - node_object = ir_util.find_object(cycle_list[0], ir) - error_group = [ - error.error(cycle_list[0][0], node_object.source_location, - "Dependency cycle\n" + node_object.name.name.text) - ] - for node in cycle_list[1:]: - node_object = ir_util.find_object(node, ir) - error_group.append(error.note(node[0], node_object.source_location, - node_object.name.name.text)) - errors.append(error_group) - return errors + """Finds dependency cycles in types in the ir.""" + dependencies, find_dependency_errors = _find_dependencies(ir) + if find_dependency_errors: + return find_dependency_errors + errors = [] + cycles = _find_cycles(dict(dependencies)) + for cycle in cycles: + # TODO(bolms): This lists the entire strongly-connected component in a + # fairly arbitrary order. This is simple, and handles components that + # aren't simple cycles, but may not be the most user-friendly way to + # present this information. 
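(Aside: because the reporter below sorts each cycle, the head of every error group is the lexicographically smallest member; the tuple shape shown here is an assumption for illustration:)

assert sorted([("m.emb", "Foo", "b"), ("m.emb", "Foo", "a")])[0] == (
    "m.emb",
    "Foo",
    "a",
)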
+ cycle_list = sorted(list(cycle)) + node_object = ir_util.find_object(cycle_list[0], ir) + error_group = [ + error.error( + cycle_list[0][0], + node_object.source_location, + "Dependency cycle\n" + node_object.name.name.text, + ) + ] + for node in cycle_list[1:]: + node_object = ir_util.find_object(node, ir) + error_group.append( + error.note( + node[0], node_object.source_location, node_object.name.name.text + ) + ) + errors.append(error_group) + return errors def _find_module_dependency_cycles(ir): - """Finds dependency cycles in modules in the ir.""" - dependencies = _find_module_import_dependencies(ir) - cycles = _find_cycles(dict(dependencies)) - errors = [] - for cycle in cycles: - cycle_list = sorted(list(cycle)) - module = ir_util.find_object(cycle_list[0], ir) - error_group = [ - error.error(cycle_list[0][0], module.source_location, - "Import dependency cycle\n" + module.source_file_name) - ] - for module_name in cycle_list[1:]: - module = ir_util.find_object(module_name, ir) - error_group.append(error.note(module_name[0], module.source_location, - module.source_file_name)) - errors.append(error_group) - return errors + """Finds dependency cycles in modules in the ir.""" + dependencies = _find_module_import_dependencies(ir) + cycles = _find_cycles(dict(dependencies)) + errors = [] + for cycle in cycles: + cycle_list = sorted(list(cycle)) + module = ir_util.find_object(cycle_list[0], ir) + error_group = [ + error.error( + cycle_list[0][0], + module.source_location, + "Import dependency cycle\n" + module.source_file_name, + ) + ] + for module_name in cycle_list[1:]: + module = ir_util.find_object(module_name, ir) + error_group.append( + error.note( + module_name[0], module.source_location, module.source_file_name + ) + ) + errors.append(error_group) + return errors def find_dependency_cycles(ir): - """Finds any dependency cycles in the ir.""" - errors = _find_module_dependency_cycles(ir) - return errors + _find_object_dependency_cycles(ir) + """Finds any dependency cycles in the ir.""" + errors = _find_module_dependency_cycles(ir) + return errors + _find_object_dependency_cycles(ir) def set_dependency_order(ir): - """Sets the fields_in_dependency_order member of Structures.""" - _find_dependency_ordering_for_fields(ir) - return [] + """Sets the fields_in_dependency_order member of Structures.""" + _find_dependency_ordering_for_fields(ir) + return [] diff --git a/compiler/front_end/dependency_checker_test.py b/compiler/front_end/dependency_checker_test.py index 27af812..a9110da 100644 --- a/compiler/front_end/dependency_checker_test.py +++ b/compiler/front_end/dependency_checker_test.py @@ -22,287 +22,408 @@ def _parse_snippet(emb_file): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": emb_file}), - stop_before_step="find_dependency_cycles") - assert not errors - return ir + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader({"m.emb": emb_file}), + stop_before_step="find_dependency_cycles", + ) + assert not errors + return ir def _find_dependencies_for_snippet(emb_file): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({ - "m.emb": emb_file - }), - stop_before_step="set_dependency_order") - assert not errors, errors - return ir + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader({"m.emb": emb_file}), + stop_before_step="set_dependency_order", + ) + assert not errors, errors + return ir 
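(Aside: as the assertEqual calls below spell out, find_dependency_cycles returns a list of error groups, one group per cycle; each group is a list whose first entry is an error.error(...) and whose remaining entries are error.note(...) items for the other cycle members. Schematically, with placeholder strings rather than real error objects:)

report = [
    ["error: Dependency cycle\nfield1", "note: field2"],  # one group per cycle
]
assert len(report) == 1 and len(report[0]) == 2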
class DependencyCheckerTest(unittest.TestCase): - def test_error_on_simple_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " 0 [+field2] UInt field1\n" - " 0 [+field1] UInt field2\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1"), - error.note("m.emb", struct.field[1].source_location, "field2") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_self_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " 0 [+field1] UInt field1\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_triple_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " 0 [+field2] UInt field1\n" - " 0 [+field3] UInt field2\n" - " 0 [+field1] UInt field3\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1"), - error.note("m.emb", struct.field[1].source_location, "field2"), - error.note("m.emb", struct.field[2].source_location, "field3"), - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_complex_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " 0 [+field2] UInt field1\n" - " 0 [+field3+field4] UInt field2\n" - " 0 [+field1] UInt field3\n" - " 0 [+field2] UInt field4\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1"), - error.note("m.emb", struct.field[1].source_location, "field2"), - error.note("m.emb", struct.field[2].source_location, "field3"), - error.note("m.emb", struct.field[3].source_location, "field4"), - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_simple_enum_value_cycle(self): - ir = _parse_snippet("enum Foo:\n" - " XX = YY\n" - " YY = XX\n") - enum = ir.module[0].type[0].enumeration - self.assertEqual([[ - error.error("m.emb", enum.value[0].source_location, - "Dependency cycle\nXX"), - error.note("m.emb", enum.value[1].source_location, "YY") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_no_error_on_no_cycle(self): - ir = _parse_snippet("enum Foo:\n" - " XX = 0\n" - " YY = XX\n") - self.assertEqual([], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_cycle_nested(self): - ir = _parse_snippet("struct Foo:\n" - " struct Bar:\n" - " 0 [+field2] UInt field1\n" - " 0 [+field1] UInt field2\n" - " 0 [+1] UInt field\n") - struct = ir.module[0].type[0].subtype[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1"), - error.note("m.emb", struct.field[1].source_location, "field2") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_import_cycle(self): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": 'import "n.emb" as n\n', - "n.emb": 'import "m.emb" as m\n'}), - stop_before_step="find_dependency_cycles") - assert not errors - self.assertEqual([[ - error.error("m.emb", ir.module[0].source_location, - "Import dependency cycle\nm.emb"), - error.note("n.emb", ir.module[2].source_location, "n.emb") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_import_cycle_and_field_cycle(self): - ir, unused_debug_info, errors = glue.parse_emboss_file( 
- "m.emb", - test_util.dict_file_reader({"m.emb": 'import "n.emb" as n\n' - "struct Foo:\n" - " 0 [+field1] UInt field1\n", - "n.emb": 'import "m.emb" as m\n'}), - stop_before_step="find_dependency_cycles") - assert not errors - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", ir.module[0].source_location, - "Import dependency cycle\nm.emb"), - error.note("n.emb", ir.module[2].source_location, "n.emb") - ], [ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfield1") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_field_existence_self_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " if x == 1:\n" - " 0 [+1] UInt x\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nx") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_field_existence_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " if y == 1:\n" - " 0 [+1] UInt x\n" - " if x == 0:\n" - " 1 [+1] UInt y\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nx"), - error.note("m.emb", struct.field[1].source_location, "y") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_virtual_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " let x = y\n" - " let y = x\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nx"), - error.note("m.emb", struct.field[1].source_location, "y") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_virtual_non_virtual_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " let x = y\n" - " x [+4] UInt y\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nx"), - error.note("m.emb", struct.field[1].source_location, "y") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_non_virtual_virtual_field_cycle(self): - ir = _parse_snippet("struct Foo:\n" - " y [+4] UInt x\n" - " let y = x\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nx"), - error.note("m.emb", struct.field[1].source_location, "y") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_error_on_cycle_involving_subfield(self): - ir = _parse_snippet("struct Bar:\n" - " foo_b.x [+4] Foo foo_a\n" - " foo_a.x [+4] Foo foo_b\n" - "struct Foo:\n" - " 0 [+4] UInt x\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].source_location, - "Dependency cycle\nfoo_a"), - error.note("m.emb", struct.field[1].source_location, "foo_b") - ]], dependency_checker.find_dependency_cycles(ir)) - - def test_dependency_ordering_with_no_dependencies(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " 0 [+4] UInt a\n" - " 4 [+4] UInt b\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([0, 1], struct.fields_in_dependency_order[:2]) - - def test_dependency_ordering_with_dependency_in_order(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " 0 [+4] UInt a\n" - " a [+4] UInt b\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct 
= ir.module[0].type[0].structure - self.assertEqual([0, 1], struct.fields_in_dependency_order[:2]) - - def test_dependency_ordering_with_dependency_in_reverse_order(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " b [+4] UInt a\n" - " 0 [+4] UInt b\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([1, 0], struct.fields_in_dependency_order[:2]) - - def test_dependency_ordering_with_extra_fields(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " d [+4] UInt a\n" - " 4 [+4] UInt b\n" - " 8 [+4] UInt c\n" - " 12 [+4] UInt d\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([1, 2, 3, 0], struct.fields_in_dependency_order[:4]) - - def test_dependency_ordering_scrambled(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " d [+4] UInt a\n" - " c [+4] UInt b\n" - " a [+4] UInt c\n" - " 12 [+4] UInt d\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([3, 0, 2, 1], struct.fields_in_dependency_order[:4]) - - def test_dependency_ordering_multiple_dependents(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " d [+4] UInt a\n" - " d [+4] UInt b\n" - " d [+4] UInt c\n" - " 12 [+4] UInt d\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([3, 0, 1, 2], struct.fields_in_dependency_order[:4]) - - def test_dependency_ordering_multiple_dependencies(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " b+c [+4] UInt a\n" - " 4 [+4] UInt b\n" - " 8 [+4] UInt c\n" - " a [+4] UInt d\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([1, 2, 0, 3], struct.fields_in_dependency_order[:4]) - - def test_dependency_ordering_with_parameter(self): - ir = _find_dependencies_for_snippet("struct Foo:\n" - " 0 [+1] Bar(x) b\n" - " 1 [+1] UInt x\n" - "struct Bar(x: UInt:8):\n" - " x [+1] UInt y\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([1, 0], struct.fields_in_dependency_order[:2]) - - def test_dependency_ordering_with_local_parameter(self): - ir = _find_dependencies_for_snippet("struct Foo(x: Int:13):\n" - " 0 [+x] Int b\n") - self.assertEqual([], dependency_checker.set_dependency_order(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([0], struct.fields_in_dependency_order[:1]) + def test_error_on_simple_field_cycle(self): + ir = _parse_snippet( + "struct Foo:\n" + " 0 [+field2] UInt field1\n" + " 0 [+field1] UInt field2\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ), + error.note("m.emb", struct.field[1].source_location, "field2"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_self_cycle(self): + ir = _parse_snippet("struct Foo:\n" " 0 [+field1] UInt field1\n") + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ) + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_triple_field_cycle(self): + ir = _parse_snippet( + "struct Foo:\n" + " 0 [+field2] 
UInt field1\n" + " 0 [+field3] UInt field2\n" + " 0 [+field1] UInt field3\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ), + error.note("m.emb", struct.field[1].source_location, "field2"), + error.note("m.emb", struct.field[2].source_location, "field3"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_complex_field_cycle(self): + ir = _parse_snippet( + "struct Foo:\n" + " 0 [+field2] UInt field1\n" + " 0 [+field3+field4] UInt field2\n" + " 0 [+field1] UInt field3\n" + " 0 [+field2] UInt field4\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ), + error.note("m.emb", struct.field[1].source_location, "field2"), + error.note("m.emb", struct.field[2].source_location, "field3"), + error.note("m.emb", struct.field[3].source_location, "field4"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_simple_enum_value_cycle(self): + ir = _parse_snippet("enum Foo:\n" " XX = YY\n" " YY = XX\n") + enum = ir.module[0].type[0].enumeration + self.assertEqual( + [ + [ + error.error( + "m.emb", enum.value[0].source_location, "Dependency cycle\nXX" + ), + error.note("m.emb", enum.value[1].source_location, "YY"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_no_error_on_no_cycle(self): + ir = _parse_snippet("enum Foo:\n" " XX = 0\n" " YY = XX\n") + self.assertEqual([], dependency_checker.find_dependency_cycles(ir)) + + def test_error_on_cycle_nested(self): + ir = _parse_snippet( + "struct Foo:\n" + " struct Bar:\n" + " 0 [+field2] UInt field1\n" + " 0 [+field1] UInt field2\n" + " 0 [+1] UInt field\n" + ) + struct = ir.module[0].type[0].subtype[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ), + error.note("m.emb", struct.field[1].source_location, "field2"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_import_cycle(self): + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader( + {"m.emb": 'import "n.emb" as n\n', "n.emb": 'import "m.emb" as m\n'} + ), + stop_before_step="find_dependency_cycles", + ) + assert not errors + self.assertEqual( + [ + [ + error.error( + "m.emb", + ir.module[0].source_location, + "Import dependency cycle\nm.emb", + ), + error.note("n.emb", ir.module[2].source_location, "n.emb"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_import_cycle_and_field_cycle(self): + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader( + { + "m.emb": 'import "n.emb" as n\n' + "struct Foo:\n" + " 0 [+field1] UInt field1\n", + "n.emb": 'import "m.emb" as m\n', + } + ), + stop_before_step="find_dependency_cycles", + ) + assert not errors + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + ir.module[0].source_location, + "Import dependency cycle\nm.emb", + ), + error.note("n.emb", ir.module[2].source_location, "n.emb"), + ], + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfield1", + ) + ], + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_field_existence_self_cycle(self): + ir = _parse_snippet("struct Foo:\n" " if x == 
1:\n" " 0 [+1] UInt x\n") + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", struct.field[0].source_location, "Dependency cycle\nx" + ) + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_field_existence_cycle(self): + ir = _parse_snippet( + "struct Foo:\n" + " if y == 1:\n" + " 0 [+1] UInt x\n" + " if x == 0:\n" + " 1 [+1] UInt y\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", struct.field[0].source_location, "Dependency cycle\nx" + ), + error.note("m.emb", struct.field[1].source_location, "y"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_virtual_field_cycle(self): + ir = _parse_snippet("struct Foo:\n" " let x = y\n" " let y = x\n") + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", struct.field[0].source_location, "Dependency cycle\nx" + ), + error.note("m.emb", struct.field[1].source_location, "y"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_virtual_non_virtual_field_cycle(self): + ir = _parse_snippet("struct Foo:\n" " let x = y\n" " x [+4] UInt y\n") + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", struct.field[0].source_location, "Dependency cycle\nx" + ), + error.note("m.emb", struct.field[1].source_location, "y"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_non_virtual_virtual_field_cycle(self): + ir = _parse_snippet("struct Foo:\n" " y [+4] UInt x\n" " let y = x\n") + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", struct.field[0].source_location, "Dependency cycle\nx" + ), + error.note("m.emb", struct.field[1].source_location, "y"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_error_on_cycle_involving_subfield(self): + ir = _parse_snippet( + "struct Bar:\n" + " foo_b.x [+4] Foo foo_a\n" + " foo_a.x [+4] Foo foo_b\n" + "struct Foo:\n" + " 0 [+4] UInt x\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].source_location, + "Dependency cycle\nfoo_a", + ), + error.note("m.emb", struct.field[1].source_location, "foo_b"), + ] + ], + dependency_checker.find_dependency_cycles(ir), + ) + + def test_dependency_ordering_with_no_dependencies(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" " 0 [+4] UInt a\n" " 4 [+4] UInt b\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([0, 1], struct.fields_in_dependency_order[:2]) + + def test_dependency_ordering_with_dependency_in_order(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" " 0 [+4] UInt a\n" " a [+4] UInt b\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([0, 1], struct.fields_in_dependency_order[:2]) + + def test_dependency_ordering_with_dependency_in_reverse_order(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" " b [+4] UInt a\n" " 0 [+4] UInt b\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([1, 0], struct.fields_in_dependency_order[:2]) + + def test_dependency_ordering_with_extra_fields(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" + " d 
[+4] UInt a\n" + " 4 [+4] UInt b\n" + " 8 [+4] UInt c\n" + " 12 [+4] UInt d\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([1, 2, 3, 0], struct.fields_in_dependency_order[:4]) + + def test_dependency_ordering_scrambled(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" + " d [+4] UInt a\n" + " c [+4] UInt b\n" + " a [+4] UInt c\n" + " 12 [+4] UInt d\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([3, 0, 2, 1], struct.fields_in_dependency_order[:4]) + + def test_dependency_ordering_multiple_dependents(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" + " d [+4] UInt a\n" + " d [+4] UInt b\n" + " d [+4] UInt c\n" + " 12 [+4] UInt d\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([3, 0, 1, 2], struct.fields_in_dependency_order[:4]) + + def test_dependency_ordering_multiple_dependencies(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" + " b+c [+4] UInt a\n" + " 4 [+4] UInt b\n" + " 8 [+4] UInt c\n" + " a [+4] UInt d\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([1, 2, 0, 3], struct.fields_in_dependency_order[:4]) + + def test_dependency_ordering_with_parameter(self): + ir = _find_dependencies_for_snippet( + "struct Foo:\n" + " 0 [+1] Bar(x) b\n" + " 1 [+1] UInt x\n" + "struct Bar(x: UInt:8):\n" + " x [+1] UInt y\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([1, 0], struct.fields_in_dependency_order[:2]) + + def test_dependency_ordering_with_local_parameter(self): + ir = _find_dependencies_for_snippet( + "struct Foo(x: Int:13):\n" " 0 [+x] Int b\n" + ) + self.assertEqual([], dependency_checker.set_dependency_order(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual([0], struct.fields_in_dependency_order[:1]) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/docs_are_up_to_date_test.py b/compiler/front_end/docs_are_up_to_date_test.py index 70fd4b3..00ab32a 100644 --- a/compiler/front_end/docs_are_up_to_date_test.py +++ b/compiler/front_end/docs_are_up_to_date_test.py @@ -21,22 +21,22 @@ class DocsAreUpToDateTest(unittest.TestCase): - """Tests that auto-generated, checked-in documentation is up to date.""" - - def test_grammar_md(self): - doc_md = pkgutil.get_data("doc", "grammar.md").decode(encoding="UTF-8") - correct_md = generate_grammar_md.generate_grammar_md() - # If this fails, run: - # - # bazel run //compiler/front_end:generate_grammar_md > doc/grammar.md - # - # Be sure to check that the results look good before committing! - doc_md_lines = doc_md.splitlines() - correct_md_lines = correct_md.splitlines() - for i in range(len(doc_md_lines)): - self.assertEqual(correct_md_lines[i], doc_md_lines[i]) - self.assertEqual(correct_md, doc_md) + """Tests that auto-generated, checked-in documentation is up to date.""" + + def test_grammar_md(self): + doc_md = pkgutil.get_data("doc", "grammar.md").decode(encoding="UTF-8") + correct_md = generate_grammar_md.generate_grammar_md() + # If this fails, run: + # + # bazel run //compiler/front_end:generate_grammar_md > doc/grammar.md + # + # Be sure to check that the results look good before committing! 
+ doc_md_lines = doc_md.splitlines() + correct_md_lines = correct_md.splitlines() + for i in range(len(doc_md_lines)): + self.assertEqual(correct_md_lines[i], doc_md_lines[i]) + self.assertEqual(correct_md, doc_md) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/emboss_front_end.py b/compiler/front_end/emboss_front_end.py index 1e30ded..c62638d 100644 --- a/compiler/front_end/emboss_front_end.py +++ b/compiler/front_end/emboss_front_end.py @@ -34,157 +34,180 @@ def _parse_command_line(argv): - """Parses the given command-line arguments.""" - parser = argparse.ArgumentParser(description="Emboss compiler front end.", - prog=argv[0]) - parser.add_argument("input_file", - type=str, - nargs=1, - help=".emb file to compile.") - parser.add_argument("--debug-show-tokenization", - action="store_true", - help="Show the tokenization of the main input file.") - parser.add_argument("--debug-show-parse-tree", - action="store_true", - help="Show the parse tree of the main input file.") - parser.add_argument("--debug-show-module-ir", - action="store_true", - help="Show the module-level IR of the main input file " - "before symbol resolution.") - parser.add_argument("--debug-show-full-ir", - action="store_true", - help="Show the final IR of the main input file.") - parser.add_argument("--debug-show-used-productions", - action="store_true", - help="Show all of the grammar productions used in " - "parsing the main input file.") - parser.add_argument("--debug-show-unused-productions", - action="store_true", - help="Show all of the grammar productions not used in " - "parsing the main input file.") - parser.add_argument("--output-ir-to-stdout", - action="store_true", - help="Dump serialized IR to stdout.") - parser.add_argument("--output-file", - type=str, - help="Write serialized IR to file.") - parser.add_argument("--no-debug-show-header-lines", - dest="debug_show_header_lines", - action="store_false", - help="Print header lines before output if true.") - parser.add_argument("--color-output", - default="if_tty", - choices=["always", "never", "if_tty", "auto"], - help="Print error messages using color. 'auto' is a " - "synonym for 'if_tty'.") - parser.add_argument("--import-dir", "-I", - dest="import_dirs", - action="append", - default=["."], - help="A directory to use when searching for imported " - "embs. 
If no import_dirs are specified, the " - "current directory will be used.") - return parser.parse_args(argv[1:]) + """Parses the given command-line arguments.""" + parser = argparse.ArgumentParser( + description="Emboss compiler front end.", prog=argv[0] + ) + parser.add_argument("input_file", type=str, nargs=1, help=".emb file to compile.") + parser.add_argument( + "--debug-show-tokenization", + action="store_true", + help="Show the tokenization of the main input file.", + ) + parser.add_argument( + "--debug-show-parse-tree", + action="store_true", + help="Show the parse tree of the main input file.", + ) + parser.add_argument( + "--debug-show-module-ir", + action="store_true", + help="Show the module-level IR of the main input file " + "before symbol resolution.", + ) + parser.add_argument( + "--debug-show-full-ir", + action="store_true", + help="Show the final IR of the main input file.", + ) + parser.add_argument( + "--debug-show-used-productions", + action="store_true", + help="Show all of the grammar productions used in " + "parsing the main input file.", + ) + parser.add_argument( + "--debug-show-unused-productions", + action="store_true", + help="Show all of the grammar productions not used in " + "parsing the main input file.", + ) + parser.add_argument( + "--output-ir-to-stdout", + action="store_true", + help="Dump serialized IR to stdout.", + ) + parser.add_argument("--output-file", type=str, help="Write serialized IR to file.") + parser.add_argument( + "--no-debug-show-header-lines", + dest="debug_show_header_lines", + action="store_false", + help="Print header lines before output if true.", + ) + parser.add_argument( + "--color-output", + default="if_tty", + choices=["always", "never", "if_tty", "auto"], + help="Print error messages using color. 'auto' is a " "synonym for 'if_tty'.", + ) + parser.add_argument( + "--import-dir", + "-I", + dest="import_dirs", + action="append", + default=["."], + help="A directory to use when searching for imported " + "embs. If no import_dirs are specified, the " + "current directory will be used.", + ) + return parser.parse_args(argv[1:]) def _show_errors(errors, ir, color_output): - """Prints errors with source code snippets.""" - source_codes = {} - if ir: - for module in ir.module: - source_codes[module.source_file_name] = module.source_text - use_color = (color_output == "always" or - (color_output in ("auto", "if_tty") and - os.isatty(sys.stderr.fileno()))) - print(error.format_errors(errors, source_codes, use_color), file=sys.stderr) + """Prints errors with source code snippets.""" + source_codes = {} + if ir: + for module in ir.module: + source_codes[module.source_file_name] = module.source_text + use_color = color_output == "always" or ( + color_output in ("auto", "if_tty") and os.isatty(sys.stderr.fileno()) + ) + print(error.format_errors(errors, source_codes, use_color), file=sys.stderr) def _find_in_dirs_and_read(import_dirs): - """Returns a function which will search import_dirs for a file.""" - - def _find_and_read(file_name): - """Searches import_dirs for file_name and returns the contents.""" - errors = [] - # *All* source files, including the one specified on the command line, will - # be searched for in the import_dirs. This may be surprising, especially if - # the current directory is *not* an import_dir. - # TODO(bolms): Determine if this is really the desired behavior. 
- for import_dir in import_dirs: - full_name = path.join(import_dir, file_name) - try: - with open(full_name) as f: - # As written, this follows the typical compiler convention of checking - # the include/import directories in the order specified by the user, - # and always reading the first matching file, even if other files - # might match in later directories. This lets files shadow other - # files, which can be useful in some cases (to override things), but - # can also cause accidental shadowing, which can be tricky to fix. - # - # TODO(bolms): Check if any other files with the same name are in the - # import path, and give a warning or error? - return f.read(), None - except IOError as e: - errors.append(str(e)) - return None, errors + ["import path " + ":".join(import_dirs)] - - return _find_and_read + """Returns a function which will search import_dirs for a file.""" + + def _find_and_read(file_name): + """Searches import_dirs for file_name and returns the contents.""" + errors = [] + # *All* source files, including the one specified on the command line, will + # be searched for in the import_dirs. This may be surprising, especially if + # the current directory is *not* an import_dir. + # TODO(bolms): Determine if this is really the desired behavior. + for import_dir in import_dirs: + full_name = path.join(import_dir, file_name) + try: + with open(full_name) as f: + # As written, this follows the typical compiler convention of checking + # the include/import directories in the order specified by the user, + # and always reading the first matching file, even if other files + # might match in later directories. This lets files shadow other + # files, which can be useful in some cases (to override things), but + # can also cause accidental shadowing, which can be tricky to fix. + # + # TODO(bolms): Check if any other files with the same name are in the + # import path, and give a warning or error? + return f.read(), None + except IOError as e: + errors.append(str(e)) + return None, errors + ["import path " + ":".join(import_dirs)] + + return _find_and_read + def parse_and_log_errors(input_file, import_dirs, color_output): - """Fully parses an .emb and logs any errors. + """Fully parses an .emb and logs any errors. + + Arguments: + input_file: The path of the module source file. + import_dirs: Directories to search for imported dependencies. + color_output: Used when logging errors: "always", "never", "if_tty", "auto" - Arguments: - input_file: The path of the module source file. - import_dirs: Directories to search for imported dependencies. 
- color_output: Used when logging errors: "always", "never", "if_tty", "auto" + Returns: + (ir, debug_info, errors) + """ + ir, debug_info, errors = glue.parse_emboss_file( + input_file, _find_in_dirs_and_read(import_dirs) + ) + if errors: + _show_errors(errors, ir, color_output) - Returns: - (ir, debug_info, errors) - """ - ir, debug_info, errors = glue.parse_emboss_file( - input_file, _find_in_dirs_and_read(import_dirs)) - if errors: - _show_errors(errors, ir, color_output) + return (ir, debug_info, errors) - return (ir, debug_info, errors) def main(flags): - ir, debug_info, errors = parse_and_log_errors( - flags.input_file[0], flags.import_dirs, flags.color_output) - if errors: - return 1 - main_module_debug_info = debug_info.modules[flags.input_file[0]] - if flags.debug_show_tokenization: - if flags.debug_show_header_lines: - print("Tokenization:") - print(main_module_debug_info.format_tokenization()) - if flags.debug_show_parse_tree: - if flags.debug_show_header_lines: - print("Parse Tree:") - print(main_module_debug_info.format_parse_tree()) - if flags.debug_show_module_ir: - if flags.debug_show_header_lines: - print("Module IR:") - print(main_module_debug_info.format_module_ir()) - if flags.debug_show_full_ir: - if flags.debug_show_header_lines: - print("Full IR:") - print(str(ir)) - if flags.debug_show_used_productions: - if flags.debug_show_header_lines: - print("Used Productions:") - print(glue.format_production_set(main_module_debug_info.used_productions)) - if flags.debug_show_unused_productions: - if flags.debug_show_header_lines: - print("Unused Productions:") - print(glue.format_production_set( - set(module_ir.PRODUCTIONS) - main_module_debug_info.used_productions)) - if flags.output_ir_to_stdout: - print(ir_data_utils.IrDataSerializer(ir).to_json()) - if flags.output_file: - with open(flags.output_file, "w") as f: - f.write(ir_data_utils.IrDataSerializer(ir).to_json()) - return 0 + ir, debug_info, errors = parse_and_log_errors( + flags.input_file[0], flags.import_dirs, flags.color_output + ) + if errors: + return 1 + main_module_debug_info = debug_info.modules[flags.input_file[0]] + if flags.debug_show_tokenization: + if flags.debug_show_header_lines: + print("Tokenization:") + print(main_module_debug_info.format_tokenization()) + if flags.debug_show_parse_tree: + if flags.debug_show_header_lines: + print("Parse Tree:") + print(main_module_debug_info.format_parse_tree()) + if flags.debug_show_module_ir: + if flags.debug_show_header_lines: + print("Module IR:") + print(main_module_debug_info.format_module_ir()) + if flags.debug_show_full_ir: + if flags.debug_show_header_lines: + print("Full IR:") + print(str(ir)) + if flags.debug_show_used_productions: + if flags.debug_show_header_lines: + print("Used Productions:") + print(glue.format_production_set(main_module_debug_info.used_productions)) + if flags.debug_show_unused_productions: + if flags.debug_show_header_lines: + print("Unused Productions:") + print( + glue.format_production_set( + set(module_ir.PRODUCTIONS) - main_module_debug_info.used_productions + ) + ) + if flags.output_ir_to_stdout: + print(ir_data_utils.IrDataSerializer(ir).to_json()) + if flags.output_file: + with open(flags.output_file, "w") as f: + f.write(ir_data_utils.IrDataSerializer(ir).to_json()) + return 0 if __name__ == "__main__": - sys.exit(main(_parse_command_line(sys.argv))) + sys.exit(main(_parse_command_line(sys.argv))) diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py index e364399..cca36ee 100644 
--- a/compiler/front_end/expression_bounds.py +++ b/compiler/front_end/expression_bounds.py @@ -26,702 +26,756 @@ # Create a local alias for math.gcd with a fallback to fractions.gcd if it is # not available. This can be dropped if pre-3.5 Python support is dropped. -if hasattr(math, 'gcd'): - _math_gcd = math.gcd +if hasattr(math, "gcd"): + _math_gcd = math.gcd else: - _math_gcd = fractions.gcd + _math_gcd = fractions.gcd def compute_constraints_of_expression(expression, ir): - """Adds appropriate bounding constraints to the given expression.""" - if ir_util.is_constant_type(expression.type): - return - expression_variety = expression.WhichOneof("expression") - if expression_variety == "constant": - _compute_constant_value_of_constant(expression) - elif expression_variety == "constant_reference": - _compute_constant_value_of_constant_reference(expression, ir) - elif expression_variety == "function": - _compute_constraints_of_function(expression, ir) - elif expression_variety == "field_reference": - _compute_constraints_of_field_reference(expression, ir) - elif expression_variety == "builtin_reference": - _compute_constraints_of_builtin_value(expression) - elif expression_variety == "boolean_constant": - _compute_constant_value_of_boolean_constant(expression) - else: - assert False, "Unknown expression variety {!r}".format(expression_variety) - if expression.type.WhichOneof("type") == "integer": - _assert_integer_constraints(expression) + """Adds appropriate bounding constraints to the given expression.""" + if ir_util.is_constant_type(expression.type): + return + expression_variety = expression.WhichOneof("expression") + if expression_variety == "constant": + _compute_constant_value_of_constant(expression) + elif expression_variety == "constant_reference": + _compute_constant_value_of_constant_reference(expression, ir) + elif expression_variety == "function": + _compute_constraints_of_function(expression, ir) + elif expression_variety == "field_reference": + _compute_constraints_of_field_reference(expression, ir) + elif expression_variety == "builtin_reference": + _compute_constraints_of_builtin_value(expression) + elif expression_variety == "boolean_constant": + _compute_constant_value_of_boolean_constant(expression) + else: + assert False, "Unknown expression variety {!r}".format(expression_variety) + if expression.type.WhichOneof("type") == "integer": + _assert_integer_constraints(expression) def _compute_constant_value_of_constant(expression): - value = expression.constant.value - expression.type.integer.modular_value = value - expression.type.integer.minimum_value = value - expression.type.integer.maximum_value = value - expression.type.integer.modulus = "infinity" + value = expression.constant.value + expression.type.integer.modular_value = value + expression.type.integer.minimum_value = value + expression.type.integer.maximum_value = value + expression.type.integer.modulus = "infinity" def _compute_constant_value_of_constant_reference(expression, ir): - referred_object = ir_util.find_object( - expression.constant_reference.canonical_name, ir) - expression = ir_data_utils.builder(expression) - if isinstance(referred_object, ir_data.EnumValue): - compute_constraints_of_expression(referred_object.value, ir) - assert ir_util.is_constant(referred_object.value) - new_value = str(ir_util.constant_value(referred_object.value)) - expression.type.enumeration.value = new_value - elif isinstance(referred_object, ir_data.Field): - assert ir_util.field_is_virtual(referred_object), ( - 
"Non-virtual non-enum-value constant reference should have been caught " - "in type_check.py") - compute_constraints_of_expression(referred_object.read_transform, ir) - expression.type.CopyFrom(referred_object.read_transform.type) - else: - assert False, "Unexpected constant reference type." + referred_object = ir_util.find_object( + expression.constant_reference.canonical_name, ir + ) + expression = ir_data_utils.builder(expression) + if isinstance(referred_object, ir_data.EnumValue): + compute_constraints_of_expression(referred_object.value, ir) + assert ir_util.is_constant(referred_object.value) + new_value = str(ir_util.constant_value(referred_object.value)) + expression.type.enumeration.value = new_value + elif isinstance(referred_object, ir_data.Field): + assert ir_util.field_is_virtual(referred_object), ( + "Non-virtual non-enum-value constant reference should have been caught " + "in type_check.py" + ) + compute_constraints_of_expression(referred_object.read_transform, ir) + expression.type.CopyFrom(referred_object.read_transform.type) + else: + assert False, "Unexpected constant reference type." def _compute_constraints_of_function(expression, ir): - """Computes the known constraints of the result of a function.""" - for arg in expression.function.args: - compute_constraints_of_expression(arg, ir) - op = expression.function.function - if op in (ir_data.FunctionMapping.ADDITION, ir_data.FunctionMapping.SUBTRACTION): - _compute_constraints_of_additive_operator(expression) - elif op == ir_data.FunctionMapping.MULTIPLICATION: - _compute_constraints_of_multiplicative_operator(expression) - elif op in (ir_data.FunctionMapping.EQUALITY, ir_data.FunctionMapping.INEQUALITY, - ir_data.FunctionMapping.LESS, ir_data.FunctionMapping.LESS_OR_EQUAL, - ir_data.FunctionMapping.GREATER, ir_data.FunctionMapping.GREATER_OR_EQUAL, - ir_data.FunctionMapping.AND, ir_data.FunctionMapping.OR): - _compute_constant_value_of_comparison_operator(expression) - elif op == ir_data.FunctionMapping.CHOICE: - _compute_constraints_of_choice_operator(expression) - elif op == ir_data.FunctionMapping.MAXIMUM: - _compute_constraints_of_maximum_function(expression) - elif op == ir_data.FunctionMapping.PRESENCE: - _compute_constraints_of_existence_function(expression, ir) - elif op in (ir_data.FunctionMapping.UPPER_BOUND, ir_data.FunctionMapping.LOWER_BOUND): - _compute_constraints_of_bound_function(expression) - else: - assert False, "Unknown operator {!r}".format(op) + """Computes the known constraints of the result of a function.""" + for arg in expression.function.args: + compute_constraints_of_expression(arg, ir) + op = expression.function.function + if op in (ir_data.FunctionMapping.ADDITION, ir_data.FunctionMapping.SUBTRACTION): + _compute_constraints_of_additive_operator(expression) + elif op == ir_data.FunctionMapping.MULTIPLICATION: + _compute_constraints_of_multiplicative_operator(expression) + elif op in ( + ir_data.FunctionMapping.EQUALITY, + ir_data.FunctionMapping.INEQUALITY, + ir_data.FunctionMapping.LESS, + ir_data.FunctionMapping.LESS_OR_EQUAL, + ir_data.FunctionMapping.GREATER, + ir_data.FunctionMapping.GREATER_OR_EQUAL, + ir_data.FunctionMapping.AND, + ir_data.FunctionMapping.OR, + ): + _compute_constant_value_of_comparison_operator(expression) + elif op == ir_data.FunctionMapping.CHOICE: + _compute_constraints_of_choice_operator(expression) + elif op == ir_data.FunctionMapping.MAXIMUM: + _compute_constraints_of_maximum_function(expression) + elif op == ir_data.FunctionMapping.PRESENCE: + 
_compute_constraints_of_existence_function(expression, ir) + elif op in ( + ir_data.FunctionMapping.UPPER_BOUND, + ir_data.FunctionMapping.LOWER_BOUND, + ): + _compute_constraints_of_bound_function(expression) + else: + assert False, "Unknown operator {!r}".format(op) def _compute_constraints_of_existence_function(expression, ir): - """Computes the constraints of a $has(field) expression.""" - field_path = expression.function.args[0].field_reference.path[-1] - field = ir_util.find_object(field_path, ir) - compute_constraints_of_expression(field.existence_condition, ir) - ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type) + """Computes the constraints of a $has(field) expression.""" + field_path = expression.function.args[0].field_reference.path[-1] + field = ir_util.find_object(field_path, ir) + compute_constraints_of_expression(field.existence_condition, ir) + ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type) def _compute_constraints_of_field_reference(expression, ir): - """Computes the constraints of a reference to a structure's field.""" - field_path = expression.field_reference.path[-1] - field = ir_util.find_object(field_path, ir) - if isinstance(field, ir_data.Field) and ir_util.field_is_virtual(field): - # References to virtual fields should have the virtual field's constraints - # copied over. - compute_constraints_of_expression(field.read_transform, ir) - ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) - return - # Non-virtual non-integer fields do not (yet) have constraints. - if expression.type.WhichOneof("type") == "integer": - # TODO(bolms): These lines will need to change when support is added for - # fixed-point types. - expression.type.integer.modulus = "1" - expression.type.integer.modular_value = "0" - type_definition = ir_util.find_parent_object(field_path, ir) - if isinstance(field, ir_data.Field): - referrent_type = field.type - else: - referrent_type = field.physical_type_alias - if referrent_type.HasField("size_in_bits"): - type_size = ir_util.constant_value(referrent_type.size_in_bits) + """Computes the constraints of a reference to a structure's field.""" + field_path = expression.field_reference.path[-1] + field = ir_util.find_object(field_path, ir) + if isinstance(field, ir_data.Field) and ir_util.field_is_virtual(field): + # References to virtual fields should have the virtual field's constraints + # copied over. + compute_constraints_of_expression(field.read_transform, ir) + ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) + return + # Non-virtual non-integer fields do not (yet) have constraints. + if expression.type.WhichOneof("type") == "integer": + # TODO(bolms): These lines will need to change when support is added for + # fixed-point types. 
+ expression.type.integer.modulus = "1" + expression.type.integer.modular_value = "0" + type_definition = ir_util.find_parent_object(field_path, ir) + if isinstance(field, ir_data.Field): + referrent_type = field.type + else: + referrent_type = field.physical_type_alias + if referrent_type.HasField("size_in_bits"): + type_size = ir_util.constant_value(referrent_type.size_in_bits) + else: + field_size = ir_util.constant_value(field.location.size) + if field_size is None: + type_size = None + else: + type_size = field_size * type_definition.addressable_unit + assert referrent_type.HasField("atomic_type"), field + assert not referrent_type.atomic_type.reference.canonical_name.module_file + _set_integer_constraints_from_physical_type( + expression, referrent_type, type_size + ) + + +def _set_integer_constraints_from_physical_type(expression, physical_type, type_size): + """Copies the integer constraints of an expression from a physical type.""" + # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes + # the ranges for all of the Emboss Prelude integer types. This would break + # any user-defined `external` integer types, but that feature isn't fully + # implemented in the C++ backend, so it doesn't matter for now. + # + # Adding the attribute(s) for integer bounds will require new operators: + # integer/flooring division, remainder, and exponentiation (2**N, 10**N). + # + # (Technically, there are a few sets of operators that would work: for + # example, just the choice operator `?:` is sufficient, but very ugly. + # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits + # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) * + # 2**($bits % 4) - 1`, in my (bolms@) opinion.) + # + # TODO(bolms): Add a scheme for defining integer bounds on user-defined + # external types. + if type_size is None: + # If the type_size is unknown, then we can't actually say anything about the + # minimum and maximum values of the type. For UInt, Int, and Bcd, an error + # will be thrown during the constraints check stage. + expression.type.integer.minimum_value = "-infinity" + expression.type.integer.maximum_value = "infinity" + return + name = tuple(physical_type.atomic_type.reference.canonical_name.object_path) + if name == ("UInt",): + expression.type.integer.minimum_value = "0" + expression.type.integer.maximum_value = str(2**type_size - 1) + elif name == ("Int",): + expression.type.integer.minimum_value = str(-(2 ** (type_size - 1))) + expression.type.integer.maximum_value = str(2 ** (type_size - 1) - 1) + elif name == ("Bcd",): + expression.type.integer.minimum_value = "0" + expression.type.integer.maximum_value = str( + 10 ** (type_size // 4) * 2 ** (type_size % 4) - 1 + ) else: - field_size = ir_util.constant_value(field.location.size) - if field_size is None: - type_size = None - else: - type_size = field_size * type_definition.addressable_unit - assert referrent_type.HasField("atomic_type"), field - assert not referrent_type.atomic_type.reference.canonical_name.module_file - _set_integer_constraints_from_physical_type( - expression, referrent_type, type_size) - - -def _set_integer_constraints_from_physical_type( - expression, physical_type, type_size): - """Copies the integer constraints of an expression from a physical type.""" - # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes - # the ranges for all of the Emboss Prelude integer types. 
This would break - # any user-defined `external` integer types, but that feature isn't fully - # implemented in the C++ backend, so it doesn't matter for now. - # - # Adding the attribute(s) for integer bounds will require new operators: - # integer/flooring division, remainder, and exponentiation (2**N, 10**N). - # - # (Technically, there are a few sets of operators that would work: for - # example, just the choice operator `?:` is sufficient, but very ugly. - # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits - # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) * - # 2**($bits % 4) - 1`, in my (bolms@) opinion.) - # - # TODO(bolms): Add a scheme for defining integer bounds on user-defined - # external types. - if type_size is None: - # If the type_size is unknown, then we can't actually say anything about the - # minimum and maximum values of the type. For UInt, Int, and Bcd, an error - # will be thrown during the constraints check stage. - expression.type.integer.minimum_value = "-infinity" - expression.type.integer.maximum_value = "infinity" - return - name = tuple(physical_type.atomic_type.reference.canonical_name.object_path) - if name == ("UInt",): - expression.type.integer.minimum_value = "0" - expression.type.integer.maximum_value = str(2**type_size - 1) - elif name == ("Int",): - expression.type.integer.minimum_value = str(-(2**(type_size - 1))) - expression.type.integer.maximum_value = str(2**(type_size - 1) - 1) - elif name == ("Bcd",): - expression.type.integer.minimum_value = "0" - expression.type.integer.maximum_value = str( - 10**(type_size // 4) * 2**(type_size % 4) - 1) - else: - assert False, "Unknown integral type " + ".".join(name) + assert False, "Unknown integral type " + ".".join(name) def _compute_constraints_of_parameter(parameter): - if parameter.type.WhichOneof("type") == "integer": - type_size = ir_util.constant_value( - parameter.physical_type_alias.size_in_bits) - _set_integer_constraints_from_physical_type( - parameter, parameter.physical_type_alias, type_size) + if parameter.type.WhichOneof("type") == "integer": + type_size = ir_util.constant_value(parameter.physical_type_alias.size_in_bits) + _set_integer_constraints_from_physical_type( + parameter, parameter.physical_type_alias, type_size + ) def _compute_constraints_of_builtin_value(expression): - """Computes the constraints of a builtin (like $static_size_in_bits).""" - name = expression.builtin_reference.canonical_name.object_path[0] - if name == "$static_size_in_bits": - expression.type.integer.modulus = "1" - expression.type.integer.modular_value = "0" - expression.type.integer.minimum_value = "0" - # The maximum theoretically-supported size of something is 2**64 bytes, - # which is 2**64 * 8 bits. - # - # Really, $static_size_in_bits is only valid in expressions that have to be - # evaluated at compile time anyway, so it doesn't really matter if the - # bounds are excessive. - expression.type.integer.maximum_value = "infinity" - elif name == "$is_statically_sized": - # No bounds on a boolean variable. - pass - elif name == "$logical_value": - # $logical_value is the placeholder used in inferred write-through - # transformations. - # - # Only integers (currently) have "real" write-through transformations, but - # fields that would otherwise be straight aliases, but which have a - # [requires] attribute, are elevated to write-through fields, so that the - # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite, - # Read, and Ok. 
- if expression.type.WhichOneof("type") == "integer": - assert expression.type.integer.modulus - assert expression.type.integer.modular_value - assert expression.type.integer.minimum_value - assert expression.type.integer.maximum_value - elif expression.type.WhichOneof("type") == "enumeration": - assert expression.type.enumeration.name - elif expression.type.WhichOneof("type") == "boolean": - pass + """Computes the constraints of a builtin (like $static_size_in_bits).""" + name = expression.builtin_reference.canonical_name.object_path[0] + if name == "$static_size_in_bits": + expression.type.integer.modulus = "1" + expression.type.integer.modular_value = "0" + expression.type.integer.minimum_value = "0" + # The maximum theoretically-supported size of something is 2**64 bytes, + # which is 2**64 * 8 bits. + # + # Really, $static_size_in_bits is only valid in expressions that have to be + # evaluated at compile time anyway, so it doesn't really matter if the + # bounds are excessive. + expression.type.integer.maximum_value = "infinity" + elif name == "$is_statically_sized": + # No bounds on a boolean variable. + pass + elif name == "$logical_value": + # $logical_value is the placeholder used in inferred write-through + # transformations. + # + # Only integers (currently) have "real" write-through transformations, but + # fields that would otherwise be straight aliases, but which have a + # [requires] attribute, are elevated to write-through fields, so that the + # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite, + # Read, and Ok. + if expression.type.WhichOneof("type") == "integer": + assert expression.type.integer.modulus + assert expression.type.integer.modular_value + assert expression.type.integer.minimum_value + assert expression.type.integer.maximum_value + elif expression.type.WhichOneof("type") == "enumeration": + assert expression.type.enumeration.name + elif expression.type.WhichOneof("type") == "boolean": + pass + else: + assert False, "Unexpected type for $logical_value" else: - assert False, "Unexpected type for $logical_value" - else: - assert False, "Unknown builtin " + name + assert False, "Unknown builtin " + name def _compute_constant_value_of_boolean_constant(expression): - expression.type.boolean.value = expression.boolean_constant.value + expression.type.boolean.value = expression.boolean_constant.value def _add(a, b): - """Adds a and b, where a and b are ints, "infinity", or "-infinity".""" - if a in ("infinity", "-infinity"): - a, b = b, a - if b == "infinity": - assert a != "-infinity" - return "infinity" - if b == "-infinity": - assert a != "infinity" - return "-infinity" - return int(a) + int(b) + """Adds a and b, where a and b are ints, "infinity", or "-infinity".""" + if a in ("infinity", "-infinity"): + a, b = b, a + if b == "infinity": + assert a != "-infinity" + return "infinity" + if b == "-infinity": + assert a != "infinity" + return "-infinity" + return int(a) + int(b) def _sub(a, b): - """Subtracts b from a, where a and b are ints, "infinity", or "-infinity".""" - if b == "infinity": - return _add(a, "-infinity") - if b == "-infinity": - return _add(a, "infinity") - return _add(a, -int(b)) + """Subtracts b from a, where a and b are ints, "infinity", or "-infinity".""" + if b == "infinity": + return _add(a, "-infinity") + if b == "-infinity": + return _add(a, "infinity") + return _add(a, -int(b)) def _sign(a): - """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0.""" - if a == "infinity": - return 1 - if a == "-infinity": - return -1 - 
if int(a) > 0: - return 1 - if int(a) < 0: - return -1 - return 0 + """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0.""" + if a == "infinity": + return 1 + if a == "-infinity": + return -1 + if int(a) > 0: + return 1 + if int(a) < 0: + return -1 + return 0 def _mul(a, b): - """Multiplies a and b, where a and b are ints, "infinity", or "-infinity".""" - if _is_infinite(a): - a, b = b, a - if _is_infinite(b): - sign = _sign(a) * _sign(b) - if sign > 0: - return "infinity" - if sign < 0: - return "-infinity" - return 0 - return int(a) * int(b) + """Multiplies a and b, where a and b are ints, "infinity", or "-infinity".""" + if _is_infinite(a): + a, b = b, a + if _is_infinite(b): + sign = _sign(a) * _sign(b) + if sign > 0: + return "infinity" + if sign < 0: + return "-infinity" + return 0 + return int(a) * int(b) def _is_infinite(a): - return a in ("infinity", "-infinity") + return a in ("infinity", "-infinity") def _max(a): - """Returns max of a, where elements are ints, "infinity", or "-infinity".""" - if any(n == "infinity" for n in a): - return "infinity" - if all(n == "-infinity" for n in a): - return "-infinity" - return max(int(n) for n in a if not _is_infinite(n)) + """Returns max of a, where elements are ints, "infinity", or "-infinity".""" + if any(n == "infinity" for n in a): + return "infinity" + if all(n == "-infinity" for n in a): + return "-infinity" + return max(int(n) for n in a if not _is_infinite(n)) def _min(a): - """Returns min of a, where elements are ints, "infinity", or "-infinity".""" - if any(n == "-infinity" for n in a): - return "-infinity" - if all(n == "infinity" for n in a): - return "infinity" - return min(int(n) for n in a if not _is_infinite(n)) + """Returns min of a, where elements are ints, "infinity", or "-infinity".""" + if any(n == "-infinity" for n in a): + return "-infinity" + if all(n == "infinity" for n in a): + return "infinity" + return min(int(n) for n in a if not _is_infinite(n)) def _compute_constraints_of_additive_operator(expression): - """Computes the modular value of an additive expression.""" - funcs = { - ir_data.FunctionMapping.ADDITION: _add, - ir_data.FunctionMapping.SUBTRACTION: _sub, - } - func = funcs[expression.function.function] - args = expression.function.args - for arg in args: - assert arg.type.integer.modular_value, str(expression) - left, right = args - unadjusted_modular_value = func(left.type.integer.modular_value, - right.type.integer.modular_value) - new_modulus = _greatest_common_divisor(left.type.integer.modulus, - right.type.integer.modulus) - expression.type.integer.modulus = str(new_modulus) - if new_modulus == "infinity": - expression.type.integer.modular_value = str(unadjusted_modular_value) - else: - expression.type.integer.modular_value = str(unadjusted_modular_value % - new_modulus) - lmax = left.type.integer.maximum_value - lmin = left.type.integer.minimum_value - if expression.function.function == ir_data.FunctionMapping.SUBTRACTION: - rmax = right.type.integer.minimum_value - rmin = right.type.integer.maximum_value - else: - rmax = right.type.integer.maximum_value - rmin = right.type.integer.minimum_value - expression.type.integer.minimum_value = str(func(lmin, rmin)) - expression.type.integer.maximum_value = str(func(lmax, rmax)) + """Computes the modular value of an additive expression.""" + funcs = { + ir_data.FunctionMapping.ADDITION: _add, + ir_data.FunctionMapping.SUBTRACTION: _sub, + } + func = funcs[expression.function.function] + args = expression.function.args + for arg in args: + assert 
arg.type.integer.modular_value, str(expression) + left, right = args + unadjusted_modular_value = func( + left.type.integer.modular_value, right.type.integer.modular_value + ) + new_modulus = _greatest_common_divisor( + left.type.integer.modulus, right.type.integer.modulus + ) + expression.type.integer.modulus = str(new_modulus) + if new_modulus == "infinity": + expression.type.integer.modular_value = str(unadjusted_modular_value) + else: + expression.type.integer.modular_value = str( + unadjusted_modular_value % new_modulus + ) + lmax = left.type.integer.maximum_value + lmin = left.type.integer.minimum_value + if expression.function.function == ir_data.FunctionMapping.SUBTRACTION: + rmax = right.type.integer.minimum_value + rmin = right.type.integer.maximum_value + else: + rmax = right.type.integer.maximum_value + rmin = right.type.integer.minimum_value + expression.type.integer.minimum_value = str(func(lmin, rmin)) + expression.type.integer.maximum_value = str(func(lmax, rmax)) def _compute_constraints_of_multiplicative_operator(expression): - """Computes the modular value of a multiplicative expression.""" - bounds = [arg.type.integer for arg in expression.function.args] - - # The minimum and maximum values can come from any of the four pairings of - # (left min, left max) with (right min, right max), depending on the signs and - # magnitudes of the minima and maxima. E.g.: - # - # max = left max * right max: [ 2, 3] * [ 2, 3] - # max = left min * right min: [-3, -2] * [-3, -2] - # max = left max * right min: [-3, -2] * [ 2, 3] - # max = left min * right max: [ 2, 3] * [-3, -2] - # max = left max * right max: [-2, 3] * [-2, 3] - # max = left min * right min: [-3, 2] * [-3, 2] - # - # For uncorrelated multiplication, the minimum and maximum will always come - # from multiplying one extreme by another: if x is nonzero, then - # - # (y + e) * x > y * x || (y - e) * x > y * x - # - # for arbitrary nonzero e, so the extrema can only occur when we either cannot - # add or cannot subtract e. - # - # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but - # Emboss is not currently trying to be that smart. - lmin, lmax = bounds[0].minimum_value, bounds[0].maximum_value - rmin, rmax = bounds[1].minimum_value, bounds[1].maximum_value - extrema = [_mul(lmax, rmax), _mul(lmin, rmax), # - _mul(lmax, rmin), _mul(lmin, rmin)] - expression.type.integer.minimum_value = str(_min(extrema)) - expression.type.integer.maximum_value = str(_max(extrema)) - - if all(bound.modulus == "infinity" for bound in bounds): - # If both sides are constant, the result is constant. - expression.type.integer.modulus = "infinity" - expression.type.integer.modular_value = str(int(bounds[0].modular_value) * - int(bounds[1].modular_value)) - return - - if any(bound.modulus == "infinity" for bound in bounds): - # If one side is constant and the other is not, then the non-constant - # modulus and modular_value can both be multiplied by the constant. E.g., - # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod - # 20: + """Computes the modular value of a multiplicative expression.""" + bounds = [arg.type.integer for arg in expression.function.args] + + # The minimum and maximum values can come from any of the four pairings of + # (left min, left max) with (right min, right max), depending on the signs and + # magnitudes of the minima and maxima. E.g.: # - # a = ... | 4 * a = ... | 4 * a mod 20 = ... 
- # 3 | 12 | 12 - # 8 | 32 | 12 - # 13 | 52 | 12 - # 18 | 72 | 12 - # 23 | 92 | 12 - # 28 | 112 | 12 - # 33 | 132 | 12 + # max = left max * right max: [ 2, 3] * [ 2, 3] + # max = left min * right min: [-3, -2] * [-3, -2] + # max = left max * right min: [-3, -2] * [ 2, 3] + # max = left min * right max: [ 2, 3] * [-3, -2] + # max = left max * right max: [-2, 3] * [-2, 3] + # max = left min * right min: [-3, 2] * [-3, 2] # - # This is trivially shown by noting that the difference between consecutive - # possible values for `4 * a` always differ by 20. - if bounds[0].modulus == "infinity": - constant, variable = bounds - else: - variable, constant = bounds - if int(constant.modular_value) == 0: - # If the constant is 0, the result is 0, no matter what the variable side - # is. - expression.type.integer.modulus = "infinity" - expression.type.integer.modular_value = "0" - return - new_modulus = int(variable.modulus) * abs(int(constant.modular_value)) - expression.type.integer.modulus = str(new_modulus) - # The `% new_modulus` will force the `modular_value` to be positive, even - # when `constant.modular_value` is negative. + # For uncorrelated multiplication, the minimum and maximum will always come + # from multiplying one extreme by another: if x is nonzero, then + # + # (y + e) * x > y * x || (y - e) * x > y * x + # + # for arbitrary nonzero e, so the extrema can only occur when we either cannot + # add or cannot subtract e. + # + # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but + # Emboss is not currently trying to be that smart. + lmin, lmax = bounds[0].minimum_value, bounds[0].maximum_value + rmin, rmax = bounds[1].minimum_value, bounds[1].maximum_value + extrema = [ + _mul(lmax, rmax), + _mul(lmin, rmax), # + _mul(lmax, rmin), + _mul(lmin, rmin), + ] + expression.type.integer.minimum_value = str(_min(extrema)) + expression.type.integer.maximum_value = str(_max(extrema)) + + if all(bound.modulus == "infinity" for bound in bounds): + # If both sides are constant, the result is constant. + expression.type.integer.modulus = "infinity" + expression.type.integer.modular_value = str( + int(bounds[0].modular_value) * int(bounds[1].modular_value) + ) + return + + if any(bound.modulus == "infinity" for bound in bounds): + # If one side is constant and the other is not, then the non-constant + # modulus and modular_value can both be multiplied by the constant. E.g., + # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod + # 20: + # + # a = ... | 4 * a = ... | 4 * a mod 20 = ... + # 3 | 12 | 12 + # 8 | 32 | 12 + # 13 | 52 | 12 + # 18 | 72 | 12 + # 23 | 92 | 12 + # 28 | 112 | 12 + # 33 | 132 | 12 + # + # This is trivially shown by noting that the difference between consecutive + # possible values for `4 * a` always differ by 20. + if bounds[0].modulus == "infinity": + constant, variable = bounds + else: + variable, constant = bounds + if int(constant.modular_value) == 0: + # If the constant is 0, the result is 0, no matter what the variable side + # is. + expression.type.integer.modulus = "infinity" + expression.type.integer.modular_value = "0" + return + new_modulus = int(variable.modulus) * abs(int(constant.modular_value)) + expression.type.integer.modulus = str(new_modulus) + # The `% new_modulus` will force the `modular_value` to be positive, even + # when `constant.modular_value` is negative. 
+ expression.type.integer.modular_value = str( + int(variable.modular_value) * int(constant.modular_value) % new_modulus + ) + return + + # If neither side is constant, then the result is more complex. Full proof is + # available in g3doc/modular_congruence_multiplication_proof.md + # + # Essentially, if: + # + # l == _ * l_mod + l_mv + # r == _ * r_mod + r_mv + # + # Then we find l_mod0 and r_mod0 in: + # + # l == (_ * l_mod_nz + l_mv_nz) * l_mod0 + # r == (_ * r_mod_nz + r_mv_nz) * r_mod0 + # + # And finally conclude: + # + # l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv + product_of_zero_congruence_moduli = 1 + product_of_modular_values = 1 + nonzero_congruence_moduli = [] + for bound in bounds: + zero_congruence_modulus = _greatest_common_divisor( + bound.modulus, bound.modular_value + ) + assert int(bound.modulus) % zero_congruence_modulus == 0 + product_of_zero_congruence_moduli *= zero_congruence_modulus + product_of_modular_values *= int(bound.modular_value) + nonzero_congruence_moduli.append(int(bound.modulus) // zero_congruence_modulus) + shared_nonzero_congruence_modulus = _greatest_common_divisor( + nonzero_congruence_moduli[0], nonzero_congruence_moduli[1] + ) + final_modulus = ( + shared_nonzero_congruence_modulus * product_of_zero_congruence_moduli + ) + expression.type.integer.modulus = str(final_modulus) expression.type.integer.modular_value = str( - int(variable.modular_value) * int(constant.modular_value) % new_modulus) - return - - # If neither side is constant, then the result is more complex. Full proof is - # available in g3doc/modular_congruence_multiplication_proof.md - # - # Essentially, if: - # - # l == _ * l_mod + l_mv - # r == _ * r_mod + r_mv - # - # Then we find l_mod0 and r_mod0 in: - # - # l == (_ * l_mod_nz + l_mv_nz) * l_mod0 - # r == (_ * r_mod_nz + r_mv_nz) * r_mod0 - # - # And finally conclude: - # - # l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv - product_of_zero_congruence_moduli = 1 - product_of_modular_values = 1 - nonzero_congruence_moduli = [] - for bound in bounds: - zero_congruence_modulus = _greatest_common_divisor(bound.modulus, - bound.modular_value) - assert int(bound.modulus) % zero_congruence_modulus == 0 - product_of_zero_congruence_moduli *= zero_congruence_modulus - product_of_modular_values *= int(bound.modular_value) - nonzero_congruence_moduli.append(int(bound.modulus) // - zero_congruence_modulus) - shared_nonzero_congruence_modulus = _greatest_common_divisor( - nonzero_congruence_moduli[0], nonzero_congruence_moduli[1]) - final_modulus = (shared_nonzero_congruence_modulus * - product_of_zero_congruence_moduli) - expression.type.integer.modulus = str(final_modulus) - expression.type.integer.modular_value = str(product_of_modular_values % - final_modulus) + product_of_modular_values % final_modulus + ) def _assert_integer_constraints(expression): - """Asserts that the integer bounds of expression are self-consistent. - - Asserts that `minimum_value` and `maximum_value` are congruent to - `modular_value` modulo `modulus`. - - If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and - `modular_value` are all equal. - - If `minimum_value` is equal to `maximum_value`, asserts that `modular_value` - is equal to both, and that `modulus` is "infinity". 
- - Arguments: - expression: an expression with type.integer - - Returns: - None - """ - bounds = expression.type.integer - if bounds.modulus == "infinity": - assert bounds.minimum_value == bounds.modular_value - assert bounds.maximum_value == bounds.modular_value - return - modulus = int(bounds.modulus) - assert modulus > 0 - if bounds.minimum_value != "-infinity": - assert int(bounds.minimum_value) % modulus == int(bounds.modular_value) - if bounds.maximum_value != "infinity": - assert int(bounds.maximum_value) % modulus == int(bounds.modular_value) - if bounds.minimum_value == bounds.maximum_value: - # TODO(bolms): I believe there are situations using the not-yet-implemented - # integer division operator that would trigger these asserts, so they should - # be turned into assignments (with corresponding tests) when implementing - # division. - assert bounds.modular_value == bounds.minimum_value - assert bounds.modulus == "infinity" - if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity": - assert int(bounds.minimum_value) <= int(bounds.maximum_value) + """Asserts that the integer bounds of expression are self-consistent. + + Asserts that `minimum_value` and `maximum_value` are congruent to + `modular_value` modulo `modulus`. + + If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and + `modular_value` are all equal. + + If `minimum_value` is equal to `maximum_value`, asserts that `modular_value` + is equal to both, and that `modulus` is "infinity". + + Arguments: + expression: an expression with type.integer + + Returns: + None + """ + bounds = expression.type.integer + if bounds.modulus == "infinity": + assert bounds.minimum_value == bounds.modular_value + assert bounds.maximum_value == bounds.modular_value + return + modulus = int(bounds.modulus) + assert modulus > 0 + if bounds.minimum_value != "-infinity": + assert int(bounds.minimum_value) % modulus == int(bounds.modular_value) + if bounds.maximum_value != "infinity": + assert int(bounds.maximum_value) % modulus == int(bounds.modular_value) + if bounds.minimum_value == bounds.maximum_value: + # TODO(bolms): I believe there are situations using the not-yet-implemented + # integer division operator that would trigger these asserts, so they should + # be turned into assignments (with corresponding tests) when implementing + # division. 
+ assert bounds.modular_value == bounds.minimum_value + assert bounds.modulus == "infinity" + if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity": + assert int(bounds.minimum_value) <= int(bounds.maximum_value) def _compute_constant_value_of_comparison_operator(expression): - """Computes the constant value, if any, of a comparison operator.""" - args = expression.function.args - if all(ir_util.is_constant(arg) for arg in args): - functions = { - ir_data.FunctionMapping.EQUALITY: operator.eq, - ir_data.FunctionMapping.INEQUALITY: operator.ne, - ir_data.FunctionMapping.LESS: operator.lt, - ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le, - ir_data.FunctionMapping.GREATER: operator.gt, - ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge, - ir_data.FunctionMapping.AND: operator.and_, - ir_data.FunctionMapping.OR: operator.or_, - } - func = functions[expression.function.function] - expression.type.boolean.value = func( - *[ir_util.constant_value(arg) for arg in args]) + """Computes the constant value, if any, of a comparison operator.""" + args = expression.function.args + if all(ir_util.is_constant(arg) for arg in args): + functions = { + ir_data.FunctionMapping.EQUALITY: operator.eq, + ir_data.FunctionMapping.INEQUALITY: operator.ne, + ir_data.FunctionMapping.LESS: operator.lt, + ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le, + ir_data.FunctionMapping.GREATER: operator.gt, + ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge, + ir_data.FunctionMapping.AND: operator.and_, + ir_data.FunctionMapping.OR: operator.or_, + } + func = functions[expression.function.function] + expression.type.boolean.value = func( + *[ir_util.constant_value(arg) for arg in args] + ) def _compute_constraints_of_bound_function(expression): - """Computes the constraints of $upper_bound or $lower_bound.""" - if expression.function.function == ir_data.FunctionMapping.UPPER_BOUND: - value = expression.function.args[0].type.integer.maximum_value - elif expression.function.function == ir_data.FunctionMapping.LOWER_BOUND: - value = expression.function.args[0].type.integer.minimum_value - else: - assert False, "Non-bound function" - expression.type.integer.minimum_value = value - expression.type.integer.maximum_value = value - expression.type.integer.modular_value = value - expression.type.integer.modulus = "infinity" + """Computes the constraints of $upper_bound or $lower_bound.""" + if expression.function.function == ir_data.FunctionMapping.UPPER_BOUND: + value = expression.function.args[0].type.integer.maximum_value + elif expression.function.function == ir_data.FunctionMapping.LOWER_BOUND: + value = expression.function.args[0].type.integer.minimum_value + else: + assert False, "Non-bound function" + expression.type.integer.minimum_value = value + expression.type.integer.maximum_value = value + expression.type.integer.modular_value = value + expression.type.integer.modulus = "infinity" def _compute_constraints_of_maximum_function(expression): - """Computes the constraints of the $max function.""" - assert expression.type.WhichOneof("type") == "integer" - args = expression.function.args - assert args[0].type.WhichOneof("type") == "integer" - # The minimum value of the result occurs when every argument takes its minimum - # value, which means that the minimum result is the maximum-of-minimums. - expression.type.integer.minimum_value = str(_max( - [arg.type.integer.minimum_value for arg in args])) - # The maximum result is the maximum-of-maximums. 
-  expression.type.integer.maximum_value = str(_max(
-      [arg.type.integer.maximum_value for arg in args]))
-  # If the expression is dominated by a constant factor, then the result is
-  # constant. I (bolms@) believe this is the only case where
-  # _compute_constraints_of_maximum_function might violate the assertions in
-  # _assert_integer_constraints.
-  if (expression.type.integer.minimum_value ==
-      expression.type.integer.maximum_value):
-    expression.type.integer.modular_value = (
-        expression.type.integer.minimum_value)
-    expression.type.integer.modulus = "infinity"
-    return
-  result_modulus = args[0].type.integer.modulus
-  result_modular_value = args[0].type.integer.modular_value
-  # The result of $max(a, b) could be either a or b, which means that the result
-  # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
-  # choice operator '?:'.
-  #
-  # This also takes advantage of the fact that $max(a, b, c, d, ...) is
-  # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
-  # call _shared_modular_value() in a loop.
-  for arg in args[1:]:
-    # TODO(bolms): I think the bounds could be tighter in some cases where
-    # arg.maximum_value is less than the new expression.minimum_value, and
-    # in some very specific cases where arg.maximum_value is greater than the
-    # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
-    # than expression.minimum_value.
-    result_modulus, result_modular_value = _shared_modular_value(
-        (result_modulus, result_modular_value),
-        (arg.type.integer.modulus, arg.type.integer.modular_value))
-  expression.type.integer.modulus = str(result_modulus)
-  expression.type.integer.modular_value = str(result_modular_value)
+    """Computes the constraints of the $max function."""
+    assert expression.type.WhichOneof("type") == "integer"
+    args = expression.function.args
+    assert args[0].type.WhichOneof("type") == "integer"
+    # The minimum value of the result occurs when every argument takes its minimum
+    # value, which means that the minimum result is the maximum-of-minimums.
+    expression.type.integer.minimum_value = str(
+        _max([arg.type.integer.minimum_value for arg in args])
+    )
+    # The maximum result is the maximum-of-maximums.
+    expression.type.integer.maximum_value = str(
+        _max([arg.type.integer.maximum_value for arg in args])
+    )
+    # If the expression is dominated by a constant factor, then the result is
+    # constant. I (bolms@) believe this is the only case where
+    # _compute_constraints_of_maximum_function might violate the assertions in
+    # _assert_integer_constraints.
+    if expression.type.integer.minimum_value == expression.type.integer.maximum_value:
+        expression.type.integer.modular_value = expression.type.integer.minimum_value
+        expression.type.integer.modulus = "infinity"
+        return
+    result_modulus = args[0].type.integer.modulus
+    result_modular_value = args[0].type.integer.modular_value
+    # The result of $max(a, b) could be either a or b, which means that the result
+    # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
+    # choice operator '?:'.
+    #
+    # This also takes advantage of the fact that $max(a, b, c, d, ...) is
+    # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
+    # call _shared_modular_value() in a loop.
+    for arg in args[1:]:
+        # TODO(bolms): I think the bounds could be tighter in some cases where
+        # arg.maximum_value is less than the new expression.minimum_value, and
+        # in some very specific cases where arg.maximum_value is greater than the
+        # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
+        # than expression.minimum_value.
+        result_modulus, result_modular_value = _shared_modular_value(
+            (result_modulus, result_modular_value),
+            (arg.type.integer.modulus, arg.type.integer.modular_value),
+        )
+    expression.type.integer.modulus = str(result_modulus)
+    expression.type.integer.modular_value = str(result_modular_value)


 def _shared_modular_value(left, right):
-  """Returns the shared modulus and modular value of left and right.
-
-  Arguments:
-    left: A tuple of (modulus, modular value)
-    right: A tuple of (modulus, modular value)
-
-  Returns:
-    A tuple of (modulus, modular_value) such that:
-
-      left.modulus % result.modulus == 0
-      right.modulus % result.modulus == 0
-      left.modular_value % result.modulus = result.modular_value
-      right.modular_value % result.modulus = result.modular_value
-
-    That is, the result.modulus and result.modular_value will be compatible
-    with, but (possibly) less restrictive than both left.(modulus,
-    modular_value) and right.(modulus, modular_value).
-  """
-  left_modulus, left_modular_value = left
-  right_modulus, right_modular_value = right
-  # The combined modulus is gcd(gcd(left_modulus, right_modulus),
-  # left_modular_value - right_modular_value).
-  #
-  # The inner gcd normalizes the left_modulus and right_modulus, but can leave
-  # incompatible modular_values. The outer gcd finds a modulus to which both
-  # modular_values are congruent. Some examples:
-  #
-  # left          | right          | res
-  # --------------+----------------+--------------------
-  # l % 12 == 7   | r % 20 == 15   | res % 4 == 3
-  # l == 35       | r % 20 == 15   | res % 20 == 15
-  # l % 24 == 15  | r % 12 == 7    | res % 4 == 3
-  # l % 20 == 15  | r % 20 == 10   | res % 5 == 0
-  # l % 20 == 16  | r % 20 == 11   | res % 5 == 1
-  # l == 10       | r == 7         | res % 3 == 1
-  # l == 4        | r == 4         | res == 4
-  #
-  # The cases where one side or the other are constant are handled
-  # automatically by the fact that _greatest_common_divisor("infinity", x)
-  # is x.
-  common_modulus = _greatest_common_divisor(left_modulus, right_modulus)
-  new_modulus = _greatest_common_divisor(
-      common_modulus, abs(int(left_modular_value) - int(right_modular_value)))
-  if new_modulus == "infinity":
-    # The only way for the new_modulus to come out as "infinity" *should* be
-    # if both if_true and if_false have the same constant value.
-    assert left_modular_value == right_modular_value
-    assert left_modulus == right_modulus == "infinity"
-    return new_modulus, left_modular_value
-  else:
-    assert (int(left_modular_value) % new_modulus ==
-            int(right_modular_value) % new_modulus)
-    return new_modulus, int(left_modular_value) % new_modulus
-
-
-def _compute_constraints_of_choice_operator(expression):
-  """Computes the constraints of a choice operation '?:'."""
-  condition, if_true, if_false = ir_data_utils.reader(expression).function.args
-  expression = ir_data_utils.builder(expression)
-  if condition.type.boolean.HasField("value"):
-    # The generated expressions for $size_in_bits and $size_in_bytes look like
+    """Returns the shared modulus and modular value of left and right.
+ + Arguments: + left: A tuple of (modulus, modular value) + right: A tuple of (modulus, modular value) + + Returns: + A tuple of (modulus, modular_value) such that: + + left.modulus % result.modulus == 0 + right.modulus % result.modulus == 0 + left.modular_value % result.modulus = result.modular_value + right.modular_value % result.modulus = result.modular_value + + That is, the result.modulus and result.modular_value will be compatible + with, but (possibly) less restrictive than both left.(modulus, + modular_value) and right.(modulus, modular_value). + """ + left_modulus, left_modular_value = left + right_modulus, right_modular_value = right + # The combined modulus is gcd(gcd(left_modulus, right_modulus), + # left_modular_value - right_modular_value). # - # $max((field1_existence_condition ? field1_start + field1_size : 0), - # (field2_existence_condition ? field2_start + field2_size : 0), - # (field3_existence_condition ? field3_start + field3_size : 0), - # ...) + # The inner gcd normalizes the left_modulus and right_modulus, but can leave + # incompatible modular_values. The outer gcd finds a modulus to which both + # modular_values are congruent. Some examples: # - # Since most existence_conditions are just "true", it is important to select - # the tighter bounds in those cases -- otherwise, only zero-length - # structures could have a constant $size_in_bits or $size_in_bytes. - side = if_true if condition.type.boolean.value else if_false - expression.type.CopyFrom(side.type) - return - # The type.integer minimum_value/maximum_value bounding code is needed since - # constraints.check_constraints() will complain if minimum and maximum are not - # set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its - # weight, but for completeness I've left it in. - if if_true.type.WhichOneof("type") == "integer": - # The minimum value of the choice is the minimum value of either side, and - # the maximum is the maximum value of either side. - expression.type.integer.minimum_value = str(_min([ - if_true.type.integer.minimum_value, - if_false.type.integer.minimum_value])) - expression.type.integer.maximum_value = str(_max([ - if_true.type.integer.maximum_value, - if_false.type.integer.maximum_value])) - new_modulus, new_modular_value = _shared_modular_value( - (if_true.type.integer.modulus, if_true.type.integer.modular_value), - (if_false.type.integer.modulus, if_false.type.integer.modular_value)) - expression.type.integer.modulus = str(new_modulus) - expression.type.integer.modular_value = str(new_modular_value) - else: - assert if_true.type.WhichOneof("type") in ("boolean", "enumeration"), ( - "Unknown type {} for expression".format( - if_true.type.WhichOneof("type"))) - - -def _greatest_common_divisor(a, b): - """Returns the greatest common divisor of a and b. - - Arguments: - a: an integer, a stringified integer, or the string "infinity" - b: an integer, a stringified integer, or the string "infinity" - - Returns: - Conceptually, "infinity" is treated as the product of all integers. 
+ # left | right | res + # --------------+----------------+-------------------- + # l % 12 == 7 | r % 20 == 15 | res % 4 == 3 + # l == 35 | r % 20 == 15 | res % 20 == 15 + # l % 24 == 15 | r % 12 == 7 | res % 4 == 3 + # l % 20 == 15 | r % 20 == 10 | res % 5 == 0 + # l % 20 == 16 | r % 20 == 11 | res % 5 == 1 + # l == 10 | r == 7 | res % 3 == 1 + # l == 4 | r == 4 | res == 4 + # + # The cases where one side or the other are constant are handled + # automatically by the fact that _greatest_common_divisor("infinity", x) + # is x. + common_modulus = _greatest_common_divisor(left_modulus, right_modulus) + new_modulus = _greatest_common_divisor( + common_modulus, abs(int(left_modular_value) - int(right_modular_value)) + ) + if new_modulus == "infinity": + # The only way for the new_modulus to come out as "infinity" *should* be + # if both if_true and if_false have the same constant value. + assert left_modular_value == right_modular_value + assert left_modulus == right_modulus == "infinity" + return new_modulus, left_modular_value + else: + assert ( + int(left_modular_value) % new_modulus + == int(right_modular_value) % new_modulus + ) + return new_modulus, int(left_modular_value) % new_modulus - If both a and b are 0, returns "infinity". - Otherwise, if either a or b are "infinity", and the other is 0, returns - "infinity". +def _compute_constraints_of_choice_operator(expression): + """Computes the constraints of a choice operation '?:'.""" + condition, if_true, if_false = ir_data_utils.reader(expression).function.args + expression = ir_data_utils.builder(expression) + if condition.type.boolean.HasField("value"): + # The generated expressions for $size_in_bits and $size_in_bytes look like + # + # $max((field1_existence_condition ? field1_start + field1_size : 0), + # (field2_existence_condition ? field2_start + field2_size : 0), + # (field3_existence_condition ? field3_start + field3_size : 0), + # ...) + # + # Since most existence_conditions are just "true", it is important to select + # the tighter bounds in those cases -- otherwise, only zero-length + # structures could have a constant $size_in_bits or $size_in_bytes. + side = if_true if condition.type.boolean.value else if_false + expression.type.CopyFrom(side.type) + return + # The type.integer minimum_value/maximum_value bounding code is needed since + # constraints.check_constraints() will complain if minimum and maximum are not + # set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its + # weight, but for completeness I've left it in. + if if_true.type.WhichOneof("type") == "integer": + # The minimum value of the choice is the minimum value of either side, and + # the maximum is the maximum value of either side. 
+ expression.type.integer.minimum_value = str( + _min( + [ + if_true.type.integer.minimum_value, + if_false.type.integer.minimum_value, + ] + ) + ) + expression.type.integer.maximum_value = str( + _max( + [ + if_true.type.integer.maximum_value, + if_false.type.integer.maximum_value, + ] + ) + ) + new_modulus, new_modular_value = _shared_modular_value( + (if_true.type.integer.modulus, if_true.type.integer.modular_value), + (if_false.type.integer.modulus, if_false.type.integer.modular_value), + ) + expression.type.integer.modulus = str(new_modulus) + expression.type.integer.modular_value = str(new_modular_value) + else: + assert if_true.type.WhichOneof("type") in ( + "boolean", + "enumeration", + ), "Unknown type {} for expression".format(if_true.type.WhichOneof("type")) - Otherwise, if either a or b are "infinity", returns the other. - Otherwise, returns the greatest common divisor of a and b. - """ - if a != "infinity": a = int(a) - if b != "infinity": b = int(b) - assert a == "infinity" or a >= 0 - assert b == "infinity" or b >= 0 - if a == b == 0: return "infinity" - # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0. - if a == 0: return b - if b == 0: return a - if a == "infinity": return b - if b == "infinity": return a - return _math_gcd(a, b) +def _greatest_common_divisor(a, b): + """Returns the greatest common divisor of a and b. + + Arguments: + a: an integer, a stringified integer, or the string "infinity" + b: an integer, a stringified integer, or the string "infinity" + + Returns: + Conceptually, "infinity" is treated as the product of all integers. + + If both a and b are 0, returns "infinity". + + Otherwise, if either a or b are "infinity", and the other is 0, returns + "infinity". + + Otherwise, if either a or b are "infinity", returns the other. + + Otherwise, returns the greatest common divisor of a and b. + """ + if a != "infinity": + a = int(a) + if b != "infinity": + b = int(b) + assert a == "infinity" or a >= 0 + assert b == "infinity" or b >= 0 + if a == b == 0: + return "infinity" + # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0. + if a == 0: + return b + if b == 0: + return a + if a == "infinity": + return b + if b == "infinity": + return a + return _math_gcd(a, b) def compute_constants(ir): - """Computes constant values for all expressions in ir. - - compute_constants calculates all constant values and adds them to the type - information for each expression and subexpression. - - Arguments: - ir: an IR on which to compute constants - - Returns: - A (possibly empty) list of errors. - """ - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Expression], compute_constraints_of_expression, - skip_descendants_of={ir_data.Expression}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.RuntimeParameter], _compute_constraints_of_parameter, - skip_descendants_of={ir_data.Expression}) - return [] + """Computes constant values for all expressions in ir. + + compute_constants calculates all constant values and adds them to the type + information for each expression and subexpression. + + Arguments: + ir: an IR on which to compute constants + + Returns: + A (possibly empty) list of errors. 
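# --- Editor's aside: a hypothetical end-to-end check of the integer case of
# the choice operator above. The result's minimum/maximum are the min/max
# over the two sides, and (modulus, modular_value) follows the shared-modulus
# rule. This sketch assumes finite moduli to stay short; the real code also
# handles "infinity" via _shared_modular_value/_greatest_common_divisor.
import math

def _toy_choice_int_bounds(if_true, if_false):
    # Each side is (minimum, maximum, modulus, modular_value).
    (t_min, t_max, t_mod, t_val), (f_min, f_max, f_mod, f_val) = if_true, if_false
    modulus = math.gcd(math.gcd(t_mod, f_mod), abs(t_val - f_val))
    return min(t_min, f_min), max(t_max, f_max), modulus, t_val % modulus

# With an 8-bit signed y: (cond ? y * 12 + 7 : y * 20 + 15), matching the
# first case of test_choice_two_non_constant_integers further down.
assert _toy_choice_int_bounds(
    (-128 * 12 + 7, 127 * 12 + 7, 12, 7),
    (-128 * 20 + 15, 127 * 20 + 15, 20, 15),
) == (-128 * 20 + 15, 127 * 20 + 15, 4, 3)
# --- end aside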
+ """ + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Expression], + compute_constraints_of_expression, + skip_descendants_of={ir_data.Expression}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.RuntimeParameter], + _compute_constraints_of_parameter, + skip_descendants_of={ir_data.Expression}, + ) + return [] diff --git a/compiler/front_end/expression_bounds_test.py b/compiler/front_end/expression_bounds_test.py index 54fa0ce..7af6836 100644 --- a/compiler/front_end/expression_bounds_test.py +++ b/compiler/front_end/expression_bounds_test.py @@ -22,1105 +22,1204 @@ class ComputeConstantsTest(unittest.TestCase): - def _make_ir(self, emb_text): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": emb_text}), - stop_before_step="compute_constants") - assert not errors, errors - return ir - - def test_constant_integer(self): - ir = self._make_ir("struct Foo:\n" - " 10 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - start = ir.module[0].type[0].structure.field[0].location.start - self.assertEqual("10", start.type.integer.minimum_value) - self.assertEqual("10", start.type.integer.maximum_value) - self.assertEqual("10", start.type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - - def test_boolean_constant(self): - ir = self._make_ir("struct Foo:\n" - " if true:\n" - " 0 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expression = ir.module[0].type[0].structure.field[0].existence_condition - self.assertTrue(expression.type.boolean.HasField("value")) - self.assertTrue(expression.type.boolean.value) - - def test_constant_equality(self): - ir = self._make_ir("struct Foo:\n" - " if 5 == 5:\n" - " 0 [+1] UInt x\n" - " if 5 == 6:\n" - " 0 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - true_condition = structure.field[0].existence_condition - false_condition = structure.field[1].existence_condition - self.assertTrue(true_condition.type.boolean.HasField("value")) - self.assertTrue(true_condition.type.boolean.value) - self.assertTrue(false_condition.type.boolean.HasField("value")) - self.assertFalse(false_condition.type.boolean.value) - - def test_constant_inequality(self): - ir = self._make_ir("struct Foo:\n" - " if 5 != 5:\n" - " 0 [+1] UInt x\n" - " if 5 != 6:\n" - " 0 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - false_condition = structure.field[0].existence_condition - true_condition = structure.field[1].existence_condition - self.assertTrue(false_condition.type.boolean.HasField("value")) - self.assertFalse(false_condition.type.boolean.value) - self.assertTrue(true_condition.type.boolean.HasField("value")) - self.assertTrue(true_condition.type.boolean.value) - - def test_constant_less_than(self): - ir = self._make_ir("struct Foo:\n" - " if 5 < 4:\n" - " 0 [+1] UInt x\n" - " if 5 < 5:\n" - " 0 [+1] UInt y\n" - " if 5 < 6:\n" - " 0 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - greater_than_condition = structure.field[0].existence_condition - equal_condition = structure.field[1].existence_condition - less_than_condition = structure.field[2].existence_condition - self.assertTrue(greater_than_condition.type.boolean.HasField("value")) - 
self.assertFalse(greater_than_condition.type.boolean.value) - self.assertTrue(equal_condition.type.boolean.HasField("value")) - self.assertFalse(equal_condition.type.boolean.value) - self.assertTrue(less_than_condition.type.boolean.HasField("value")) - self.assertTrue(less_than_condition.type.boolean.value) - - def test_constant_less_than_or_equal(self): - ir = self._make_ir("struct Foo:\n" - " if 5 <= 4:\n" - " 0 [+1] UInt x\n" - " if 5 <= 5:\n" - " 0 [+1] UInt y\n" - " if 5 <= 6:\n" - " 0 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - greater_than_condition = structure.field[0].existence_condition - equal_condition = structure.field[1].existence_condition - less_than_condition = structure.field[2].existence_condition - self.assertTrue(greater_than_condition.type.boolean.HasField("value")) - self.assertFalse(greater_than_condition.type.boolean.value) - self.assertTrue(equal_condition.type.boolean.HasField("value")) - self.assertTrue(equal_condition.type.boolean.value) - self.assertTrue(less_than_condition.type.boolean.HasField("value")) - self.assertTrue(less_than_condition.type.boolean.value) - - def test_constant_greater_than(self): - ir = self._make_ir("struct Foo:\n" - " if 5 > 4:\n" - " 0 [+1] UInt x\n" - " if 5 > 5:\n" - " 0 [+1] UInt y\n" - " if 5 > 6:\n" - " 0 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - greater_than_condition = structure.field[0].existence_condition - equal_condition = structure.field[1].existence_condition - less_than_condition = structure.field[2].existence_condition - self.assertTrue(greater_than_condition.type.boolean.HasField("value")) - self.assertTrue(greater_than_condition.type.boolean.value) - self.assertTrue(equal_condition.type.boolean.HasField("value")) - self.assertFalse(equal_condition.type.boolean.value) - self.assertTrue(less_than_condition.type.boolean.HasField("value")) - self.assertFalse(less_than_condition.type.boolean.value) - - def test_constant_greater_than_or_equal(self): - ir = self._make_ir("struct Foo:\n" - " if 5 >= 4:\n" - " 0 [+1] UInt x\n" - " if 5 >= 5:\n" - " 0 [+1] UInt y\n" - " if 5 >= 6:\n" - " 0 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - greater_than_condition = structure.field[0].existence_condition - equal_condition = structure.field[1].existence_condition - less_than_condition = structure.field[2].existence_condition - self.assertTrue(greater_than_condition.type.boolean.HasField("value")) - self.assertTrue(greater_than_condition.type.boolean.value) - self.assertTrue(equal_condition.type.boolean.HasField("value")) - self.assertTrue(equal_condition.type.boolean.value) - self.assertTrue(less_than_condition.type.boolean.HasField("value")) - self.assertFalse(less_than_condition.type.boolean.value) - - def test_constant_and(self): - ir = self._make_ir("struct Foo:\n" - " if false && false:\n" - " 0 [+1] UInt x\n" - " if true && false:\n" - " 0 [+1] UInt y\n" - " if false && true:\n" - " 0 [+1] UInt z\n" - " if true && true:\n" - " 0 [+1] UInt w\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - false_false_condition = structure.field[0].existence_condition - true_false_condition = structure.field[1].existence_condition - false_true_condition = structure.field[2].existence_condition - true_true_condition = 
structure.field[3].existence_condition - self.assertTrue(false_false_condition.type.boolean.HasField("value")) - self.assertFalse(false_false_condition.type.boolean.value) - self.assertTrue(true_false_condition.type.boolean.HasField("value")) - self.assertFalse(true_false_condition.type.boolean.value) - self.assertTrue(false_true_condition.type.boolean.HasField("value")) - self.assertFalse(false_true_condition.type.boolean.value) - self.assertTrue(true_true_condition.type.boolean.HasField("value")) - self.assertTrue(true_true_condition.type.boolean.value) - - def test_constant_or(self): - ir = self._make_ir("struct Foo:\n" - " if false || false:\n" - " 0 [+1] UInt x\n" - " if true || false:\n" - " 0 [+1] UInt y\n" - " if false || true:\n" - " 0 [+1] UInt z\n" - " if true || true:\n" - " 0 [+1] UInt w\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - false_false_condition = structure.field[0].existence_condition - true_false_condition = structure.field[1].existence_condition - false_true_condition = structure.field[2].existence_condition - true_true_condition = structure.field[3].existence_condition - self.assertTrue(false_false_condition.type.boolean.HasField("value")) - self.assertFalse(false_false_condition.type.boolean.value) - self.assertTrue(true_false_condition.type.boolean.HasField("value")) - self.assertTrue(true_false_condition.type.boolean.value) - self.assertTrue(false_true_condition.type.boolean.HasField("value")) - self.assertTrue(false_true_condition.type.boolean.value) - self.assertTrue(true_true_condition.type.boolean.HasField("value")) - self.assertTrue(true_true_condition.type.boolean.value) - - def test_enum_constant(self): - ir = self._make_ir("struct Foo:\n" - " if Bar.QUX == Bar.QUX:\n" - " 0 [+1] Bar x\n" - "enum Bar:\n" - " QUX = 12\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - condition = ir.module[0].type[0].structure.field[0].existence_condition - left = condition.function.args[0] - self.assertEqual("12", left.type.enumeration.value) - - def test_non_constant_field_reference(self): - ir = self._make_ir("struct Foo:\n" - " y [+1] UInt x\n" - " 0 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - start = ir.module[0].type[0].structure.field[0].location.start - self.assertEqual("0", start.type.integer.minimum_value) - self.assertEqual("255", start.type.integer.maximum_value) - self.assertEqual("0", start.type.integer.modular_value) - self.assertEqual("1", start.type.integer.modulus) - - def test_field_reference_bounds_are_uncomputable(self): - # Variable-sized UInt/Int/Bcd should not cause an error here: they are - # handled in the constraints pass. 
- ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 0 [+x] UInt y\n" - " y [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - - def test_field_references_references_bounds_are_uncomputable(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 0 [+x] UInt y\n" - " 0 [+y] UInt z\n" - " z [+1] UInt q\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - - def test_non_constant_equality(self): - ir = self._make_ir("struct Foo:\n" - " if 5 == y:\n" - " 0 [+1] UInt x\n" - " 0 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - structure = ir.module[0].type[0].structure - condition = structure.field[0].existence_condition - self.assertFalse(condition.type.boolean.HasField("value")) - - def test_constant_addition(self): - ir = self._make_ir("struct Foo:\n" - " 7+5 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - start = ir.module[0].type[0].structure.field[0].location.start - self.assertEqual("12", start.type.integer.minimum_value) - self.assertEqual("12", start.type.integer.maximum_value) - self.assertEqual("12", start.type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("7", start.function.args[0].type.integer.minimum_value) - self.assertEqual("7", start.function.args[0].type.integer.maximum_value) - self.assertEqual("7", start.function.args[0].type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("5", start.function.args[1].type.integer.minimum_value) - self.assertEqual("5", start.function.args[1].type.integer.maximum_value) - self.assertEqual("5", start.function.args[1].type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - - def test_constant_subtraction(self): - ir = self._make_ir("struct Foo:\n" - " 7-5 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - start = ir.module[0].type[0].structure.field[0].location.start - self.assertEqual("2", start.type.integer.minimum_value) - self.assertEqual("2", start.type.integer.maximum_value) - self.assertEqual("2", start.type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("7", start.function.args[0].type.integer.minimum_value) - self.assertEqual("7", start.function.args[0].type.integer.maximum_value) - self.assertEqual("7", start.function.args[0].type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("5", start.function.args[1].type.integer.minimum_value) - self.assertEqual("5", start.function.args[1].type.integer.maximum_value) - self.assertEqual("5", start.function.args[1].type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - - def test_constant_multiplication(self): - ir = self._make_ir("struct Foo:\n" - " 7*5 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - start = ir.module[0].type[0].structure.field[0].location.start - self.assertEqual("35", start.type.integer.minimum_value) - self.assertEqual("35", start.type.integer.maximum_value) - self.assertEqual("35", start.type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("7", start.function.args[0].type.integer.minimum_value) - self.assertEqual("7", start.function.args[0].type.integer.maximum_value) - self.assertEqual("7", start.function.args[0].type.integer.modular_value) - 
self.assertEqual("infinity", start.type.integer.modulus) - self.assertEqual("5", start.function.args[1].type.integer.minimum_value) - self.assertEqual("5", start.function.args[1].type.integer.maximum_value) - self.assertEqual("5", start.function.args[1].type.integer.modular_value) - self.assertEqual("infinity", start.type.integer.modulus) - - def test_nested_constant_expression(self): - ir = self._make_ir("struct Foo:\n" - " if 7*(3+1) == 28:\n" - " 0 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - condition = ir.module[0].type[0].structure.field[0].existence_condition - self.assertTrue(condition.type.boolean.value) - condition_left = condition.function.args[0] - self.assertEqual("28", condition_left.type.integer.minimum_value) - self.assertEqual("28", condition_left.type.integer.maximum_value) - self.assertEqual("28", condition_left.type.integer.modular_value) - self.assertEqual("infinity", condition_left.type.integer.modulus) - condition_left_left = condition_left.function.args[0] - self.assertEqual("7", condition_left_left.type.integer.minimum_value) - self.assertEqual("7", condition_left_left.type.integer.maximum_value) - self.assertEqual("7", condition_left_left.type.integer.modular_value) - self.assertEqual("infinity", condition_left_left.type.integer.modulus) - condition_left_right = condition_left.function.args[1] - self.assertEqual("4", condition_left_right.type.integer.minimum_value) - self.assertEqual("4", condition_left_right.type.integer.maximum_value) - self.assertEqual("4", condition_left_right.type.integer.modular_value) - self.assertEqual("infinity", condition_left_right.type.integer.modulus) - condition_left_right_left = condition_left_right.function.args[0] - self.assertEqual("3", condition_left_right_left.type.integer.minimum_value) - self.assertEqual("3", condition_left_right_left.type.integer.maximum_value) - self.assertEqual("3", condition_left_right_left.type.integer.modular_value) - self.assertEqual("infinity", condition_left_right_left.type.integer.modulus) - condition_left_right_right = condition_left_right.function.args[1] - self.assertEqual("1", condition_left_right_right.type.integer.minimum_value) - self.assertEqual("1", condition_left_right_right.type.integer.maximum_value) - self.assertEqual("1", condition_left_right_right.type.integer.modular_value) - self.assertEqual("infinity", - condition_left_right_right.type.integer.modulus) - - def test_constant_plus_non_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 5+(4*x) [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual("4", y_start.type.integer.modulus) - self.assertEqual("1", y_start.type.integer.modular_value) - self.assertEqual("5", y_start.type.integer.minimum_value) - self.assertEqual("1025", y_start.type.integer.maximum_value) - - def test_constant_minus_non_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 5-(4*x) [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual("4", y_start.type.integer.modulus) - self.assertEqual("1", y_start.type.integer.modular_value) - self.assertEqual("-1015", y_start.type.integer.minimum_value) - self.assertEqual("5", y_start.type.integer.maximum_value) - - def test_non_constant_minus_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 
(4*x)-5 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual(str((4 * 0) - 5), y_start.type.integer.minimum_value) - self.assertEqual(str((4 * 255) - 5), y_start.type.integer.maximum_value) - self.assertEqual("4", y_start.type.integer.modulus) - self.assertEqual("3", y_start.type.integer.modular_value) - - def test_non_constant_plus_non_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (4*x)+(6*y+3) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("3", z_start.type.integer.minimum_value) - self.assertEqual(str(4 * 255 + 6 * 255 + 3), - z_start.type.integer.maximum_value) - self.assertEqual("2", z_start.type.integer.modulus) - self.assertEqual("1", z_start.type.integer.modular_value) - - def test_non_constant_minus_non_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (x*3)-(y*3) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("3", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual(str(-3 * 255), z_start.type.integer.minimum_value) - self.assertEqual(str(3 * 255), z_start.type.integer.maximum_value) - - def test_non_constant_times_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " (4*x+1)*5 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual("20", y_start.type.integer.modulus) - self.assertEqual("5", y_start.type.integer.modular_value) - self.assertEqual("5", y_start.type.integer.minimum_value) - self.assertEqual(str((4 * 255 + 1) * 5), y_start.type.integer.maximum_value) - - def test_non_constant_times_negative_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " (4*x+1)*-5 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual("20", y_start.type.integer.modulus) - self.assertEqual("15", y_start.type.integer.modular_value) - self.assertEqual(str((4 * 255 + 1) * -5), - y_start.type.integer.minimum_value) - self.assertEqual("-5", y_start.type.integer.maximum_value) - - def test_non_constant_times_zero(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " (4*x+1)*0 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - y_start = ir.module[0].type[0].structure.field[1].location.start - self.assertEqual("infinity", y_start.type.integer.modulus) - self.assertEqual("0", y_start.type.integer.modular_value) - self.assertEqual("0", y_start.type.integer.minimum_value) - self.assertEqual("0", y_start.type.integer.maximum_value) - - def test_non_constant_times_non_constant_shared_modulus(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (4*x+3)*(4*y+3) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("4", z_start.type.integer.modulus) - self.assertEqual("1", z_start.type.integer.modular_value) - self.assertEqual("9", 
z_start.type.integer.minimum_value) - self.assertEqual(str((4 * 255 + 3)**2), z_start.type.integer.maximum_value) - - def test_non_constant_times_non_constant_congruent_to_zero(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (4*x)*(4*y) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("16", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual("0", z_start.type.integer.minimum_value) - self.assertEqual(str((4 * 255)**2), z_start.type.integer.maximum_value) - - def test_non_constant_times_non_constant_partially_shared_modulus(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (4*x+3)*(8*y+3) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("4", z_start.type.integer.modulus) - self.assertEqual("1", z_start.type.integer.modular_value) - self.assertEqual("9", z_start.type.integer.minimum_value) - self.assertEqual(str((4 * 255 + 3) * (8 * 255 + 3)), - z_start.type.integer.maximum_value) - - def test_non_constant_times_non_constant_full_complexity(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (12*x+9)*(40*y+15) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("60", z_start.type.integer.modulus) - self.assertEqual("15", z_start.type.integer.modular_value) - self.assertEqual(str(9 * 15), z_start.type.integer.minimum_value) - self.assertEqual(str((12 * 255 + 9) * (40 * 255 + 15)), - z_start.type.integer.maximum_value) - - def test_signed_non_constant_times_signed_non_constant_full_complexity(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] Int x\n" - " 1 [+1] Int y\n" - " (12*x+9)*(40*y+15) [+1] Int z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("60", z_start.type.integer.modulus) - self.assertEqual("15", z_start.type.integer.modular_value) - # Max x/min y is slightly lower than min x/max y (-7825965 vs -7780065). - self.assertEqual(str((12 * 127 + 9) * (40 * -128 + 15)), - z_start.type.integer.minimum_value) - # Max x/max y is slightly higher than min x/min y (7810635 vs 7795335). - self.assertEqual(str((12 * 127 + 9) * (40 * 127 + 15)), - z_start.type.integer.maximum_value) - - def test_non_constant_times_non_constant_flipped_min_max(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt y\n" - " (-x*3)*(y*3) [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("9", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual(str(-((3 * 255)**2)), z_start.type.integer.minimum_value) - self.assertEqual("0", z_start.type.integer.maximum_value) - - # Currently, only `$static_size_in_bits` has an infinite bound, so all of the - # examples below use `$static_size_in_bits`. 
Unfortunately, this also means - # that these tests rely on the fact that Emboss doesn't try to do any term - # rewriting or smart correlation between the arguments of various operators: - # for example, several tests rely on `$static_size_in_bits - - # $static_size_in_bits` having the range `-infinity` to `infinity`, when a - # trivial term rewrite would turn that expression into `0`. - # - # Unbounded expressions are only allowed at compile-time anyway, so these - # tests cover some fairly unlikely uses of the Emboss expression language. - def test_unbounded_plus_constant(self): - ir = self._make_ir("external Foo:\n" - " [requires: $static_size_in_bits + 2 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("2", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_negative_unbounded_plus_constant(self): - ir = self._make_ir("external Foo:\n" - " [requires: -$static_size_in_bits + 2 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("2", expr.type.integer.maximum_value) - - def test_negative_unbounded_plus_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: -$static_size_in_bits + $static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_unbounded_minus_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: $static_size_in_bits - $static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_unbounded_minus_negative_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: $static_size_in_bits - -$static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("0", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_unbounded_times_constant(self): - ir = self._make_ir("external Foo:\n" - " [requires: ($static_size_in_bits + 1) * 2 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("2", expr.type.integer.modulus) - self.assertEqual("0", 
expr.type.integer.modular_value) - self.assertEqual("2", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_unbounded_times_negative_constant(self): - ir = self._make_ir("external Foo:\n" - " [requires: ($static_size_in_bits + 1) * -2 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("2", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("-2", expr.type.integer.maximum_value) - - def test_unbounded_times_negative_zero(self): - ir = self._make_ir("external Foo:\n" - " [requires: ($static_size_in_bits + 1) * 0 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("infinity", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("0", expr.type.integer.minimum_value) - self.assertEqual("0", expr.type.integer.maximum_value) - - def test_negative_unbounded_times_constant(self): - ir = self._make_ir("external Foo:\n" - " [requires: (-$static_size_in_bits + 1) * 2 > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("2", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("2", expr.type.integer.maximum_value) - - def test_double_unbounded_minus_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: 2 * $static_size_in_bits - $static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_double_unbounded_times_negative_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: 2 * $static_size_in_bits * -$static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("2", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("0", expr.type.integer.maximum_value) - - def test_upper_bound_of_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] Int x\n" - " let u = $upper_bound(x)\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - u_type = ir.module[0].type[0].structure.field[1].read_transform.type - self.assertEqual("infinity", u_type.integer.modulus) - self.assertEqual("127", u_type.integer.maximum_value) - self.assertEqual("127", u_type.integer.minimum_value) - self.assertEqual("127", u_type.integer.modular_value) - - def test_lower_bound_of_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] Int x\n" - " let l = $lower_bound(x)\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - l_type = ir.module[0].type[0].structure.field[1].read_transform.type 
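# --- Editor's aside: a small sketch (not Emboss code) of why the unbounded
# tests above expect ranges like [-infinity, infinity]: bounds are combined
# interval-wise with no term rewriting, so $static_size_in_bits minus itself
# is *not* simplified to 0. Python's float infinities stand in here for the
# stringified "infinity"/"-infinity" bounds the compiler uses.
INF = float("inf")

def _toy_interval_add(x, y):
    return (x[0] + y[0], x[1] + y[1])

def _toy_interval_neg(x):
    return (-x[1], -x[0])

static_size = (0.0, INF)  # $static_size_in_bits: at least zero, unbounded above.
assert _toy_interval_add(static_size, _toy_interval_neg(static_size)) == (-INF, INF)
assert _toy_interval_add(static_size, (2.0, 2.0)) == (2.0, INF)       # min 2, max inf
assert _toy_interval_add(_toy_interval_neg(static_size), (2.0, 2.0)) == (-INF, 2.0)
# --- end aside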
- self.assertEqual("infinity", l_type.integer.modulus) - self.assertEqual("-128", l_type.integer.maximum_value) - self.assertEqual("-128", l_type.integer.minimum_value) - self.assertEqual("-128", l_type.integer.modular_value) - - def test_upper_bound_of_max(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] Int x\n" - " 1 [+1] UInt y\n" - " let u = $upper_bound($max(x, y))\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - u_type = ir.module[0].type[0].structure.field[2].read_transform.type - self.assertEqual("infinity", u_type.integer.modulus) - self.assertEqual("255", u_type.integer.maximum_value) - self.assertEqual("255", u_type.integer.minimum_value) - self.assertEqual("255", u_type.integer.modular_value) - - def test_lower_bound_of_max(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] Int x\n" - " 1 [+1] UInt y\n" - " let l = $lower_bound($max(x, y))\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - l_type = ir.module[0].type[0].structure.field[2].read_transform.type - self.assertEqual("infinity", l_type.integer.modulus) - self.assertEqual("0", l_type.integer.maximum_value) - self.assertEqual("0", l_type.integer.minimum_value) - self.assertEqual("0", l_type.integer.modular_value) - - def test_double_unbounded_both_ends_times_negative_unbounded(self): - ir = self._make_ir( - "external Foo:\n" - " [requires: (2 * ($static_size_in_bits - $static_size_in_bits) + 1) " - " * -$static_size_in_bits > 0]\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] - self.assertEqual("1", expr.type.integer.modulus) - self.assertEqual("0", expr.type.integer.modular_value) - self.assertEqual("-infinity", expr.type.integer.minimum_value) - self.assertEqual("infinity", expr.type.integer.maximum_value) - - def test_choice_two_non_constant_integers(self): - cases = [ - # t % 12 == 7 and f % 20 == 15 ==> r % 4 == 3 - (12, 7, 20, 15, 4, 3, -128 * 20 + 15, 127 * 20 + 15), - # t % 24 == 15 and f % 12 == 7 ==> r % 4 == 3 - (24, 15, 12, 7, 4, 3, -128 * 24 + 15, 127 * 24 + 15), - # t % 20 == 15 and f % 20 == 10 ==> r % 5 == 0 - (20, 15, 20, 10, 5, 0, -128 * 20 + 10, 127 * 20 + 15), - # t % 20 == 16 and f % 20 == 11 ==> r % 5 == 1 - (20, 16, 20, 11, 5, 1, -128 * 20 + 11, 127 * 20 + 16), - ] - for (t_mod, t_val, f_mod, f_val, r_mod, r_val, r_min, r_max) in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if (x == 0 ? 
y * {} + {} : y * {} + {}) == 0:\n" - " 1 [+1] UInt z\n".format( - t_mod, t_val, f_mod, f_val)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[2] - expr = field.existence_condition.function.args[0] - self.assertEqual(str(r_mod), expr.type.integer.modulus) - self.assertEqual(str(r_val), expr.type.integer.modular_value) - self.assertEqual(str(r_min), expr.type.integer.minimum_value) - self.assertEqual(str(r_max), expr.type.integer.maximum_value) - - def test_choice_one_non_constant_integer(self): - cases = [ - # t == 35 and f % 20 == 15 ==> res % 20 == 15 - (35, 20, 15, 20, 15, -128 * 20 + 15, 127 * 20 + 15), - # t == 200035 and f % 20 == 15 ==> res % 20 == 15 - (200035, 20, 15, 20, 15, -128 * 20 + 15, 200035), - # t == 21 and f % 20 == 16 ==> res % 5 == 1 - (21, 20, 16, 5, 1, -128 * 20 + 16, 127 * 20 + 16), - ] - for (t_val, f_mod, f_val, r_mod, r_val, r_min, r_max) in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if (x == 0 ? {0} : y * {1} + {2}) == 0:\n" - " 1 [+1] UInt z\n" - " if (x == 0 ? y * {1} + {2} : {0}) == 0:\n" - " 1 [+1] UInt q\n".format(t_val, f_mod, f_val)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_constant_true = ir.module[0].type[0].structure.field[2] - constant_true = field_constant_true.existence_condition.function.args[0] - field_constant_false = ir.module[0].type[0].structure.field[3] - constant_false = field_constant_false.existence_condition.function.args[0] - self.assertEqual(str(r_mod), constant_true.type.integer.modulus) - self.assertEqual(str(r_val), constant_true.type.integer.modular_value) - self.assertEqual(str(r_min), constant_true.type.integer.minimum_value) - self.assertEqual(str(r_max), constant_true.type.integer.maximum_value) - self.assertEqual(str(r_mod), constant_false.type.integer.modulus) - self.assertEqual(str(r_val), constant_false.type.integer.modular_value) - self.assertEqual(str(r_min), constant_false.type.integer.minimum_value) - self.assertEqual(str(r_max), constant_false.type.integer.maximum_value) - - def test_choice_two_constant_integers(self): - cases = [ - # t == 10 and f == 7 ==> res % 3 == 1 - (10, 7, 3, 1, 7, 10), - # t == 4 and f == 4 ==> res == 4 - (4, 4, "infinity", 4, 4, 4), - ] - for (t_val, f_val, r_mod, r_val, r_min, r_max) in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if (x == 0 ? 
{} : {}) == 0:\n" - " 1 [+1] UInt z\n".format(t_val, f_val)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_constant_true = ir.module[0].type[0].structure.field[2] - constant_true = field_constant_true.existence_condition.function.args[0] - self.assertEqual(str(r_mod), constant_true.type.integer.modulus) - self.assertEqual(str(r_val), constant_true.type.integer.modular_value) - self.assertEqual(str(r_min), constant_true.type.integer.minimum_value) - self.assertEqual(str(r_max), constant_true.type.integer.maximum_value) - - def test_constant_true_has(self): - ir = self._make_ir("struct Foo:\n" - " if $present(x):\n" - " 1 [+1] UInt q\n" - " 0 [+1] UInt x\n" - " if x > 10:\n" - " 1 [+1] Int y\n" - " if false:\n" - " 2 [+1] Int z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[0] - has_func = field.existence_condition - self.assertTrue(has_func.type.boolean.value) - - def test_constant_false_has(self): - ir = self._make_ir("struct Foo:\n" - " if $present(z):\n" - " 1 [+1] UInt q\n" - " 0 [+1] UInt x\n" - " if x > 10:\n" - " 1 [+1] Int y\n" - " if false:\n" - " 2 [+1] Int z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[0] - has_func = field.existence_condition - self.assertTrue(has_func.type.boolean.HasField("value")) - self.assertFalse(has_func.type.boolean.value) - - def test_variable_has(self): - ir = self._make_ir("struct Foo:\n" - " if $present(y):\n" - " 1 [+1] UInt q\n" - " 0 [+1] UInt x\n" - " if x > 10:\n" - " 1 [+1] Int y\n" - " if false:\n" - " 2 [+1] Int z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[0] - has_func = field.existence_condition - self.assertFalse(has_func.type.boolean.HasField("value")) - - def test_max_of_constants(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if $max(0, 1, 2) == 0:\n" - " 1 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[2] - max_func = field.existence_condition.function.args[0] - self.assertEqual("infinity", max_func.type.integer.modulus) - self.assertEqual("2", max_func.type.integer.modular_value) - self.assertEqual("2", max_func.type.integer.minimum_value) - self.assertEqual("2", max_func.type.integer.maximum_value) - - def test_max_dominated_by_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if $max(x, y, 255) == 0:\n" - " 1 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[2] - max_func = field.existence_condition.function.args[0] - self.assertEqual("infinity", max_func.type.integer.modulus) - self.assertEqual("255", max_func.type.integer.modular_value) - self.assertEqual("255", max_func.type.integer.minimum_value) - self.assertEqual("255", max_func.type.integer.maximum_value) - - def test_max_of_variables(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if $max(x, y) == 0:\n" - " 1 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[2] - max_func = field.existence_condition.function.args[0] - self.assertEqual("1", max_func.type.integer.modulus) - self.assertEqual("0", max_func.type.integer.modular_value) - self.assertEqual("0", 
max_func.type.integer.minimum_value) - self.assertEqual("255", max_func.type.integer.maximum_value) - - def test_max_of_variables_with_shared_modulus(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " if $max(x * 8 + 5, y * 4 + 3) == 0:\n" - " 1 [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[2] - max_func = field.existence_condition.function.args[0] - self.assertEqual("2", max_func.type.integer.modulus) - self.assertEqual("1", max_func.type.integer.modular_value) - self.assertEqual("5", max_func.type.integer.minimum_value) - self.assertEqual("2045", max_func.type.integer.maximum_value) - - def test_max_of_three_variables(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " 2 [+2] Int z\n" - " if $max(x, y, z) == 0:\n" - " 1 [+1] UInt q\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[3] - max_func = field.existence_condition.function.args[0] - self.assertEqual("1", max_func.type.integer.modulus) - self.assertEqual("0", max_func.type.integer.modular_value) - self.assertEqual("0", max_func.type.integer.minimum_value) - self.assertEqual("32767", max_func.type.integer.maximum_value) - - def test_max_of_one_variable(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " 2 [+2] Int z\n" - " if $max(x * 2 + 3) == 0:\n" - " 1 [+1] UInt q\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[3] - max_func = field.existence_condition.function.args[0] - self.assertEqual("2", max_func.type.integer.modulus) - self.assertEqual("1", max_func.type.integer.modular_value) - self.assertEqual("3", max_func.type.integer.minimum_value) - self.assertEqual("513", max_func.type.integer.maximum_value) - - def test_max_of_one_variable_and_one_constant(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] Int y\n" - " 2 [+2] Int z\n" - " if $max(x * 2 + 3, 311) == 0:\n" - " 1 [+1] UInt q\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field = ir.module[0].type[0].structure.field[3] - max_func = field.existence_condition.function.args[0] - self.assertEqual("2", max_func.type.integer.modulus) - self.assertEqual("1", max_func.type.integer.modular_value) - self.assertEqual("311", max_func.type.integer.minimum_value) - self.assertEqual("513", max_func.type.integer.maximum_value) - - def test_choice_non_integer_arguments(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " if x == 0 ? 
false : true:\n" - " 1 [+1] UInt y\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - expr = ir.module[0].type[0].structure.field[1].existence_condition - self.assertEqual("boolean", expr.type.WhichOneof("type")) - self.assertFalse(expr.type.boolean.HasField("value")) - - def test_uint_value_range_for_explicit_size(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+x] UInt:16 y\n" - " y [+1] UInt z\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("1", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual("0", z_start.type.integer.minimum_value) - self.assertEqual("65535", z_start.type.integer.maximum_value) - - def test_uint_value_ranges(self): - cases = [ - (1, 1), - (2, 3), - (3, 7), - (4, 15), - (8, 255), - (12, 4095), - (15, 32767), - (16, 65535), - (32, 4294967295), - (48, 281474976710655), - (64, 18446744073709551615), - ] - for bits, upper in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+8] bits:\n" - " 0 [+{}] UInt x\n" - " x [+1] UInt z\n".format(bits)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("1", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual("0", z_start.type.integer.minimum_value) - self.assertEqual(str(upper), z_start.type.integer.maximum_value) - - def test_int_value_ranges(self): - cases = [ - (1, -1, 0), - (2, -2, 1), - (3, -4, 3), - (4, -8, 7), - (8, -128, 127), - (12, -2048, 2047), - (15, -16384, 16383), - (16, -32768, 32767), - (32, -2147483648, 2147483647), - (48, -140737488355328, 140737488355327), - (64, -9223372036854775808, 9223372036854775807), - ] - for bits, lower, upper in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+8] bits:\n" - " 0 [+{}] Int x\n" - " x [+1] UInt z\n".format(bits)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("1", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual(str(lower), z_start.type.integer.minimum_value) - self.assertEqual(str(upper), z_start.type.integer.maximum_value) - - def test_bcd_value_ranges(self): - cases = [ - (1, 1), - (2, 3), - (3, 7), - (4, 9), - (8, 99), - (12, 999), - (15, 7999), - (16, 9999), - (32, 99999999), - (48, 999999999999), - (64, 9999999999999999), - ] - for bits, upper in cases: - ir = self._make_ir("struct Foo:\n" - " 0 [+8] bits:\n" - " 0 [+{}] Bcd x\n" - " x [+1] UInt z\n".format(bits)) - self.assertEqual([], expression_bounds.compute_constants(ir)) - z_start = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual("1", z_start.type.integer.modulus) - self.assertEqual("0", z_start.type.integer.modular_value) - self.assertEqual("0", z_start.type.integer.minimum_value) - self.assertEqual(str(upper), z_start.type.integer.maximum_value) - - def test_virtual_field_bounds(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = x + 10\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_y = ir.module[0].type[0].structure.field[1] - self.assertEqual("1", field_y.read_transform.type.integer.modulus) - self.assertEqual("0", field_y.read_transform.type.integer.modular_value) - self.assertEqual("10", 
field_y.read_transform.type.integer.minimum_value) - self.assertEqual("265", field_y.read_transform.type.integer.maximum_value) - - def test_virtual_field_bounds_copied(self): - ir = self._make_ir("struct Foo:\n" - " let z = y + 100\n" - " let y = x + 10\n" - " 0 [+1] UInt x\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_z = ir.module[0].type[0].structure.field[0] - self.assertEqual("1", field_z.read_transform.type.integer.modulus) - self.assertEqual("0", field_z.read_transform.type.integer.modular_value) - self.assertEqual("110", field_z.read_transform.type.integer.minimum_value) - self.assertEqual("365", field_z.read_transform.type.integer.maximum_value) - y_reference = field_z.read_transform.function.args[0] - self.assertEqual("1", y_reference.type.integer.modulus) - self.assertEqual("0", y_reference.type.integer.modular_value) - self.assertEqual("10", y_reference.type.integer.minimum_value) - self.assertEqual("265", y_reference.type.integer.maximum_value) - - def test_constant_reference_to_virtual_bounds_copied(self): - ir = self._make_ir("struct Foo:\n" - " let ten = Bar.ten\n" - " let truth = Bar.truth\n" - "struct Bar:\n" - " let ten = 10\n" - " let truth = true\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_ten = ir.module[0].type[0].structure.field[0] - self.assertEqual("infinity", field_ten.read_transform.type.integer.modulus) - self.assertEqual("10", field_ten.read_transform.type.integer.modular_value) - self.assertEqual("10", field_ten.read_transform.type.integer.minimum_value) - self.assertEqual("10", field_ten.read_transform.type.integer.maximum_value) - field_truth = ir.module[0].type[0].structure.field[1] - self.assertTrue(field_truth.read_transform.type.boolean.value) - - def test_forward_reference_to_reference_to_enum_correctly_calculated(self): - ir = self._make_ir("struct Foo:\n" - " let ten = Bar.TEN\n" - "enum Bar:\n" - " TEN = TEN2\n" - " TEN2 = 5 + 5\n") - self.assertEqual([], expression_bounds.compute_constants(ir)) - field_ten = ir.module[0].type[0].structure.field[0] - self.assertEqual("10", field_ten.read_transform.type.enumeration.value) + def _make_ir(self, emb_text): + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader({"m.emb": emb_text}), + stop_before_step="compute_constants", + ) + assert not errors, errors + return ir + + def test_constant_integer(self): + ir = self._make_ir("struct Foo:\n" " 10 [+1] UInt x\n") + self.assertEqual([], expression_bounds.compute_constants(ir)) + start = ir.module[0].type[0].structure.field[0].location.start + self.assertEqual("10", start.type.integer.minimum_value) + self.assertEqual("10", start.type.integer.maximum_value) + self.assertEqual("10", start.type.integer.modular_value) + self.assertEqual("infinity", start.type.integer.modulus) + + def test_boolean_constant(self): + ir = self._make_ir("struct Foo:\n" " if true:\n" " 0 [+1] UInt x\n") + self.assertEqual([], expression_bounds.compute_constants(ir)) + expression = ir.module[0].type[0].structure.field[0].existence_condition + self.assertTrue(expression.type.boolean.HasField("value")) + self.assertTrue(expression.type.boolean.value) + + def test_constant_equality(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 == 5:\n" + " 0 [+1] UInt x\n" + " if 5 == 6:\n" + " 0 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + true_condition = structure.field[0].existence_condition + 
false_condition = structure.field[1].existence_condition + self.assertTrue(true_condition.type.boolean.HasField("value")) + self.assertTrue(true_condition.type.boolean.value) + self.assertTrue(false_condition.type.boolean.HasField("value")) + self.assertFalse(false_condition.type.boolean.value) + + def test_constant_inequality(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 != 5:\n" + " 0 [+1] UInt x\n" + " if 5 != 6:\n" + " 0 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + false_condition = structure.field[0].existence_condition + true_condition = structure.field[1].existence_condition + self.assertTrue(false_condition.type.boolean.HasField("value")) + self.assertFalse(false_condition.type.boolean.value) + self.assertTrue(true_condition.type.boolean.HasField("value")) + self.assertTrue(true_condition.type.boolean.value) + + def test_constant_less_than(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 < 4:\n" + " 0 [+1] UInt x\n" + " if 5 < 5:\n" + " 0 [+1] UInt y\n" + " if 5 < 6:\n" + " 0 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + greater_than_condition = structure.field[0].existence_condition + equal_condition = structure.field[1].existence_condition + less_than_condition = structure.field[2].existence_condition + self.assertTrue(greater_than_condition.type.boolean.HasField("value")) + self.assertFalse(greater_than_condition.type.boolean.value) + self.assertTrue(equal_condition.type.boolean.HasField("value")) + self.assertFalse(equal_condition.type.boolean.value) + self.assertTrue(less_than_condition.type.boolean.HasField("value")) + self.assertTrue(less_than_condition.type.boolean.value) + + def test_constant_less_than_or_equal(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 <= 4:\n" + " 0 [+1] UInt x\n" + " if 5 <= 5:\n" + " 0 [+1] UInt y\n" + " if 5 <= 6:\n" + " 0 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + greater_than_condition = structure.field[0].existence_condition + equal_condition = structure.field[1].existence_condition + less_than_condition = structure.field[2].existence_condition + self.assertTrue(greater_than_condition.type.boolean.HasField("value")) + self.assertFalse(greater_than_condition.type.boolean.value) + self.assertTrue(equal_condition.type.boolean.HasField("value")) + self.assertTrue(equal_condition.type.boolean.value) + self.assertTrue(less_than_condition.type.boolean.HasField("value")) + self.assertTrue(less_than_condition.type.boolean.value) + + def test_constant_greater_than(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 > 4:\n" + " 0 [+1] UInt x\n" + " if 5 > 5:\n" + " 0 [+1] UInt y\n" + " if 5 > 6:\n" + " 0 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + greater_than_condition = structure.field[0].existence_condition + equal_condition = structure.field[1].existence_condition + less_than_condition = structure.field[2].existence_condition + self.assertTrue(greater_than_condition.type.boolean.HasField("value")) + self.assertTrue(greater_than_condition.type.boolean.value) + self.assertTrue(equal_condition.type.boolean.HasField("value")) + self.assertFalse(equal_condition.type.boolean.value) + self.assertTrue(less_than_condition.type.boolean.HasField("value")) + 
self.assertFalse(less_than_condition.type.boolean.value) + + def test_constant_greater_than_or_equal(self): + ir = self._make_ir( + "struct Foo:\n" + " if 5 >= 4:\n" + " 0 [+1] UInt x\n" + " if 5 >= 5:\n" + " 0 [+1] UInt y\n" + " if 5 >= 6:\n" + " 0 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + greater_than_condition = structure.field[0].existence_condition + equal_condition = structure.field[1].existence_condition + less_than_condition = structure.field[2].existence_condition + self.assertTrue(greater_than_condition.type.boolean.HasField("value")) + self.assertTrue(greater_than_condition.type.boolean.value) + self.assertTrue(equal_condition.type.boolean.HasField("value")) + self.assertTrue(equal_condition.type.boolean.value) + self.assertTrue(less_than_condition.type.boolean.HasField("value")) + self.assertFalse(less_than_condition.type.boolean.value) + + def test_constant_and(self): + ir = self._make_ir( + "struct Foo:\n" + " if false && false:\n" + " 0 [+1] UInt x\n" + " if true && false:\n" + " 0 [+1] UInt y\n" + " if false && true:\n" + " 0 [+1] UInt z\n" + " if true && true:\n" + " 0 [+1] UInt w\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + false_false_condition = structure.field[0].existence_condition + true_false_condition = structure.field[1].existence_condition + false_true_condition = structure.field[2].existence_condition + true_true_condition = structure.field[3].existence_condition + self.assertTrue(false_false_condition.type.boolean.HasField("value")) + self.assertFalse(false_false_condition.type.boolean.value) + self.assertTrue(true_false_condition.type.boolean.HasField("value")) + self.assertFalse(true_false_condition.type.boolean.value) + self.assertTrue(false_true_condition.type.boolean.HasField("value")) + self.assertFalse(false_true_condition.type.boolean.value) + self.assertTrue(true_true_condition.type.boolean.HasField("value")) + self.assertTrue(true_true_condition.type.boolean.value) + + def test_constant_or(self): + ir = self._make_ir( + "struct Foo:\n" + " if false || false:\n" + " 0 [+1] UInt x\n" + " if true || false:\n" + " 0 [+1] UInt y\n" + " if false || true:\n" + " 0 [+1] UInt z\n" + " if true || true:\n" + " 0 [+1] UInt w\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + structure = ir.module[0].type[0].structure + false_false_condition = structure.field[0].existence_condition + true_false_condition = structure.field[1].existence_condition + false_true_condition = structure.field[2].existence_condition + true_true_condition = structure.field[3].existence_condition + self.assertTrue(false_false_condition.type.boolean.HasField("value")) + self.assertFalse(false_false_condition.type.boolean.value) + self.assertTrue(true_false_condition.type.boolean.HasField("value")) + self.assertTrue(true_false_condition.type.boolean.value) + self.assertTrue(false_true_condition.type.boolean.HasField("value")) + self.assertTrue(false_true_condition.type.boolean.value) + self.assertTrue(true_true_condition.type.boolean.HasField("value")) + self.assertTrue(true_true_condition.type.boolean.value) + + def test_enum_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " if Bar.QUX == Bar.QUX:\n" + " 0 [+1] Bar x\n" + "enum Bar:\n" + " QUX = 12\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + condition = ir.module[0].type[0].structure.field[0].existence_condition + left = 
condition.function.args[0]
+        self.assertEqual("12", left.type.enumeration.value)
+
+    def test_non_constant_field_reference(self):
+        ir = self._make_ir("struct Foo:\n" "  y [+1]  UInt  x\n" "  0 [+1]  UInt  y\n")
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        start = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual("0", start.type.integer.minimum_value)
+        self.assertEqual("255", start.type.integer.maximum_value)
+        self.assertEqual("0", start.type.integer.modular_value)
+        self.assertEqual("1", start.type.integer.modulus)
+
+    def test_field_reference_bounds_are_uncomputable(self):
+        # Variable-sized UInt/Int/Bcd should not cause an error here: they are
+        # handled in the constraints pass.
+        ir = self._make_ir(
+            "struct Foo:\n"
+            "  0 [+1]  UInt  x\n"
+            "  0 [+x]  UInt  y\n"
+            "  y [+1]  UInt  z\n"
+        )
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+
+    def test_field_references_references_bounds_are_uncomputable(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            "  0 [+1]  UInt  x\n"
+            "  0 [+x]  UInt  y\n"
+            "  0 [+y]  UInt  z\n"
+            "  z [+1]  UInt  q\n"
+        )
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+
+    def test_non_constant_equality(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            "  if 5 == y:\n"
+            "    0 [+1]  UInt  x\n"
+            "  0 [+1]  UInt  y\n"
+        )
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        structure = ir.module[0].type[0].structure
+        condition = structure.field[0].existence_condition
+        self.assertFalse(condition.type.boolean.HasField("value"))
+
+    def test_constant_addition(self):
+        ir = self._make_ir("struct Foo:\n" "  7+5 [+1]  UInt  x\n")
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        start = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual("12", start.type.integer.minimum_value)
+        self.assertEqual("12", start.type.integer.maximum_value)
+        self.assertEqual("12", start.type.integer.modular_value)
+        self.assertEqual("infinity", start.type.integer.modulus)
+        self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[0].type.integer.modulus)
+        self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[1].type.integer.modulus)
+
+    def test_constant_subtraction(self):
+        ir = self._make_ir("struct Foo:\n" "  7-5 [+1]  UInt  x\n")
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        start = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual("2", start.type.integer.minimum_value)
+        self.assertEqual("2", start.type.integer.maximum_value)
+        self.assertEqual("2", start.type.integer.modular_value)
+        self.assertEqual("infinity", start.type.integer.modulus)
+        self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[0].type.integer.modulus)
+        self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[1].type.integer.modulus)
+
+    def test_constant_multiplication(self):
+        ir = self._make_ir("struct Foo:\n" "  7*5 [+1]  UInt  x\n")
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        start = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual("35", start.type.integer.minimum_value)
+        self.assertEqual("35", start.type.integer.maximum_value)
+        self.assertEqual("35", start.type.integer.modular_value)
+        self.assertEqual("infinity", start.type.integer.modulus)
+        self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+        self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[0].type.integer.modulus)
+        self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+        self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+        self.assertEqual("infinity", start.function.args[1].type.integer.modulus)
+
+    def test_nested_constant_expression(self):
+        ir = self._make_ir(
+            "struct Foo:\n" "  if 7*(3+1) == 28:\n" "    0 [+1]  UInt  x\n"
+        )
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        condition = ir.module[0].type[0].structure.field[0].existence_condition
+        self.assertTrue(condition.type.boolean.value)
+        condition_left = condition.function.args[0]
+        self.assertEqual("28", condition_left.type.integer.minimum_value)
+        self.assertEqual("28", condition_left.type.integer.maximum_value)
+        self.assertEqual("28", condition_left.type.integer.modular_value)
+        self.assertEqual("infinity", condition_left.type.integer.modulus)
+        condition_left_left = condition_left.function.args[0]
+        self.assertEqual("7", condition_left_left.type.integer.minimum_value)
+        self.assertEqual("7", condition_left_left.type.integer.maximum_value)
+        self.assertEqual("7", condition_left_left.type.integer.modular_value)
+        self.assertEqual("infinity", condition_left_left.type.integer.modulus)
+        condition_left_right = condition_left.function.args[1]
+        self.assertEqual("4", condition_left_right.type.integer.minimum_value)
+        self.assertEqual("4", condition_left_right.type.integer.maximum_value)
+        self.assertEqual("4", condition_left_right.type.integer.modular_value)
+        self.assertEqual("infinity", condition_left_right.type.integer.modulus)
+        condition_left_right_left = condition_left_right.function.args[0]
+        self.assertEqual("3", condition_left_right_left.type.integer.minimum_value)
+        self.assertEqual("3", condition_left_right_left.type.integer.maximum_value)
+        self.assertEqual("3", condition_left_right_left.type.integer.modular_value)
+        self.assertEqual("infinity", condition_left_right_left.type.integer.modulus)
+        condition_left_right_right = condition_left_right.function.args[1]
+        self.assertEqual("1", condition_left_right_right.type.integer.minimum_value)
+        self.assertEqual("1", condition_left_right_right.type.integer.maximum_value)
+        self.assertEqual("1", condition_left_right_right.type.integer.modular_value)
+        self.assertEqual("infinity", condition_left_right_right.type.integer.modulus)
+
+    def test_constant_plus_non_constant(self):
+        ir = self._make_ir(
+            "struct Foo:\n" "  0 [+1]  UInt  x\n" "  5+(4*x) [+1]  UInt  y\n"
+        )
+        self.assertEqual([], expression_bounds.compute_constants(ir))
+        y_start = ir.module[0].type[0].structure.field[1].location.start
+        self.assertEqual("4", y_start.type.integer.modulus)
+        self.assertEqual("1", 
y_start.type.integer.modular_value) + self.assertEqual("5", y_start.type.integer.minimum_value) + self.assertEqual("1025", y_start.type.integer.maximum_value) + + def test_constant_minus_non_constant(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " 5-(4*x) [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + y_start = ir.module[0].type[0].structure.field[1].location.start + self.assertEqual("4", y_start.type.integer.modulus) + self.assertEqual("1", y_start.type.integer.modular_value) + self.assertEqual("-1015", y_start.type.integer.minimum_value) + self.assertEqual("5", y_start.type.integer.maximum_value) + + def test_non_constant_minus_constant(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x)-5 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + y_start = ir.module[0].type[0].structure.field[1].location.start + self.assertEqual(str((4 * 0) - 5), y_start.type.integer.minimum_value) + self.assertEqual(str((4 * 255) - 5), y_start.type.integer.maximum_value) + self.assertEqual("4", y_start.type.integer.modulus) + self.assertEqual("3", y_start.type.integer.modular_value) + + def test_non_constant_plus_non_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (4*x)+(6*y+3) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("3", z_start.type.integer.minimum_value) + self.assertEqual(str(4 * 255 + 6 * 255 + 3), z_start.type.integer.maximum_value) + self.assertEqual("2", z_start.type.integer.modulus) + self.assertEqual("1", z_start.type.integer.modular_value) + + def test_non_constant_minus_non_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (x*3)-(y*3) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("3", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual(str(-3 * 255), z_start.type.integer.minimum_value) + self.assertEqual(str(3 * 255), z_start.type.integer.maximum_value) + + def test_non_constant_times_constant(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x+1)*5 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + y_start = ir.module[0].type[0].structure.field[1].location.start + self.assertEqual("20", y_start.type.integer.modulus) + self.assertEqual("5", y_start.type.integer.modular_value) + self.assertEqual("5", y_start.type.integer.minimum_value) + self.assertEqual(str((4 * 255 + 1) * 5), y_start.type.integer.maximum_value) + + def test_non_constant_times_negative_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " (4*x+1)*-5 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + y_start = ir.module[0].type[0].structure.field[1].location.start + self.assertEqual("20", y_start.type.integer.modulus) + self.assertEqual("15", y_start.type.integer.modular_value) + self.assertEqual(str((4 * 255 + 1) * -5), y_start.type.integer.minimum_value) + self.assertEqual("-5", y_start.type.integer.maximum_value) + + def test_non_constant_times_zero(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x+1)*0 [+1] UInt y\n" + ) + self.assertEqual([], 
expression_bounds.compute_constants(ir)) + y_start = ir.module[0].type[0].structure.field[1].location.start + self.assertEqual("infinity", y_start.type.integer.modulus) + self.assertEqual("0", y_start.type.integer.modular_value) + self.assertEqual("0", y_start.type.integer.minimum_value) + self.assertEqual("0", y_start.type.integer.maximum_value) + + def test_non_constant_times_non_constant_shared_modulus(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (4*x+3)*(4*y+3) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("4", z_start.type.integer.modulus) + self.assertEqual("1", z_start.type.integer.modular_value) + self.assertEqual("9", z_start.type.integer.minimum_value) + self.assertEqual(str((4 * 255 + 3) ** 2), z_start.type.integer.maximum_value) + + def test_non_constant_times_non_constant_congruent_to_zero(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (4*x)*(4*y) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("16", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual("0", z_start.type.integer.minimum_value) + self.assertEqual(str((4 * 255) ** 2), z_start.type.integer.maximum_value) + + def test_non_constant_times_non_constant_partially_shared_modulus(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (4*x+3)*(8*y+3) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("4", z_start.type.integer.modulus) + self.assertEqual("1", z_start.type.integer.modular_value) + self.assertEqual("9", z_start.type.integer.minimum_value) + self.assertEqual( + str((4 * 255 + 3) * (8 * 255 + 3)), z_start.type.integer.maximum_value + ) + + def test_non_constant_times_non_constant_full_complexity(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (12*x+9)*(40*y+15) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("60", z_start.type.integer.modulus) + self.assertEqual("15", z_start.type.integer.modular_value) + self.assertEqual(str(9 * 15), z_start.type.integer.minimum_value) + self.assertEqual( + str((12 * 255 + 9) * (40 * 255 + 15)), z_start.type.integer.maximum_value + ) + + def test_signed_non_constant_times_signed_non_constant_full_complexity(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] Int x\n" + " 1 [+1] Int y\n" + " (12*x+9)*(40*y+15) [+1] Int z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("60", z_start.type.integer.modulus) + self.assertEqual("15", z_start.type.integer.modular_value) + # Max x/min y is slightly lower than min x/max y (-7825965 vs -7780065). + self.assertEqual( + str((12 * 127 + 9) * (40 * -128 + 15)), z_start.type.integer.minimum_value + ) + # Max x/max y is slightly higher than min x/min y (7810635 vs 7795335). 
+ self.assertEqual( + str((12 * 127 + 9) * (40 * 127 + 15)), z_start.type.integer.maximum_value + ) + + def test_non_constant_times_non_constant_flipped_min_max(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] UInt y\n" + " (-x*3)*(y*3) [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("9", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual(str(-((3 * 255) ** 2)), z_start.type.integer.minimum_value) + self.assertEqual("0", z_start.type.integer.maximum_value) + + # Currently, only `$static_size_in_bits` has an infinite bound, so all of the + # examples below use `$static_size_in_bits`. Unfortunately, this also means + # that these tests rely on the fact that Emboss doesn't try to do any term + # rewriting or smart correlation between the arguments of various operators: + # for example, several tests rely on `$static_size_in_bits - + # $static_size_in_bits` having the range `-infinity` to `infinity`, when a + # trivial term rewrite would turn that expression into `0`. + # + # Unbounded expressions are only allowed at compile-time anyway, so these + # tests cover some fairly unlikely uses of the Emboss expression language. + def test_unbounded_plus_constant(self): + ir = self._make_ir( + "external Foo:\n" " [requires: $static_size_in_bits + 2 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("2", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_negative_unbounded_plus_constant(self): + ir = self._make_ir( + "external Foo:\n" " [requires: -$static_size_in_bits + 2 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("2", expr.type.integer.maximum_value) + + def test_negative_unbounded_plus_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " [requires: -$static_size_in_bits + $static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_unbounded_minus_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " [requires: $static_size_in_bits - $static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_unbounded_minus_negative_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " 
[requires: $static_size_in_bits - -$static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("0", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_unbounded_times_constant(self): + ir = self._make_ir( + "external Foo:\n" " [requires: ($static_size_in_bits + 1) * 2 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("2", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("2", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_unbounded_times_negative_constant(self): + ir = self._make_ir( + "external Foo:\n" " [requires: ($static_size_in_bits + 1) * -2 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("2", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("-2", expr.type.integer.maximum_value) + + def test_unbounded_times_negative_zero(self): + ir = self._make_ir( + "external Foo:\n" " [requires: ($static_size_in_bits + 1) * 0 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("infinity", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("0", expr.type.integer.minimum_value) + self.assertEqual("0", expr.type.integer.maximum_value) + + def test_negative_unbounded_times_constant(self): + ir = self._make_ir( + "external Foo:\n" " [requires: (-$static_size_in_bits + 1) * 2 > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("2", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("2", expr.type.integer.maximum_value) + + def test_double_unbounded_minus_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " [requires: 2 * $static_size_in_bits - $static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_double_unbounded_times_negative_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " [requires: 2 * $static_size_in_bits * -$static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("2", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", 
expr.type.integer.minimum_value) + self.assertEqual("0", expr.type.integer.maximum_value) + + def test_upper_bound_of_field(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] Int x\n" " let u = $upper_bound(x)\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + u_type = ir.module[0].type[0].structure.field[1].read_transform.type + self.assertEqual("infinity", u_type.integer.modulus) + self.assertEqual("127", u_type.integer.maximum_value) + self.assertEqual("127", u_type.integer.minimum_value) + self.assertEqual("127", u_type.integer.modular_value) + + def test_lower_bound_of_field(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] Int x\n" " let l = $lower_bound(x)\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + l_type = ir.module[0].type[0].structure.field[1].read_transform.type + self.assertEqual("infinity", l_type.integer.modulus) + self.assertEqual("-128", l_type.integer.maximum_value) + self.assertEqual("-128", l_type.integer.minimum_value) + self.assertEqual("-128", l_type.integer.modular_value) + + def test_upper_bound_of_max(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] Int x\n" + " 1 [+1] UInt y\n" + " let u = $upper_bound($max(x, y))\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + u_type = ir.module[0].type[0].structure.field[2].read_transform.type + self.assertEqual("infinity", u_type.integer.modulus) + self.assertEqual("255", u_type.integer.maximum_value) + self.assertEqual("255", u_type.integer.minimum_value) + self.assertEqual("255", u_type.integer.modular_value) + + def test_lower_bound_of_max(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] Int x\n" + " 1 [+1] UInt y\n" + " let l = $lower_bound($max(x, y))\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + l_type = ir.module[0].type[0].structure.field[2].read_transform.type + self.assertEqual("infinity", l_type.integer.modulus) + self.assertEqual("0", l_type.integer.maximum_value) + self.assertEqual("0", l_type.integer.minimum_value) + self.assertEqual("0", l_type.integer.modular_value) + + def test_double_unbounded_both_ends_times_negative_unbounded(self): + ir = self._make_ir( + "external Foo:\n" + " [requires: (2 * ($static_size_in_bits - $static_size_in_bits) + 1) " + " * -$static_size_in_bits > 0]\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0] + self.assertEqual("1", expr.type.integer.modulus) + self.assertEqual("0", expr.type.integer.modular_value) + self.assertEqual("-infinity", expr.type.integer.minimum_value) + self.assertEqual("infinity", expr.type.integer.maximum_value) + + def test_choice_two_non_constant_integers(self): + cases = [ + # t % 12 == 7 and f % 20 == 15 ==> r % 4 == 3 + (12, 7, 20, 15, 4, 3, -128 * 20 + 15, 127 * 20 + 15), + # t % 24 == 15 and f % 12 == 7 ==> r % 4 == 3 + (24, 15, 12, 7, 4, 3, -128 * 24 + 15, 127 * 24 + 15), + # t % 20 == 15 and f % 20 == 10 ==> r % 5 == 0 + (20, 15, 20, 10, 5, 0, -128 * 20 + 10, 127 * 20 + 15), + # t % 20 == 16 and f % 20 == 11 ==> r % 5 == 1 + (20, 16, 20, 11, 5, 1, -128 * 20 + 11, 127 * 20 + 16), + ] + for t_mod, t_val, f_mod, f_val, r_mod, r_val, r_min, r_max in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if (x == 0 ? 
y * {} + {} : y * {} + {}) == 0:\n" + " 1 [+1] UInt z\n".format(t_mod, t_val, f_mod, f_val) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[2] + expr = field.existence_condition.function.args[0] + self.assertEqual(str(r_mod), expr.type.integer.modulus) + self.assertEqual(str(r_val), expr.type.integer.modular_value) + self.assertEqual(str(r_min), expr.type.integer.minimum_value) + self.assertEqual(str(r_max), expr.type.integer.maximum_value) + + def test_choice_one_non_constant_integer(self): + cases = [ + # t == 35 and f % 20 == 15 ==> res % 20 == 15 + (35, 20, 15, 20, 15, -128 * 20 + 15, 127 * 20 + 15), + # t == 200035 and f % 20 == 15 ==> res % 20 == 15 + (200035, 20, 15, 20, 15, -128 * 20 + 15, 200035), + # t == 21 and f % 20 == 16 ==> res % 5 == 1 + (21, 20, 16, 5, 1, -128 * 20 + 16, 127 * 20 + 16), + ] + for t_val, f_mod, f_val, r_mod, r_val, r_min, r_max in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if (x == 0 ? {0} : y * {1} + {2}) == 0:\n" + " 1 [+1] UInt z\n" + " if (x == 0 ? y * {1} + {2} : {0}) == 0:\n" + " 1 [+1] UInt q\n".format(t_val, f_mod, f_val) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_constant_true = ir.module[0].type[0].structure.field[2] + constant_true = field_constant_true.existence_condition.function.args[0] + field_constant_false = ir.module[0].type[0].structure.field[3] + constant_false = field_constant_false.existence_condition.function.args[0] + self.assertEqual(str(r_mod), constant_true.type.integer.modulus) + self.assertEqual(str(r_val), constant_true.type.integer.modular_value) + self.assertEqual(str(r_min), constant_true.type.integer.minimum_value) + self.assertEqual(str(r_max), constant_true.type.integer.maximum_value) + self.assertEqual(str(r_mod), constant_false.type.integer.modulus) + self.assertEqual(str(r_val), constant_false.type.integer.modular_value) + self.assertEqual(str(r_min), constant_false.type.integer.minimum_value) + self.assertEqual(str(r_max), constant_false.type.integer.maximum_value) + + def test_choice_two_constant_integers(self): + cases = [ + # t == 10 and f == 7 ==> res % 3 == 1 + (10, 7, 3, 1, 7, 10), + # t == 4 and f == 4 ==> res == 4 + (4, 4, "infinity", 4, 4, 4), + ] + for t_val, f_val, r_mod, r_val, r_min, r_max in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if (x == 0 ? 
{} : {}) == 0:\n" + " 1 [+1] UInt z\n".format(t_val, f_val) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_constant_true = ir.module[0].type[0].structure.field[2] + constant_true = field_constant_true.existence_condition.function.args[0] + self.assertEqual(str(r_mod), constant_true.type.integer.modulus) + self.assertEqual(str(r_val), constant_true.type.integer.modular_value) + self.assertEqual(str(r_min), constant_true.type.integer.minimum_value) + self.assertEqual(str(r_max), constant_true.type.integer.maximum_value) + + def test_constant_true_has(self): + ir = self._make_ir( + "struct Foo:\n" + " if $present(x):\n" + " 1 [+1] UInt q\n" + " 0 [+1] UInt x\n" + " if x > 10:\n" + " 1 [+1] Int y\n" + " if false:\n" + " 2 [+1] Int z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[0] + has_func = field.existence_condition + self.assertTrue(has_func.type.boolean.value) + + def test_constant_false_has(self): + ir = self._make_ir( + "struct Foo:\n" + " if $present(z):\n" + " 1 [+1] UInt q\n" + " 0 [+1] UInt x\n" + " if x > 10:\n" + " 1 [+1] Int y\n" + " if false:\n" + " 2 [+1] Int z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[0] + has_func = field.existence_condition + self.assertTrue(has_func.type.boolean.HasField("value")) + self.assertFalse(has_func.type.boolean.value) + + def test_variable_has(self): + ir = self._make_ir( + "struct Foo:\n" + " if $present(y):\n" + " 1 [+1] UInt q\n" + " 0 [+1] UInt x\n" + " if x > 10:\n" + " 1 [+1] Int y\n" + " if false:\n" + " 2 [+1] Int z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[0] + has_func = field.existence_condition + self.assertFalse(has_func.type.boolean.HasField("value")) + + def test_max_of_constants(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if $max(0, 1, 2) == 0:\n" + " 1 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[2] + max_func = field.existence_condition.function.args[0] + self.assertEqual("infinity", max_func.type.integer.modulus) + self.assertEqual("2", max_func.type.integer.modular_value) + self.assertEqual("2", max_func.type.integer.minimum_value) + self.assertEqual("2", max_func.type.integer.maximum_value) + + def test_max_dominated_by_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if $max(x, y, 255) == 0:\n" + " 1 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[2] + max_func = field.existence_condition.function.args[0] + self.assertEqual("infinity", max_func.type.integer.modulus) + self.assertEqual("255", max_func.type.integer.modular_value) + self.assertEqual("255", max_func.type.integer.minimum_value) + self.assertEqual("255", max_func.type.integer.maximum_value) + + def test_max_of_variables(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if $max(x, y) == 0:\n" + " 1 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[2] + max_func = field.existence_condition.function.args[0] + self.assertEqual("1", max_func.type.integer.modulus) + self.assertEqual("0", max_func.type.integer.modular_value) + 
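The shared-modulus behavior that these `$max` tests exercise, and that the `_shared_modular_value` tests near the end of this diff check directly, reduces to folding two congruences `v ≡ a (mod m)` and `v ≡ b (mod n)` into the strongest congruence both satisfy. A sketch of that combiner, assuming the file's convention that a modulus of `"infinity"` marks an exact constant (illustrative, not the module's real implementation):

import math
from functools import reduce

def shared_congruence(a, b):
    """Folds two (modulus, value) pairs; modulus "infinity" means an exact
    constant.  Returns the strongest congruence both inputs satisfy."""
    (a_mod, a_val), (b_mod, b_val) = a, b
    if a_mod == b_mod == "infinity" and a_val == b_val:
        return ("infinity", a_val)  # the same constant on both sides
    finite = [m for m in (a_mod, b_mod) if m != "infinity"]
    # Any common modulus must divide both original moduli and the
    # difference of the two modular values.
    modulus = reduce(math.gcd, finite + [abs(a_val - b_val)])
    return (modulus, a_val % modulus)

# Mirrors the _shared_modular_value expectations later in this diff:
assert shared_congruence((21, 11), (14, 4)) == (7, 4)
assert shared_congruence(("infinity", 25), (14, 4)) == (7, 4)
assert shared_congruence(("infinity", 19), ("infinity", 5)) == (14, 5)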
self.assertEqual("0", max_func.type.integer.minimum_value) + self.assertEqual("255", max_func.type.integer.maximum_value) + + def test_max_of_variables_with_shared_modulus(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " if $max(x * 8 + 5, y * 4 + 3) == 0:\n" + " 1 [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[2] + max_func = field.existence_condition.function.args[0] + self.assertEqual("2", max_func.type.integer.modulus) + self.assertEqual("1", max_func.type.integer.modular_value) + self.assertEqual("5", max_func.type.integer.minimum_value) + self.assertEqual("2045", max_func.type.integer.maximum_value) + + def test_max_of_three_variables(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " 2 [+2] Int z\n" + " if $max(x, y, z) == 0:\n" + " 1 [+1] UInt q\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[3] + max_func = field.existence_condition.function.args[0] + self.assertEqual("1", max_func.type.integer.modulus) + self.assertEqual("0", max_func.type.integer.modular_value) + self.assertEqual("0", max_func.type.integer.minimum_value) + self.assertEqual("32767", max_func.type.integer.maximum_value) + + def test_max_of_one_variable(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " 2 [+2] Int z\n" + " if $max(x * 2 + 3) == 0:\n" + " 1 [+1] UInt q\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[3] + max_func = field.existence_condition.function.args[0] + self.assertEqual("2", max_func.type.integer.modulus) + self.assertEqual("1", max_func.type.integer.modular_value) + self.assertEqual("3", max_func.type.integer.minimum_value) + self.assertEqual("513", max_func.type.integer.maximum_value) + + def test_max_of_one_variable_and_one_constant(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+1] Int y\n" + " 2 [+2] Int z\n" + " if $max(x * 2 + 3, 311) == 0:\n" + " 1 [+1] UInt q\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field = ir.module[0].type[0].structure.field[3] + max_func = field.existence_condition.function.args[0] + self.assertEqual("2", max_func.type.integer.modulus) + self.assertEqual("1", max_func.type.integer.modular_value) + self.assertEqual("311", max_func.type.integer.minimum_value) + self.assertEqual("513", max_func.type.integer.maximum_value) + + def test_choice_non_integer_arguments(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " if x == 0 ? 
false : true:\n" + " 1 [+1] UInt y\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + expr = ir.module[0].type[0].structure.field[1].existence_condition + self.assertEqual("boolean", expr.type.WhichOneof("type")) + self.assertFalse(expr.type.boolean.HasField("value")) + + def test_uint_value_range_for_explicit_size(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+x] UInt:16 y\n" + " y [+1] UInt z\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("1", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual("0", z_start.type.integer.minimum_value) + self.assertEqual("65535", z_start.type.integer.maximum_value) + + def test_uint_value_ranges(self): + cases = [ + (1, 1), + (2, 3), + (3, 7), + (4, 15), + (8, 255), + (12, 4095), + (15, 32767), + (16, 65535), + (32, 4294967295), + (48, 281474976710655), + (64, 18446744073709551615), + ] + for bits, upper in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+8] bits:\n" + " 0 [+{}] UInt x\n" + " x [+1] UInt z\n".format(bits) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("1", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual("0", z_start.type.integer.minimum_value) + self.assertEqual(str(upper), z_start.type.integer.maximum_value) + + def test_int_value_ranges(self): + cases = [ + (1, -1, 0), + (2, -2, 1), + (3, -4, 3), + (4, -8, 7), + (8, -128, 127), + (12, -2048, 2047), + (15, -16384, 16383), + (16, -32768, 32767), + (32, -2147483648, 2147483647), + (48, -140737488355328, 140737488355327), + (64, -9223372036854775808, 9223372036854775807), + ] + for bits, lower, upper in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+8] bits:\n" + " 0 [+{}] Int x\n" + " x [+1] UInt z\n".format(bits) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("1", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual(str(lower), z_start.type.integer.minimum_value) + self.assertEqual(str(upper), z_start.type.integer.maximum_value) + + def test_bcd_value_ranges(self): + cases = [ + (1, 1), + (2, 3), + (3, 7), + (4, 9), + (8, 99), + (12, 999), + (15, 7999), + (16, 9999), + (32, 99999999), + (48, 999999999999), + (64, 9999999999999999), + ] + for bits, upper in cases: + ir = self._make_ir( + "struct Foo:\n" + " 0 [+8] bits:\n" + " 0 [+{}] Bcd x\n" + " x [+1] UInt z\n".format(bits) + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + z_start = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual("1", z_start.type.integer.modulus) + self.assertEqual("0", z_start.type.integer.modular_value) + self.assertEqual("0", z_start.type.integer.minimum_value) + self.assertEqual(str(upper), z_start.type.integer.maximum_value) + + def test_virtual_field_bounds(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " let y = x + 10\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_y = ir.module[0].type[0].structure.field[1] + self.assertEqual("1", field_y.read_transform.type.integer.modulus) + self.assertEqual("0", field_y.read_transform.type.integer.modular_value) + 
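The `$static_size_in_bits` tests above and the `InfinityAugmentedArithmeticTest` cases below both depend on arithmetic over integers extended with `"infinity"` and `"-infinity"` sentinels. A sketch of the min/max half of that arithmetic, under the same string-sentinel convention (illustrative only; the real helpers live in `expression_bounds`):

def aug_min(values):
    """Minimum where "-infinity" < every int < "infinity"."""
    if "-infinity" in values:
        return "-infinity"
    finite = [v for v in values if v != "infinity"]
    return min(finite) if finite else "infinity"

def aug_max(values):
    """Maximum under the same ordering."""
    if "infinity" in values:
        return "infinity"
    finite = [v for v in values if v != "-infinity"]
    return max(finite) if finite else "-infinity"

# Matches test_min_of_infinities / test_max_of_negative_infinities below:
assert aug_min(["infinity", "infinity"]) == "infinity"
assert aug_max(["-infinity", "-infinity"]) == "-infinity"
assert aug_min([3, "infinity", -1]) == -1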
self.assertEqual("10", field_y.read_transform.type.integer.minimum_value) + self.assertEqual("265", field_y.read_transform.type.integer.maximum_value) + + def test_virtual_field_bounds_copied(self): + ir = self._make_ir( + "struct Foo:\n" + " let z = y + 100\n" + " let y = x + 10\n" + " 0 [+1] UInt x\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_z = ir.module[0].type[0].structure.field[0] + self.assertEqual("1", field_z.read_transform.type.integer.modulus) + self.assertEqual("0", field_z.read_transform.type.integer.modular_value) + self.assertEqual("110", field_z.read_transform.type.integer.minimum_value) + self.assertEqual("365", field_z.read_transform.type.integer.maximum_value) + y_reference = field_z.read_transform.function.args[0] + self.assertEqual("1", y_reference.type.integer.modulus) + self.assertEqual("0", y_reference.type.integer.modular_value) + self.assertEqual("10", y_reference.type.integer.minimum_value) + self.assertEqual("265", y_reference.type.integer.maximum_value) + + def test_constant_reference_to_virtual_bounds_copied(self): + ir = self._make_ir( + "struct Foo:\n" + " let ten = Bar.ten\n" + " let truth = Bar.truth\n" + "struct Bar:\n" + " let ten = 10\n" + " let truth = true\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_ten = ir.module[0].type[0].structure.field[0] + self.assertEqual("infinity", field_ten.read_transform.type.integer.modulus) + self.assertEqual("10", field_ten.read_transform.type.integer.modular_value) + self.assertEqual("10", field_ten.read_transform.type.integer.minimum_value) + self.assertEqual("10", field_ten.read_transform.type.integer.maximum_value) + field_truth = ir.module[0].type[0].structure.field[1] + self.assertTrue(field_truth.read_transform.type.boolean.value) + + def test_forward_reference_to_reference_to_enum_correctly_calculated(self): + ir = self._make_ir( + "struct Foo:\n" + " let ten = Bar.TEN\n" + "enum Bar:\n" + " TEN = TEN2\n" + " TEN2 = 5 + 5\n" + ) + self.assertEqual([], expression_bounds.compute_constants(ir)) + field_ten = ir.module[0].type[0].structure.field[0] + self.assertEqual("10", field_ten.read_transform.type.enumeration.value) class InfinityAugmentedArithmeticTest(unittest.TestCase): - # TODO(bolms): Will there ever be any situations where all elements of the arg - # to _min would be "infinity"? - def test_min_of_infinities(self): - self.assertEqual("infinity", - expression_bounds._min(["infinity", "infinity"])) - - # TODO(bolms): Will there ever be any situations where all elements of the arg - # to _max would be "-infinity"? 
- def test_max_of_negative_infinities(self): - self.assertEqual("-infinity", - expression_bounds._max(["-infinity", "-infinity"])) - - def test_shared_modular_value_of_identical_modulus_and_value(self): - self.assertEqual((10, 8), - expression_bounds._shared_modular_value((10, 8), (10, 8))) - - def test_shared_modular_value_of_identical_modulus(self): - self.assertEqual((5, 3), - expression_bounds._shared_modular_value((10, 8), (10, 3))) - - def test_shared_modular_value_of_identical_value(self): - self.assertEqual((6, 2), - expression_bounds._shared_modular_value((18, 2), (12, 2))) - - def test_shared_modular_value_of_different_arguments(self): - self.assertEqual((7, 4), - expression_bounds._shared_modular_value((21, 11), (14, 4))) - - def test_shared_modular_value_of_infinity_and_non(self): - self.assertEqual((7, 4), - expression_bounds._shared_modular_value(("infinity", 25), - (14, 4))) - - def test_shared_modular_value_of_infinity_and_infinity(self): - self.assertEqual((14, 5), - expression_bounds._shared_modular_value(("infinity", 19), - ("infinity", 5))) - - def test_shared_modular_value_of_infinity_and_identical_value(self): - self.assertEqual(("infinity", 5), - expression_bounds._shared_modular_value(("infinity", 5), - ("infinity", 5))) + # TODO(bolms): Will there ever be any situations where all elements of the arg + # to _min would be "infinity"? + def test_min_of_infinities(self): + self.assertEqual("infinity", expression_bounds._min(["infinity", "infinity"])) + + # TODO(bolms): Will there ever be any situations where all elements of the arg + # to _max would be "-infinity"? + def test_max_of_negative_infinities(self): + self.assertEqual( + "-infinity", expression_bounds._max(["-infinity", "-infinity"]) + ) + + def test_shared_modular_value_of_identical_modulus_and_value(self): + self.assertEqual( + (10, 8), expression_bounds._shared_modular_value((10, 8), (10, 8)) + ) + + def test_shared_modular_value_of_identical_modulus(self): + self.assertEqual( + (5, 3), expression_bounds._shared_modular_value((10, 8), (10, 3)) + ) + + def test_shared_modular_value_of_identical_value(self): + self.assertEqual( + (6, 2), expression_bounds._shared_modular_value((18, 2), (12, 2)) + ) + + def test_shared_modular_value_of_different_arguments(self): + self.assertEqual( + (7, 4), expression_bounds._shared_modular_value((21, 11), (14, 4)) + ) + + def test_shared_modular_value_of_infinity_and_non(self): + self.assertEqual( + (7, 4), expression_bounds._shared_modular_value(("infinity", 25), (14, 4)) + ) + + def test_shared_modular_value_of_infinity_and_infinity(self): + self.assertEqual( + (14, 5), + expression_bounds._shared_modular_value(("infinity", 19), ("infinity", 5)), + ) + + def test_shared_modular_value_of_infinity_and_identical_value(self): + self.assertEqual( + ("infinity", 5), + expression_bounds._shared_modular_value(("infinity", 5), ("infinity", 5)), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/format.py b/compiler/front_end/format.py index 1e2389b..1b7de31 100644 --- a/compiler/front_end/format.py +++ b/compiler/front_end/format.py @@ -31,99 +31,112 @@ def _parse_command_line(argv): - """Parses the given command-line arguments.""" - argparser = argparse.ArgumentParser(description='Emboss compiler front end.', - prog=argv[0]) - argparser.add_argument('input_file', - type=str, - nargs='+', - help='.emb file to compile.') - argparser.add_argument('--no-check-result', - default=True, - action='store_false', - dest='check_result', - 
help='Verify that the resulting formatted text ' - 'contains only whitespace changes.') - argparser.add_argument('--debug-show-line-types', - default=False, - help='Show the computed type of each line.') - argparser.add_argument('--no-edit-in-place', - default=True, - action='store_false', - dest='edit_in_place', - help='Write the formatted text back to the input ' - 'file.') - argparser.add_argument('--indent', - type=int, - default=2, - help='Number of spaces to use for each level of ' - 'indentation.') - argparser.add_argument('--color-output', - default='if-tty', - choices=['always', 'never', 'if-tty', 'auto'], - help="Print error messages using color. 'auto' is a " - "synonym for 'if-tty'.") - return argparser.parse_args(argv[1:]) + """Parses the given command-line arguments.""" + argparser = argparse.ArgumentParser( + description="Emboss compiler front end.", prog=argv[0] + ) + argparser.add_argument( + "input_file", type=str, nargs="+", help=".emb file to compile." + ) + argparser.add_argument( + "--no-check-result", + default=True, + action="store_false", + dest="check_result", + help="Verify that the resulting formatted text " + "contains only whitespace changes.", + ) + argparser.add_argument( + "--debug-show-line-types", + default=False, + help="Show the computed type of each line.", + ) + argparser.add_argument( + "--no-edit-in-place", + default=True, + action="store_false", + dest="edit_in_place", + help="Write the formatted text back to the input " "file.", + ) + argparser.add_argument( + "--indent", + type=int, + default=2, + help="Number of spaces to use for each level of " "indentation.", + ) + argparser.add_argument( + "--color-output", + default="if-tty", + choices=["always", "never", "if-tty", "auto"], + help="Print error messages using color. 
'auto' is a " "synonym for 'if-tty'.", + ) + return argparser.parse_args(argv[1:]) def _print_errors(errors, source_codes, flags): - use_color = (flags.color_output == 'always' or - (flags.color_output in ('auto', 'if-tty') and - os.isatty(sys.stderr.fileno()))) - print(error.format_errors(errors, source_codes, use_color), file=sys.stderr) + use_color = flags.color_output == "always" or ( + flags.color_output in ("auto", "if-tty") and os.isatty(sys.stderr.fileno()) + ) + print(error.format_errors(errors, source_codes, use_color), file=sys.stderr) def main(argv=()): - flags = _parse_command_line(argv) - - if not flags.edit_in_place and len(flags.input_file) > 1: - print('Multiple files may only be formatted without --no-edit-in-place.', - file=sys.stderr) - return 1 - - if flags.edit_in_place and flags.debug_show_line_types: - print('The flag --debug-show-line-types requires --no-edit-in-place.', - file=sys.stderr) - return 1 - - for file_name in flags.input_file: - with open(file_name) as f: - source_code = f.read() - - tokens, errors = tokenizer.tokenize(source_code, file_name) - if errors: - _print_errors(errors, {file_name: source_code}, flags) - continue - - parse_result = parser.parse_module(tokens) - if parse_result.error: - _print_errors( - [error.make_error_from_parse_error(file_name, parse_result.error)], - {file_name: source_code}, - flags) - continue - - formatted_text = format_emb.format_emboss_parse_tree( - parse_result.parse_tree, - format_emb.Config(show_line_types=flags.debug_show_line_types, - indent_width=flags.indent)) - - if flags.check_result and not flags.debug_show_line_types: - errors = format_emb.sanity_check_format_result(formatted_text, - source_code) - if errors: - for e in errors: - print(e, file=sys.stderr) - continue - - if flags.edit_in_place: - with open(file_name, 'w') as f: - f.write(formatted_text) - else: - sys.stdout.write(formatted_text) - - return 0 - - -if __name__ == '__main__': - sys.exit(main(sys.argv)) + flags = _parse_command_line(argv) + + if not flags.edit_in_place and len(flags.input_file) > 1: + print( + "Multiple files may only be formatted without --no-edit-in-place.", + file=sys.stderr, + ) + return 1 + + if flags.edit_in_place and flags.debug_show_line_types: + print( + "The flag --debug-show-line-types requires --no-edit-in-place.", + file=sys.stderr, + ) + return 1 + + for file_name in flags.input_file: + with open(file_name) as f: + source_code = f.read() + + tokens, errors = tokenizer.tokenize(source_code, file_name) + if errors: + _print_errors(errors, {file_name: source_code}, flags) + continue + + parse_result = parser.parse_module(tokens) + if parse_result.error: + _print_errors( + [error.make_error_from_parse_error(file_name, parse_result.error)], + {file_name: source_code}, + flags, + ) + continue + + formatted_text = format_emb.format_emboss_parse_tree( + parse_result.parse_tree, + format_emb.Config( + show_line_types=flags.debug_show_line_types, indent_width=flags.indent + ), + ) + + if flags.check_result and not flags.debug_show_line_types: + errors = format_emb.sanity_check_format_result(formatted_text, source_code) + if errors: + for e in errors: + print(e, file=sys.stderr) + continue + + if flags.edit_in_place: + with open(file_name, "w") as f: + f.write(formatted_text) + else: + sys.stdout.write(formatted_text) + + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/compiler/front_end/format_emb.py b/compiler/front_end/format_emb.py index e4b5c23..df9bcf3 100644 --- 
a/compiler/front_end/format_emb.py +++ b/compiler/front_end/format_emb.py @@ -28,27 +28,26 @@ from compiler.util import parser_types -class Config(collections.namedtuple('Config', - ['indent_width', 'show_line_types'])): - """Configuration for formatting.""" +class Config(collections.namedtuple("Config", ["indent_width", "show_line_types"])): + """Configuration for formatting.""" - def __new__(cls, indent_width=2, show_line_types=False): - return super(cls, Config).__new__(cls, indent_width, show_line_types) + def __new__(cls, indent_width=2, show_line_types=False): + return super(cls, Config).__new__(cls, indent_width, show_line_types) -class _Row(collections.namedtuple('Row', ['name', 'columns', 'indent'])): - """Structured contents of a single line.""" +class _Row(collections.namedtuple("Row", ["name", "columns", "indent"])): + """Structured contents of a single line.""" - def __new__(cls, name, columns=None, indent=0): - return super(cls, _Row).__new__(cls, name, tuple(columns or []), indent) + def __new__(cls, name, columns=None, indent=0): + return super(cls, _Row).__new__(cls, name, tuple(columns or []), indent) -class _Block(collections.namedtuple('Block', ['prefix', 'header', 'body'])): - """Structured block of multiple lines.""" +class _Block(collections.namedtuple("Block", ["prefix", "header", "body"])): + """Structured block of multiple lines.""" - def __new__(cls, prefix, header, body): - assert header - return super(cls, _Block).__new__(cls, prefix, header, body) + def __new__(cls, prefix, header, body): + assert header + return super(cls, _Block).__new__(cls, prefix, header, body) # Map of productions to their formatters. @@ -56,280 +55,295 @@ def __new__(cls, prefix, header, body): def format_emboss_parse_tree(parse_tree, config, used_productions=None): - """Formats Emboss source code. - - Arguments: - parse_tree: A parse tree of an Emboss source file. - config: A Config tuple with formatting options. - used_productions: An optional set to which all used productions will be - added. Intended for use by test code to ensure full production - coverage. - - Returns: - A string of the reformatted source text. - """ - if hasattr(parse_tree, 'children'): - parsed_children = [format_emboss_parse_tree(child, config, used_productions) - for child in parse_tree.children] - args = parsed_children + [config] - if used_productions is not None: - used_productions.add(parse_tree.production) - return _formatters[parse_tree.production](*args) - else: - assert isinstance(parse_tree, parser_types.Token), str(parse_tree) - return parse_tree.text + """Formats Emboss source code. + + Arguments: + parse_tree: A parse tree of an Emboss source file. + config: A Config tuple with formatting options. + used_productions: An optional set to which all used productions will be + added. Intended for use by test code to ensure full production + coverage. + + Returns: + A string of the reformatted source text. 
+ """ + if hasattr(parse_tree, "children"): + parsed_children = [ + format_emboss_parse_tree(child, config, used_productions) + for child in parse_tree.children + ] + args = parsed_children + [config] + if used_productions is not None: + used_productions.add(parse_tree.production) + return _formatters[parse_tree.production](*args) + else: + assert isinstance(parse_tree, parser_types.Token), str(parse_tree) + return parse_tree.text def sanity_check_format_result(formatted_text, original_text): - """Checks that the given texts are equivalent.""" - # The texts are considered equivalent if they tokenize to the same token - # stream, except that: - # - # Multiple consecutive newline tokens are equivalent to a single newline - # token. - # - # Extra newline tokens at the start of the stream should be ignored. - # - # Whitespace at the start or end of a token should be ignored. This matters - # for documentation and comment tokens, which may have had trailing whitespace - # in the original text, and for indent tokens, which may contain a different - # number of space and/or tab characters. - original_tokens, errors = tokenizer.tokenize(original_text, '') - if errors: - return ['BUG: original text is not tokenizable: {!r}'.format(errors)] - - formatted_tokens, errors = tokenizer.tokenize(formatted_text, '') - if errors: - return ['BUG: formatted text is not tokenizable: {!r}'.format(errors)] - - o_tokens = _collapse_newline_tokens(original_tokens) - f_tokens = _collapse_newline_tokens(formatted_tokens) - for i in range(len(o_tokens)): - if (o_tokens[i].symbol != f_tokens[i].symbol or - o_tokens[i].text.strip() != f_tokens[i].text.strip()): - return ['BUG: Symbol {} differs: {!r} vs {!r}'.format(i, o_tokens[i], - f_tokens[i])] - return [] + """Checks that the given texts are equivalent.""" + # The texts are considered equivalent if they tokenize to the same token + # stream, except that: + # + # Multiple consecutive newline tokens are equivalent to a single newline + # token. + # + # Extra newline tokens at the start of the stream should be ignored. + # + # Whitespace at the start or end of a token should be ignored. This matters + # for documentation and comment tokens, which may have had trailing whitespace + # in the original text, and for indent tokens, which may contain a different + # number of space and/or tab characters. + original_tokens, errors = tokenizer.tokenize(original_text, "") + if errors: + return ["BUG: original text is not tokenizable: {!r}".format(errors)] + + formatted_tokens, errors = tokenizer.tokenize(formatted_text, "") + if errors: + return ["BUG: formatted text is not tokenizable: {!r}".format(errors)] + + o_tokens = _collapse_newline_tokens(original_tokens) + f_tokens = _collapse_newline_tokens(formatted_tokens) + for i in range(len(o_tokens)): + if ( + o_tokens[i].symbol != f_tokens[i].symbol + or o_tokens[i].text.strip() != f_tokens[i].text.strip() + ): + return [ + "BUG: Symbol {} differs: {!r} vs {!r}".format( + i, o_tokens[i], f_tokens[i] + ) + ] + return [] def _collapse_newline_tokens(token_list): - r"""Collapses multiple consecutive "\\n" tokens into a single newline.""" - result = [] - for symbol, group in itertools.groupby(token_list, lambda x: x.symbol): - if symbol == '"\\n"': - # Skip all newlines if they are at the start, otherwise add a single - # newline for each consecutive run of newlines. 
- if result: - result.append(list(group)[0]) - else: - result.extend(group) - return result + r"""Collapses multiple consecutive "\\n" tokens into a single newline.""" + result = [] + for symbol, group in itertools.groupby(token_list, lambda x: x.symbol): + if symbol == '"\\n"': + # Skip all newlines if they are at the start, otherwise add a single + # newline for each consecutive run of newlines. + if result: + result.append(list(group)[0]) + else: + result.extend(group) + return result def _indent_row(row): - """Adds one level of indent to the given row, returning a new row.""" - assert isinstance(row, _Row), repr(row) - return _Row(name=row.name, - columns=row.columns, - indent=row.indent + 1) + """Adds one level of indent to the given row, returning a new row.""" + assert isinstance(row, _Row), repr(row) + return _Row(name=row.name, columns=row.columns, indent=row.indent + 1) def _indent_rows(rows): - """Adds one level of indent to the given rows, returning a new list.""" - return list(map(_indent_row, rows)) + """Adds one level of indent to the given rows, returning a new list.""" + return list(map(_indent_row, rows)) def _indent_blocks(blocks): - """Adds one level of indent to the given blocks, returning a new list.""" - return [_Block(prefix=_indent_rows(block.prefix), - header=_indent_row(block.header), - body=_indent_rows(block.body)) - for block in blocks] + """Adds one level of indent to the given blocks, returning a new list.""" + return [ + _Block( + prefix=_indent_rows(block.prefix), + header=_indent_row(block.header), + body=_indent_rows(block.body), + ) + for block in blocks + ] def _intersperse(interspersed, sections): - """Intersperses `interspersed` between non-empty `sections`.""" - result = [] - for section in sections: - if section: - if result: - result.extend(interspersed) - result.extend(section) - return result + """Intersperses `interspersed` between non-empty `sections`.""" + result = [] + for section in sections: + if section: + if result: + result.extend(interspersed) + result.extend(section) + return result def _should_add_blank_lines(blocks): - """Returns true if blank lines should be added between blocks.""" - other_non_empty_lines = 0 - last_non_empty_lines = 0 - for block in blocks: - last_non_empty_lines = len([line for line in - block.body + block.prefix - if line.columns]) - other_non_empty_lines += last_non_empty_lines - # Vertical spaces should be added if there are more interior - # non-empty-non-header lines than header lines. - return len(blocks) <= other_non_empty_lines - last_non_empty_lines + """Returns true if blank lines should be added between blocks.""" + other_non_empty_lines = 0 + last_non_empty_lines = 0 + for block in blocks: + last_non_empty_lines = len( + [line for line in block.body + block.prefix if line.columns] + ) + other_non_empty_lines += last_non_empty_lines + # Vertical spaces should be added if there are more interior + # non-empty-non-header lines than header lines. + return len(blocks) <= other_non_empty_lines - last_non_empty_lines def _columnize(blocks, indent_width, indent_columns=1): - """Aligns columns in the header rows of the given blocks. - - The `indent_columns` argument is used to determine how many columns should be - indented. 
With `indent_columns == 1`, the result would be: - - AA BB CC - AAA BBB CCC - A B C - - With `indent_columns == 2`: - - AA BB CC - AAA BBB CCC - A B C - - With `indent_columns == 1`, only the first column is indented compared to - surrounding rows; with `indent_columns == 2`, both the first and second - columns are indented. - - Arguments: - blocks: A list of _Blocks to columnize. - indent_width: The number of spaces per level of indent. - indent_columns: The number of columns to indent. - - Returns: - A list of _Rows of the prefix, header, and body _Rows of each block, where - the header _Rows of each type have had their columns aligned. - """ - single_width_separators = {'enum-value': {0, 1}, 'field': {0}} - # For each type of row, figure out how many characters each column needs. - row_types = collections.defaultdict( - lambda: collections.defaultdict(lambda: 0)) - for block in blocks: - max_lengths = row_types[block.header.name] - for i in range(len(block.header.columns)): - if i == indent_columns - 1: - adjustment = block.header.indent * indent_width - else: - adjustment = 0 - max_lengths[i] = max(max_lengths[i], - len(block.header.columns[i]) + adjustment) - - assert len(row_types) < 3 - - # Then, for each row, actually columnize it. - result = [] - for block in blocks: - columns = [] - for i in range(len(block.header.columns)): - column_width = row_types[block.header.name][i] - if column_width == 0: - # Zero-width columns are entirely omitted, including their column - # separators. - pass - else: - if i == indent_columns - 1: - # This function only performs the right padding for each column. - # Since the left padding for indent will be added later, the - # corresponding space needs to be removed from the right padding of - # the first column. - column_width -= block.header.indent * indent_width - if i in single_width_separators.get(block.header.name, []): - # Only one space around the "=" in enum values and between the start - # and size in field locations. - column_width += 1 - else: - column_width += 2 - columns.append(block.header.columns[i].ljust(column_width)) - result.append(block.prefix + [_Row(block.header.name, - [''.join(columns).rstrip()], - block.header.indent)] + block.body) - return result + """Aligns columns in the header rows of the given blocks. + + The `indent_columns` argument is used to determine how many columns should be + indented. With `indent_columns == 1`, the result would be: + + AA BB CC + AAA BBB CCC + A B C + + With `indent_columns == 2`: + + AA BB CC + AAA BBB CCC + A B C + + With `indent_columns == 1`, only the first column is indented compared to + surrounding rows; with `indent_columns == 2`, both the first and second + columns are indented. + + Arguments: + blocks: A list of _Blocks to columnize. + indent_width: The number of spaces per level of indent. + indent_columns: The number of columns to indent. + + Returns: + A list of _Rows of the prefix, header, and body _Rows of each block, where + the header _Rows of each type have had their columns aligned. + """ + single_width_separators = {"enum-value": {0, 1}, "field": {0}} + # For each type of row, figure out how many characters each column needs. 
+ row_types = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + for block in blocks: + max_lengths = row_types[block.header.name] + for i in range(len(block.header.columns)): + if i == indent_columns - 1: + adjustment = block.header.indent * indent_width + else: + adjustment = 0 + max_lengths[i] = max( + max_lengths[i], len(block.header.columns[i]) + adjustment + ) + + assert len(row_types) < 3 + + # Then, for each row, actually columnize it. + result = [] + for block in blocks: + columns = [] + for i in range(len(block.header.columns)): + column_width = row_types[block.header.name][i] + if column_width == 0: + # Zero-width columns are entirely omitted, including their column + # separators. + pass + else: + if i == indent_columns - 1: + # This function only performs the right padding for each column. + # Since the left padding for indent will be added later, the + # corresponding space needs to be removed from the right padding of + # the first column. + column_width -= block.header.indent * indent_width + if i in single_width_separators.get(block.header.name, []): + # Only one space around the "=" in enum values and between the start + # and size in field locations. + column_width += 1 + else: + column_width += 2 + columns.append(block.header.columns[i].ljust(column_width)) + result.append( + block.prefix + + [ + _Row( + block.header.name, ["".join(columns).rstrip()], block.header.indent + ) + ] + + block.body + ) + return result def _indent_blanks_and_comments(rows): - """Indents blank and comment lines to match the next non-blank line.""" - result = [] - previous_indent = 0 - for row in reversed(rows): - if not ''.join(row.columns) or row.name == 'comment': - result.append(_Row(row.name, row.columns, previous_indent)) - else: - result.append(row) - previous_indent = row.indent - return reversed(result) + """Indents blank and comment lines to match the next non-blank line.""" + result = [] + previous_indent = 0 + for row in reversed(rows): + if not "".join(row.columns) or row.name == "comment": + result.append(_Row(row.name, row.columns, previous_indent)) + else: + result.append(row) + previous_indent = row.indent + return reversed(result) def _add_blank_rows_on_dedent(rows): - """Adds blank rows before dedented lines, where needed.""" - result = [] - previous_indent = 0 - previous_row_was_blank = True - for row in rows: - row_is_blank = not ''.join(row.columns) - found_dedent = previous_indent > row.indent - if found_dedent and not previous_row_was_blank and not row_is_blank: - result.append(_Row('dedent-space', [], row.indent)) - result.append(row) - previous_indent = row.indent - previous_row_was_blank = row_is_blank - return result + """Adds blank rows before dedented lines, where needed.""" + result = [] + previous_indent = 0 + previous_row_was_blank = True + for row in rows: + row_is_blank = not "".join(row.columns) + found_dedent = previous_indent > row.indent + if found_dedent and not previous_row_was_blank and not row_is_blank: + result.append(_Row("dedent-space", [], row.indent)) + result.append(row) + previous_indent = row.indent + previous_row_was_blank = row_is_blank + return result def _render_row_to_text(row, indent_width): - assert len(row.columns) < 2, '{!r}'.format(row) - text = ' ' * indent_width * row.indent - text += ''.join(row.columns) - return text.rstrip() + assert len(row.columns) < 2, "{!r}".format(row) + text = " " * indent_width * row.indent + text += "".join(row.columns) + return text.rstrip() def _render_rows_to_text(rows, indent_width, 
show_line_types): - max_row_name_len = max([0] + [len(row.name) for row in rows]) - flattened_rows = [] - for row in rows: - row_text = _render_row_to_text(row, indent_width) - if show_line_types: - row_text = row.name.ljust(max_row_name_len) + '|' + row_text - flattened_rows.append(row_text) - return '\n'.join(flattened_rows + ['']) + max_row_name_len = max([0] + [len(row.name) for row in rows]) + flattened_rows = [] + for row in rows: + row_text = _render_row_to_text(row, indent_width) + if show_line_types: + row_text = row.name.ljust(max_row_name_len) + "|" + row_text + flattened_rows.append(row_text) + return "\n".join(flattened_rows + [""]) def _check_productions(): - """Asserts that the productions in this module match those in module_ir.""" - productions_ok = True - for production in module_ir.PRODUCTIONS: - if production not in _formatters: - productions_ok = False - print('@_formats({!r})'.format(str(production))) + """Asserts that the productions in this module match those in module_ir.""" + productions_ok = True + for production in module_ir.PRODUCTIONS: + if production not in _formatters: + productions_ok = False + print("@_formats({!r})".format(str(production))) - for production in _formatters: - if production not in module_ir.PRODUCTIONS: - productions_ok = False - print('not @_formats({!r})'.format(str(production))) + for production in _formatters: + if production not in module_ir.PRODUCTIONS: + productions_ok = False + print("not @_formats({!r})".format(str(production))) - assert productions_ok, 'Grammar mismatch.' + assert productions_ok, "Grammar mismatch." def _formats_with_config(production_text): - """Marks a function as a formatter requiring a config argument.""" - production = parser_types.Production.parse(production_text) + """Marks a function as a formatter requiring a config argument.""" + production = parser_types.Production.parse(production_text) - def formats(f): - assert production not in _formatters, production - _formatters[production] = f - return f + def formats(f): + assert production not in _formatters, production + _formatters[production] = f + return f - return formats + return formats def _formats(production_text): - """Marks a function as the formatter for a particular production.""" + """Marks a function as the formatter for a particular production.""" - def strip_config_argument(f): - _formats_with_config(production_text)(lambda *a, **kw: f(*a[:-1], **kw)) - return f + def strip_config_argument(f): + _formats_with_config(production_text)(lambda *a, **kw: f(*a[:-1], **kw)) + return f - return strip_config_argument + return strip_config_argument ################################################################################ @@ -358,388 +372,490 @@ def strip_config_argument(f): # strings. -@_formats_with_config('module -> comment-line* doc-line* import-line*' - ' attribute-line* type-definition*') +@_formats_with_config( + "module -> comment-line* doc-line* import-line*" + " attribute-line* type-definition*" +) def _module(comments, docs, imports, attributes, types, config): - """Performs top-level formatting for an Emboss source file.""" - - # The top-level sections other than types should be separated by single lines. - header_rows = _intersperse( - [_Row('section-break')], - [_strip_empty_leading_trailing_comment_lines(comments), docs, imports, - attributes]) - - # Top-level types should be separated by double lines from themselves and from - # the header rows. 
- rows = _intersperse( - [_Row('top-type-separator'), _Row('top-type-separator')], - [header_rows] + types) - - # Final fixups. - rows = _indent_blanks_and_comments(rows) - rows = _add_blank_rows_on_dedent(rows) - return _render_rows_to_text(rows, config.indent_width, config.show_line_types) - - -@_formats('doc-line -> doc Comment? eol') + """Performs top-level formatting for an Emboss source file.""" + + # The top-level sections other than types should be separated by single lines. + header_rows = _intersperse( + [_Row("section-break")], + [ + _strip_empty_leading_trailing_comment_lines(comments), + docs, + imports, + attributes, + ], + ) + + # Top-level types should be separated by double lines from themselves and from + # the header rows. + rows = _intersperse( + [_Row("top-type-separator"), _Row("top-type-separator")], [header_rows] + types + ) + + # Final fixups. + rows = _indent_blanks_and_comments(rows) + rows = _add_blank_rows_on_dedent(rows) + return _render_rows_to_text(rows, config.indent_width, config.show_line_types) + + +@_formats("doc-line -> doc Comment? eol") def _doc_line(doc, comment, eol): - assert not comment, 'Comment should not be possible on the same line as doc.' - return [_Row('doc', [doc])] + eol + assert not comment, "Comment should not be possible on the same line as doc." + return [_Row("doc", [doc])] + eol -@_formats('import-line -> "import" string-constant "as" snake-word Comment?' - ' eol') +@_formats( + 'import-line -> "import" string-constant "as" snake-word Comment?' + " eol" +) def _import_line(import_, filename, as_, name, comment, eol): - return [_Row('import', ['{} {} {} {} {}'.format( - import_, filename, as_, name, comment)])] + eol + return [ + _Row( + "import", ["{} {} {} {} {}".format(import_, filename, as_, name, comment)] + ) + ] + eol -@_formats('attribute-line -> attribute Comment? eol') +@_formats("attribute-line -> attribute Comment? eol") def _attribute_line(attribute, comment, eol): - return [_Row('attribute', ['{} {}'.format(attribute, comment)])] + eol + return [_Row("attribute", ["{} {}".format(attribute, comment)])] + eol -@_formats('attribute -> "[" attribute-context? "$default"? snake-word ":"' - ' attribute-value "]"') +@_formats( + 'attribute -> "[" attribute-context? "$default"? snake-word ":"' + ' attribute-value "]"' +) def _attribute(open_, context, default, name, colon, value, close): - return ''.join([open_, - _concatenate_with_spaces(context, default, name + colon, - value), - close]) + return "".join( + [open_, _concatenate_with_spaces(context, default, name + colon, value), close] + ) @_formats('parameter-definition -> snake-name ":" type') def _parameter_definition(name, colon, type_specifier): - return '{}{} {}'.format(name, colon, type_specifier) + return "{}{} {}".format(name, colon, type_specifier) -@_formats('type-definition* -> type-definition type-definition*') +@_formats("type-definition* -> type-definition type-definition*") def _type_defitinions(definition, definitions): - return [definition] + definitions + return [definition] + definitions -@_formats('bits -> "bits" type-name delimited-parameter-definition-list? ":"' - ' Comment? eol bits-body') -@_formats('struct -> "struct" type-name delimited-parameter-definition-list?' - ' ":" Comment? eol struct-body') +@_formats( + 'bits -> "bits" type-name delimited-parameter-definition-list? ":"' + " Comment? eol bits-body" +) +@_formats( + 'struct -> "struct" type-name delimited-parameter-definition-list?' + ' ":" Comment? 
eol struct-body' +) def _structure_type(struct, name, parameters, colon, comment, eol, body): - return ([_Row('type-header', - ['{} {}{}{} {}'.format( - struct, name, parameters, colon, comment)])] + - eol + body) + return ( + [ + _Row( + "type-header", + ["{} {}{}{} {}".format(struct, name, parameters, colon, comment)], + ) + ] + + eol + + body + ) @_formats('enum -> "enum" type-name ":" Comment? eol enum-body') @_formats('external -> "external" type-name ":" Comment? eol external-body') def _type(struct, name, colon, comment, eol, body): - return ([_Row('type-header', - ['{} {}{} {}'.format(struct, name, colon, comment)])] + - eol + body) + return ( + [_Row("type-header", ["{} {}{} {}".format(struct, name, colon, comment)])] + + eol + + body + ) -@_formats_with_config('bits-body -> Indent doc-line* attribute-line*' - ' type-definition* bits-field-block Dedent') @_formats_with_config( - 'struct-body -> Indent doc-line* attribute-line*' - ' type-definition* struct-field-block Dedent') -def _structure_body(indent, docs, attributes, type_definitions, fields, dedent, - config): - del indent, dedent # Unused. - spacing = [_Row('field-separator')] if _should_add_blank_lines(fields) else [] - columnized_fields = _columnize(fields, config.indent_width, indent_columns=2) - return _indent_rows(_intersperse( - spacing, [docs, attributes] + type_definitions + columnized_fields)) + "bits-body -> Indent doc-line* attribute-line*" + " type-definition* bits-field-block Dedent" +) +@_formats_with_config( + "struct-body -> Indent doc-line* attribute-line*" + " type-definition* struct-field-block Dedent" +) +def _structure_body(indent, docs, attributes, type_definitions, fields, dedent, config): + del indent, dedent # Unused. + spacing = [_Row("field-separator")] if _should_add_blank_lines(fields) else [] + columnized_fields = _columnize(fields, config.indent_width, indent_columns=2) + return _indent_rows( + _intersperse(spacing, [docs, attributes] + type_definitions + columnized_fields) + ) @_formats('field-location -> expression "[" "+" expression "]"') def _field_location(start, open_bracket, plus, size, close_bracket): - return [start, open_bracket + plus + size + close_bracket] - - -@_formats('anonymous-bits-field-block -> conditional-anonymous-bits-field-block' - ' anonymous-bits-field-block') -@_formats('anonymous-bits-field-block -> unconditional-anonymous-bits-field' - ' anonymous-bits-field-block') -@_formats('bits-field-block -> conditional-bits-field-block bits-field-block') -@_formats('bits-field-block -> unconditional-bits-field bits-field-block') -@_formats('struct-field-block -> conditional-struct-field-block' - ' struct-field-block') -@_formats('struct-field-block -> unconditional-struct-field struct-field-block') -@_formats('unconditional-anonymous-bits-field* ->' - ' unconditional-anonymous-bits-field' - ' unconditional-anonymous-bits-field*') -@_formats('unconditional-anonymous-bits-field+ ->' - ' unconditional-anonymous-bits-field' - ' unconditional-anonymous-bits-field*') -@_formats('unconditional-bits-field* -> unconditional-bits-field' - ' unconditional-bits-field*') -@_formats('unconditional-bits-field+ -> unconditional-bits-field' - ' unconditional-bits-field*') -@_formats('unconditional-struct-field* -> unconditional-struct-field' - ' unconditional-struct-field*') -@_formats('unconditional-struct-field+ -> unconditional-struct-field' - ' unconditional-struct-field*') + return [start, open_bracket + plus + size + close_bracket] + + +@_formats( + "anonymous-bits-field-block -> 
conditional-anonymous-bits-field-block" + " anonymous-bits-field-block" +) +@_formats( + "anonymous-bits-field-block -> unconditional-anonymous-bits-field" + " anonymous-bits-field-block" +) +@_formats("bits-field-block -> conditional-bits-field-block bits-field-block") +@_formats("bits-field-block -> unconditional-bits-field bits-field-block") +@_formats( + "struct-field-block -> conditional-struct-field-block" + " struct-field-block" +) +@_formats("struct-field-block -> unconditional-struct-field struct-field-block") +@_formats( + "unconditional-anonymous-bits-field* ->" + " unconditional-anonymous-bits-field" + " unconditional-anonymous-bits-field*" +) +@_formats( + "unconditional-anonymous-bits-field+ ->" + " unconditional-anonymous-bits-field" + " unconditional-anonymous-bits-field*" +) +@_formats( + "unconditional-bits-field* -> unconditional-bits-field" + " unconditional-bits-field*" +) +@_formats( + "unconditional-bits-field+ -> unconditional-bits-field" + " unconditional-bits-field*" +) +@_formats( + "unconditional-struct-field* -> unconditional-struct-field" + " unconditional-struct-field*" +) +@_formats( + "unconditional-struct-field+ -> unconditional-struct-field" + " unconditional-struct-field*" +) def _structure_block(field, block): - """Prepends field to block.""" - return field + block + """Prepends field to block.""" + return field + block -@_formats('virtual-field -> "let" snake-name "=" expression Comment? eol' - ' field-body?') +@_formats( + 'virtual-field -> "let" snake-name "=" expression Comment? eol' + " field-body?" +) def _virtual_field(let_keyword, name, equals, value, comment, eol, body): - # This formatting doesn't look the best when there are blocks of several - # virtual fields next to each other, but works pretty well when they're - # intermixed with physical fields. It's probably good enough for now, since - # there aren't (yet) any virtual fields in real .embs, and will probably only - # be a few in the near future. - return [_Block([], - _Row('virtual-field', - [_concatenate_with( - ' ', - _concatenate_with_spaces(let_keyword, name, equals, - value), - comment)]), - eol + body)] - - -@_formats('field -> field-location type snake-name abbreviation?' - ' attribute* doc? Comment? eol field-body?') -def _unconditional_field(location, type_, name, abbreviation, attributes, doc, - comment, eol, body): - return [_Block([], - _Row('field', - location + [type_, - _concatenate_with_spaces(name, abbreviation), - attributes, doc, comment]), - eol + body)] - - -@_formats('field-body -> Indent doc-line* attribute-line* Dedent') + # This formatting doesn't look the best when there are blocks of several + # virtual fields next to each other, but works pretty well when they're + # intermixed with physical fields. It's probably good enough for now, since + # there aren't (yet) any virtual fields in real .embs, and will probably only + # be a few in the near future. + return [ + _Block( + [], + _Row( + "virtual-field", + [ + _concatenate_with( + " ", + _concatenate_with_spaces(let_keyword, name, equals, value), + comment, + ) + ], + ), + eol + body, + ) + ] + + +@_formats( + "field -> field-location type snake-name abbreviation?" + " attribute* doc? Comment? eol field-body?" 
+) +def _unconditional_field( + location, type_, name, abbreviation, attributes, doc, comment, eol, body +): + return [ + _Block( + [], + _Row( + "field", + location + + [ + type_, + _concatenate_with_spaces(name, abbreviation), + attributes, + doc, + comment, + ], + ), + eol + body, + ) + ] + + +@_formats("field-body -> Indent doc-line* attribute-line* Dedent") def _field_body(indent, docs, attributes, dedent): - del indent, dedent # Unused - return _indent_rows(docs + attributes) + del indent, dedent # Unused + return _indent_rows(docs + attributes) -@_formats('anonymous-bits-field-definition ->' - ' field-location "bits" ":" Comment? eol anonymous-bits-body') +@_formats( + "anonymous-bits-field-definition ->" + ' field-location "bits" ":" Comment? eol anonymous-bits-body' +) def _inline_bits(location, bits, colon, comment, eol, body): - # Even though an anonymous bits field technically defines a new, anonymous - # type, conceptually it's more like defining a bunch of fields on the - # surrounding type, so it is treated as an inline list of blocks, instead of - # being separately formatted. - header_row = _Row('field', [location[0], location[1] + ' ' + bits + colon, - '', '', '', '', comment]) - return ([_Block([], header_row, eol + body.header_lines)] + - body.field_blocks) + # Even though an anonymous bits field technically defines a new, anonymous + # type, conceptually it's more like defining a bunch of fields on the + # surrounding type, so it is treated as an inline list of blocks, instead of + # being separately formatted. + header_row = _Row( + "field", + [location[0], location[1] + " " + bits + colon, "", "", "", "", comment], + ) + return [_Block([], header_row, eol + body.header_lines)] + body.field_blocks -@_formats('inline-enum-field-definition ->' - ' field-location "enum" snake-name abbreviation? ":" Comment? eol' - ' enum-body') @_formats( - 'inline-struct-field-definition ->' + "inline-enum-field-definition ->" + ' field-location "enum" snake-name abbreviation? ":" Comment? eol' + " enum-body" +) +@_formats( + "inline-struct-field-definition ->" ' field-location "struct" snake-name abbreviation? ":" Comment? eol' - ' struct-body') -@_formats('inline-bits-field-definition ->' - ' field-location "bits" snake-name abbreviation? ":" Comment? eol' - ' bits-body') -def _inline_type(location, keyword, name, abbreviation, colon, comment, eol, - body): - """Formats an inline type in a struct or bits.""" - header_row = _Row( - 'field', location + [keyword, - _concatenate_with_spaces(name, abbreviation) + colon, - '', '', comment]) - return [_Block([], header_row, eol + body)] - - -@_formats('conditional-struct-field-block -> "if" expression ":" Comment? eol' - ' Indent unconditional-struct-field+' - ' Dedent') -@_formats('conditional-bits-field-block -> "if" expression ":" Comment? eol' - ' Indent unconditional-bits-field+' - ' Dedent') -@_formats('conditional-anonymous-bits-field-block ->' - ' "if" expression ":" Comment? eol' - ' Indent unconditional-anonymous-bits-field+ Dedent') -def _conditional_field(if_, condition, colon, comment, eol, indent, body, - dedent): - """Formats an `if` construct.""" - del indent, dedent # Unused - # The body of an 'if' should be columnized with the surrounding blocks, so - # much like an inline 'bits', its body is treated as an inline list of blocks. - header_row = _Row('if', - ['{} {}{} {}'.format(if_, condition, colon, comment)]) - indented_body = _indent_blocks(body) - assert indented_body, 'Expected body of if condition.' 
- return [_Block([header_row] + eol + indented_body[0].prefix, - indented_body[0].header, - indented_body[0].body)] + indented_body[1:] - - -_InlineBitsBodyType = collections.namedtuple('InlineBitsBodyType', - ['header_lines', 'field_blocks']) - - -@_formats('anonymous-bits-body ->' - ' Indent attribute-line* anonymous-bits-field-block Dedent') + " struct-body" +) +@_formats( + "inline-bits-field-definition ->" + ' field-location "bits" snake-name abbreviation? ":" Comment? eol' + " bits-body" +) +def _inline_type(location, keyword, name, abbreviation, colon, comment, eol, body): + """Formats an inline type in a struct or bits.""" + header_row = _Row( + "field", + location + + [ + keyword, + _concatenate_with_spaces(name, abbreviation) + colon, + "", + "", + comment, + ], + ) + return [_Block([], header_row, eol + body)] + + +@_formats( + 'conditional-struct-field-block -> "if" expression ":" Comment? eol' + " Indent unconditional-struct-field+" + " Dedent" +) +@_formats( + 'conditional-bits-field-block -> "if" expression ":" Comment? eol' + " Indent unconditional-bits-field+" + " Dedent" +) +@_formats( + "conditional-anonymous-bits-field-block ->" + ' "if" expression ":" Comment? eol' + " Indent unconditional-anonymous-bits-field+ Dedent" +) +def _conditional_field(if_, condition, colon, comment, eol, indent, body, dedent): + """Formats an `if` construct.""" + del indent, dedent # Unused + # The body of an 'if' should be columnized with the surrounding blocks, so + # much like an inline 'bits', its body is treated as an inline list of blocks. + header_row = _Row("if", ["{} {}{} {}".format(if_, condition, colon, comment)]) + indented_body = _indent_blocks(body) + assert indented_body, "Expected body of if condition." + return [ + _Block( + [header_row] + eol + indented_body[0].prefix, + indented_body[0].header, + indented_body[0].body, + ) + ] + indented_body[1:] + + +_InlineBitsBodyType = collections.namedtuple( + "InlineBitsBodyType", ["header_lines", "field_blocks"] +) + + +@_formats( + "anonymous-bits-body ->" + " Indent attribute-line* anonymous-bits-field-block Dedent" +) def _inline_bits_body(indent, attributes, fields, dedent): - del indent, dedent # Unused - return _InlineBitsBodyType(header_lines=_indent_rows(attributes), - field_blocks=_indent_blocks(fields)) + del indent, dedent # Unused + return _InlineBitsBodyType( + header_lines=_indent_rows(attributes), field_blocks=_indent_blocks(fields) + ) @_formats_with_config( - 'enum-body -> Indent doc-line* attribute-line* enum-value+' - ' Dedent') + "enum-body -> Indent doc-line* attribute-line* enum-value+" " Dedent" +) def _enum_body(indent, docs, attributes, values, dedent, config): - del indent, dedent # Unused - spacing = [_Row('value-separator')] if _should_add_blank_lines(values) else [] - columnized_values = _columnize(values, config.indent_width) - return _indent_rows(_intersperse(spacing, - [docs, attributes] + columnized_values)) + del indent, dedent # Unused + spacing = [_Row("value-separator")] if _should_add_blank_lines(values) else [] + columnized_values = _columnize(values, config.indent_width) + return _indent_rows(_intersperse(spacing, [docs, attributes] + columnized_values)) -@_formats('enum-value* -> enum-value enum-value*') -@_formats('enum-value+ -> enum-value enum-value*') +@_formats("enum-value* -> enum-value enum-value*") +@_formats("enum-value+ -> enum-value enum-value*") def _enum_values(value, block): - return value + block + return value + block -@_formats('enum-value -> constant-name "=" expression 
attribute* doc? Comment? eol' - ' enum-value-body?') +@_formats( + 'enum-value -> constant-name "=" expression attribute* doc? Comment? eol' + " enum-value-body?" +) def _enum_value(name, equals, value, attributes, docs, comment, eol, body): - return [_Block([], _Row('enum-value', [name, equals, value, attributes, docs, comment]), - eol + body)] + return [ + _Block( + [], + _Row("enum-value", [name, equals, value, attributes, docs, comment]), + eol + body, + ) + ] -@_formats('enum-value-body -> Indent doc-line* attribute-line* Dedent') +@_formats("enum-value-body -> Indent doc-line* attribute-line* Dedent") def _enum_value_body(indent, docs, attributes, dedent): - del indent, dedent # Unused - return _indent_rows(docs + attributes) + del indent, dedent # Unused + return _indent_rows(docs + attributes) -@_formats('external-body -> Indent doc-line* attribute-line* Dedent') +@_formats("external-body -> Indent doc-line* attribute-line* Dedent") def _external_body(indent, docs, attributes, dedent): - del indent, dedent # Unused - return _indent_rows(_intersperse([_Row('section-break')], [docs, attributes])) + del indent, dedent # Unused + return _indent_rows(_intersperse([_Row("section-break")], [docs, attributes])) @_formats('comment-line -> Comment? "\\n"') def _comment_line(comment, eol): - del eol # Unused - if comment: - return [_Row('comment', [comment])] - else: - return [_Row('comment')] + del eol # Unused + if comment: + return [_Row("comment", [comment])] + else: + return [_Row("comment")] @_formats('eol -> "\\n" comment-line*') def _eol(eol, comments): - del eol # Unused - return _strip_empty_leading_trailing_comment_lines(comments) + del eol # Unused + return _strip_empty_leading_trailing_comment_lines(comments) def _strip_empty_leading_trailing_comment_lines(comments): - first_non_empty_line = None - last_non_empty_line = None - for i in range(len(comments)): - if comments[i].columns: - if first_non_empty_line is None: - first_non_empty_line = i - last_non_empty_line = i - if first_non_empty_line is None: - return [] - else: - return comments[first_non_empty_line:last_non_empty_line + 1] - - -@_formats('attribute-line* -> ') -@_formats('anonymous-bits-field-block -> ') -@_formats('bits-field-block -> ') -@_formats('comment-line* -> ') -@_formats('doc-line* -> ') -@_formats('enum-value* -> ') -@_formats('enum-value-body? -> ') -@_formats('field-body? -> ') -@_formats('import-line* -> ') -@_formats('struct-field-block -> ') -@_formats('type-definition* -> ') -@_formats('unconditional-anonymous-bits-field* -> ') -@_formats('unconditional-bits-field* -> ') -@_formats('unconditional-struct-field* -> ') + first_non_empty_line = None + last_non_empty_line = None + for i in range(len(comments)): + if comments[i].columns: + if first_non_empty_line is None: + first_non_empty_line = i + last_non_empty_line = i + if first_non_empty_line is None: + return [] + else: + return comments[first_non_empty_line : last_non_empty_line + 1] + + +@_formats("attribute-line* -> ") +@_formats("anonymous-bits-field-block -> ") +@_formats("bits-field-block -> ") +@_formats("comment-line* -> ") +@_formats("doc-line* -> ") +@_formats("enum-value* -> ") +@_formats("enum-value-body? -> ") +@_formats("field-body? 
-> ") +@_formats("import-line* -> ") +@_formats("struct-field-block -> ") +@_formats("type-definition* -> ") +@_formats("unconditional-anonymous-bits-field* -> ") +@_formats("unconditional-bits-field* -> ") +@_formats("unconditional-struct-field* -> ") def _empty_list(): - return [] - - -@_formats('abbreviation? -> ') -@_formats('additive-expression-right* -> ') -@_formats('and-expression-right* -> ') -@_formats('argument-list -> ') -@_formats('array-length-specifier* -> ') -@_formats('attribute* -> ') -@_formats('attribute-context? -> ') -@_formats('comma-then-expression* -> ') -@_formats('Comment? -> ') + return [] + + +@_formats("abbreviation? -> ") +@_formats("additive-expression-right* -> ") +@_formats("and-expression-right* -> ") +@_formats("argument-list -> ") +@_formats("array-length-specifier* -> ") +@_formats("attribute* -> ") +@_formats("attribute-context? -> ") +@_formats("comma-then-expression* -> ") +@_formats("Comment? -> ") @_formats('"$default"? -> ') -@_formats('delimited-argument-list? -> ') -@_formats('delimited-parameter-definition-list? -> ') -@_formats('doc? -> ') -@_formats('equality-expression-right* -> ') -@_formats('equality-or-greater-expression-right* -> ') -@_formats('equality-or-less-expression-right* -> ') -@_formats('field-reference-tail* -> ') -@_formats('or-expression-right* -> ') -@_formats('parameter-definition-list -> ') -@_formats('parameter-definition-list-tail* -> ') -@_formats('times-expression-right* -> ') -@_formats('type-size-specifier? -> ') +@_formats("delimited-argument-list? -> ") +@_formats("delimited-parameter-definition-list? -> ") +@_formats("doc? -> ") +@_formats("equality-expression-right* -> ") +@_formats("equality-or-greater-expression-right* -> ") +@_formats("equality-or-less-expression-right* -> ") +@_formats("field-reference-tail* -> ") +@_formats("or-expression-right* -> ") +@_formats("parameter-definition-list -> ") +@_formats("parameter-definition-list-tail* -> ") +@_formats("times-expression-right* -> ") +@_formats("type-size-specifier? -> ") def _empty_string(): - return '' + return "" -@_formats('abbreviation? -> abbreviation') +@_formats("abbreviation? -> abbreviation") @_formats('additive-operator -> "-"') @_formats('additive-operator -> "+"') @_formats('and-operator -> "&&"') -@_formats('attribute-context? -> attribute-context') -@_formats('attribute-value -> expression') -@_formats('attribute-value -> string-constant') -@_formats('boolean-constant -> BooleanConstant') -@_formats('bottom-expression -> boolean-constant') -@_formats('bottom-expression -> builtin-reference') -@_formats('bottom-expression -> constant-reference') -@_formats('bottom-expression -> field-reference') -@_formats('bottom-expression -> numeric-constant') +@_formats("attribute-context? 
-> attribute-context") +@_formats("attribute-value -> expression") +@_formats("attribute-value -> string-constant") +@_formats("boolean-constant -> BooleanConstant") +@_formats("bottom-expression -> boolean-constant") +@_formats("bottom-expression -> builtin-reference") +@_formats("bottom-expression -> constant-reference") +@_formats("bottom-expression -> field-reference") +@_formats("bottom-expression -> numeric-constant") @_formats('builtin-field-word -> "$max_size_in_bits"') @_formats('builtin-field-word -> "$max_size_in_bytes"') @_formats('builtin-field-word -> "$min_size_in_bits"') @_formats('builtin-field-word -> "$min_size_in_bytes"') @_formats('builtin-field-word -> "$size_in_bits"') @_formats('builtin-field-word -> "$size_in_bytes"') -@_formats('builtin-reference -> builtin-word') +@_formats("builtin-reference -> builtin-word") @_formats('builtin-word -> "$is_statically_sized"') @_formats('builtin-word -> "$next"') @_formats('builtin-word -> "$static_size_in_bits"') -@_formats('choice-expression -> logical-expression') -@_formats('Comment? -> Comment') -@_formats('comparison-expression -> additive-expression') -@_formats('constant-name -> constant-word') -@_formats('constant-reference -> constant-reference-tail') -@_formats('constant-reference-tail -> constant-word') -@_formats('constant-word -> ShoutyWord') +@_formats("choice-expression -> logical-expression") +@_formats("Comment? -> Comment") +@_formats("comparison-expression -> additive-expression") +@_formats("constant-name -> constant-word") +@_formats("constant-reference -> constant-reference-tail") +@_formats("constant-reference-tail -> constant-word") +@_formats("constant-word -> ShoutyWord") @_formats('"$default"? -> "$default"') -@_formats('delimited-argument-list? -> delimited-argument-list') -@_formats('doc? -> doc') -@_formats('doc -> Documentation') -@_formats('enum-value-body? -> enum-value-body') +@_formats("delimited-argument-list? -> delimited-argument-list") +@_formats("doc? -> doc") +@_formats("doc -> Documentation") +@_formats("enum-value-body? -> enum-value-body") @_formats('equality-operator -> "=="') -@_formats('equality-or-greater-expression-right -> equality-expression-right') -@_formats('equality-or-greater-expression-right -> greater-expression-right') -@_formats('equality-or-less-expression-right -> equality-expression-right') -@_formats('equality-or-less-expression-right -> less-expression-right') -@_formats('expression -> choice-expression') -@_formats('field-body? -> field-body') +@_formats("equality-or-greater-expression-right -> equality-expression-right") +@_formats("equality-or-greater-expression-right -> greater-expression-right") +@_formats("equality-or-less-expression-right -> equality-expression-right") +@_formats("equality-or-less-expression-right -> less-expression-right") +@_formats("expression -> choice-expression") +@_formats("field-body? 
-> field-body") @_formats('function-name -> "$lower_bound"') @_formats('function-name -> "$present"') @_formats('function-name -> "$max"') @@ -749,146 +865,186 @@ def _empty_string(): @_formats('inequality-operator -> "!="') @_formats('less-operator -> "<="') @_formats('less-operator -> "<"') -@_formats('logical-expression -> and-expression') -@_formats('logical-expression -> comparison-expression') -@_formats('logical-expression -> or-expression') +@_formats("logical-expression -> and-expression") +@_formats("logical-expression -> comparison-expression") +@_formats("logical-expression -> or-expression") @_formats('multiplicative-operator -> "*"') -@_formats('negation-expression -> bottom-expression') -@_formats('numeric-constant -> Number') +@_formats("negation-expression -> bottom-expression") +@_formats("numeric-constant -> Number") @_formats('or-operator -> "||"') -@_formats('snake-name -> snake-word') -@_formats('snake-reference -> builtin-field-word') -@_formats('snake-reference -> snake-word') -@_formats('snake-word -> SnakeWord') -@_formats('string-constant -> String') -@_formats('type-definition -> bits') -@_formats('type-definition -> enum') -@_formats('type-definition -> external') -@_formats('type-definition -> struct') -@_formats('type-name -> type-word') -@_formats('type-reference-tail -> type-word') -@_formats('type-reference -> type-reference-tail') -@_formats('type-size-specifier? -> type-size-specifier') -@_formats('type-word -> CamelWord') -@_formats('unconditional-anonymous-bits-field -> field') -@_formats('unconditional-anonymous-bits-field -> inline-bits-field-definition') -@_formats('unconditional-anonymous-bits-field -> inline-enum-field-definition') -@_formats('unconditional-bits-field -> unconditional-anonymous-bits-field') -@_formats('unconditional-bits-field -> virtual-field') -@_formats('unconditional-struct-field -> anonymous-bits-field-definition') -@_formats('unconditional-struct-field -> field') -@_formats('unconditional-struct-field -> inline-bits-field-definition') -@_formats('unconditional-struct-field -> inline-enum-field-definition') -@_formats('unconditional-struct-field -> inline-struct-field-definition') -@_formats('unconditional-struct-field -> virtual-field') +@_formats("snake-name -> snake-word") +@_formats("snake-reference -> builtin-field-word") +@_formats("snake-reference -> snake-word") +@_formats("snake-word -> SnakeWord") +@_formats("string-constant -> String") +@_formats("type-definition -> bits") +@_formats("type-definition -> enum") +@_formats("type-definition -> external") +@_formats("type-definition -> struct") +@_formats("type-name -> type-word") +@_formats("type-reference-tail -> type-word") +@_formats("type-reference -> type-reference-tail") +@_formats("type-size-specifier? 
-> type-size-specifier") +@_formats("type-word -> CamelWord") +@_formats("unconditional-anonymous-bits-field -> field") +@_formats("unconditional-anonymous-bits-field -> inline-bits-field-definition") +@_formats("unconditional-anonymous-bits-field -> inline-enum-field-definition") +@_formats("unconditional-bits-field -> unconditional-anonymous-bits-field") +@_formats("unconditional-bits-field -> virtual-field") +@_formats("unconditional-struct-field -> anonymous-bits-field-definition") +@_formats("unconditional-struct-field -> field") +@_formats("unconditional-struct-field -> inline-bits-field-definition") +@_formats("unconditional-struct-field -> inline-enum-field-definition") +@_formats("unconditional-struct-field -> inline-struct-field-definition") +@_formats("unconditional-struct-field -> virtual-field") def _identity(x): - return x + return x -@_formats('argument-list -> expression comma-then-expression*') -@_formats('times-expression -> negation-expression times-expression-right*') -@_formats('type -> type-reference delimited-argument-list? type-size-specifier?' - ' array-length-specifier*') +@_formats("argument-list -> expression comma-then-expression*") +@_formats("times-expression -> negation-expression times-expression-right*") +@_formats( + "type -> type-reference delimited-argument-list? type-size-specifier?" + " array-length-specifier*" +) @_formats('array-length-specifier -> "[" expression "]"') -@_formats('array-length-specifier* -> array-length-specifier' - ' array-length-specifier*') +@_formats( + "array-length-specifier* -> array-length-specifier" + " array-length-specifier*" +) @_formats('type-size-specifier -> ":" numeric-constant') @_formats('attribute-context -> "(" snake-word ")"') @_formats('constant-reference -> snake-reference "." constant-reference-tail') @_formats('constant-reference-tail -> type-word "." constant-reference-tail') @_formats('constant-reference-tail -> type-word "." snake-reference') @_formats('type-reference-tail -> type-word "." type-reference-tail') -@_formats('field-reference -> snake-reference field-reference-tail*') +@_formats("field-reference -> snake-reference field-reference-tail*") @_formats('abbreviation -> "(" snake-word ")"') -@_formats('additive-expression-right -> additive-operator times-expression') -@_formats('additive-expression-right* -> additive-expression-right' - ' additive-expression-right*') -@_formats('additive-expression -> times-expression additive-expression-right*') +@_formats("additive-expression-right -> additive-operator times-expression") +@_formats( + "additive-expression-right* -> additive-expression-right" + " additive-expression-right*" +) +@_formats("additive-expression -> times-expression additive-expression-right*") @_formats('array-length-specifier -> "[" "]"') @_formats('delimited-argument-list -> "(" argument-list ")"') -@_formats('delimited-parameter-definition-list? ->' - ' delimited-parameter-definition-list') -@_formats('delimited-parameter-definition-list ->' - ' "(" parameter-definition-list ")"') -@_formats('parameter-definition-list -> parameter-definition' - ' parameter-definition-list-tail*') -@_formats('parameter-definition-list-tail* -> parameter-definition-list-tail' - ' parameter-definition-list-tail*') -@_formats('times-expression-right -> multiplicative-operator' - ' negation-expression') -@_formats('times-expression-right* -> times-expression-right' - ' times-expression-right*') +@_formats( + "delimited-parameter-definition-list? 
->" " delimited-parameter-definition-list" +) +@_formats( + "delimited-parameter-definition-list ->" ' "(" parameter-definition-list ")"' +) +@_formats( + "parameter-definition-list -> parameter-definition" + " parameter-definition-list-tail*" +) +@_formats( + "parameter-definition-list-tail* -> parameter-definition-list-tail" + " parameter-definition-list-tail*" +) +@_formats( + "times-expression-right -> multiplicative-operator" + " negation-expression" +) +@_formats( + "times-expression-right* -> times-expression-right" + " times-expression-right*" +) @_formats('field-reference-tail -> "." snake-reference') -@_formats('field-reference-tail* -> field-reference-tail field-reference-tail*') -@_formats('negation-expression -> additive-operator bottom-expression') +@_formats("field-reference-tail* -> field-reference-tail field-reference-tail*") +@_formats("negation-expression -> additive-operator bottom-expression") @_formats('type-reference -> snake-word "." type-reference-tail') @_formats('bottom-expression -> "(" expression ")"') @_formats('bottom-expression -> function-name "(" argument-list ")"') -@_formats('comma-then-expression* -> comma-then-expression' - ' comma-then-expression*') -@_formats('or-expression-right* -> or-expression-right or-expression-right*') -@_formats('less-expression-right-list -> equality-expression-right*' - ' less-expression-right' - ' equality-or-less-expression-right*') -@_formats('or-expression-right+ -> or-expression-right or-expression-right*') -@_formats('and-expression -> comparison-expression and-expression-right+') -@_formats('comparison-expression -> additive-expression' - ' greater-expression-right-list') -@_formats('comparison-expression -> additive-expression' - ' equality-expression-right+') -@_formats('or-expression -> comparison-expression or-expression-right+') -@_formats('equality-expression-right+ -> equality-expression-right' - ' equality-expression-right*') -@_formats('and-expression-right* -> and-expression-right and-expression-right*') -@_formats('equality-or-greater-expression-right* ->' - ' equality-or-greater-expression-right' - ' equality-or-greater-expression-right*') -@_formats('and-expression-right+ -> and-expression-right and-expression-right*') -@_formats('equality-or-less-expression-right* ->' - ' equality-or-less-expression-right' - ' equality-or-less-expression-right*') -@_formats('equality-expression-right* -> equality-expression-right' - ' equality-expression-right*') -@_formats('greater-expression-right-list ->' - ' equality-expression-right* greater-expression-right' - ' equality-or-greater-expression-right*') -@_formats('comparison-expression -> additive-expression' - ' less-expression-right-list') +@_formats( + "comma-then-expression* -> comma-then-expression" + " comma-then-expression*" +) +@_formats("or-expression-right* -> or-expression-right or-expression-right*") +@_formats( + "less-expression-right-list -> equality-expression-right*" + " less-expression-right" + " equality-or-less-expression-right*" +) +@_formats("or-expression-right+ -> or-expression-right or-expression-right*") +@_formats("and-expression -> comparison-expression and-expression-right+") +@_formats( + "comparison-expression -> additive-expression" + " greater-expression-right-list" +) +@_formats( + "comparison-expression -> additive-expression" + " equality-expression-right+" +) +@_formats("or-expression -> comparison-expression or-expression-right+") +@_formats( + "equality-expression-right+ -> equality-expression-right" + " 
equality-expression-right*" +) +@_formats("and-expression-right* -> and-expression-right and-expression-right*") +@_formats( + "equality-or-greater-expression-right* ->" + " equality-or-greater-expression-right" + " equality-or-greater-expression-right*" +) +@_formats("and-expression-right+ -> and-expression-right and-expression-right*") +@_formats( + "equality-or-less-expression-right* ->" + " equality-or-less-expression-right" + " equality-or-less-expression-right*" +) +@_formats( + "equality-expression-right* -> equality-expression-right" + " equality-expression-right*" +) +@_formats( + "greater-expression-right-list ->" + " equality-expression-right* greater-expression-right" + " equality-or-greater-expression-right*" +) +@_formats( + "comparison-expression -> additive-expression" + " less-expression-right-list" +) def _concatenate(*elements): - """Concatenates all arguments with no delimiters.""" - return ''.join(elements) + """Concatenates all arguments with no delimiters.""" + return "".join(elements) -@_formats('equality-expression-right -> equality-operator additive-expression') -@_formats('less-expression-right -> less-operator additive-expression') -@_formats('greater-expression-right -> greater-operator additive-expression') -@_formats('or-expression-right -> or-operator comparison-expression') -@_formats('and-expression-right -> and-operator comparison-expression') +@_formats("equality-expression-right -> equality-operator additive-expression") +@_formats("less-expression-right -> less-operator additive-expression") +@_formats("greater-expression-right -> greater-operator additive-expression") +@_formats("or-expression-right -> or-operator comparison-expression") +@_formats("and-expression-right -> and-operator comparison-expression") def _concatenate_with_prefix_spaces(*elements): - return ''.join(' ' + element for element in elements if element) + return "".join(" " + element for element in elements if element) -@_formats('attribute* -> attribute attribute*') +@_formats("attribute* -> attribute attribute*") @_formats('comma-then-expression -> "," expression') -@_formats('comparison-expression -> additive-expression inequality-operator' - ' additive-expression') -@_formats('choice-expression -> logical-expression "?" logical-expression' - ' ":" logical-expression') +@_formats( + "comparison-expression -> additive-expression inequality-operator" + " additive-expression" +) +@_formats( + 'choice-expression -> logical-expression "?" 
logical-expression' + ' ":" logical-expression' +) @_formats('parameter-definition-list-tail -> "," parameter-definition') def _concatenate_with_spaces(*elements): - return _concatenate_with(' ', *elements) + return _concatenate_with(" ", *elements) def _concatenate_with(joiner, *elements): - return joiner.join(element for element in elements if element) + return joiner.join(element for element in elements if element) -@_formats('attribute-line* -> attribute-line attribute-line*') -@_formats('comment-line* -> comment-line comment-line*') -@_formats('doc-line* -> doc-line doc-line*') -@_formats('import-line* -> import-line import-line*') +@_formats("attribute-line* -> attribute-line attribute-line*") +@_formats("comment-line* -> comment-line comment-line*") +@_formats("doc-line* -> doc-line doc-line*") +@_formats("import-line* -> import-line import-line*") def _concatenate_lists(head, tail): - return head + tail + return head + tail _check_productions() diff --git a/compiler/front_end/format_emb_test.py b/compiler/front_end/format_emb_test.py index e633f59..3d5331e 100644 --- a/compiler/front_end/format_emb_test.py +++ b/compiler/front_end/format_emb_test.py @@ -31,163 +31,180 @@ class SanityCheckerTest(unittest.TestCase): - def test_text_does_not_tokenize(self): - self.assertTrue(format_emb.sanity_check_format_result("-- doc", "~ bad")) - - def test_original_text_does_not_tokenize(self): - self.assertTrue(format_emb.sanity_check_format_result("~ bad", "-- doc")) - - def test_text_matches(self): - self.assertFalse(format_emb.sanity_check_format_result("-- doc", "-- doc")) - - def test_text_has_extra_eols(self): - self.assertFalse( - format_emb.sanity_check_format_result("-- doc\n\n-- doc", - "-- doc\n\n\n-- doc")) - - def test_text_has_fewer_eols(self): - self.assertFalse(format_emb.sanity_check_format_result("-- doc\n\n-- doc", - "-- doc\n-- doc")) - - def test_original_text_has_leading_eols(self): - self.assertFalse(format_emb.sanity_check_format_result("\n\n-- doc\n", - "-- doc\n")) - - def test_original_text_has_extra_doc_whitespace(self): - self.assertFalse(format_emb.sanity_check_format_result("-- doc \n", - "-- doc\n")) - - def test_comments_differ(self): - self.assertTrue(format_emb.sanity_check_format_result("#c\n-- doc\n", - "#d\n-- doc\n")) - - def test_comment_missing(self): - self.assertTrue(format_emb.sanity_check_format_result("#c\n-- doc\n", - "\n-- doc\n")) - - def test_comment_added(self): - self.assertTrue(format_emb.sanity_check_format_result("\n-- doc\n", - "#d\n-- doc\n")) - - def test_token_text_differs(self): - self.assertTrue(format_emb.sanity_check_format_result("-- doc\n", - "-- bad doc\n")) - - def test_token_type_differs(self): - self.assertTrue(format_emb.sanity_check_format_result("-- doc\n", - "abc\n")) - - def test_eol_missing(self): - self.assertTrue(format_emb.sanity_check_format_result("abc\n-- doc\n", - "abc -- doc\n")) + def test_text_does_not_tokenize(self): + self.assertTrue(format_emb.sanity_check_format_result("-- doc", "~ bad")) + + def test_original_text_does_not_tokenize(self): + self.assertTrue(format_emb.sanity_check_format_result("~ bad", "-- doc")) + + def test_text_matches(self): + self.assertFalse(format_emb.sanity_check_format_result("-- doc", "-- doc")) + + def test_text_has_extra_eols(self): + self.assertFalse( + format_emb.sanity_check_format_result( + "-- doc\n\n-- doc", "-- doc\n\n\n-- doc" + ) + ) + + def test_text_has_fewer_eols(self): + self.assertFalse( + format_emb.sanity_check_format_result("-- doc\n\n-- doc", "-- doc\n-- doc") 
+ ) + + def test_original_text_has_leading_eols(self): + self.assertFalse( + format_emb.sanity_check_format_result("\n\n-- doc\n", "-- doc\n") + ) + + def test_original_text_has_extra_doc_whitespace(self): + self.assertFalse( + format_emb.sanity_check_format_result("-- doc \n", "-- doc\n") + ) + + def test_comments_differ(self): + self.assertTrue( + format_emb.sanity_check_format_result("#c\n-- doc\n", "#d\n-- doc\n") + ) + + def test_comment_missing(self): + self.assertTrue( + format_emb.sanity_check_format_result("#c\n-- doc\n", "\n-- doc\n") + ) + + def test_comment_added(self): + self.assertTrue( + format_emb.sanity_check_format_result("\n-- doc\n", "#d\n-- doc\n") + ) + + def test_token_text_differs(self): + self.assertTrue( + format_emb.sanity_check_format_result("-- doc\n", "-- bad doc\n") + ) + + def test_token_type_differs(self): + self.assertTrue(format_emb.sanity_check_format_result("-- doc\n", "abc\n")) + + def test_eol_missing(self): + self.assertTrue( + format_emb.sanity_check_format_result("abc\n-- doc\n", "abc -- doc\n") + ) class FormatEmbTest(unittest.TestCase): - pass + pass def _make_golden_file_tests(): - """Generates test cases from the golden files in the resource bundle.""" - - package = "testdata.format" - path_prefix = "" - - def make_test_case(name, unformatted_text, expected_text, indent_width): - - def test_case(self): - self.maxDiff = 100000 - unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, name) - self.assertFalse(errors) - parsed_unformatted = parser.parse_module(unformatted_tokens) - self.assertFalse(parsed_unformatted.error) - formatted_text = format_emb.format_emboss_parse_tree( - parsed_unformatted.parse_tree, - format_emb.Config(indent_width=indent_width)) - self.assertEqual(expected_text, formatted_text) - annotated_text = format_emb.format_emboss_parse_tree( - parsed_unformatted.parse_tree, - format_emb.Config(indent_width=indent_width, show_line_types=True)) - self.assertEqual(expected_text, re.sub(r"^.*?\|", "", annotated_text, - flags=re.MULTILINE)) - self.assertFalse(re.search("^[^|]+$", annotated_text, flags=re.MULTILINE)) - - return test_case - - all_unformatted_texts = [] - - for filename in ( - "abbreviations", - "anonymous_bits_formatting", - "arithmetic_expressions", - "array_length", - "attributes", - "choice_expression", - "comparison_expressions", - "conditional_field_formatting", - "conditional_inline_bits_formatting", - "dotted_names", - "empty", - "enum_value_attributes", - "enum_value_bodies", - "enum_values_aligned", - "equality_expressions", - "external", - "extra_newlines", - "fields_aligned", - "functions", - "header_and_type", - "indent", - "inline_attributes_get_a_column", - "inline_bits", - "inline_documentation_gets_a_column", - "inline_enum", - "inline_struct", - "lines_not_spaced_out_with_excess_trailing_noise_lines", - "lines_not_spaced_out_with_not_enough_noise_lines", - "lines_spaced_out_with_noise_lines", - "logical_expressions", - "multiline_ifs", - "multiple_header_sections", - "nested_types_are_columnized_independently", - "one_type", - "parameterized_struct", - "sanity_check", - "spacing_between_types", - "trailing_spaces", - "virtual_fields"): - for suffix, width in ((".emb.formatted", 2), - (".emb.formatted_indent_4", 4)): - unformatted_name = path_prefix + filename + ".emb" - expected_name = path_prefix + filename + suffix - unformatted_text = pkgutil.get_data(package, - unformatted_name).decode("utf-8") - expected_text = pkgutil.get_data(package, expected_name).decode("utf-8") - 
setattr(FormatEmbTest, "test {} indent {}".format(filename, width), - make_test_case(filename, unformatted_text, expected_text, width)) - - all_unformatted_texts.append(unformatted_text) - - def test_all_productions_used(self): - used_productions = set() - for unformatted_text in all_unformatted_texts: - unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, "") - self.assertFalse(errors) - parsed_unformatted = parser.parse_module(unformatted_tokens) - self.assertFalse(parsed_unformatted.error) - format_emb.format_emboss_parse_tree(parsed_unformatted.parse_tree, - format_emb.Config(), used_productions) - unused_productions = set(module_ir.PRODUCTIONS) - used_productions - if unused_productions: - print("Used production total:", len(used_productions), file=sys.stderr) - for production in unused_productions: - print("Unused production:", str(production), file=sys.stderr) - print("Total:", len(unused_productions), file=sys.stderr) - self.assertEqual(set(module_ir.PRODUCTIONS), used_productions) - - FormatEmbTest.testAllProductionsUsed = test_all_productions_used + """Generates test cases from the golden files in the resource bundle.""" + + package = "testdata.format" + path_prefix = "" + + def make_test_case(name, unformatted_text, expected_text, indent_width): + + def test_case(self): + self.maxDiff = 100000 + unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, name) + self.assertFalse(errors) + parsed_unformatted = parser.parse_module(unformatted_tokens) + self.assertFalse(parsed_unformatted.error) + formatted_text = format_emb.format_emboss_parse_tree( + parsed_unformatted.parse_tree, + format_emb.Config(indent_width=indent_width), + ) + self.assertEqual(expected_text, formatted_text) + annotated_text = format_emb.format_emboss_parse_tree( + parsed_unformatted.parse_tree, + format_emb.Config(indent_width=indent_width, show_line_types=True), + ) + self.assertEqual( + expected_text, re.sub(r"^.*?\|", "", annotated_text, flags=re.MULTILINE) + ) + self.assertFalse(re.search("^[^|]+$", annotated_text, flags=re.MULTILINE)) + + return test_case + + all_unformatted_texts = [] + + for filename in ( + "abbreviations", + "anonymous_bits_formatting", + "arithmetic_expressions", + "array_length", + "attributes", + "choice_expression", + "comparison_expressions", + "conditional_field_formatting", + "conditional_inline_bits_formatting", + "dotted_names", + "empty", + "enum_value_attributes", + "enum_value_bodies", + "enum_values_aligned", + "equality_expressions", + "external", + "extra_newlines", + "fields_aligned", + "functions", + "header_and_type", + "indent", + "inline_attributes_get_a_column", + "inline_bits", + "inline_documentation_gets_a_column", + "inline_enum", + "inline_struct", + "lines_not_spaced_out_with_excess_trailing_noise_lines", + "lines_not_spaced_out_with_not_enough_noise_lines", + "lines_spaced_out_with_noise_lines", + "logical_expressions", + "multiline_ifs", + "multiple_header_sections", + "nested_types_are_columnized_independently", + "one_type", + "parameterized_struct", + "sanity_check", + "spacing_between_types", + "trailing_spaces", + "virtual_fields", + ): + for suffix, width in ((".emb.formatted", 2), (".emb.formatted_indent_4", 4)): + unformatted_name = path_prefix + filename + ".emb" + expected_name = path_prefix + filename + suffix + unformatted_text = pkgutil.get_data(package, unformatted_name).decode( + "utf-8" + ) + expected_text = pkgutil.get_data(package, expected_name).decode("utf-8") + setattr( + FormatEmbTest, + "test {} indent 
{}".format(filename, width), + make_test_case(filename, unformatted_text, expected_text, width), + ) + + all_unformatted_texts.append(unformatted_text) + + def test_all_productions_used(self): + used_productions = set() + for unformatted_text in all_unformatted_texts: + unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, "") + self.assertFalse(errors) + parsed_unformatted = parser.parse_module(unformatted_tokens) + self.assertFalse(parsed_unformatted.error) + format_emb.format_emboss_parse_tree( + parsed_unformatted.parse_tree, format_emb.Config(), used_productions + ) + unused_productions = set(module_ir.PRODUCTIONS) - used_productions + if unused_productions: + print("Used production total:", len(used_productions), file=sys.stderr) + for production in unused_productions: + print("Unused production:", str(production), file=sys.stderr) + print("Total:", len(unused_productions), file=sys.stderr) + self.assertEqual(set(module_ir.PRODUCTIONS), used_productions) + + FormatEmbTest.testAllProductionsUsed = test_all_productions_used _make_golden_file_tests() if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/generate_grammar_md.py b/compiler/front_end/generate_grammar_md.py index 6cb2c58..53c9e50 100644 --- a/compiler/front_end/generate_grammar_md.py +++ b/compiler/front_end/generate_grammar_md.py @@ -65,172 +65,179 @@ def _sort_productions(productions, start_symbol): - """Sorts the given productions in a human-friendly order.""" - productions_by_lhs = {} - for p in productions: - if p.lhs not in productions_by_lhs: - productions_by_lhs[p.lhs] = set() - productions_by_lhs[p.lhs].add(p) - - queue = [start_symbol] - previously_queued_symbols = set(queue) - main_production_list = [] - # This sorts productions depth-first. I'm not sure if it is better to sort - # them breadth-first or depth-first, or with some hybrid. - while queue: - symbol = queue.pop(-1) - if symbol not in productions_by_lhs: - continue - for production in sorted(productions_by_lhs[symbol]): - main_production_list.append(production) - for symbol in production.rhs: - # Skip boilerplate productions for now, but include their base - # production. - if symbol and symbol[-1] in "*+?": - symbol = symbol[0:-1] - if symbol not in previously_queued_symbols: - queue.append(symbol) - previously_queued_symbols.add(symbol) - - # It's not particularly important to put boilerplate productions in any - # particular order. - boilerplate_production_list = sorted( - set(productions) - set(main_production_list)) - for production in boilerplate_production_list: - assert production.lhs[-1] in "*+?", "Found orphaned production {}".format( - production.lhs) - assert set(productions) == set( - main_production_list + boilerplate_production_list) - assert len(productions) == len(main_production_list) + len( - boilerplate_production_list) - return main_production_list, boilerplate_production_list + """Sorts the given productions in a human-friendly order.""" + productions_by_lhs = {} + for p in productions: + if p.lhs not in productions_by_lhs: + productions_by_lhs[p.lhs] = set() + productions_by_lhs[p.lhs].add(p) + + queue = [start_symbol] + previously_queued_symbols = set(queue) + main_production_list = [] + # This sorts productions depth-first. I'm not sure if it is better to sort + # them breadth-first or depth-first, or with some hybrid. 
+ while queue: + symbol = queue.pop(-1) + if symbol not in productions_by_lhs: + continue + for production in sorted(productions_by_lhs[symbol]): + main_production_list.append(production) + for symbol in production.rhs: + # Skip boilerplate productions for now, but include their base + # production. + if symbol and symbol[-1] in "*+?": + symbol = symbol[0:-1] + if symbol not in previously_queued_symbols: + queue.append(symbol) + previously_queued_symbols.add(symbol) + + # It's not particularly important to put boilerplate productions in any + # particular order. + boilerplate_production_list = sorted(set(productions) - set(main_production_list)) + for production in boilerplate_production_list: + assert production.lhs[-1] in "*+?", "Found orphaned production {}".format( + production.lhs + ) + assert set(productions) == set(main_production_list + boilerplate_production_list) + assert len(productions) == len(main_production_list) + len( + boilerplate_production_list + ) + return main_production_list, boilerplate_production_list def _word_wrap_at_column(words, width): - """Wraps words to the specified width, and returns a list of wrapped lines.""" - result = [] - in_progress = [] - for word in words: - if len(" ".join(in_progress + [word])) > width: - result.append(" ".join(in_progress)) - assert len(result[-1]) <= width - in_progress = [] - in_progress.append(word) - result.append(" ".join(in_progress)) - assert len(result[-1]) <= width - return result + """Wraps words to the specified width, and returns a list of wrapped lines.""" + result = [] + in_progress = [] + for word in words: + if len(" ".join(in_progress + [word])) > width: + result.append(" ".join(in_progress)) + assert len(result[-1]) <= width + in_progress = [] + in_progress.append(word) + result.append(" ".join(in_progress)) + assert len(result[-1]) <= width + return result def _format_productions(productions): - """Formats a list of productions for inclusion in a Markdown document.""" - max_lhs_len = max([len(production.lhs) for production in productions]) - - # TODO(bolms): This highlighting is close for now, but not actually right. - result = ["```shell\n"] - last_lhs = None - for production in productions: - if last_lhs == production.lhs: - lhs = "" - delimiter = " |" - else: - lhs = production.lhs - delimiter = "->" - leader = "{lhs:{width}} {delimiter}".format( - lhs=lhs, - width=max_lhs_len, - delimiter=delimiter) - for rhs_block in _word_wrap_at_column( - production.rhs or [""], _MAX_OUTPUT_WIDTH - len(leader)): - result.append("{leader} {rhs}\n".format(leader=leader, rhs=rhs_block)) - leader = " " * len(leader) - last_lhs = production.lhs - result.append("```\n") - return "".join(result) + """Formats a list of productions for inclusion in a Markdown document.""" + max_lhs_len = max([len(production.lhs) for production in productions]) + + # TODO(bolms): This highlighting is close for now, but not actually right. 
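[Note: _word_wrap_at_column, reformatted above, wraps greedily: it flushes the in-progress line as soon as adding the next word would exceed width. Expected behavior, for illustration, assuming the function from generate_grammar_md.py is in scope:

    assert _word_wrap_at_column(["aa", "bb", "cc"], 5) == ["aa bb", "cc"]
    # A single word longer than `width` cannot be wrapped and will trip
    # the function's internal `len(result[-1]) <= width` assertion, so
    # callers must keep individual words within the column width.
]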
+ result = ["```shell\n"] + last_lhs = None + for production in productions: + if last_lhs == production.lhs: + lhs = "" + delimiter = " |" + else: + lhs = production.lhs + delimiter = "->" + leader = "{lhs:{width}} {delimiter}".format( + lhs=lhs, width=max_lhs_len, delimiter=delimiter + ) + for rhs_block in _word_wrap_at_column( + production.rhs or [""], _MAX_OUTPUT_WIDTH - len(leader) + ): + result.append("{leader} {rhs}\n".format(leader=leader, rhs=rhs_block)) + leader = " " * len(leader) + last_lhs = production.lhs + result.append("```\n") + return "".join(result) def _normalize_literal_patterns(literals): - """Normalizes a list of strings to a list of (regex, symbol) pairs.""" - return [(re.sub(r"(\W)", r"\\\1", literal), '"' + literal + '"') - for literal in literals] + """Normalizes a list of strings to a list of (regex, symbol) pairs.""" + return [ + (re.sub(r"(\W)", r"\\\1", literal), '"' + literal + '"') for literal in literals + ] def _normalize_regex_patterns(regexes): - """Normalizes a list of tokenizer regexes to a list of (regex, symbol).""" - # g3doc breaks up patterns containing '|' when they are inserted into a table, - # unless they're preceded by '\'. Note that other special characters, - # including '\', should *not* be escaped with '\'. - return [(re.sub(r"\|", r"\\|", r.regex.pattern), r.symbol) for r in regexes] + """Normalizes a list of tokenizer regexes to a list of (regex, symbol).""" + # g3doc breaks up patterns containing '|' when they are inserted into a table, + # unless they're preceded by '\'. Note that other special characters, + # including '\', should *not* be escaped with '\'. + return [(re.sub(r"\|", r"\\|", r.regex.pattern), r.symbol) for r in regexes] def _normalize_reserved_word_list(reserved_words): - """Returns words that would be allowed as names if they were not reserved.""" - interesting_reserved_words = [] - for word in reserved_words: - tokens, errors = tokenizer.tokenize(word, "") - assert tokens and not errors, "Failed to tokenize " + word - if tokens[0].symbol in ["SnakeWord", "CamelWord", "ShoutyWord"]: - interesting_reserved_words.append(word) - return sorted(interesting_reserved_words) + """Returns words that would be allowed as names if they were not reserved.""" + interesting_reserved_words = [] + for word in reserved_words: + tokens, errors = tokenizer.tokenize(word, "") + assert tokens and not errors, "Failed to tokenize " + word + if tokens[0].symbol in ["SnakeWord", "CamelWord", "ShoutyWord"]: + interesting_reserved_words.append(word) + return sorted(interesting_reserved_words) def _format_token_rules(token_rules): - """Formats a list of (pattern, symbol) pairs as a table.""" - pattern_width = max([len(rule[0]) for rule in token_rules]) - pattern_width += 2 # For the `` characters. - result = ["{pat_header:{width}} | Symbol\n" - "{empty:-<{width}} | {empty:-<30}\n".format(pat_header="Pattern", - width=pattern_width, - empty="")] - for rule in token_rules: - if rule[1]: - symbol_name = "`" + rule[1] + "`" - else: - symbol_name = "*no symbol emitted*" - result.append( - "{pattern:{width}} | {symbol}\n".format(pattern="`" + rule[0] + "`", - width=pattern_width, - symbol=symbol_name)) - return "".join(result) + """Formats a list of (pattern, symbol) pairs as a table.""" + pattern_width = max([len(rule[0]) for rule in token_rules]) + pattern_width += 2 # For the `` characters. 
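[Note: in _normalize_literal_patterns above, the re.sub(r"(\W)", r"\\\1", ...) call backslash-escapes every non-word character so a literal token can be displayed as a regex. A quick check; the escape_non_word name is illustrative:

    import re

    def escape_non_word(literal):
        # Prefix every non-word character with a backslash, as
        # _normalize_literal_patterns does.
        return re.sub(r"(\W)", r"\\\1", literal)

    assert escape_non_word("[+") == r"\[\+"
    # The stdlib re.escape() is the closest built-in analogue, though the
    # exact set of characters it escapes differs across Python versions.
]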
+ result = [ + "{pat_header:{width}} | Symbol\n" + "{empty:-<{width}} | {empty:-<30}\n".format( + pat_header="Pattern", width=pattern_width, empty="" + ) + ] + for rule in token_rules: + if rule[1]: + symbol_name = "`" + rule[1] + "`" + else: + symbol_name = "*no symbol emitted*" + result.append( + "{pattern:{width}} | {symbol}\n".format( + pattern="`" + rule[0] + "`", width=pattern_width, symbol=symbol_name + ) + ) + return "".join(result) def _format_keyword_list(reserved_words): - """formats a list of reserved words.""" - lines = [] - current_line = "" - for word in reserved_words: - if len(current_line) + len(word) + 2 > 80: - lines.append(current_line) - current_line = "" - current_line += "`{}` ".format(word) - return "".join([line[:-1] + "\n" for line in lines]) + """formats a list of reserved words.""" + lines = [] + current_line = "" + for word in reserved_words: + if len(current_line) + len(word) + 2 > 80: + lines.append(current_line) + current_line = "" + current_line += "`{}` ".format(word) + return "".join([line[:-1] + "\n" for line in lines]) def generate_grammar_md(): - """Generates up-to-date text for grammar.md.""" - main_productions, boilerplate_productions = _sort_productions( - module_ir.PRODUCTIONS, module_ir.START_SYMBOL) - result = [_HEADER, _format_productions(main_productions), - _BOILERPLATE_PRODUCTION_HEADER, - _format_productions(boilerplate_productions)] + """Generates up-to-date text for grammar.md.""" + main_productions, boilerplate_productions = _sort_productions( + module_ir.PRODUCTIONS, module_ir.START_SYMBOL + ) + result = [ + _HEADER, + _format_productions(main_productions), + _BOILERPLATE_PRODUCTION_HEADER, + _format_productions(boilerplate_productions), + ] - main_tokens = _normalize_literal_patterns(tokenizer.LITERAL_TOKEN_PATTERNS) - main_tokens += _normalize_regex_patterns(tokenizer.REGEX_TOKEN_PATTERNS) - result.append(_TOKENIZER_RULE_HEADER) - result.append(_format_token_rules(main_tokens)) + main_tokens = _normalize_literal_patterns(tokenizer.LITERAL_TOKEN_PATTERNS) + main_tokens += _normalize_regex_patterns(tokenizer.REGEX_TOKEN_PATTERNS) + result.append(_TOKENIZER_RULE_HEADER) + result.append(_format_token_rules(main_tokens)) - reserved_words = _normalize_reserved_word_list( - constraints.get_reserved_word_list()) - result.append(_KEYWORDS_HEADER.format(len(reserved_words))) - result.append(_format_keyword_list(reserved_words)) + reserved_words = _normalize_reserved_word_list(constraints.get_reserved_word_list()) + result.append(_KEYWORDS_HEADER.format(len(reserved_words))) + result.append(_format_keyword_list(reserved_words)) - return "".join(result) + return "".join(result) def main(argv): - del argv # Unused. - print(generate_grammar_md(), end="") - return 0 + del argv # Unused. 
+ print(generate_grammar_md(), end="") + return 0 if __name__ == "__main__": - sys.exit(main(sys.argv)) + sys.exit(main(sys.argv)) diff --git a/compiler/front_end/glue.py b/compiler/front_end/glue.py index a1e1a5b..2744086 100644 --- a/compiler/front_end/glue.py +++ b/compiler/front_end/glue.py @@ -38,330 +38,347 @@ from compiler.util import parser_types from compiler.util import resources -_IrDebugInfo = collections.namedtuple("IrDebugInfo", ["ir", "debug_info", - "errors"]) +_IrDebugInfo = collections.namedtuple("IrDebugInfo", ["ir", "debug_info", "errors"]) class DebugInfo(object): - """Debug information about Emboss parsing.""" - __slots__ = ("modules") + """Debug information about Emboss parsing.""" - def __init__(self): - self.modules = {} + __slots__ = "modules" - def __eq__(self, other): - return self.modules == other.modules + def __init__(self): + self.modules = {} - def __ne__(self, other): - return not self == other + def __eq__(self, other): + return self.modules == other.modules + def __ne__(self, other): + return not self == other -class ModuleDebugInfo(object): - """Debug information about the parse of a single file. - - Attributes: - file_name: The name of the file from which this module came. - tokens: The tokenization of this module's source text. - parse_tree: The raw parse tree for this module. - ir: The intermediate representation of this module, before additional - processing such as symbol resolution. - used_productions: The set of grammar productions used when parsing this - module. - source_code: The source text of the module. - """ - __slots__ = ("file_name", "tokens", "parse_tree", "ir", "used_productions", - "source_code") - - def __init__(self, file_name): - self.file_name = file_name - self.tokens = None - self.parse_tree = None - self.ir = None - self.used_productions = None - self.source_code = None - - def __eq__(self, other): - return (self.file_name == other.file_name and self.tokens == other.tokens - and self.parse_tree == other.parse_tree and self.ir == other.ir and - self.used_productions == other.used_productions and - self.source_code == other.source_code) - - def __ne__(self, other): - return not self == other - - def format_tokenization(self): - """Renders self.tokens in a human-readable format.""" - return "\n".join([str(token) for token in self.tokens]) - - def format_parse_tree(self, parse_tree=None, indent=""): - """Renders self.parse_tree in a human-readable format.""" - if parse_tree is None: - parse_tree = self.parse_tree - result = [] - if isinstance(parse_tree, lr1.Reduction): - result.append(indent + parse_tree.symbol) - if parse_tree.children: - result.append(":\n") - for child in parse_tree.children: - result.append(self.format_parse_tree(child, indent + " ")) - else: - result.append("\n") - else: - result.append("{}{}\n".format(indent, parse_tree)) - return "".join(result) - def format_module_ir(self): - """Renders self.ir in a human-readable format.""" - return ir_data_utils.IrDataSerializer(self.ir).to_json(indent=2) +class ModuleDebugInfo(object): + """Debug information about the parse of a single file. + + Attributes: + file_name: The name of the file from which this module came. + tokens: The tokenization of this module's source text. + parse_tree: The raw parse tree for this module. + ir: The intermediate representation of this module, before additional + processing such as symbol resolution. + used_productions: The set of grammar productions used when parsing this + module. + source_code: The source text of the module. 
+ """ + + __slots__ = ( + "file_name", + "tokens", + "parse_tree", + "ir", + "used_productions", + "source_code", + ) + + def __init__(self, file_name): + self.file_name = file_name + self.tokens = None + self.parse_tree = None + self.ir = None + self.used_productions = None + self.source_code = None + + def __eq__(self, other): + return ( + self.file_name == other.file_name + and self.tokens == other.tokens + and self.parse_tree == other.parse_tree + and self.ir == other.ir + and self.used_productions == other.used_productions + and self.source_code == other.source_code + ) + + def __ne__(self, other): + return not self == other + + def format_tokenization(self): + """Renders self.tokens in a human-readable format.""" + return "\n".join([str(token) for token in self.tokens]) + + def format_parse_tree(self, parse_tree=None, indent=""): + """Renders self.parse_tree in a human-readable format.""" + if parse_tree is None: + parse_tree = self.parse_tree + result = [] + if isinstance(parse_tree, lr1.Reduction): + result.append(indent + parse_tree.symbol) + if parse_tree.children: + result.append(":\n") + for child in parse_tree.children: + result.append(self.format_parse_tree(child, indent + " ")) + else: + result.append("\n") + else: + result.append("{}{}\n".format(indent, parse_tree)) + return "".join(result) + + def format_module_ir(self): + """Renders self.ir in a human-readable format.""" + return ir_data_utils.IrDataSerializer(self.ir).to_json(indent=2) def format_production_set(productions): - """Renders a set of productions in a human-readable format.""" - return "\n".join([str(production) for production in sorted(productions)]) + """Renders a set of productions in a human-readable format.""" + return "\n".join([str(production) for production in sorted(productions)]) _cached_modules = {} def parse_module_text(source_code, file_name): - """Parses the text of a module, returning a module-level IR. - - Arguments: - source_code: The text of the module to parse. - file_name: The name of the module's source file (will be included in the - resulting IR). - - Returns: - A module-level intermediate representation (IR), prior to import and symbol - resolution, and a corresponding ModuleDebugInfo, for debugging the parser. - - Raises: - FrontEndFailure: An error occurred while parsing the module. str(error) - will give a human-readable error message. - """ - # This is strictly an optimization to speed up tests, mostly by avoiding the - # need to re-parse the prelude for every test .emb. 
- if (source_code, file_name) in _cached_modules: - debug_info = _cached_modules[source_code, file_name] - ir = ir_data_utils.copy(debug_info.ir) - else: - debug_info = ModuleDebugInfo(file_name) - debug_info.source_code = source_code - tokens, errors = tokenizer.tokenize(source_code, file_name) - if errors: - return _IrDebugInfo(None, debug_info, errors) - debug_info.tokens = tokens - parse_result = parser.parse_module(tokens) - if parse_result.error: - return _IrDebugInfo( - None, - debug_info, - [error.make_error_from_parse_error(file_name, parse_result.error)]) - debug_info.parse_tree = parse_result.parse_tree - used_productions = set() - ir = module_ir.build_ir(parse_result.parse_tree, used_productions) - ir.source_text = source_code - debug_info.used_productions = used_productions - debug_info.ir = ir_data_utils.copy(ir) - _cached_modules[source_code, file_name] = debug_info - ir.source_file_name = file_name - return _IrDebugInfo(ir, debug_info, []) + """Parses the text of a module, returning a module-level IR. + + Arguments: + source_code: The text of the module to parse. + file_name: The name of the module's source file (will be included in the + resulting IR). + + Returns: + A module-level intermediate representation (IR), prior to import and symbol + resolution, and a corresponding ModuleDebugInfo, for debugging the parser. + + Raises: + FrontEndFailure: An error occurred while parsing the module. str(error) + will give a human-readable error message. + """ + # This is strictly an optimization to speed up tests, mostly by avoiding the + # need to re-parse the prelude for every test .emb. + if (source_code, file_name) in _cached_modules: + debug_info = _cached_modules[source_code, file_name] + ir = ir_data_utils.copy(debug_info.ir) + else: + debug_info = ModuleDebugInfo(file_name) + debug_info.source_code = source_code + tokens, errors = tokenizer.tokenize(source_code, file_name) + if errors: + return _IrDebugInfo(None, debug_info, errors) + debug_info.tokens = tokens + parse_result = parser.parse_module(tokens) + if parse_result.error: + return _IrDebugInfo( + None, + debug_info, + [error.make_error_from_parse_error(file_name, parse_result.error)], + ) + debug_info.parse_tree = parse_result.parse_tree + used_productions = set() + ir = module_ir.build_ir(parse_result.parse_tree, used_productions) + ir.source_text = source_code + debug_info.used_productions = used_productions + debug_info.ir = ir_data_utils.copy(ir) + _cached_modules[source_code, file_name] = debug_info + ir.source_file_name = file_name + return _IrDebugInfo(ir, debug_info, []) def parse_module(file_name, file_reader): - """Parses a module, returning a module-level IR. - - Arguments: - file_name: The name of the module's source file. - file_reader: A callable that returns either: - (file_contents, None) or - (None, list_of_error_detail_strings) - - Returns: - (ir, debug_info, errors), where ir is a module-level intermediate - representation (IR), debug_info is a ModuleDebugInfo containing the - tokenization, parse tree, and original source text of all modules, and - errors is a list of tokenization or parse errors. If errors is not an empty - list, ir will be None. - - Raises: - FrontEndFailure: An error occurred while reading or parsing the module. - str(error) will give a human-readable error message. 
- """ - source_code, errors = file_reader(file_name) - if errors: - location = parser_types.make_location((1, 1), (1, 1)) - return None, None, [ - [error.error(file_name, location, "Unable to read file.")] + - [error.note(file_name, location, e) for e in errors] - ] - return parse_module_text(source_code, file_name) + """Parses a module, returning a module-level IR. + + Arguments: + file_name: The name of the module's source file. + file_reader: A callable that returns either: + (file_contents, None) or + (None, list_of_error_detail_strings) + + Returns: + (ir, debug_info, errors), where ir is a module-level intermediate + representation (IR), debug_info is a ModuleDebugInfo containing the + tokenization, parse tree, and original source text of all modules, and + errors is a list of tokenization or parse errors. If errors is not an empty + list, ir will be None. + + Raises: + FrontEndFailure: An error occurred while reading or parsing the module. + str(error) will give a human-readable error message. + """ + source_code, errors = file_reader(file_name) + if errors: + location = parser_types.make_location((1, 1), (1, 1)) + return ( + None, + None, + [ + [error.error(file_name, location, "Unable to read file.")] + + [error.note(file_name, location, e) for e in errors] + ], + ) + return parse_module_text(source_code, file_name) def get_prelude(): - """Returns the module IR and debug info of the Emboss Prelude.""" - return parse_module_text( - resources.load("compiler.front_end", "prelude.emb"), "") + """Returns the module IR and debug info of the Emboss Prelude.""" + return parse_module_text(resources.load("compiler.front_end", "prelude.emb"), "") def parse_emboss_file(file_name, file_reader, stop_before_step=None): - """Fully parses an .emb, and returns an IR suitable for passing to a back end. - - parse_emboss_file is a convenience function which calls only_parse_emboss_file - and process_ir. - - Arguments: - file_name: The name of the module's source file. - file_reader: A callable that returns the contents of files, or raises - IOError. - stop_before_step: If set, parse_emboss_file will stop normalizing the IR - just before the specified step. This parameter should be None for - non-test code. - - Returns: - (ir, debug_info, errors), where ir is a complete IR, ready for consumption - by an Emboss back end, debug_info is a DebugInfo containing the - tokenization, parse tree, and original source text of all modules, and - errors is a list of tokenization or parse errors. If errors is not an empty - list, ir will be None. - """ - ir, debug_info, errors = only_parse_emboss_file(file_name, file_reader) - if errors: - return _IrDebugInfo(None, debug_info, errors) - ir, errors = process_ir(ir, stop_before_step) - if errors: - return _IrDebugInfo(None, debug_info, errors) - return _IrDebugInfo(ir, debug_info, errors) + """Fully parses an .emb, and returns an IR suitable for passing to a back end. + + parse_emboss_file is a convenience function which calls only_parse_emboss_file + and process_ir. + + Arguments: + file_name: The name of the module's source file. + file_reader: A callable that returns the contents of files, or raises + IOError. + stop_before_step: If set, parse_emboss_file will stop normalizing the IR + just before the specified step. This parameter should be None for + non-test code. 
+ + Returns: + (ir, debug_info, errors), where ir is a complete IR, ready for consumption + by an Emboss back end, debug_info is a DebugInfo containing the + tokenization, parse tree, and original source text of all modules, and + errors is a list of tokenization or parse errors. If errors is not an empty + list, ir will be None. + """ + ir, debug_info, errors = only_parse_emboss_file(file_name, file_reader) + if errors: + return _IrDebugInfo(None, debug_info, errors) + ir, errors = process_ir(ir, stop_before_step) + if errors: + return _IrDebugInfo(None, debug_info, errors) + return _IrDebugInfo(ir, debug_info, errors) def only_parse_emboss_file(file_name, file_reader): - """Parses an .emb, and returns an IR suitable for process_ir. - - only_parse_emboss_file parses the given file and all of its transitive - imports, and returns a first-stage intermediate representation, which can be - passed to process_ir. - - Arguments: - file_name: The name of the module's source file. - file_reader: A callable that returns the contents of files, or raises - IOError. - - Returns: - (ir, debug_info, errors), where ir is an intermediate representation (IR), - debug_info is a DebugInfo containing the tokenization, parse tree, and - original source text of all modules, and errors is a list of tokenization or - parse errors. If errors is not an empty list, ir will be None. - """ - file_queue = [file_name] - files = {file_name} - debug_info = DebugInfo() - ir = ir_data.EmbossIr(module=[]) - while file_queue: - file_to_parse = file_queue[0] - del file_queue[0] - if file_to_parse: - module, module_debug_info, errors = parse_module(file_to_parse, - file_reader) - else: - module, module_debug_info, errors = get_prelude() - if module_debug_info: - debug_info.modules[file_to_parse] = module_debug_info - if errors: - return _IrDebugInfo(None, debug_info, errors) - ir.module.extend([module]) # Proto supports extend but not append here. - for import_ in module.foreign_import: - if import_.file_name.text not in files: - file_queue.append(import_.file_name.text) - files.add(import_.file_name.text) - return _IrDebugInfo(ir, debug_info, []) + """Parses an .emb, and returns an IR suitable for process_ir. + + only_parse_emboss_file parses the given file and all of its transitive + imports, and returns a first-stage intermediate representation, which can be + passed to process_ir. + + Arguments: + file_name: The name of the module's source file. + file_reader: A callable that returns the contents of files, or raises + IOError. + + Returns: + (ir, debug_info, errors), where ir is an intermediate representation (IR), + debug_info is a DebugInfo containing the tokenization, parse tree, and + original source text of all modules, and errors is a list of tokenization or + parse errors. If errors is not an empty list, ir will be None. + """ + file_queue = [file_name] + files = {file_name} + debug_info = DebugInfo() + ir = ir_data.EmbossIr(module=[]) + while file_queue: + file_to_parse = file_queue[0] + del file_queue[0] + if file_to_parse: + module, module_debug_info, errors = parse_module(file_to_parse, file_reader) + else: + module, module_debug_info, errors = get_prelude() + if module_debug_info: + debug_info.modules[file_to_parse] = module_debug_info + if errors: + return _IrDebugInfo(None, debug_info, errors) + ir.module.extend([module]) # Proto supports extend but not append here. 
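[Note: only_parse_emboss_file drains a worklist of file names while a companion `files` set prevents re-queuing a module that is imported more than once; the prelude rides along as the empty file name. The traversal pattern in isolation; imports_of is a hypothetical callable mapping a file name to its direct imports:

    def transitive_files(root, imports_of):
        # Worklist plus seen-set, mirroring the file_queue/files pair in
        # only_parse_emboss_file. FIFO order parses imports breadth-first.
        queue, seen, order = [root], {root}, []
        while queue:
            current = queue.pop(0)
            order.append(current)
            for dep in imports_of(current):
                if dep not in seen:
                    queue.append(dep)
                    seen.add(dep)
        return order

    deps = {"a.emb": ["b.emb", "c.emb"], "b.emb": ["c.emb"], "c.emb": []}
    assert transitive_files("a.emb", deps.__getitem__) == ["a.emb", "b.emb", "c.emb"]
]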
+ for import_ in module.foreign_import: + if import_.file_name.text not in files: + file_queue.append(import_.file_name.text) + files.add(import_.file_name.text) + return _IrDebugInfo(ir, debug_info, []) def process_ir(ir, stop_before_step): - """Turns a first-stage IR into a fully-processed IR. - - process_ir performs all of the semantic processing steps on `ir`: resolving - symbols, checking dependencies, adding type annotations, normalizing - attributes, etc. process_ir is generally meant to be called with the result - of parse_emboss_file(), but in theory could be called with a first-stage - intermediate representation (IR) from another source. - - Arguments: - ir: The IR to process. This structure will be modified during processing. - stop_before_step: If set, process_ir will stop normalizing the IR just - before the specified step. This parameter should be None for non-test - code. - - Returns: - (ir, errors), where ir is a complete IR, ready for consumption by an Emboss - back end, and errors is a list of compilation errors. If errors is not an - empty list, ir will be None. - """ - passes = (synthetics.desugar, - symbol_resolver.resolve_symbols, - dependency_checker.find_dependency_cycles, - dependency_checker.set_dependency_order, - symbol_resolver.resolve_field_references, - type_check.annotate_types, - type_check.check_types, - expression_bounds.compute_constants, - attribute_checker.normalize_and_verify, - constraints.check_constraints, - write_inference.set_write_methods) - assert stop_before_step in [None] + [f.__name__ for f in passes], ( - "Bad value for stop_before_step.") - # Some parts of the IR are synthesized from "natural" parts of the IR, before - # the natural parts have been fully error checked. Because of this, the - # synthesized parts can have errors; in a couple of cases, they can have - # errors that show up in an earlier pass than the errors in the natural parts - # of the IR. As an example: - # - # struct Foo: - # 0 [+1] bits: - # 0 [+1] Flag flag - # 1 [+flag] UInt:8 field - # - # In this case, the use of `flag` as the size of `field` is incorrect, because - # `flag` is a boolean, but the size of a field must be an integer. - # - # Type checking occurs in two passes: in the first pass, expressions are - # checked for internal consistency. In the second pass, expression types are - # checked against their location. The use of `flag` would be caught in the - # second pass. - # - # However, the generated_fields pass will synthesize a $size_in_bytes virtual - # field that would look like: - # - # struct Foo: - # 0 [+1] bits: - # 0 [+1] Flag flag - # 1 [+flag] UInt:8 field - # let $size_in_bytes = $max(true ? 0 + 1 : 0, true ? 1 + flag : 0) - # - # Since `1 + flag` is not internally consistent, this type error would be - # caught in the first pass, and the user would see a very strange error - # message that "the right-hand argument of operator `+` must be an integer." - # - # In order to avoid showing these kinds of errors to the user, we defer any - # errors in synthetic parts of the IR. Unless there is a compiler bug, those - # errors will show up as errors in the natural parts of the IR, which should - # be much more comprehensible to end users. - # - # If, for some reason, there is an error in the synthetic IR, but no error in - # the natural IR, the synthetic errors will be shown. 
In this case, the - # formatting for the synthetic errors will show '[compiler bug]' for the - # error location, which (hopefully) will provide the end user with a cue that - # the error is a compiler bug. - deferred_errors = [] - for function in passes: - if stop_before_step == function.__name__: - return (ir, []) - errors, hidden_errors = error.split_errors(function(ir)) - if errors: - return (None, errors) - deferred_errors.extend(hidden_errors) - - if deferred_errors: - return (None, deferred_errors) - - assert stop_before_step is None, "Bad value for stop_before_step." - return (ir, []) + """Turns a first-stage IR into a fully-processed IR. + + process_ir performs all of the semantic processing steps on `ir`: resolving + symbols, checking dependencies, adding type annotations, normalizing + attributes, etc. process_ir is generally meant to be called with the result + of parse_emboss_file(), but in theory could be called with a first-stage + intermediate representation (IR) from another source. + + Arguments: + ir: The IR to process. This structure will be modified during processing. + stop_before_step: If set, process_ir will stop normalizing the IR just + before the specified step. This parameter should be None for non-test + code. + + Returns: + (ir, errors), where ir is a complete IR, ready for consumption by an Emboss + back end, and errors is a list of compilation errors. If errors is not an + empty list, ir will be None. + """ + passes = ( + synthetics.desugar, + symbol_resolver.resolve_symbols, + dependency_checker.find_dependency_cycles, + dependency_checker.set_dependency_order, + symbol_resolver.resolve_field_references, + type_check.annotate_types, + type_check.check_types, + expression_bounds.compute_constants, + attribute_checker.normalize_and_verify, + constraints.check_constraints, + write_inference.set_write_methods, + ) + assert stop_before_step in [None] + [ + f.__name__ for f in passes + ], "Bad value for stop_before_step." + # Some parts of the IR are synthesized from "natural" parts of the IR, before + # the natural parts have been fully error checked. Because of this, the + # synthesized parts can have errors; in a couple of cases, they can have + # errors that show up in an earlier pass than the errors in the natural parts + # of the IR. As an example: + # + # struct Foo: + # 0 [+1] bits: + # 0 [+1] Flag flag + # 1 [+flag] UInt:8 field + # + # In this case, the use of `flag` as the size of `field` is incorrect, because + # `flag` is a boolean, but the size of a field must be an integer. + # + # Type checking occurs in two passes: in the first pass, expressions are + # checked for internal consistency. In the second pass, expression types are + # checked against their location. The use of `flag` would be caught in the + # second pass. + # + # However, the generated_fields pass will synthesize a $size_in_bytes virtual + # field that would look like: + # + # struct Foo: + # 0 [+1] bits: + # 0 [+1] Flag flag + # 1 [+flag] UInt:8 field + # let $size_in_bytes = $max(true ? 0 + 1 : 0, true ? 1 + flag : 0) + # + # Since `1 + flag` is not internally consistent, this type error would be + # caught in the first pass, and the user would see a very strange error + # message that "the right-hand argument of operator `+` must be an integer." + # + # In order to avoid showing these kinds of errors to the user, we defer any + # errors in synthetic parts of the IR. 
Unless there is a compiler bug, those + # errors will show up as errors in the natural parts of the IR, which should + # be much more comprehensible to end users. + # + # If, for some reason, there is an error in the synthetic IR, but no error in + # the natural IR, the synthetic errors will be shown. In this case, the + # formatting for the synthetic errors will show '[compiler bug]' for the + # error location, which (hopefully) will provide the end user with a cue that + # the error is a compiler bug. + deferred_errors = [] + for function in passes: + if stop_before_step == function.__name__: + return (ir, []) + errors, hidden_errors = error.split_errors(function(ir)) + if errors: + return (None, errors) + deferred_errors.extend(hidden_errors) + + if deferred_errors: + return (None, deferred_errors) + + assert stop_before_step is None, "Bad value for stop_before_step." + return (ir, []) diff --git a/compiler/front_end/glue_test.py b/compiler/front_end/glue_test.py index 2f2ddc5..1decae3 100644 --- a/compiler/front_end/glue_test.py +++ b/compiler/front_end/glue_test.py @@ -30,272 +30,328 @@ _GOLDEN_PATH = "" _SPAN_SE_LOG_FILE_PATH = _GOLDEN_PATH + "span_se_log_file_status.emb" -_SPAN_SE_LOG_FILE_EMB = pkgutil.get_data( - _ROOT_PACKAGE, _SPAN_SE_LOG_FILE_PATH).decode(encoding="UTF-8") +_SPAN_SE_LOG_FILE_EMB = pkgutil.get_data(_ROOT_PACKAGE, _SPAN_SE_LOG_FILE_PATH).decode( + encoding="UTF-8" +) _SPAN_SE_LOG_FILE_READER = test_util.dict_file_reader( - {_SPAN_SE_LOG_FILE_PATH: _SPAN_SE_LOG_FILE_EMB}) -_SPAN_SE_LOG_FILE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, + {_SPAN_SE_LOG_FILE_PATH: _SPAN_SE_LOG_FILE_EMB} +) +_SPAN_SE_LOG_FILE_IR = ir_data_utils.IrDataSerializer.from_json( + ir_data.Module, pkgutil.get_data( - _ROOT_PACKAGE, - _GOLDEN_PATH + "span_se_log_file_status.ir.txt" - ).decode(encoding="UTF-8")) + _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.ir.txt" + ).decode(encoding="UTF-8"), +) _SPAN_SE_LOG_FILE_PARSE_TREE_TEXT = pkgutil.get_data( - _ROOT_PACKAGE, - _GOLDEN_PATH + "span_se_log_file_status.parse_tree.txt" + _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.parse_tree.txt" ).decode(encoding="UTF-8") _SPAN_SE_LOG_FILE_TOKENIZATION_TEXT = pkgutil.get_data( - _ROOT_PACKAGE, - _GOLDEN_PATH + "span_se_log_file_status.tokens.txt" + _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.tokens.txt" ).decode(encoding="UTF-8") class FrontEndGlueTest(unittest.TestCase): - """Tests for front_end.glue.""" - - def test_parse_module(self): - # parse_module(file) should return the same thing as - # parse_module_text(text), assuming file can be read. 
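[Note: the deferral loop in process_ir above hinges on error.split_errors partitioning each pass's output into immediately-visible and deferred groups. A simplified stand-in showing only the partition shape; the real criteria live in compiler/util/error.py, so treat this as an assumption-laden sketch rather than the actual implementation:

    def split_errors_sketch(error_groups):
        shown, deferred = [], []
        for group in error_groups:
            # Defer a group if any entry points into synthetic
            # (compiler-generated) IR; show everything else immediately.
            if any(e.location.is_synthetic for e in group):
                deferred.append(group)
            else:
                shown.append(group)
        return shown, deferred
]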
- main_module, debug_info, errors = glue.parse_module( - _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER) - main_module2, debug_info2, errors2 = glue.parse_module_text( - _SPAN_SE_LOG_FILE_EMB, _SPAN_SE_LOG_FILE_PATH) - self.assertEqual([], errors) - self.assertEqual([], errors2) - self.assertEqual(main_module, main_module2) - self.assertEqual(debug_info, debug_info2) - - def test_parse_module_no_such_file(self): - file_name = "nonexistent.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader({})) - self.assertEqual([[ - error.error("nonexistent.emb", _location((1, 1), (1, 1)), - "Unable to read file."), - error.note("nonexistent.emb", _location((1, 1), (1, 1)), - "File 'nonexistent.emb' not found."), - ]], errors) - self.assertFalse(file_name in debug_info.modules) - self.assertFalse(ir) - - def test_parse_module_tokenization_error(self): - file_name = "tokens.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader({file_name: "@"})) - self.assertTrue(debug_info.modules[file_name].source_code) - self.assertTrue(errors) - self.assertEqual("Unrecognized token", errors[0][0].message) - self.assertFalse(ir) - - def test_parse_module_indentation_error(self): - file_name = "indent.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader( - {file_name: "struct Foo:\n" - " 1 [+1] Int x\n" - " 2 [+1] Int y\n"})) - self.assertTrue(debug_info.modules[file_name].source_code) - self.assertTrue(errors) - self.assertEqual("Bad indentation", errors[0][0].message) - self.assertFalse(ir) - - def test_parse_module_parse_error(self): - file_name = "parse.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader( - {file_name: "struct foo:\n" - " 1 [+1] Int x\n" - " 3 [+1] Int y\n"})) - self.assertTrue(debug_info.modules[file_name].source_code) - self.assertEqual([[ - error.error(file_name, _location((1, 8), (1, 11)), - "A type name must be CamelCase.\n" - "Found 'foo' (SnakeWord), expected CamelWord.") - ]], errors) - self.assertFalse(ir) - - def test_parse_error(self): - file_name = "parse.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader( - {file_name: "struct foo:\n" - " 1 [+1] Int x\n" - " 2 [+1] Int y\n"})) - self.assertTrue(debug_info.modules[file_name].source_code) - self.assertEqual([[ - error.error(file_name, _location((1, 8), (1, 11)), - "A type name must be CamelCase.\n" - "Found 'foo' (SnakeWord), expected CamelWord.") - ]], errors) - self.assertFalse(ir) - - def test_circular_dependency_error(self): - file_name = "cycle.emb" - ir, debug_info, errors = glue.parse_emboss_file( - file_name, test_util.dict_file_reader({ - file_name: "struct Foo:\n" - " 0 [+field1] UInt field1\n" - })) - self.assertTrue(debug_info.modules[file_name].source_code) - self.assertTrue(errors) - self.assertEqual("Dependency cycle\nfield1", errors[0][0].message) - self.assertFalse(ir) - - def test_ir_from_parse_module(self): - log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR) - log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH - self.assertEqual(log_file_path_ir, glue.parse_module( - _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir) - - def test_debug_info_from_parse_module(self): - debug_info = glue.parse_module(_SPAN_SE_LOG_FILE_PATH, - _SPAN_SE_LOG_FILE_READER).debug_info - self.maxDiff = 200000 # pylint:disable=invalid-name - self.assertEqual(_SPAN_SE_LOG_FILE_TOKENIZATION_TEXT.strip(), - 
debug_info.format_tokenization().strip()) - self.assertEqual(_SPAN_SE_LOG_FILE_PARSE_TREE_TEXT.strip(), - debug_info.format_parse_tree().strip()) - self.assertEqual(_SPAN_SE_LOG_FILE_IR, debug_info.ir) - self.assertEqual(ir_data_utils.IrDataSerializer(_SPAN_SE_LOG_FILE_IR).to_json(indent=2), - debug_info.format_module_ir()) - - def test_parse_emboss_file(self): - # parse_emboss_file calls parse_module, wraps its results, and calls - # symbol_resolver.resolve_symbols() on the resulting IR. - ir, debug_info, errors = glue.parse_emboss_file(_SPAN_SE_LOG_FILE_PATH, - _SPAN_SE_LOG_FILE_READER) - module_ir, module_debug_info, module_errors = glue.parse_module( - _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER) - self.assertEqual([], errors) - self.assertEqual([], module_errors) - self.assertTrue(test_util.proto_is_superset(ir.module[0], module_ir)) - self.assertEqual(module_debug_info, - debug_info.modules[_SPAN_SE_LOG_FILE_PATH]) - self.assertEqual(2, len(debug_info.modules)) - self.assertEqual(2, len(ir.module)) - self.assertEqual(_SPAN_SE_LOG_FILE_PATH, ir.module[0].source_file_name) - self.assertEqual("", ir.module[1].source_file_name) - - def test_synthetic_error(self): - file_name = "missing_byte_order_attribute.emb" - ir, unused_debug_info, errors = glue.only_parse_emboss_file( - file_name, test_util.dict_file_reader({ - file_name: "struct Foo:\n" - " 0 [+8] UInt field\n" - })) - self.assertFalse(errors) - # Artificially mark the first field as is_synthetic. - first_field = ir.module[0].type[0].structure.field[0] - first_field.source_location.is_synthetic = True - ir, errors = glue.process_ir(ir, None) - self.assertTrue(errors) - self.assertEqual("Attribute 'byte_order' required on field which is byte " - "order dependent.", errors[0][0].message) - self.assertTrue(errors[0][0].location.is_synthetic) - self.assertFalse(ir) - - def test_suppressed_synthetic_error(self): - file_name = "triplicate_symbol.emb" - ir, unused_debug_info, errors = glue.only_parse_emboss_file( - file_name, test_util.dict_file_reader({ - file_name: "struct Foo:\n" - " 0 [+1] UInt field\n" - " 1 [+1] UInt field\n" - " 2 [+1] UInt field\n" - })) - self.assertFalse(errors) - # Artificially mark the name of the second field as is_synthetic. - second_field = ir.module[0].type[0].structure.field[1] - second_field.name.source_location.is_synthetic = True - second_field.name.name.source_location.is_synthetic = True - ir, errors = glue.process_ir(ir, None) - self.assertEqual(1, len(errors)) - self.assertEqual("Duplicate name 'field'", errors[0][0].message) - self.assertFalse(errors[0][0].location.is_synthetic) - self.assertFalse(errors[0][1].location.is_synthetic) - self.assertFalse(ir) + """Tests for front_end.glue.""" + + def test_parse_module(self): + # parse_module(file) should return the same thing as + # parse_module_text(text), assuming file can be read. 
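[Note: several tests here pass test_util.dict_file_reader, which satisfies the file_reader contract documented on parse_module: return (contents, None) on success or (None, [detail strings]) on failure. A minimal reader of that shape; dict_reader is an illustrative stand-in, not the real helper:

    def dict_reader(files):
        def read(name):
            if name in files:
                return files[name], None
            return None, ["File '{}' not found.".format(name)]
        return read

    reader = dict_reader({"m.emb": "struct Foo:\n  0 [+1]  UInt  x\n"})
    assert reader("m.emb")[1] is None
    assert reader("missing.emb")[0] is None

The "File ... not found." detail string matches the note text that test_parse_module_no_such_file asserts on.]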
+ main_module, debug_info, errors = glue.parse_module( + _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER + ) + main_module2, debug_info2, errors2 = glue.parse_module_text( + _SPAN_SE_LOG_FILE_EMB, _SPAN_SE_LOG_FILE_PATH + ) + self.assertEqual([], errors) + self.assertEqual([], errors2) + self.assertEqual(main_module, main_module2) + self.assertEqual(debug_info, debug_info2) + + def test_parse_module_no_such_file(self): + file_name = "nonexistent.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, test_util.dict_file_reader({}) + ) + self.assertEqual( + [ + [ + error.error( + "nonexistent.emb", + _location((1, 1), (1, 1)), + "Unable to read file.", + ), + error.note( + "nonexistent.emb", + _location((1, 1), (1, 1)), + "File 'nonexistent.emb' not found.", + ), + ] + ], + errors, + ) + self.assertFalse(file_name in debug_info.modules) + self.assertFalse(ir) + + def test_parse_module_tokenization_error(self): + file_name = "tokens.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, test_util.dict_file_reader({file_name: "@"}) + ) + self.assertTrue(debug_info.modules[file_name].source_code) + self.assertTrue(errors) + self.assertEqual("Unrecognized token", errors[0][0].message) + self.assertFalse(ir) + + def test_parse_module_indentation_error(self): + file_name = "indent.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, + test_util.dict_file_reader( + {file_name: "struct Foo:\n" " 1 [+1] Int x\n" " 2 [+1] Int y\n"} + ), + ) + self.assertTrue(debug_info.modules[file_name].source_code) + self.assertTrue(errors) + self.assertEqual("Bad indentation", errors[0][0].message) + self.assertFalse(ir) + + def test_parse_module_parse_error(self): + file_name = "parse.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, + test_util.dict_file_reader( + {file_name: "struct foo:\n" " 1 [+1] Int x\n" " 3 [+1] Int y\n"} + ), + ) + self.assertTrue(debug_info.modules[file_name].source_code) + self.assertEqual( + [ + [ + error.error( + file_name, + _location((1, 8), (1, 11)), + "A type name must be CamelCase.\n" + "Found 'foo' (SnakeWord), expected CamelWord.", + ) + ] + ], + errors, + ) + self.assertFalse(ir) + + def test_parse_error(self): + file_name = "parse.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, + test_util.dict_file_reader( + {file_name: "struct foo:\n" " 1 [+1] Int x\n" " 2 [+1] Int y\n"} + ), + ) + self.assertTrue(debug_info.modules[file_name].source_code) + self.assertEqual( + [ + [ + error.error( + file_name, + _location((1, 8), (1, 11)), + "A type name must be CamelCase.\n" + "Found 'foo' (SnakeWord), expected CamelWord.", + ) + ] + ], + errors, + ) + self.assertFalse(ir) + + def test_circular_dependency_error(self): + file_name = "cycle.emb" + ir, debug_info, errors = glue.parse_emboss_file( + file_name, + test_util.dict_file_reader( + {file_name: "struct Foo:\n" " 0 [+field1] UInt field1\n"} + ), + ) + self.assertTrue(debug_info.modules[file_name].source_code) + self.assertTrue(errors) + self.assertEqual("Dependency cycle\nfield1", errors[0][0].message) + self.assertFalse(ir) + + def test_ir_from_parse_module(self): + log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR) + log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH + self.assertEqual( + log_file_path_ir, + glue.parse_module(_SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir, + ) + + def test_debug_info_from_parse_module(self): + debug_info = glue.parse_module( + _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER + ).debug_info + 
self.maxDiff = 200000 # pylint:disable=invalid-name + self.assertEqual( + _SPAN_SE_LOG_FILE_TOKENIZATION_TEXT.strip(), + debug_info.format_tokenization().strip(), + ) + self.assertEqual( + _SPAN_SE_LOG_FILE_PARSE_TREE_TEXT.strip(), + debug_info.format_parse_tree().strip(), + ) + self.assertEqual(_SPAN_SE_LOG_FILE_IR, debug_info.ir) + self.assertEqual( + ir_data_utils.IrDataSerializer(_SPAN_SE_LOG_FILE_IR).to_json(indent=2), + debug_info.format_module_ir(), + ) + + def test_parse_emboss_file(self): + # parse_emboss_file calls parse_module, wraps its results, and calls + # symbol_resolver.resolve_symbols() on the resulting IR. + ir, debug_info, errors = glue.parse_emboss_file( + _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER + ) + module_ir, module_debug_info, module_errors = glue.parse_module( + _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER + ) + self.assertEqual([], errors) + self.assertEqual([], module_errors) + self.assertTrue(test_util.proto_is_superset(ir.module[0], module_ir)) + self.assertEqual(module_debug_info, debug_info.modules[_SPAN_SE_LOG_FILE_PATH]) + self.assertEqual(2, len(debug_info.modules)) + self.assertEqual(2, len(ir.module)) + self.assertEqual(_SPAN_SE_LOG_FILE_PATH, ir.module[0].source_file_name) + self.assertEqual("", ir.module[1].source_file_name) + + def test_synthetic_error(self): + file_name = "missing_byte_order_attribute.emb" + ir, unused_debug_info, errors = glue.only_parse_emboss_file( + file_name, + test_util.dict_file_reader( + {file_name: "struct Foo:\n" " 0 [+8] UInt field\n"} + ), + ) + self.assertFalse(errors) + # Artificially mark the first field as is_synthetic. + first_field = ir.module[0].type[0].structure.field[0] + first_field.source_location.is_synthetic = True + ir, errors = glue.process_ir(ir, None) + self.assertTrue(errors) + self.assertEqual( + "Attribute 'byte_order' required on field which is byte " + "order dependent.", + errors[0][0].message, + ) + self.assertTrue(errors[0][0].location.is_synthetic) + self.assertFalse(ir) + + def test_suppressed_synthetic_error(self): + file_name = "triplicate_symbol.emb" + ir, unused_debug_info, errors = glue.only_parse_emboss_file( + file_name, + test_util.dict_file_reader( + { + file_name: "struct Foo:\n" + " 0 [+1] UInt field\n" + " 1 [+1] UInt field\n" + " 2 [+1] UInt field\n" + } + ), + ) + self.assertFalse(errors) + # Artificially mark the name of the second field as is_synthetic. 
+ second_field = ir.module[0].type[0].structure.field[1] + second_field.name.source_location.is_synthetic = True + second_field.name.name.source_location.is_synthetic = True + ir, errors = glue.process_ir(ir, None) + self.assertEqual(1, len(errors)) + self.assertEqual("Duplicate name 'field'", errors[0][0].message) + self.assertFalse(errors[0][0].location.is_synthetic) + self.assertFalse(errors[0][1].location.is_synthetic) + self.assertFalse(ir) class DebugInfoTest(unittest.TestCase): - """Tests for DebugInfo and ModuleDebugInfo classes.""" - - def test_debug_info_initialization(self): - debug_info = glue.DebugInfo() - self.assertEqual({}, debug_info.modules) - - def test_debug_info_invalid_attribute_set(self): - debug_info = glue.DebugInfo() - with self.assertRaises(AttributeError): - debug_info.foo = "foo" - - def test_debug_info_equality(self): - debug_info = glue.DebugInfo() - debug_info2 = glue.DebugInfo() - self.assertEqual(debug_info, debug_info2) - debug_info.modules["foo"] = glue.ModuleDebugInfo("foo") - self.assertNotEqual(debug_info, debug_info2) - debug_info2.modules["foo"] = glue.ModuleDebugInfo("foo") - self.assertEqual(debug_info, debug_info2) - - def test_module_debug_info_initialization(self): - module_info = glue.ModuleDebugInfo("bar.emb") - self.assertEqual("bar.emb", module_info.file_name) - self.assertEqual(None, module_info.tokens) - self.assertEqual(None, module_info.parse_tree) - self.assertEqual(None, module_info.ir) - self.assertEqual(None, module_info.used_productions) - - def test_module_debug_info_attribute_set(self): - module_info = glue.ModuleDebugInfo("bar.emb") - module_info.tokens = "a" - module_info.parse_tree = "b" - module_info.ir = "c" - module_info.used_productions = "d" - module_info.source_code = "e" - self.assertEqual("a", module_info.tokens) - self.assertEqual("b", module_info.parse_tree) - self.assertEqual("c", module_info.ir) - self.assertEqual("d", module_info.used_productions) - self.assertEqual("e", module_info.source_code) - - def test_module_debug_info_bad_attribute_set(self): - module_info = glue.ModuleDebugInfo("bar.emb") - with self.assertRaises(AttributeError): - module_info.foo = "foo" - - def test_module_debug_info_equality(self): - module_info = glue.ModuleDebugInfo("foo") - module_info2 = glue.ModuleDebugInfo("foo") - module_info_bar = glue.ModuleDebugInfo("bar") - self.assertEqual(module_info, module_info2) - module_info_bar = glue.ModuleDebugInfo("bar") - self.assertNotEqual(module_info, module_info_bar) - module_info.tokens = [] - self.assertNotEqual(module_info, module_info2) - module_info2.tokens = [] - self.assertEqual(module_info, module_info2) - module_info.parse_tree = [] - self.assertNotEqual(module_info, module_info2) - module_info2.parse_tree = [] - self.assertEqual(module_info, module_info2) - module_info.ir = [] - self.assertNotEqual(module_info, module_info2) - module_info2.ir = [] - self.assertEqual(module_info, module_info2) - module_info.used_productions = [] - self.assertNotEqual(module_info, module_info2) - module_info2.used_productions = [] - self.assertEqual(module_info, module_info2) + """Tests for DebugInfo and ModuleDebugInfo classes.""" + + def test_debug_info_initialization(self): + debug_info = glue.DebugInfo() + self.assertEqual({}, debug_info.modules) + + def test_debug_info_invalid_attribute_set(self): + debug_info = glue.DebugInfo() + with self.assertRaises(AttributeError): + debug_info.foo = "foo" + + def test_debug_info_equality(self): + debug_info = glue.DebugInfo() + debug_info2 = glue.DebugInfo() 
+ self.assertEqual(debug_info, debug_info2) + debug_info.modules["foo"] = glue.ModuleDebugInfo("foo") + self.assertNotEqual(debug_info, debug_info2) + debug_info2.modules["foo"] = glue.ModuleDebugInfo("foo") + self.assertEqual(debug_info, debug_info2) + + def test_module_debug_info_initialization(self): + module_info = glue.ModuleDebugInfo("bar.emb") + self.assertEqual("bar.emb", module_info.file_name) + self.assertEqual(None, module_info.tokens) + self.assertEqual(None, module_info.parse_tree) + self.assertEqual(None, module_info.ir) + self.assertEqual(None, module_info.used_productions) + + def test_module_debug_info_attribute_set(self): + module_info = glue.ModuleDebugInfo("bar.emb") + module_info.tokens = "a" + module_info.parse_tree = "b" + module_info.ir = "c" + module_info.used_productions = "d" + module_info.source_code = "e" + self.assertEqual("a", module_info.tokens) + self.assertEqual("b", module_info.parse_tree) + self.assertEqual("c", module_info.ir) + self.assertEqual("d", module_info.used_productions) + self.assertEqual("e", module_info.source_code) + + def test_module_debug_info_bad_attribute_set(self): + module_info = glue.ModuleDebugInfo("bar.emb") + with self.assertRaises(AttributeError): + module_info.foo = "foo" + + def test_module_debug_info_equality(self): + module_info = glue.ModuleDebugInfo("foo") + module_info2 = glue.ModuleDebugInfo("foo") + module_info_bar = glue.ModuleDebugInfo("bar") + self.assertEqual(module_info, module_info2) + module_info_bar = glue.ModuleDebugInfo("bar") + self.assertNotEqual(module_info, module_info_bar) + module_info.tokens = [] + self.assertNotEqual(module_info, module_info2) + module_info2.tokens = [] + self.assertEqual(module_info, module_info2) + module_info.parse_tree = [] + self.assertNotEqual(module_info, module_info2) + module_info2.parse_tree = [] + self.assertEqual(module_info, module_info2) + module_info.ir = [] + self.assertNotEqual(module_info, module_info2) + module_info2.ir = [] + self.assertEqual(module_info, module_info2) + module_info.used_productions = [] + self.assertNotEqual(module_info, module_info2) + module_info2.used_productions = [] + self.assertEqual(module_info, module_info2) class TestFormatProductionSet(unittest.TestCase): - """Tests for format_production_set.""" + """Tests for format_production_set.""" - def test_format_production_set(self): - production_texts = ["A -> B", "B -> C", "A -> C", "C -> A"] - productions = [parser_types.Production.parse(p) for p in production_texts] - self.assertEqual("\n".join(sorted(production_texts)), - glue.format_production_set(set(productions))) + def test_format_production_set(self): + production_texts = ["A -> B", "B -> C", "A -> C", "C -> A"] + productions = [parser_types.Production.parse(p) for p in production_texts] + self.assertEqual( + "\n".join(sorted(production_texts)), + glue.format_production_set(set(productions)), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/lr1.py b/compiler/front_end/lr1.py index 41111d1..be99f95 100644 --- a/compiler/front_end/lr1.py +++ b/compiler/front_end/lr1.py @@ -31,81 +31,97 @@ from compiler.util import parser_types -class Item(collections.namedtuple("Item", ["production", "dot", "terminal", - "next_symbol"])): - """An Item is an LR(1) Item: a production, a cursor location, and a terminal. - - An Item represents a partially-parsed production, and a lookahead symbol. The - position of the dot indicates what portion of the production has been parsed. 
-  Generally, Items are an internal implementation detail, but they can be useful
-  elsewhere, particularly for debugging.
-
-  Attributes:
-    production: The Production this Item covers.
-    dot: The index of the "dot" in production's rhs.
-    terminal: The terminal lookahead symbol that follows the production in the
-      input stream.
-  """
-
-  def __str__(self):
-    """__str__ generates ASLU notation."""
-    return (str(self.production.lhs) + " -> " + " ".join(
-        [str(r) for r in self.production.rhs[0:self.dot] + (".",) +
-         self.production.rhs[self.dot:]]) + ", " + str(self.terminal))
-
-  @staticmethod
-  def parse(text):
-    """Parses an Item in ALSU notation.
-
-    Parses an Item from notation like:
-
-       symbol -> foo . bar baz, qux
-
-    where "symbol -> foo bar baz" will be taken as the production, the position
-    of the "." is taken as "dot" (in this case 1), and the symbol after "," is
-    taken as the "terminal". The following are also valid items:
-
-       sym -> ., foo
-       sym -> . foo bar, baz
-       sym -> foo bar ., baz
-
-    Symbols on the right-hand side of the production should be separated by
-    whitespace.
-
-    Arguments:
-      text: The text to parse into an Item.
-
-    Returns:
-      An Item.
+class Item(
+    collections.namedtuple("Item", ["production", "dot", "terminal", "next_symbol"])
+):
+    """An Item is an LR(1) Item: a production, a cursor location, and a terminal.
+
+    An Item represents a partially-parsed production, and a lookahead symbol. The
+    position of the dot indicates what portion of the production has been parsed.
+    Generally, Items are an internal implementation detail, but they can be useful
+    elsewhere, particularly for debugging.
+
+    Attributes:
+      production: The Production this Item covers.
+      dot: The index of the "dot" in production's rhs.
+      terminal: The terminal lookahead symbol that follows the production in the
+        input stream.
    """
-    production, terminal = text.split(",")
-    terminal = terminal.strip()
-    if terminal == "$":
-      terminal = END_OF_INPUT
-    lhs, rhs = production.split("->")
-    lhs = lhs.strip()
-    if lhs == "S'":
-      lhs = START_PRIME
-    before_dot, after_dot = rhs.split(".")
-    handle = before_dot.split()
-    tail = after_dot.split()
-    return make_item(parser_types.Production(lhs, tuple(handle + tail)),
-                     len(handle), terminal)
+
+    def __str__(self):
+        """__str__ generates ALSU notation."""
+        return (
+            str(self.production.lhs)
+            + " -> "
+            + " ".join(
+                [
+                    str(r)
+                    for r in self.production.rhs[0 : self.dot]
+                    + (".",)
+                    + self.production.rhs[self.dot :]
+                ]
+            )
+            + ", "
+            + str(self.terminal)
+        )
+
+    @staticmethod
+    def parse(text):
+        """Parses an Item in ALSU notation.
+
+        Parses an Item from notation like:
+
+           symbol -> foo . bar baz, qux
+
+        where "symbol -> foo bar baz" will be taken as the production, the position
+        of the "." is taken as "dot" (in this case 1), and the symbol after "," is
+        taken as the "terminal". The following are also valid items:
+
+           sym -> ., foo
+           sym -> . foo bar, baz
+           sym -> foo bar ., baz
+
+        Symbols on the right-hand side of the production should be separated by
+        whitespace.
+
+        Arguments:
+          text: The text to parse into an Item.
+
+        Returns:
+          An Item.
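To make the notation concrete, here is a minimal sketch of a round-trip through Item.parse; it assumes only that the module is importable as compiler.front_end.lr1, the path shown in this diff:

    from compiler.front_end import lr1

    # "S -> C . C, $" reads: production S -> C C, dot after the first C,
    # lookahead terminal $ (end of input).
    item = lr1.Item.parse("S -> C . C, $")
    assert item.production.lhs == "S"
    assert item.production.rhs == ("C", "C")
    assert item.dot == 1                      # one rhs symbol already consumed
    assert item.next_symbol == "C"            # the symbol just after the dot
    assert item.terminal == lr1.END_OF_INPUT  # "$" is mapped to END_OF_INPUT
    print(item)                               # __str__ regenerates the notation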
+ """ + production, terminal = text.split(",") + terminal = terminal.strip() + if terminal == "$": + terminal = END_OF_INPUT + lhs, rhs = production.split("->") + lhs = lhs.strip() + if lhs == "S'": + lhs = START_PRIME + before_dot, after_dot = rhs.split(".") + handle = before_dot.split() + tail = after_dot.split() + return make_item( + parser_types.Production(lhs, tuple(handle + tail)), len(handle), terminal + ) def make_item(production, dot, symbol): - return Item(production, dot, symbol, - None if dot >= len(production.rhs) else production.rhs[dot]) + return Item( + production, + dot, + symbol, + None if dot >= len(production.rhs) else production.rhs[dot], + ) -class Conflict( - collections.namedtuple("Conflict", ["state", "symbol", "actions"]) -): - """Conflict represents a parse conflict.""" +class Conflict(collections.namedtuple("Conflict", ["state", "symbol", "actions"])): + """Conflict represents a parse conflict.""" - def __str__(self): - return "Conflict for {} in state {}: ".format( - self.symbol, self.state) + " vs ".join([str(a) for a in self.actions]) + def __str__(self): + return "Conflict for {} in state {}: ".format( + self.symbol, self.state + ) + " vs ".join([str(a) for a in self.actions]) Shift = collections.namedtuple("Shift", ["state", "items"]) @@ -123,637 +139,695 @@ def __str__(self): # ANY_TOKEN is used by mark_error as a "wildcard" token that should be replaced # by every other token. -ANY_TOKEN = parser_types.Token(object(), "*", - parser_types.parse_location("0:0-0:0")) +ANY_TOKEN = parser_types.Token(object(), "*", parser_types.parse_location("0:0-0:0")) -class Reduction(collections.namedtuple("Reduction", - ["symbol", "children", "production", - "source_location"])): - """A Reduction is a non-leaf node in a parse tree. +class Reduction( + collections.namedtuple( + "Reduction", ["symbol", "children", "production", "source_location"] + ) +): + """A Reduction is a non-leaf node in a parse tree. + + Attributes: + symbol: The name of this element in the parse. + children: The child elements of this parse. + production: The grammar production to which this reduction corresponds. + source_location: If known, the range in the source text corresponding to the + tokens from which this reduction was parsed. May be 'None' if this + reduction was produced from no symbols, or if the tokens fed to `parse` + did not include source_location. + """ - Attributes: - symbol: The name of this element in the parse. - children: The child elements of this parse. - production: The grammar production to which this reduction corresponds. - source_location: If known, the range in the source text corresponding to the - tokens from which this reduction was parsed. May be 'None' if this - reduction was produced from no symbols, or if the tokens fed to `parse` - did not include source_location. - """ - pass + pass class Grammar(object): - """Grammar is an LR(1) context-free grammar. - - Attributes: - start: The start symbol for the grammar. - productions: A list of productions in the grammar, including the S' -> start - production. - symbols: A set of all symbols in the grammar, including $ and S'. - nonterminals: A set of all nonterminal symbols in the grammar, including S'. - terminals: A set of all terminal symbols in the grammar, including $. - """ - - def __init__(self, start_symbol, productions): - """Constructs a Grammar object. - - Arguments: - start_symbol: The start symbol for the grammar. - productions: A list of productions (not including the "S' -> start_symbol" - production). 
- """ - object.__init__(self) - self.start = start_symbol - self._seed_production = parser_types.Production(START_PRIME, (self.start,)) - self.productions = productions + [self._seed_production] - - self._single_level_closure_of_item_cache = {} - self._closure_of_item_cache = {} - self._compute_symbols() - self._compute_seed_firsts() - self._set_productions_by_lhs() - self._populate_item_cache() - - def _set_productions_by_lhs(self): - # Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%, - # which is significant on medium-to-large grammars. - self._productions_by_lhs = {} - for production in self.productions: - self._productions_by_lhs.setdefault(production.lhs, list()).append( - production) - - def _populate_item_cache(self): - # There are a relatively small number of possible Items for a grammar, and - # the algorithm needs to get Items from their constituent components very - # frequently. As it turns out, pre-caching all possible Items results in a - # ~35% overall speedup to Grammar.parser(). - self._item_cache = {} - for symbol in self.terminals: - for production in self.productions: - for dot in range(len(production.rhs) + 1): - self._item_cache[production, dot, symbol] = make_item( - production, dot, symbol) - - def _compute_symbols(self): - """Finds all grammar symbols, and sorts them into terminal and non-terminal. - - Nonterminal symbols are those which appear on the left side of any - production. Terminal symbols are those which do not. - - _compute_symbols is used during __init__. - """ - self.symbols = {END_OF_INPUT} - self.nonterminals = set() - for production in self.productions: - self.symbols.add(production.lhs) - self.nonterminals.add(production.lhs) - for symbol in production.rhs: - self.symbols.add(symbol) - self.terminals = self.symbols - self.nonterminals - - def _compute_seed_firsts(self): - """Computes FIRST (ALSU p221) for all terminal and nonterminal symbols. - - The algorithm for computing FIRST is an iterative one that terminates when - it reaches a fixed point (that is, when further iterations stop changing - state). _compute_seed_firsts computes the fixed point for all single-symbol - strings, by repeatedly calling _first and updating the internal _firsts - table with the results. - - Once _compute_seed_firsts has completed, _first will return correct results - for both single- and multi-symbol strings. - - _compute_seed_firsts is used during __init__. - """ - self.firsts = {} - # FIRST for a terminal symbol is always just that terminal symbol. - for terminal in self.terminals: - self.firsts[terminal] = set([terminal]) - for nonterminal in self.nonterminals: - self.firsts[nonterminal] = set() - while True: - # The first iteration picks up all the productions that start with - # terminal symbols. The second iteration picks up productions that start - # with nonterminals that the first iteration picked up. The third - # iteration picks up nonterminals that the first and second picked up, and - # so on. - # - # This is guaranteed to end, in the worst case, when every terminal - # symbol and epsilon has been added to the _firsts set for every - # nonterminal symbol. This would be slow, but requires a pathological - # grammar; useful grammars should complete in only a few iterations. 
- firsts_to_add = {} - for production in self.productions: - for first in self._first(production.rhs): - if first not in self.firsts[production.lhs]: - if production.lhs not in firsts_to_add: - firsts_to_add[production.lhs] = set() - firsts_to_add[production.lhs].add(first) - if not firsts_to_add: - break - for symbol in firsts_to_add: - self.firsts[symbol].update(firsts_to_add[symbol]) - - def _first(self, symbols): - """The FIRST function from ALSU p221. - - _first takes a string of symbols (both terminals and nonterminals) and - returns the set of terminal symbols which could be the first terminal symbol - of a string produced by the given list of symbols. - - _first will not give fully-correct results until _compute_seed_firsts - finishes, but is called by _compute_seed_firsts, and must provide partial - results during that method's execution. - - Args: - symbols: A list of symbols. - - Returns: - A set of terminals which could be the first terminal in "symbols." - """ - result = set() - all_contain_epsilon = True - for symbol in symbols: - for first in self.firsts[symbol]: - if first: - result.add(first) - if None not in self.firsts[symbol]: - all_contain_epsilon = False - break - if all_contain_epsilon: - # "None" seems like a Pythonic way of representing epsilon (no symbol). - result.add(None) - return result - - def _closure_of_item(self, root_item): - """Modified implementation of CLOSURE from ALSU p261. - - _closure_of_item performs the CLOSURE function with a single seed item, with - memoization. In the algorithm as presented in ALSU, CLOSURE is called with - a different set of items every time, which is unhelpful for memoization. - Instead, we let _parallel_goto merge the sets returned by _closure_of_item, - which results in a ~40% speedup. - - CLOSURE, roughly, computes the set of LR(1) Items which might be active when - a "seed" set of Items is active. - - Technically, it is the epsilon-closure of the NFA states represented by - "items," where an epsilon transition (a transition that does not consume any - symbols) occurs from a->Z.bY,q to b->.X,p when p is in FIRST(Yq). (a and b - are nonterminals, X, Y, and Z are arbitrary strings of symbols, and p and q - are terminals.) That is, it is the set of all NFA states which can be - reached from "items" without consuming any input. This set corresponds to a - single DFA state. - - Args: - root_item: The initial LR(1) Item. - - Returns: - A set of LR(1) items which may be active at the time when the provided - item is active. + """Grammar is an LR(1) context-free grammar. + + Attributes: + start: The start symbol for the grammar. + productions: A list of productions in the grammar, including the S' -> start + production. + symbols: A set of all symbols in the grammar, including $ and S'. + nonterminals: A set of all nonterminal symbols in the grammar, including S'. + terminals: A set of all terminal symbols in the grammar, including $. """ - if root_item in self._closure_of_item_cache: - return self._closure_of_item_cache[root_item] - item_set = set([root_item]) - item_list = [root_item] - i = 0 - # Each newly-added Item may trigger the addition of further Items, so - # iterate until no new Items are added. In the worst case, a new Item will - # be added for each production. - # - # This algorithm is really looking for "next" nonterminals in the existing - # items, and adding new items corresponding to their productions. 
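That "add items until nothing new appears" loop is an instance of a generic worklist closure. A bare sketch of the pattern (a hypothetical helper, not part of this module), with a toy expand function standing in for "items for each production of the next nonterminal":

    # Worklist closure: keep expanding until nothing new appears (sketch).
    def transitive_closure(seeds, expand):
        seen = set(seeds)
        work = list(seeds)
        while work:
            item = work.pop()
            for new_item in expand(item):
                if new_item not in seen:
                    seen.add(new_item)
                    work.append(new_item)
        return seen

    # Toy stand-in for the real expansion step: halve even numbers.
    assert transitive_closure(
        {8}, lambda n: {n // 2} if n % 2 == 0 else set()
    ) == {8, 4, 2, 1}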
- while i < len(item_list): - item = item_list[i] - i += 1 - if not item.next_symbol: - continue - # If _closure_of_item_cache contains the full closure of item, then we can - # add its full closure to the result set, and skip checking any of its - # items: any item that would be added by any item in the cached result - # will already be in the _closure_of_item_cache entry. - if item in self._closure_of_item_cache: - item_set |= self._closure_of_item_cache[item] - continue - # Even if we don't have the full closure of item, we may have the - # immediate closure of item. It turns out that memoizing just this step - # speeds up this function by about 50%, even after the - # _closure_of_item_cache check. - if item not in self._single_level_closure_of_item_cache: - new_items = set() - for production in self._productions_by_lhs.get(item.next_symbol, []): - for terminal in self._first(item.production.rhs[item.dot + 1:] + - (item.terminal,)): - new_items.add(self._item_cache[production, 0, terminal]) - self._single_level_closure_of_item_cache[item] = new_items - for new_item in self._single_level_closure_of_item_cache[item]: - if new_item not in item_set: - item_set.add(new_item) - item_list.append(new_item) - self._closure_of_item_cache[root_item] = item_set - # Typically, _closure_of_item() will be called on items whose closures - # bring in the greatest number of additional items, then on items which - # close over fewer and fewer other items. Since items are not added to - # _closure_of_item_cache unless _closure_of_item() is called directly on - # them, this means that it is unlikely that items brought in will (without - # intervention) have entries in _closure_of_item_cache, which slows down the - # computation of the larger closures. - # - # Although it is not guaranteed, items added to item_list last will tend to - # close over fewer items, and therefore be easier to compute. By forcibly - # re-calculating closures from last to first, and adding the results to - # _closure_of_item_cache at each step, we get a modest performance - # improvement: roughly 50% less time spent in _closure_of_item, which - # translates to about 5% less time in parser(). - for item in item_list[::-1]: - self._closure_of_item(item) - return item_set - - def _parallel_goto(self, items): - """The GOTO function from ALSU p261, executed on all symbols. - - _parallel_goto takes a set of Items, and returns a dict from every symbol in - self.symbols to the set of Items that would be active after a shift - operation (if symbol is a terminal) or after a reduction operation (if - symbol is a nonterminal). - - _parallel_goto is used in lieu of the single-symbol GOTO from ALSU because - it eliminates the outer loop over self.terminals, and thereby reduces the - number of next_symbol calls by a factor of len(self.terminals). - - Args: - items: The set of items representing the initial DFA state. - - Returns: - A dict from symbols to sets of items representing the new DFA states. - """ - results = collections.defaultdict(set) - for item in items: - next_symbol = item.next_symbol - if next_symbol is None: - continue - item = self._item_cache[item.production, item.dot + 1, item.terminal] - # Inlining the cache check results in a ~25% speedup in this function, and - # about 10% overall speedup to parser(). - if item in self._closure_of_item_cache: - closure = self._closure_of_item_cache[item] - else: - closure = self._closure_of_item(item) - # _closure will add newly-started Items (Items with dot=0) to the result - # set. 
After this operation, the result set will correspond to the new - # state. - results[next_symbol].update(closure) - return results - - def _items(self): - """The items function from ALSU p261. - - _items computes the set of sets of LR(1) items for a shift-reduce parser - that matches the grammar. Each set of LR(1) items corresponds to a single - DFA state. - - Returns: - A tuple. - - The first element of the tuple is a list of sets of LR(1) items (each set - corresponding to a DFA state). - - The second element of the tuple is a dictionary from (int, symbol) pairs - to ints, where all the ints are indexes into the list of sets of LR(1) - items. This dictionary is based on the results of the _Goto function, - where item_sets[dict[i, sym]] == self._Goto(item_sets[i], sym). - """ - # The list of states is seeded with the marker S' production. - item_list = [ - frozenset(self._closure_of_item( - self._item_cache[self._seed_production, 0, END_OF_INPUT])) - ] - items = {item_list[0]: 0} - goto_table = {} - i = 0 - # For each state, figure out what the new state when each symbol is added to - # the top of the parsing stack (see the comments in parser._parse). See - # _Goto for an explanation of how that is actually computed. - while i < len(item_list): - item_set = item_list[i] - gotos = self._parallel_goto(item_set) - for symbol, goto in gotos.items(): - goto = frozenset(goto) - if goto not in items: - items[goto] = len(item_list) - item_list.append(goto) - goto_table[i, symbol] = items[goto] - i += 1 - return item_list, goto_table - - def parser(self): - """parser returns an LR(1) parser for the Grammar. - - This implements the Canonical LR(1) ("LR(1)") parser algorithm ("Algorithm - 4.56", ALSU p265), rather than the more common Lookahead LR(1) ("LALR(1)") - algorithm. LALR(1) produces smaller tables, but is more complex and does - not cover all LR(1) grammars. When the LR(1) and LALR(1) algorithms were - invented, table sizes were an important consideration; now, the difference - between a few hundred and a few thousand entries is unlikely to matter. - - At this time, Grammar does not handle ambiguous grammars, which are commonly - used to handle precedence, associativity, and the "dangling else" problem. - Formally, these can always be handled by an unambiguous grammar, though - doing so can be cumbersome, particularly for expression languages with many - levels of precedence. ALSU section 4.8 (pp278-287) contains some techniques - for handling these kinds of ambiguity. - - Returns: - A Parser. 
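Exercising parser() end to end on the example grammar from lr1_test.py; a sketch that assumes the package layout shown in this diff (compiler.front_end.lr1 and compiler.util.parser_types):

    import collections

    from compiler.front_end import lr1
    from compiler.util import parser_types

    # Any object with a .symbol attribute works as a token; this mirrors the
    # namedtuple used in lr1_test.py.
    Token = collections.namedtuple("Token", ["symbol", "source_location"])

    productions = [
        parser_types.Production.parse(p)
        for p in ("S -> C C", "C -> c C", "C -> d")
    ]
    parser = lr1.Grammar("S", productions).parser()
    assert not parser.conflicts  # the grammar is LR(1): no table conflicts

    result = parser.parse([Token(c, None) for c in "cdd"])
    assert result.error is None
    assert result.parse_tree.symbol == "S"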
- """ - item_sets, goto = self._items() - action = {} - conflicts = set() - end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT] - for i in range(len(item_sets)): - for item in item_sets[i]: - new_action = None - if (item.next_symbol is None and - item.production != self._seed_production): - terminal = item.terminal - new_action = Reduce(item.production) - elif item.next_symbol in self.terminals: - terminal = item.next_symbol - assert goto[i, terminal] is not None - new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]]) - if new_action: - if (i, terminal) in action and action[i, terminal] != new_action: - conflicts.add( - Conflict(i, terminal, - frozenset([action[i, terminal], new_action]))) - action[i, terminal] = new_action - if item == end_item: - new_action = Accept() - assert (i, END_OF_INPUT - ) not in action or action[i, END_OF_INPUT] == new_action - action[i, END_OF_INPUT] = new_action - trimmed_goto = {} - for k in goto: - if k[1] in self.nonterminals: - trimmed_goto[k] = goto[k] - expected = {} - for state, terminal in action: - if state not in expected: - expected[state] = set() - expected[state].add(terminal) - return Parser(item_sets, trimmed_goto, action, expected, conflicts, - self.terminals, self.nonterminals, self.productions) - - -ParseError = collections.namedtuple("ParseError", ["code", "index", "token", - "state", "expected_tokens"]) -ParseResult = collections.namedtuple("ParseResult", ["parse_tree", "error"]) - -class Parser(object): - """Parser is a shift-reduce LR(1) parser. - - Generally, clients will want to get a Parser from a Grammar, rather than - directly instantiating one. - - Parser exposes the raw tables needed to feed into a Shift-Reduce parser, - but can also be used directly for parsing. - - Attributes: - item_sets: A list of item sets which correspond to the state numbers in - the action and goto tables. This is not necessary for parsing, but is - useful for debugging parsers. - goto: The GOTO table for this parser. - action: The ACTION table for this parser. - expected: A table of terminal symbols that are expected (that is, that - have a non-Error action) for each state. This can be used to provide - more helpful error messages for parse errors. - conflicts: A set of unresolved conflicts found during table generation. - terminals: A set of terminal symbols in the grammar. - nonterminals: A set of nonterminal symbols in the grammar. - productions: A list of productions in the grammar. - default_errors: A dict of states to default error codes to use when - encountering an error in that state, when a more-specific Error for the - state/terminal pair has not been set. - """ - - def __init__(self, item_sets, goto, action, expected, conflicts, terminals, - nonterminals, productions): - super(Parser, self).__init__() - self.item_sets = item_sets - self.goto = goto - self.action = action - self.expected = expected - self.conflicts = conflicts - self.terminals = terminals - self.nonterminals = nonterminals - self.productions = productions - self.default_errors = {} - - def _parse(self, tokens): - """_parse implements Shift-Reduce parsing algorithm. - - _parse implements the standard shift-reduce algorithm outlined on ASLU - pp236-237. - - Arguments: - tokens: the list of token objects to parse. - - Returns: - A ParseResult. - """ - # The END_OF_INPUT token is explicitly added to avoid explicit "cursor < - # len(tokens)" checks. 
- tokens = list(tokens) + [Symbol(END_OF_INPUT)] - - # Each element of stack is a parse state and a (possibly partial) parse - # tree. The state at the top of the stack encodes which productions are - # "active" (that is, which ones the parser has seen partial input which - # matches some prefix of the production, in a place where that production - # might be valid), and, for each active production, how much of the - # production has been completed. - stack = [(0, None)] - - def state(): - return stack[-1][0] - - cursor = 0 - - # On each iteration, look at the next symbol and the current state, and - # perform the corresponding action. - while True: - if (state(), tokens[cursor].symbol) not in self.action: - # Most state/symbol entries would be Errors, so rather than exhaustively - # adding error entries, we just check here. - if state() in self.default_errors: - next_action = Error(self.default_errors[state()]) - else: - next_action = Error(None) - else: - next_action = self.action[state(), tokens[cursor].symbol] - - if isinstance(next_action, Shift): - # Shift means that there are no "complete" productions on the stack, - # and so the current token should be shifted onto the stack, with a new - # state indicating the new set of "active" productions. - stack.append((next_action.state, tokens[cursor])) - cursor += 1 - elif isinstance(next_action, Accept): - # Accept means that parsing is over, successfully. - assert len(stack) == 2, "Accepted incompletely-reduced input." - assert tokens[cursor].symbol == END_OF_INPUT, ("Accepted parse before " - "end of input.") - return ParseResult(stack[-1][1], None) - elif isinstance(next_action, Reduce): - # Reduce means that there is a complete production on the stack, and - # that the next symbol implies that the completed production is the - # correct production. + def __init__(self, start_symbol, productions): + """Constructs a Grammar object. + + Arguments: + start_symbol: The start symbol for the grammar. + productions: A list of productions (not including the "S' -> start_symbol" + production). + """ + object.__init__(self) + self.start = start_symbol + self._seed_production = parser_types.Production(START_PRIME, (self.start,)) + self.productions = productions + [self._seed_production] + + self._single_level_closure_of_item_cache = {} + self._closure_of_item_cache = {} + self._compute_symbols() + self._compute_seed_firsts() + self._set_productions_by_lhs() + self._populate_item_cache() + + def _set_productions_by_lhs(self): + # Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%, + # which is significant on medium-to-large grammars. + self._productions_by_lhs = {} + for production in self.productions: + self._productions_by_lhs.setdefault(production.lhs, list()).append( + production + ) + + def _populate_item_cache(self): + # There are a relatively small number of possible Items for a grammar, and + # the algorithm needs to get Items from their constituent components very + # frequently. As it turns out, pre-caching all possible Items results in a + # ~35% overall speedup to Grammar.parser(). + self._item_cache = {} + for symbol in self.terminals: + for production in self.productions: + for dot in range(len(production.rhs) + 1): + self._item_cache[production, dot, symbol] = make_item( + production, dot, symbol + ) + + def _compute_symbols(self): + """Finds all grammar symbols, and sorts them into terminal and non-terminal. + + Nonterminal symbols are those which appear on the left side of any + production. 
Terminal symbols are those which do not. + + _compute_symbols is used during __init__. + """ + self.symbols = {END_OF_INPUT} + self.nonterminals = set() + for production in self.productions: + self.symbols.add(production.lhs) + self.nonterminals.add(production.lhs) + for symbol in production.rhs: + self.symbols.add(symbol) + self.terminals = self.symbols - self.nonterminals + + def _compute_seed_firsts(self): + """Computes FIRST (ALSU p221) for all terminal and nonterminal symbols. + + The algorithm for computing FIRST is an iterative one that terminates when + it reaches a fixed point (that is, when further iterations stop changing + state). _compute_seed_firsts computes the fixed point for all single-symbol + strings, by repeatedly calling _first and updating the internal _firsts + table with the results. + + Once _compute_seed_firsts has completed, _first will return correct results + for both single- and multi-symbol strings. + + _compute_seed_firsts is used during __init__. + """ + self.firsts = {} + # FIRST for a terminal symbol is always just that terminal symbol. + for terminal in self.terminals: + self.firsts[terminal] = set([terminal]) + for nonterminal in self.nonterminals: + self.firsts[nonterminal] = set() + while True: + # The first iteration picks up all the productions that start with + # terminal symbols. The second iteration picks up productions that start + # with nonterminals that the first iteration picked up. The third + # iteration picks up nonterminals that the first and second picked up, and + # so on. + # + # This is guaranteed to end, in the worst case, when every terminal + # symbol and epsilon has been added to the _firsts set for every + # nonterminal symbol. This would be slow, but requires a pathological + # grammar; useful grammars should complete in only a few iterations. + firsts_to_add = {} + for production in self.productions: + for first in self._first(production.rhs): + if first not in self.firsts[production.lhs]: + if production.lhs not in firsts_to_add: + firsts_to_add[production.lhs] = set() + firsts_to_add[production.lhs].add(first) + if not firsts_to_add: + break + for symbol in firsts_to_add: + self.firsts[symbol].update(firsts_to_add[symbol]) + + def _first(self, symbols): + """The FIRST function from ALSU p221. + + _first takes a string of symbols (both terminals and nonterminals) and + returns the set of terminal symbols which could be the first terminal symbol + of a string produced by the given list of symbols. + + _first will not give fully-correct results until _compute_seed_firsts + finishes, but is called by _compute_seed_firsts, and must provide partial + results during that method's execution. + + Args: + symbols: A list of symbols. + + Returns: + A set of terminals which could be the first terminal in "symbols." + """ + result = set() + all_contain_epsilon = True + for symbol in symbols: + for first in self.firsts[symbol]: + if first: + result.add(first) + if None not in self.firsts[symbol]: + all_contain_epsilon = False + break + if all_contain_epsilon: + # "None" seems like a Pythonic way of representing epsilon (no symbol). + result.add(None) + return result + + def _closure_of_item(self, root_item): + """Modified implementation of CLOSURE from ALSU p261. + + _closure_of_item performs the CLOSURE function with a single seed item, with + memoization. In the algorithm as presented in ALSU, CLOSURE is called with + a different set of items every time, which is unhelpful for memoization. 
+ Instead, we let _parallel_goto merge the sets returned by _closure_of_item, + which results in a ~40% speedup. + + CLOSURE, roughly, computes the set of LR(1) Items which might be active when + a "seed" set of Items is active. + + Technically, it is the epsilon-closure of the NFA states represented by + "items," where an epsilon transition (a transition that does not consume any + symbols) occurs from a->Z.bY,q to b->.X,p when p is in FIRST(Yq). (a and b + are nonterminals, X, Y, and Z are arbitrary strings of symbols, and p and q + are terminals.) That is, it is the set of all NFA states which can be + reached from "items" without consuming any input. This set corresponds to a + single DFA state. + + Args: + root_item: The initial LR(1) Item. + + Returns: + A set of LR(1) items which may be active at the time when the provided + item is active. + """ + if root_item in self._closure_of_item_cache: + return self._closure_of_item_cache[root_item] + item_set = set([root_item]) + item_list = [root_item] + i = 0 + # Each newly-added Item may trigger the addition of further Items, so + # iterate until no new Items are added. In the worst case, a new Item will + # be added for each production. # - # Per ALSU, we would simply pop an element off the state stack for each - # symbol on the rhs of the production, and then push a new state by - # looking up the (post-pop) current state and the lhs of the production - # in GOTO. The GOTO table, in some sense, is equivalent to shift - # actions for nonterminal symbols. + # This algorithm is really looking for "next" nonterminals in the existing + # items, and adding new items corresponding to their productions. + while i < len(item_list): + item = item_list[i] + i += 1 + if not item.next_symbol: + continue + # If _closure_of_item_cache contains the full closure of item, then we can + # add its full closure to the result set, and skip checking any of its + # items: any item that would be added by any item in the cached result + # will already be in the _closure_of_item_cache entry. + if item in self._closure_of_item_cache: + item_set |= self._closure_of_item_cache[item] + continue + # Even if we don't have the full closure of item, we may have the + # immediate closure of item. It turns out that memoizing just this step + # speeds up this function by about 50%, even after the + # _closure_of_item_cache check. + if item not in self._single_level_closure_of_item_cache: + new_items = set() + for production in self._productions_by_lhs.get(item.next_symbol, []): + for terminal in self._first( + item.production.rhs[item.dot + 1 :] + (item.terminal,) + ): + new_items.add(self._item_cache[production, 0, terminal]) + self._single_level_closure_of_item_cache[item] = new_items + for new_item in self._single_level_closure_of_item_cache[item]: + if new_item not in item_set: + item_set.add(new_item) + item_list.append(new_item) + self._closure_of_item_cache[root_item] = item_set + # Typically, _closure_of_item() will be called on items whose closures + # bring in the greatest number of additional items, then on items which + # close over fewer and fewer other items. Since items are not added to + # _closure_of_item_cache unless _closure_of_item() is called directly on + # them, this means that it is unlikely that items brought in will (without + # intervention) have entries in _closure_of_item_cache, which slows down the + # computation of the larger closures. 
# - # Here, we attach a new partial parse tree, with the production lhs as - # the "name" of the tree, and the popped trees as the "children" of the - # new tree. - children = [ - item[1] for item in stack[len(stack) - len(next_action.rule.rhs):] + # Although it is not guaranteed, items added to item_list last will tend to + # close over fewer items, and therefore be easier to compute. By forcibly + # re-calculating closures from last to first, and adding the results to + # _closure_of_item_cache at each step, we get a modest performance + # improvement: roughly 50% less time spent in _closure_of_item, which + # translates to about 5% less time in parser(). + for item in item_list[::-1]: + self._closure_of_item(item) + return item_set + + def _parallel_goto(self, items): + """The GOTO function from ALSU p261, executed on all symbols. + + _parallel_goto takes a set of Items, and returns a dict from every symbol in + self.symbols to the set of Items that would be active after a shift + operation (if symbol is a terminal) or after a reduction operation (if + symbol is a nonterminal). + + _parallel_goto is used in lieu of the single-symbol GOTO from ALSU because + it eliminates the outer loop over self.terminals, and thereby reduces the + number of next_symbol calls by a factor of len(self.terminals). + + Args: + items: The set of items representing the initial DFA state. + + Returns: + A dict from symbols to sets of items representing the new DFA states. + """ + results = collections.defaultdict(set) + for item in items: + next_symbol = item.next_symbol + if next_symbol is None: + continue + item = self._item_cache[item.production, item.dot + 1, item.terminal] + # Inlining the cache check results in a ~25% speedup in this function, and + # about 10% overall speedup to parser(). + if item in self._closure_of_item_cache: + closure = self._closure_of_item_cache[item] + else: + closure = self._closure_of_item(item) + # _closure will add newly-started Items (Items with dot=0) to the result + # set. After this operation, the result set will correspond to the new + # state. + results[next_symbol].update(closure) + return results + + def _items(self): + """The items function from ALSU p261. + + _items computes the set of sets of LR(1) items for a shift-reduce parser + that matches the grammar. Each set of LR(1) items corresponds to a single + DFA state. + + Returns: + A tuple. + + The first element of the tuple is a list of sets of LR(1) items (each set + corresponding to a DFA state). + + The second element of the tuple is a dictionary from (int, symbol) pairs + to ints, where all the ints are indexes into the list of sets of LR(1) + items. This dictionary is based on the results of the _Goto function, + where item_sets[dict[i, sym]] == self._Goto(item_sets[i], sym). + """ + # The list of states is seeded with the marker S' production. + item_list = [ + frozenset( + self._closure_of_item( + self._item_cache[self._seed_production, 0, END_OF_INPUT] + ) + ) ] - # Attach source_location, if known. The source location will not be - # known if the reduction consumes no symbols (empty rhs) or if the - # client did not specify source_locations for tokens. - # - # It is necessary to loop in order to handle cases like: - # - # C -> c D - # D -> - # - # The D child of the C reduction will not have a source location - # (because it is not produced from any source), so it is necessary to - # scan backwards through C's children to find the end position. 
The
-        # opposite is required in the case where initial children have no
-        # source.
-        #
-        # These loops implicitly handle the case where the reduction has no
-        # children, setting the source_location to None in that case.
-        start_position = None
-        end_position = None
-        for child in children:
-          if hasattr(child,
-                     "source_location") and child.source_location is not None:
-            start_position = child.source_location.start
-            break
-        for child in reversed(children):
-          if hasattr(child,
-                     "source_location") and child.source_location is not None:
-            end_position = child.source_location.end
-            break
-        if start_position is None:
-          source_location = None
-        else:
-          source_location = parser_types.make_location(start_position,
-                                                       end_position)
-        reduction = Reduction(next_action.rule.lhs, children, next_action.rule,
-                              source_location)
-        del stack[len(stack) - len(next_action.rule.rhs):]
-        stack.append((self.goto[state(), next_action.rule.lhs], reduction))
-      elif isinstance(next_action, Error):
-        # Error means that the parse is impossible. For typical grammars and
-        # texts, this usually happens within a few tokens after the mistake in
-        # the input stream, which is convenient (though imperfect) for error
-        # reporting.
-        return ParseResult(None,
-                           ParseError(next_action.code, cursor, tokens[cursor],
-                                      state(), self.expected[state()]))
-      else:
-        assert False, "Shouldn't be here."
-
-  def mark_error(self, tokens, error_token, error_code):
-    """Marks an error state with the given error code.
-
-    mark_error implements the equivalent of the "Merr" system presented in
-    "Generating LR Syntax error Messages from Examples" (Jeffery, 2003).
-    This system has limitations, but has the primary advantage that error
-    messages can be specified by giving an example of the error and the
-    message itself.
-
-    Arguments:
-      tokens: a list of tokens to parse.
-      error_token: the token where the parse should fail, or None if the parse
-        should fail at the implicit end-of-input token.
-
-        If the error_token is the special ANY_TOKEN, then the error will be
-        recorded as the default error for the error state.
-      error_code: a value to record for the error state reached by parsing
-        tokens.
-
-    Returns:
-      None if error_code was successfully recorded, or an error message if there
-      was a problem.
+        items = {item_list[0]: 0}
+        goto_table = {}
+        i = 0
+        # For each state, figure out what the new state is when each symbol is added to
+        # the top of the parsing stack (see the comments in parser._parse). See
+        # _Goto for an explanation of how that is actually computed.
+        while i < len(item_list):
+            item_set = item_list[i]
+            gotos = self._parallel_goto(item_set)
+            for symbol, goto in gotos.items():
+                goto = frozenset(goto)
+                if goto not in items:
+                    items[goto] = len(item_list)
+                    item_list.append(goto)
+                goto_table[i, symbol] = items[goto]
+            i += 1
+        return item_list, goto_table
+
+    def parser(self):
+        """parser returns an LR(1) parser for the Grammar.
+
+        This implements the Canonical LR(1) ("LR(1)") parser algorithm ("Algorithm
+        4.56", ALSU p265), rather than the more common Lookahead LR(1) ("LALR(1)")
+        algorithm. LALR(1) produces smaller tables, but is more complex and does
+        not cover all LR(1) grammars. When the LR(1) and LALR(1) algorithms were
+        invented, table sizes were an important consideration; now, the difference
+        between a few hundred and a few thousand entries is unlikely to matter.
+ + At this time, Grammar does not handle ambiguous grammars, which are commonly + used to handle precedence, associativity, and the "dangling else" problem. + Formally, these can always be handled by an unambiguous grammar, though + doing so can be cumbersome, particularly for expression languages with many + levels of precedence. ALSU section 4.8 (pp278-287) contains some techniques + for handling these kinds of ambiguity. + + Returns: + A Parser. + """ + item_sets, goto = self._items() + action = {} + conflicts = set() + end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT] + for i in range(len(item_sets)): + for item in item_sets[i]: + new_action = None + if ( + item.next_symbol is None + and item.production != self._seed_production + ): + terminal = item.terminal + new_action = Reduce(item.production) + elif item.next_symbol in self.terminals: + terminal = item.next_symbol + assert goto[i, terminal] is not None + new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]]) + if new_action: + if (i, terminal) in action and action[i, terminal] != new_action: + conflicts.add( + Conflict( + i, + terminal, + frozenset([action[i, terminal], new_action]), + ) + ) + action[i, terminal] = new_action + if item == end_item: + new_action = Accept() + assert (i, END_OF_INPUT) not in action or action[ + i, END_OF_INPUT + ] == new_action + action[i, END_OF_INPUT] = new_action + trimmed_goto = {} + for k in goto: + if k[1] in self.nonterminals: + trimmed_goto[k] = goto[k] + expected = {} + for state, terminal in action: + if state not in expected: + expected[state] = set() + expected[state].add(terminal) + return Parser( + item_sets, + trimmed_goto, + action, + expected, + conflicts, + self.terminals, + self.nonterminals, + self.productions, + ) + + +ParseError = collections.namedtuple( + "ParseError", ["code", "index", "token", "state", "expected_tokens"] +) +ParseResult = collections.namedtuple("ParseResult", ["parse_tree", "error"]) + + +class Parser(object): + """Parser is a shift-reduce LR(1) parser. + + Generally, clients will want to get a Parser from a Grammar, rather than + directly instantiating one. + + Parser exposes the raw tables needed to feed into a Shift-Reduce parser, + but can also be used directly for parsing. + + Attributes: + item_sets: A list of item sets which correspond to the state numbers in + the action and goto tables. This is not necessary for parsing, but is + useful for debugging parsers. + goto: The GOTO table for this parser. + action: The ACTION table for this parser. + expected: A table of terminal symbols that are expected (that is, that + have a non-Error action) for each state. This can be used to provide + more helpful error messages for parse errors. + conflicts: A set of unresolved conflicts found during table generation. + terminals: A set of terminal symbols in the grammar. + nonterminals: A set of nonterminal symbols in the grammar. + productions: A list of productions in the grammar. + default_errors: A dict of states to default error codes to use when + encountering an error in that state, when a more-specific Error for the + state/terminal pair has not been set. """ - result = self._parse(tokens) - - # There is no error state to mark on a successful parse. - if not result.error: - return "Input successfully parsed." - - # Check if the error occurred at the specified token; if not, then this was - # not the expected error. 
-    if error_token is None:
-      error_symbol = END_OF_INPUT
-      if result.error.token.symbol != END_OF_INPUT:
-        return "error occurred on {} token, not end of input.".format(
-            result.error.token.symbol)
-    else:
-      error_symbol = error_token.symbol
-      if result.error.token != error_token:
-        return "error occurred on {} token, not {} token.".format(
-            result.error.token.symbol, error_token.symbol)
-
-    # If the expected error was found, attempt to mark it. It is acceptable if
-    # the given error_code is already set as the error code for the given parse,
-    # but not if a different code is set.
-    if result.error.token == ANY_TOKEN:
-      # For ANY_TOKEN, mark it as a default error.
-      if result.error.state in self.default_errors:
-        if self.default_errors[result.error.state] == error_code:
-          return None
+
+    def __init__(
+        self,
+        item_sets,
+        goto,
+        action,
+        expected,
+        conflicts,
+        terminals,
+        nonterminals,
+        productions,
+    ):
+        super(Parser, self).__init__()
+        self.item_sets = item_sets
+        self.goto = goto
+        self.action = action
+        self.expected = expected
+        self.conflicts = conflicts
+        self.terminals = terminals
+        self.nonterminals = nonterminals
+        self.productions = productions
+        self.default_errors = {}
+
+    def _parse(self, tokens):
+        """_parse implements the shift-reduce parsing algorithm.
+
+        _parse implements the standard shift-reduce algorithm outlined on ALSU
+        pp236-237.
+
+        Arguments:
+          tokens: the list of token objects to parse.
+
+        Returns:
+          A ParseResult.
+        """
+        # The END_OF_INPUT token is explicitly added to avoid explicit "cursor <
+        # len(tokens)" checks.
+        tokens = list(tokens) + [Symbol(END_OF_INPUT)]
+
+        # Each element of stack is a parse state and a (possibly partial) parse
+        # tree. The state at the top of the stack encodes which productions are
+        # "active" (that is, which ones the parser has seen partial input which
+        # matches some prefix of the production, in a place where that production
+        # might be valid), and, for each active production, how much of the
+        # production has been completed.
+        stack = [(0, None)]
+
+        def state():
+            return stack[-1][0]
+
+        cursor = 0
+
+        # On each iteration, look at the next symbol and the current state, and
+        # perform the corresponding action.
+        while True:
+            if (state(), tokens[cursor].symbol) not in self.action:
+                # Most state/symbol entries would be Errors, so rather than exhaustively
+                # adding error entries, we just check here.
+                if state() in self.default_errors:
+                    next_action = Error(self.default_errors[state()])
+                else:
+                    next_action = Error(None)
+            else:
+                next_action = self.action[state(), tokens[cursor].symbol]
+
+            if isinstance(next_action, Shift):
+                # Shift means that there are no "complete" productions on the stack,
+                # and so the current token should be shifted onto the stack, with a new
+                # state indicating the new set of "active" productions.
+                stack.append((next_action.state, tokens[cursor]))
+                cursor += 1
+            elif isinstance(next_action, Accept):
+                # Accept means that parsing is over, successfully.
+                assert len(stack) == 2, "Accepted incompletely-reduced input."
+                assert tokens[cursor].symbol == END_OF_INPUT, (
+                    "Accepted parse before " "end of input."
+                )
+                return ParseResult(stack[-1][1], None)
+            elif isinstance(next_action, Reduce):
+                # Reduce means that there is a complete production on the stack, and
+                # that the next symbol implies that the completed production is the
+                # correct production.
+ # + # Per ALSU, we would simply pop an element off the state stack for each + # symbol on the rhs of the production, and then push a new state by + # looking up the (post-pop) current state and the lhs of the production + # in GOTO. The GOTO table, in some sense, is equivalent to shift + # actions for nonterminal symbols. + # + # Here, we attach a new partial parse tree, with the production lhs as + # the "name" of the tree, and the popped trees as the "children" of the + # new tree. + children = [ + item[1] for item in stack[len(stack) - len(next_action.rule.rhs) :] + ] + # Attach source_location, if known. The source location will not be + # known if the reduction consumes no symbols (empty rhs) or if the + # client did not specify source_locations for tokens. + # + # It is necessary to loop in order to handle cases like: + # + # C -> c D + # D -> + # + # The D child of the C reduction will not have a source location + # (because it is not produced from any source), so it is necessary to + # scan backwards through C's children to find the end position. The + # opposite is required in the case where initial children have no + # source. + # + # These loops implicitly handle the case where the reduction has no + # children, setting the source_location to None in that case. + start_position = None + end_position = None + for child in children: + if ( + hasattr(child, "source_location") + and child.source_location is not None + ): + start_position = child.source_location.start + break + for child in reversed(children): + if ( + hasattr(child, "source_location") + and child.source_location is not None + ): + end_position = child.source_location.end + break + if start_position is None: + source_location = None + else: + source_location = parser_types.make_location( + start_position, end_position + ) + reduction = Reduction( + next_action.rule.lhs, children, next_action.rule, source_location + ) + del stack[len(stack) - len(next_action.rule.rhs) :] + stack.append((self.goto[state(), next_action.rule.lhs], reduction)) + elif isinstance(next_action, Error): + # Error means that the parse is impossible. For typical grammars and + # texts, this usually happens within a few tokens after the mistake in + # the input stream, which is convenient (though imperfect) for error + # reporting. + return ParseResult( + None, + ParseError( + next_action.code, + cursor, + tokens[cursor], + state(), + self.expected[state()], + ), + ) + else: + assert False, "Shouldn't be here." + + def mark_error(self, tokens, error_token, error_code): + """Marks an error state with the given error code. + + mark_error implements the equivalent of the "Merr" system presented in + "Generating LR Syntax error Messages from Examples" (Jeffery, 2003). + This system has limitations, but has the primary advantage that error + messages can be specified by giving an example of the error and the + message itself. + + Arguments: + tokens: a list of tokens to parse. + error_token: the token where the parse should fail, or None if the parse + should fail at the implicit end-of-input token. + + If the error_token is the special ANY_TOKEN, then the error will be + recorded as the default error for the error state. + error_code: a value to record for the error state reached by parsing + tokens. + + Returns: + None if error_code was successfully recorded, or an error message if there + was a problem. + """ + result = self._parse(tokens) + + # There is no error state to mark on a successful parse. 
+ if not result.error: + return "Input successfully parsed." + + # Check if the error occurred at the specified token; if not, then this was + # not the expected error. + if error_token is None: + error_symbol = END_OF_INPUT + if result.error.token.symbol != END_OF_INPUT: + return "error occurred on {} token, not end of input.".format( + result.error.token.symbol + ) else: - return ("Attempted to overwrite existing default error code {!r} " - "with new error code {!r} for state {}".format( - self.default_errors[result.error.state], error_code, - result.error.state)) - else: - self.default_errors[result.error.state] = error_code - return None - else: - if (result.error.state, error_symbol) in self.action: - existing_error = self.action[result.error.state, error_symbol] - assert isinstance(existing_error, Error), "Bug" - if existing_error.code == error_code: - return None + error_symbol = error_token.symbol + if result.error.token != error_token: + return "error occurred on {} token, not {} token.".format( + result.error.token.symbol, error_token.symbol + ) + + # If the expected error was found, attempt to mark it. It is acceptable if + # the given error_code is already set as the error code for the given parse, + # but not if a different code is set. + if result.error.token == ANY_TOKEN: + # For ANY_TOKEN, mark it as a default error. + if result.error.state in self.default_errors: + if self.default_errors[result.error.state] == error_code: + return None + else: + return ( + "Attempted to overwrite existing default error code {!r} " + "with new error code {!r} for state {}".format( + self.default_errors[result.error.state], + error_code, + result.error.state, + ) + ) + else: + self.default_errors[result.error.state] = error_code + return None else: - return ("Attempted to overwrite existing error code {!r} with new " - "error code {!r} for state {}, terminal {}".format( - existing_error.code, error_code, result.error.state, - error_symbol)) - else: - self.action[result.error.state, error_symbol] = Error(error_code) - return None - assert False, "All other paths should lead to return." - - def parse(self, tokens): - """Parses a list of tokens. - - Arguments: - tokens: a list of tokens to parse. - - Returns: - A ParseResult. - """ - result = self._parse(tokens) - return result + if (result.error.state, error_symbol) in self.action: + existing_error = self.action[result.error.state, error_symbol] + assert isinstance(existing_error, Error), "Bug" + if existing_error.code == error_code: + return None + else: + return ( + "Attempted to overwrite existing error code {!r} with new " + "error code {!r} for state {}, terminal {}".format( + existing_error.code, + error_code, + result.error.state, + error_symbol, + ) + ) + else: + self.action[result.error.state, error_symbol] = Error(error_code) + return None + assert False, "All other paths should lead to return." + + def parse(self, tokens): + """Parses a list of tokens. + + Arguments: + tokens: a list of tokens to parse. + + Returns: + A ParseResult. 
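The Merr-style workflow described above, sketched with the toy grammar from the tests (same assumptions as the earlier sketch: the modules are importable at the paths shown in this diff):

    import collections

    from compiler.front_end import lr1
    from compiler.util import parser_types

    Token = collections.namedtuple("Token", ["symbol", "source_location"])

    def tokenize(text):
        return [Token(c, None) for c in text]

    productions = [
        parser_types.Production.parse(p)
        for p in ("S -> C C", "C -> c C", "C -> d")
    ]
    parser = lr1.Grammar("S", productions).parser()

    # "d" parses one C and then hits end of input; record a message for the
    # state reached at that point.
    assert parser.mark_error(tokenize("d"), None, "missing second C") is None

    # Any later parse that fails in the same state reports that code.
    assert parser.parse(tokenize("d")).error.code == "missing second C"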
+ """ + result = self._parse(tokens) + return result diff --git a/compiler/front_end/lr1_test.py b/compiler/front_end/lr1_test.py index 44ffa75..ae03e2d 100644 --- a/compiler/front_end/lr1_test.py +++ b/compiler/front_end/lr1_test.py @@ -22,58 +22,77 @@ def _make_items(text): - """Makes a list of lr1.Items from the lines in text.""" - return frozenset([lr1.Item.parse(line.strip()) for line in text.splitlines()]) + """Makes a list of lr1.Items from the lines in text.""" + return frozenset([lr1.Item.parse(line.strip()) for line in text.splitlines()]) Token = collections.namedtuple("Token", ["symbol", "source_location"]) def _tokenize(text): - """"Tokenizes" text by making each character into a token.""" - result = [] - for i in range(len(text)): - result.append(Token(text[i], parser_types.make_location( - (1, i + 1), (1, i + 2)))) - return result + """ "Tokenizes" text by making each character into a token.""" + result = [] + for i in range(len(text)): + result.append( + Token(text[i], parser_types.make_location((1, i + 1), (1, i + 2))) + ) + return result def _parse_productions(text): - """Parses text into a grammar by calling Production.parse on each line.""" - return [parser_types.Production.parse(line) for line in text.splitlines()] + """Parses text into a grammar by calling Production.parse on each line.""" + return [parser_types.Production.parse(line) for line in text.splitlines()] + # Example grammar 4.54 from Aho, Sethi, Lam, Ullman (ASLU) p263. -_alsu_grammar = lr1.Grammar("S", _parse_productions("""S -> C C +_alsu_grammar = lr1.Grammar( + "S", + _parse_productions( + """S -> C C C -> c C - C -> d""")) + C -> d""" + ), +) # Item sets corresponding to the above grammar, ASLU pp263-264. _alsu_items = [ - _make_items("""S' -> . S, $ + _make_items( + """S' -> . S, $ S -> . C C, $ C -> . c C, c C -> . c C, d C -> . d, c - C -> . d, d"""), + C -> . d, d""" + ), _make_items("""S' -> S ., $"""), - _make_items("""S -> C . C, $ + _make_items( + """S -> C . C, $ C -> . c C, $ - C -> . d, $"""), - _make_items("""C -> c . C, c + C -> . d, $""" + ), + _make_items( + """C -> c . C, c C -> c . C, d C -> . c C, c C -> . c C, d C -> . d, c - C -> . d, d"""), - _make_items("""C -> d ., c - C -> d ., d"""), + C -> . d, d""" + ), + _make_items( + """C -> d ., c + C -> d ., d""" + ), _make_items("""S -> C C ., $"""), - _make_items("""C -> c . C, $ + _make_items( + """C -> c . C, $ C -> . c C, $ - C -> . d, $"""), + C -> . d, $""" + ), _make_items("""C -> d ., $"""), - _make_items("""C -> c C ., c - C -> c C ., d"""), + _make_items( + """C -> c C ., c + C -> c C ., d""" + ), _make_items("""C -> c C ., $"""), ] @@ -98,220 +117,315 @@ def _parse_productions(text): } # GOTO table corresponding to the above grammar, ASLU p266. 
-_alsu_goto = {(0, "S"): 1, (0, "C"): 2, (2, "C"): 5, (3, "C"): 8, (6, "C"): 9,} +_alsu_goto = { + (0, "S"): 1, + (0, "C"): 2, + (2, "C"): 5, + (3, "C"): 8, + (6, "C"): 9, +} def _normalize_table(items, table): - """Returns a canonical-form version of items and table, for comparisons.""" - item_to_original_index = {} - for i in range(len(items)): - item_to_original_index[items[i]] = i - sorted_items = items[0:1] + sorted(items[1:], key=sorted) - original_index_to_index = {} - for i in range(len(sorted_items)): - original_index_to_index[item_to_original_index[sorted_items[i]]] = i - updated_table = {} - for k in table: - new_k = original_index_to_index[k[0]], k[1] - new_value = table[k] - if isinstance(new_value, int): - new_value = original_index_to_index[new_value] - elif isinstance(new_value, lr1.Shift): - new_value = lr1.Shift(original_index_to_index[new_value.state], - new_value.items) - updated_table[new_k] = new_value - return sorted_items, updated_table + """Returns a canonical-form version of items and table, for comparisons.""" + item_to_original_index = {} + for i in range(len(items)): + item_to_original_index[items[i]] = i + sorted_items = items[0:1] + sorted(items[1:], key=sorted) + original_index_to_index = {} + for i in range(len(sorted_items)): + original_index_to_index[item_to_original_index[sorted_items[i]]] = i + updated_table = {} + for k in table: + new_k = original_index_to_index[k[0]], k[1] + new_value = table[k] + if isinstance(new_value, int): + new_value = original_index_to_index[new_value] + elif isinstance(new_value, lr1.Shift): + new_value = lr1.Shift( + original_index_to_index[new_value.state], new_value.items + ) + updated_table[new_k] = new_value + return sorted_items, updated_table class Lr1Test(unittest.TestCase): - """Tests for lr1.""" - - def test_parse_lr1item(self): - self.assertEqual(lr1.Item.parse("S' -> . 
S, $"), - lr1.Item(parser_types.Production(lr1.START_PRIME, ("S",)), - 0, lr1.END_OF_INPUT, "S")) - - def test_symbol_extraction(self): - self.assertEqual(_alsu_grammar.terminals, set(["c", "d", lr1.END_OF_INPUT])) - self.assertEqual(_alsu_grammar.nonterminals, set(["S", "C", - lr1.START_PRIME])) - self.assertEqual(_alsu_grammar.symbols, - set(["c", "d", "S", "C", lr1.END_OF_INPUT, - lr1.START_PRIME])) - - def test_items(self): - self.assertEqual(set(_alsu_grammar._items()[0]), frozenset(_alsu_items)) - - def test_terminal_nonterminal_production_tables(self): - parser = _alsu_grammar.parser() - self.assertEqual(parser.terminals, _alsu_grammar.terminals) - self.assertEqual(parser.nonterminals, _alsu_grammar.nonterminals) - self.assertEqual(parser.productions, _alsu_grammar.productions) - - def test_action_table(self): - parser = _alsu_grammar.parser() - norm_items, norm_action = _normalize_table(parser.item_sets, parser.action) - test_items, test_action = _normalize_table(_alsu_items, _alsu_action) - self.assertEqual(norm_items, test_items) - self.assertEqual(norm_action, test_action) - - def test_goto_table(self): - parser = _alsu_grammar.parser() - norm_items, norm_goto = _normalize_table(parser.item_sets, parser.goto) - test_items, test_goto = _normalize_table(_alsu_items, _alsu_goto) - self.assertEqual(norm_items, test_items) - self.assertEqual(norm_goto, test_goto) - - def test_successful_parse(self): - parser = _alsu_grammar.parser() - loc = parser_types.parse_location - s_to_c_c = parser_types.Production.parse("S -> C C") - c_to_c_c = parser_types.Production.parse("C -> c C") - c_to_d = parser_types.Production.parse("C -> d") - self.assertEqual( - lr1.Reduction("S", [lr1.Reduction("C", [ - Token("c", loc("1:1-1:2")), lr1.Reduction( - "C", [Token("c", loc("1:2-1:3")), - lr1.Reduction("C", - [Token("c", loc("1:3-1:4")), lr1.Reduction( - "C", [Token("d", loc("1:4-1:5"))], - c_to_d, loc("1:4-1:5"))], c_to_c_c, - loc("1:3-1:5"))], c_to_c_c, loc("1:2-1:5")) - ], c_to_c_c, loc("1:1-1:5")), lr1.Reduction( - "C", [Token("c", loc("1:5-1:6")), - lr1.Reduction("C", [Token("d", loc("1:6-1:7"))], c_to_d, - loc("1:6-1:7"))], c_to_c_c, loc("1:5-1:7"))], - s_to_c_c, loc("1:1-1:7")), - parser.parse(_tokenize("cccdcd")).parse_tree) - self.assertEqual( - lr1.Reduction("S", [ - lr1.Reduction("C", [Token("d", loc("1:1-1:2"))], c_to_d, loc( - "1:1-1:2")), lr1.Reduction("C", [Token("d", loc("1:2-1:3"))], - c_to_d, loc("1:2-1:3")) - ], s_to_c_c, loc("1:1-1:3")), parser.parse(_tokenize("dd")).parse_tree) - - def test_parse_with_no_source_information(self): - parser = _alsu_grammar.parser() - s_to_c_c = parser_types.Production.parse("S -> C C") - c_to_d = parser_types.Production.parse("C -> d") - self.assertEqual( - lr1.Reduction("S", [ - lr1.Reduction("C", [Token("d", None)], c_to_d, None), - lr1.Reduction("C", [Token("d", None)], c_to_d, None) - ], s_to_c_c, None), - parser.parse([Token("d", None), Token("d", None)]).parse_tree) - - def test_failed_parses(self): - parser = _alsu_grammar.parser() - self.assertEqual(None, parser.parse(_tokenize("d")).parse_tree) - self.assertEqual(None, parser.parse(_tokenize("cccd")).parse_tree) - self.assertEqual(None, parser.parse(_tokenize("")).parse_tree) - self.assertEqual(None, parser.parse(_tokenize("cccdc")).parse_tree) - - def test_mark_error(self): - parser = _alsu_grammar.parser() - self.assertIsNone(parser.mark_error(_tokenize("cccdc"), None, - "missing last d")) - self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C")) - # Marking an 
already-marked error with the same error code should succeed. - self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C")) - # Marking an already-marked error with a different error code should fail. - self.assertRegexpMatches( - parser.mark_error(_tokenize("d"), None, "different message"), - r"^Attempted to overwrite existing error code 'missing last C' with " - r"new error code 'different message' for state \d+, terminal \$$") - self.assertEqual( - "Input successfully parsed.", - parser.mark_error(_tokenize("dd"), None, "good parse")) - self.assertEqual( - parser.mark_error(_tokenize("x"), None, "wrong location"), - "error occurred on x token, not end of input.") - self.assertEqual( - parser.mark_error([], _tokenize("x")[0], "wrong location"), - "error occurred on $ token, not x token.") - self.assertIsNone( - parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error")) - # Marking an already-marked error with the same error code should succeed. - self.assertIsNone( - parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error")) - # Marking an already-marked error with a different error code should fail. - self.assertRegexpMatches( - parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error 2"), - r"^Attempted to overwrite existing default error code 'default error' " - r"with new error code 'default error 2' for state \d+$") - - self.assertEqual( - "missing last d", parser.parse(_tokenize("cccdc")).error.code) - self.assertEqual("missing last d", parser.parse(_tokenize("dc")).error.code) - self.assertEqual("missing last C", parser.parse(_tokenize("d")).error.code) - self.assertEqual("default error", parser.parse(_tokenize("z")).error.code) - self.assertEqual( - "missing last C", parser.parse(_tokenize("ccccd")).error.code) - self.assertEqual(None, parser.parse(_tokenize("ccc")).error.code) - - def test_grammar_with_empty_rhs(self): - grammar = lr1.Grammar("S", _parse_productions("""S -> A B + """Tests for lr1.""" + + def test_parse_lr1item(self): + self.assertEqual( + lr1.Item.parse("S' -> . 
S, $"), + lr1.Item( + parser_types.Production(lr1.START_PRIME, ("S",)), + 0, + lr1.END_OF_INPUT, + "S", + ), + ) + + def test_symbol_extraction(self): + self.assertEqual(_alsu_grammar.terminals, set(["c", "d", lr1.END_OF_INPUT])) + self.assertEqual(_alsu_grammar.nonterminals, set(["S", "C", lr1.START_PRIME])) + self.assertEqual( + _alsu_grammar.symbols, + set(["c", "d", "S", "C", lr1.END_OF_INPUT, lr1.START_PRIME]), + ) + + def test_items(self): + self.assertEqual(set(_alsu_grammar._items()[0]), frozenset(_alsu_items)) + + def test_terminal_nonterminal_production_tables(self): + parser = _alsu_grammar.parser() + self.assertEqual(parser.terminals, _alsu_grammar.terminals) + self.assertEqual(parser.nonterminals, _alsu_grammar.nonterminals) + self.assertEqual(parser.productions, _alsu_grammar.productions) + + def test_action_table(self): + parser = _alsu_grammar.parser() + norm_items, norm_action = _normalize_table(parser.item_sets, parser.action) + test_items, test_action = _normalize_table(_alsu_items, _alsu_action) + self.assertEqual(norm_items, test_items) + self.assertEqual(norm_action, test_action) + + def test_goto_table(self): + parser = _alsu_grammar.parser() + norm_items, norm_goto = _normalize_table(parser.item_sets, parser.goto) + test_items, test_goto = _normalize_table(_alsu_items, _alsu_goto) + self.assertEqual(norm_items, test_items) + self.assertEqual(norm_goto, test_goto) + + def test_successful_parse(self): + parser = _alsu_grammar.parser() + loc = parser_types.parse_location + s_to_c_c = parser_types.Production.parse("S -> C C") + c_to_c_c = parser_types.Production.parse("C -> c C") + c_to_d = parser_types.Production.parse("C -> d") + self.assertEqual( + lr1.Reduction( + "S", + [ + lr1.Reduction( + "C", + [ + Token("c", loc("1:1-1:2")), + lr1.Reduction( + "C", + [ + Token("c", loc("1:2-1:3")), + lr1.Reduction( + "C", + [ + Token("c", loc("1:3-1:4")), + lr1.Reduction( + "C", + [Token("d", loc("1:4-1:5"))], + c_to_d, + loc("1:4-1:5"), + ), + ], + c_to_c_c, + loc("1:3-1:5"), + ), + ], + c_to_c_c, + loc("1:2-1:5"), + ), + ], + c_to_c_c, + loc("1:1-1:5"), + ), + lr1.Reduction( + "C", + [ + Token("c", loc("1:5-1:6")), + lr1.Reduction( + "C", + [Token("d", loc("1:6-1:7"))], + c_to_d, + loc("1:6-1:7"), + ), + ], + c_to_c_c, + loc("1:5-1:7"), + ), + ], + s_to_c_c, + loc("1:1-1:7"), + ), + parser.parse(_tokenize("cccdcd")).parse_tree, + ) + self.assertEqual( + lr1.Reduction( + "S", + [ + lr1.Reduction( + "C", [Token("d", loc("1:1-1:2"))], c_to_d, loc("1:1-1:2") + ), + lr1.Reduction( + "C", [Token("d", loc("1:2-1:3"))], c_to_d, loc("1:2-1:3") + ), + ], + s_to_c_c, + loc("1:1-1:3"), + ), + parser.parse(_tokenize("dd")).parse_tree, + ) + + def test_parse_with_no_source_information(self): + parser = _alsu_grammar.parser() + s_to_c_c = parser_types.Production.parse("S -> C C") + c_to_d = parser_types.Production.parse("C -> d") + self.assertEqual( + lr1.Reduction( + "S", + [ + lr1.Reduction("C", [Token("d", None)], c_to_d, None), + lr1.Reduction("C", [Token("d", None)], c_to_d, None), + ], + s_to_c_c, + None, + ), + parser.parse([Token("d", None), Token("d", None)]).parse_tree, + ) + + def test_failed_parses(self): + parser = _alsu_grammar.parser() + self.assertEqual(None, parser.parse(_tokenize("d")).parse_tree) + self.assertEqual(None, parser.parse(_tokenize("cccd")).parse_tree) + self.assertEqual(None, parser.parse(_tokenize("")).parse_tree) + self.assertEqual(None, parser.parse(_tokenize("cccdc")).parse_tree) + + def test_mark_error(self): + parser = _alsu_grammar.parser() + 
self.assertIsNone(parser.mark_error(_tokenize("cccdc"), None, "missing last d")) + self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C")) + # Marking an already-marked error with the same error code should succeed. + self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C")) + # Marking an already-marked error with a different error code should fail. + self.assertRegexpMatches( + parser.mark_error(_tokenize("d"), None, "different message"), + r"^Attempted to overwrite existing error code 'missing last C' with " + r"new error code 'different message' for state \d+, terminal \$$", + ) + self.assertEqual( + "Input successfully parsed.", + parser.mark_error(_tokenize("dd"), None, "good parse"), + ) + self.assertEqual( + parser.mark_error(_tokenize("x"), None, "wrong location"), + "error occurred on x token, not end of input.", + ) + self.assertEqual( + parser.mark_error([], _tokenize("x")[0], "wrong location"), + "error occurred on $ token, not x token.", + ) + self.assertIsNone( + parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error") + ) + # Marking an already-marked error with the same error code should succeed. + self.assertIsNone( + parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error") + ) + # Marking an already-marked error with a different error code should fail. + self.assertRegexpMatches( + parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error 2"), + r"^Attempted to overwrite existing default error code 'default error' " + r"with new error code 'default error 2' for state \d+$", + ) + + self.assertEqual("missing last d", parser.parse(_tokenize("cccdc")).error.code) + self.assertEqual("missing last d", parser.parse(_tokenize("dc")).error.code) + self.assertEqual("missing last C", parser.parse(_tokenize("d")).error.code) + self.assertEqual("default error", parser.parse(_tokenize("z")).error.code) + self.assertEqual("missing last C", parser.parse(_tokenize("ccccd")).error.code) + self.assertEqual(None, parser.parse(_tokenize("ccc")).error.code) + + def test_grammar_with_empty_rhs(self): + grammar = lr1.Grammar( + "S", + _parse_productions( + """S -> A B A -> a A A -> - B -> b""")) - parser = grammar.parser() - self.assertFalse(parser.conflicts) - self.assertTrue(parser.parse(_tokenize("ab")).parse_tree) - self.assertTrue(parser.parse(_tokenize("b")).parse_tree) - self.assertTrue(parser.parse(_tokenize("aab")).parse_tree) - - def test_grammar_with_reduce_reduce_conflicts(self): - grammar = lr1.Grammar("S", _parse_productions("""S -> A c + B -> b""" + ), + ) + parser = grammar.parser() + self.assertFalse(parser.conflicts) + self.assertTrue(parser.parse(_tokenize("ab")).parse_tree) + self.assertTrue(parser.parse(_tokenize("b")).parse_tree) + self.assertTrue(parser.parse(_tokenize("aab")).parse_tree) + + def test_grammar_with_reduce_reduce_conflicts(self): + grammar = lr1.Grammar( + "S", + _parse_productions( + """S -> A c S -> B c A -> a - B -> a""")) - parser = grammar.parser() - self.assertEqual(len(parser.conflicts), 1) - # parser.conflicts is a set - for conflict in parser.conflicts: - for action in conflict.actions: - self.assertTrue(isinstance(action, lr1.Reduce)) - - def test_grammar_with_shift_reduce_conflicts(self): - grammar = lr1.Grammar("S", _parse_productions("""S -> A B + B -> a""" + ), + ) + parser = grammar.parser() + self.assertEqual(len(parser.conflicts), 1) + # parser.conflicts is a set + for conflict in parser.conflicts: + for action in conflict.actions: + self.assertTrue(isinstance(action, 
lr1.Reduce)) + + def test_grammar_with_shift_reduce_conflicts(self): + grammar = lr1.Grammar( + "S", + _parse_productions( + """S -> A B A -> a A -> B -> a - B ->""")) - parser = grammar.parser() - self.assertEqual(len(parser.conflicts), 1) - # parser.conflicts is a set - for conflict in parser.conflicts: - reduces = 0 - shifts = 0 - for action in conflict.actions: - if isinstance(action, lr1.Reduce): - reduces += 1 - elif isinstance(action, lr1.Shift): - shifts += 1 - self.assertEqual(1, reduces) - self.assertEqual(1, shifts) - - def test_item_str(self): - self.assertEqual( - "a -> b c ., d", - str(lr1.make_item(parser_types.Production.parse("a -> b c"), 2, "d"))) - self.assertEqual( - "a -> b . c, d", - str(lr1.make_item(parser_types.Production.parse("a -> b c"), 1, "d"))) - self.assertEqual( - "a -> . b c, d", - str(lr1.make_item(parser_types.Production.parse("a -> b c"), 0, "d"))) - self.assertEqual( - "a -> ., d", - str(lr1.make_item(parser_types.Production.parse("a ->"), 0, "d"))) - - def test_conflict_str(self): - self.assertEqual("Conflict for 'A' in state 12: R vs S", - str(lr1.Conflict(12, "'A'", ["R", "S"]))) - self.assertEqual("Conflict for 'A' in state 12: R vs S vs T", - str(lr1.Conflict(12, "'A'", ["R", "S", "T"]))) + B ->""" + ), + ) + parser = grammar.parser() + self.assertEqual(len(parser.conflicts), 1) + # parser.conflicts is a set + for conflict in parser.conflicts: + reduces = 0 + shifts = 0 + for action in conflict.actions: + if isinstance(action, lr1.Reduce): + reduces += 1 + elif isinstance(action, lr1.Shift): + shifts += 1 + self.assertEqual(1, reduces) + self.assertEqual(1, shifts) + + def test_item_str(self): + self.assertEqual( + "a -> b c ., d", + str(lr1.make_item(parser_types.Production.parse("a -> b c"), 2, "d")), + ) + self.assertEqual( + "a -> b . c, d", + str(lr1.make_item(parser_types.Production.parse("a -> b c"), 1, "d")), + ) + self.assertEqual( + "a -> . b c, d", + str(lr1.make_item(parser_types.Production.parse("a -> b c"), 0, "d")), + ) + self.assertEqual( + "a -> ., d", + str(lr1.make_item(parser_types.Production.parse("a ->"), 0, "d")), + ) + + def test_conflict_str(self): + self.assertEqual( + "Conflict for 'A' in state 12: R vs S", + str(lr1.Conflict(12, "'A'", ["R", "S"])), + ) + self.assertEqual( + "Conflict for 'A' in state 12: R vs S vs T", + str(lr1.Conflict(12, "'A'", ["R", "S", "T"])), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py index c9ba765..bd27c8a 100644 --- a/compiler/front_end/module_ir.py +++ b/compiler/front_end/module_ir.py @@ -33,51 +33,54 @@ # Intermediate types; should not be found in the final IR. class _List(object): - """A list with source location information.""" - __slots__ = ('list', 'source_location') + """A list with source location information.""" - def __init__(self, l): - assert isinstance(l, list), "_List object must wrap list, not '%r'" % l - self.list = l - self.source_location = ir_data.Location() + __slots__ = ("list", "source_location") + + def __init__(self, l): + assert isinstance(l, list), "_List object must wrap list, not '%r'" % l + self.list = l + self.source_location = ir_data.Location() class _ExpressionTail(object): - """A fragment of an expression with an operator and right-hand side. + """A fragment of an expression with an operator and right-hand side. 
- _ExpressionTail is the tail of an expression, consisting of an operator and - the right-hand argument to the operator; for example, in the expression (6+8), - the _ExpressionTail would be "+8". + _ExpressionTail is the tail of an expression, consisting of an operator and + the right-hand argument to the operator; for example, in the expression (6+8), + the _ExpressionTail would be "+8". - This is used as a temporary object while converting the right-recursive - "expression" and "times-expression" productions into left-associative - Expressions. + This is used as a temporary object while converting the right-recursive + "expression" and "times-expression" productions into left-associative + Expressions. - Attributes: - operator: An ir_data.Word of the operator's name. - expression: The expression on the right side of the operator. - source_location: The source location of the operation fragment. - """ - __slots__ = ('operator', 'expression', 'source_location') + Attributes: + operator: An ir_data.Word of the operator's name. + expression: The expression on the right side of the operator. + source_location: The source location of the operation fragment. + """ - def __init__(self, operator, expression): - self.operator = operator - self.expression = expression - self.source_location = ir_data.Location() + __slots__ = ("operator", "expression", "source_location") + + def __init__(self, operator, expression): + self.operator = operator + self.expression = expression + self.source_location = ir_data.Location() class _FieldWithType(object): - """A field with zero or more types defined inline with that field.""" - __slots__ = ('field', 'subtypes', 'source_location') + """A field with zero or more types defined inline with that field.""" + + __slots__ = ("field", "subtypes", "source_location") - def __init__(self, field, subtypes=None): - self.field = field - self.subtypes = subtypes or [] - self.source_location = ir_data.Location() + def __init__(self, field, subtypes=None): + self.field = field + self.subtypes = subtypes or [] + self.source_location = ir_data.Location() def build_ir(parse_tree, used_productions=None): - r"""Builds a module-level intermediate representation from a valid parse tree. + r"""Builds a module-level intermediate representation from a valid parse tree. The parse tree is precisely dictated by the exact productions in the grammar used by the parser, with no semantic information. _really_build_ir transforms @@ -129,37 +132,39 @@ def build_ir(parse_tree, used_productions=None): a forest of module IRs so that names from other modules can be resolved. """ - # TODO(b/140259131): Refactor _really_build_ir to be less recursive/use an - # explicit stack. - old_recursion_limit = sys.getrecursionlimit() - sys.setrecursionlimit(16 * 1024) # ~8000 top-level entities in one module. - try: - result = _really_build_ir(parse_tree, used_productions) - finally: - sys.setrecursionlimit(old_recursion_limit) - return result + # TODO(b/140259131): Refactor _really_build_ir to be less recursive/use an + # explicit stack. + old_recursion_limit = sys.getrecursionlimit() + sys.setrecursionlimit(16 * 1024) # ~8000 top-level entities in one module. 
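    # The try/finally that follows restores the process-wide recursion limit
    # even when IR construction raises, so one deeply nested module cannot
    # leave the inflated limit behind for unrelated code in the same process.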
+    try:
+        result = _really_build_ir(parse_tree, used_productions)
+    finally:
+        sys.setrecursionlimit(old_recursion_limit)
+    return result


 def _really_build_ir(parse_tree, used_productions):
-  """Real implementation of build_ir()."""
-  if used_productions is None:
-    used_productions = set()
-  if hasattr(parse_tree, 'children'):
-    parsed_children = [_really_build_ir(child, used_productions)
-                       for child in parse_tree.children]
-    used_productions.add(parse_tree.production)
-    result = _handlers[parse_tree.production](*parsed_children)
-    if parse_tree.source_location is not None:
-      if result.source_location:
-        ir_data_utils.update(result.source_location, parse_tree.source_location)
-      else:
-        result.source_location = ir_data_utils.copy(parse_tree.source_location)
-    return result
-  else:
-    # For leaf nodes, the temporary "IR" is just the token. Higher-level rules
-    # will translate it to a real IR.
-    assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
-    return parse_tree
+    """Real implementation of build_ir()."""
+    if used_productions is None:
+        used_productions = set()
+    if hasattr(parse_tree, "children"):
+        parsed_children = [
+            _really_build_ir(child, used_productions) for child in parse_tree.children
+        ]
+        used_productions.add(parse_tree.production)
+        result = _handlers[parse_tree.production](*parsed_children)
+        if parse_tree.source_location is not None:
+            if result.source_location:
+                ir_data_utils.update(result.source_location, parse_tree.source_location)
+            else:
+                result.source_location = ir_data_utils.copy(parse_tree.source_location)
+        return result
+    else:
+        # For leaf nodes, the temporary "IR" is just the token. Higher-level rules
+        # will translate it to a real IR.
+        assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
+        return parse_tree
+

 # Map of productions to their handlers.
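# Each @_handles("lhs -> rhs") decorator defined below parses its production
# string with parser_types.Production.parse and stores the decorated function
# in this dict, keyed by the resulting Production; _really_build_ir then looks
# the handler up by parse_tree.production when converting interior nodes.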
_handlers = {} @@ -168,58 +173,59 @@ def _really_build_ir(parse_tree, used_productions): def _get_anonymous_field_name(): - global _anonymous_name_counter - _anonymous_name_counter += 1 - return 'emboss_reserved_anonymous_field_{}'.format(_anonymous_name_counter) + global _anonymous_name_counter + _anonymous_name_counter += 1 + return "emboss_reserved_anonymous_field_{}".format(_anonymous_name_counter) def _handles(production_text): - """_handles marks a function as the handler for a particular production.""" - production = parser_types.Production.parse(production_text) + """_handles marks a function as the handler for a particular production.""" + production = parser_types.Production.parse(production_text) - def handles(f): - _handlers[production] = f - return f + def handles(f): + _handlers[production] = f + return f - return handles + return handles def _make_prelude_import(position): - """Helper function to construct a synthetic ir_data.Import for the prelude.""" - location = parser_types.make_location(position, position) - return ir_data.Import( - file_name=ir_data.String(text='', source_location=location), - local_name=ir_data.Word(text='', source_location=location), - source_location=location) + """Helper function to construct a synthetic ir_data.Import for the prelude.""" + location = parser_types.make_location(position, position) + return ir_data.Import( + file_name=ir_data.String(text="", source_location=location), + local_name=ir_data.Word(text="", source_location=location), + source_location=location, + ) def _text_to_operator(text): - """Converts an operator's textual name to its corresponding enum.""" - operations = { - '+': ir_data.FunctionMapping.ADDITION, - '-': ir_data.FunctionMapping.SUBTRACTION, - '*': ir_data.FunctionMapping.MULTIPLICATION, - '==': ir_data.FunctionMapping.EQUALITY, - '!=': ir_data.FunctionMapping.INEQUALITY, - '&&': ir_data.FunctionMapping.AND, - '||': ir_data.FunctionMapping.OR, - '>': ir_data.FunctionMapping.GREATER, - '>=': ir_data.FunctionMapping.GREATER_OR_EQUAL, - '<': ir_data.FunctionMapping.LESS, - '<=': ir_data.FunctionMapping.LESS_OR_EQUAL, - } - return operations[text] + """Converts an operator's textual name to its corresponding enum.""" + operations = { + "+": ir_data.FunctionMapping.ADDITION, + "-": ir_data.FunctionMapping.SUBTRACTION, + "*": ir_data.FunctionMapping.MULTIPLICATION, + "==": ir_data.FunctionMapping.EQUALITY, + "!=": ir_data.FunctionMapping.INEQUALITY, + "&&": ir_data.FunctionMapping.AND, + "||": ir_data.FunctionMapping.OR, + ">": ir_data.FunctionMapping.GREATER, + ">=": ir_data.FunctionMapping.GREATER_OR_EQUAL, + "<": ir_data.FunctionMapping.LESS, + "<=": ir_data.FunctionMapping.LESS_OR_EQUAL, + } + return operations[text] def _text_to_function(text): - """Converts a function's textual name to its corresponding enum.""" - functions = { - '$max': ir_data.FunctionMapping.MAXIMUM, - '$present': ir_data.FunctionMapping.PRESENCE, - '$upper_bound': ir_data.FunctionMapping.UPPER_BOUND, - '$lower_bound': ir_data.FunctionMapping.LOWER_BOUND, - } - return functions[text] + """Converts a function's textual name to its corresponding enum.""" + functions = { + "$max": ir_data.FunctionMapping.MAXIMUM, + "$present": ir_data.FunctionMapping.PRESENCE, + "$upper_bound": ir_data.FunctionMapping.UPPER_BOUND, + "$lower_bound": ir_data.FunctionMapping.LOWER_BOUND, + } + return functions[text] ################################################################################ @@ -245,136 +251,153 @@ def _text_to_function(text): # A module file is a 
list of documentation, then imports, then top-level # attributes, then type definitions. Any section may be missing. # TODO(bolms): Should Emboss disallow completely empty files? -@_handles('module -> comment-line* doc-line* import-line* attribute-line*' - ' type-definition*') +@_handles( + "module -> comment-line* doc-line* import-line* attribute-line*" + " type-definition*" +) def _file(leading_newlines, docs, imports, attributes, type_definitions): - """Assembles the top-level IR for a module.""" - del leading_newlines # Unused. - # Figure out the best synthetic source_location for the synthesized prelude - # import. - if imports.list: - position = imports.list[0].source_location.start - elif docs.list: - position = docs.list[0].source_location.end - elif attributes.list: - position = attributes.list[0].source_location.start - elif type_definitions.list: - position = type_definitions.list[0].source_location.start - else: - position = 1, 1 - - # If the source file is completely empty, build_ir won't automatically - # populate the source_location attribute for the module. - if (not docs.list and not imports.list and not attributes.list and - not type_definitions.list): - module_source_location = parser_types.make_location((1, 1), (1, 1)) - else: - module_source_location = None - - return ir_data.Module( - documentation=docs.list, - foreign_import=[_make_prelude_import(position)] + imports.list, - attribute=attributes.list, - type=type_definitions.list, - source_location=module_source_location) - - -@_handles('import-line ->' - ' "import" string-constant "as" snake-word Comment? eol') + """Assembles the top-level IR for a module.""" + del leading_newlines # Unused. + # Figure out the best synthetic source_location for the synthesized prelude + # import. + if imports.list: + position = imports.list[0].source_location.start + elif docs.list: + position = docs.list[0].source_location.end + elif attributes.list: + position = attributes.list[0].source_location.start + elif type_definitions.list: + position = type_definitions.list[0].source_location.start + else: + position = 1, 1 + + # If the source file is completely empty, build_ir won't automatically + # populate the source_location attribute for the module. + if ( + not docs.list + and not imports.list + and not attributes.list + and not type_definitions.list + ): + module_source_location = parser_types.make_location((1, 1), (1, 1)) + else: + module_source_location = None + + return ir_data.Module( + documentation=docs.list, + foreign_import=[_make_prelude_import(position)] + imports.list, + attribute=attributes.list, + type=type_definitions.list, + source_location=module_source_location, + ) + + +@_handles("import-line ->" ' "import" string-constant "as" snake-word Comment? eol') def _import(import_, file_name, as_, local_name, comment, eol): - del import_, as_, comment, eol # Unused - return ir_data.Import(file_name=file_name, local_name=local_name) + del import_, as_, comment, eol # Unused + return ir_data.Import(file_name=file_name, local_name=local_name) -@_handles('doc-line -> doc Comment? eol') +@_handles("doc-line -> doc Comment? eol") def _doc_line(doc, comment, eol): - del comment, eol # Unused. - return doc + del comment, eol # Unused. + return doc -@_handles('doc -> Documentation') +@_handles("doc -> Documentation") def _doc(documentation): - # As a special case, an empty documentation string may omit the trailing - # space. 
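  # Example: a Documentation token with text "-- byte count" yields
  # ir_data.Documentation(text="byte count"), while a bare "--" is first
  # padded to "-- " so the slice below produces the empty string.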
- if documentation.text == '--': - doc_text = '-- ' - else: - doc_text = documentation.text - assert doc_text[0:3] == '-- ', ( - "Documentation token '{}' in unknown format.".format( - documentation.text)) - return ir_data.Documentation(text=doc_text[3:]) + # As a special case, an empty documentation string may omit the trailing + # space. + if documentation.text == "--": + doc_text = "-- " + else: + doc_text = documentation.text + assert doc_text[0:3] == "-- ", "Documentation token '{}' in unknown format.".format( + documentation.text + ) + return ir_data.Documentation(text=doc_text[3:]) # A attribute-line is just a attribute on its own line. -@_handles('attribute-line -> attribute Comment? eol') +@_handles("attribute-line -> attribute Comment? eol") def _attribute_line(attr, comment, eol): - del comment, eol # Unused. - return attr + del comment, eol # Unused. + return attr # A attribute is [name = value]. -@_handles('attribute -> "[" attribute-context? "$default"?' - ' snake-word ":" attribute-value "]"') -def _attribute(open_bracket, context_specifier, default_specifier, name, colon, - attribute_value, close_bracket): - del open_bracket, colon, close_bracket # Unused. - if context_specifier.list: - return ir_data.Attribute(name=name, - value=attribute_value, - is_default=bool(default_specifier.list), - back_end=context_specifier.list[0]) - else: - return ir_data.Attribute(name=name, - value=attribute_value, - is_default=bool(default_specifier.list)) +@_handles( + 'attribute -> "[" attribute-context? "$default"?' + ' snake-word ":" attribute-value "]"' +) +def _attribute( + open_bracket, + context_specifier, + default_specifier, + name, + colon, + attribute_value, + close_bracket, +): + del open_bracket, colon, close_bracket # Unused. + if context_specifier.list: + return ir_data.Attribute( + name=name, + value=attribute_value, + is_default=bool(default_specifier.list), + back_end=context_specifier.list[0], + ) + else: + return ir_data.Attribute( + name=name, value=attribute_value, is_default=bool(default_specifier.list) + ) @_handles('attribute-context -> "(" snake-word ")"') def _attribute_context(open_paren, context_name, close_paren): - del open_paren, close_paren # Unused. - return context_name + del open_paren, close_paren # Unused. + return context_name -@_handles('attribute-value -> expression') +@_handles("attribute-value -> expression") def _attribute_value_expression(expression): - return ir_data.AttributeValue(expression=expression) + return ir_data.AttributeValue(expression=expression) -@_handles('attribute-value -> string-constant') +@_handles("attribute-value -> string-constant") def _attribute_value_string(string): - return ir_data.AttributeValue(string_constant=string) + return ir_data.AttributeValue(string_constant=string) -@_handles('boolean-constant -> BooleanConstant') +@_handles("boolean-constant -> BooleanConstant") def _boolean_constant(boolean): - return ir_data.BooleanConstant(value=(boolean.text == 'true')) + return ir_data.BooleanConstant(value=(boolean.text == "true")) -@_handles('string-constant -> String') +@_handles("string-constant -> String") def _string_constant(string): - """Turns a String token into an ir_data.String, with proper unescaping. - - Arguments: - string: A String token. - - Returns: - An ir_data.String with the "text" field set to the unescaped value of - string.text. - """ - # TODO(bolms): If/when this logic becomes more complex (e.g., to handle \NNN - # or \xNN escapes), extract this into a separate module with separate tests. 
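  # Worked example: for the source literal "a\nb" the token text is the six
  # characters quote, a, backslash, n, b, quote; text[1:-1] splits on the
  # (\\.) pattern into ['a', '\\n', 'b'], and the escape map turns the middle
  # piece into a real newline, giving a three-character result.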
- assert string.text[0] == '"' - assert string.text[-1] == '"' - assert len(string.text) >= 2 - result = [] - for substring in re.split(r'(\\.)', string.text[1:-1]): - if substring and substring[0] == '\\': - assert len(substring) == 2 - result.append({'\\': '\\', '"': '"', 'n': '\n'}[substring[1]]) - else: - result.append(substring) - return ir_data.String(text=''.join(result)) + """Turns a String token into an ir_data.String, with proper unescaping. + + Arguments: + string: A String token. + + Returns: + An ir_data.String with the "text" field set to the unescaped value of + string.text. + """ + # TODO(bolms): If/when this logic becomes more complex (e.g., to handle \NNN + # or \xNN escapes), extract this into a separate module with separate tests. + assert string.text[0] == '"' + assert string.text[-1] == '"' + assert len(string.text) >= 2 + result = [] + for substring in re.split(r"(\\.)", string.text[1:-1]): + if substring and substring[0] == "\\": + assert len(substring) == 2 + result.append({"\\": "\\", '"': '"', "n": "\n"}[substring[1]]) + else: + result.append(substring) + return ir_data.String(text="".join(result)) # In Emboss, '&&' and '||' may not be mixed without parentheses. These are all @@ -419,179 +442,205 @@ def _string_constant(string): # and-expression-right -> '&&' equality-expression # # In either case, explicit parenthesization is handled elsewhere in the grammar. -@_handles('logical-expression -> and-expression') -@_handles('logical-expression -> or-expression') -@_handles('logical-expression -> comparison-expression') -@_handles('choice-expression -> logical-expression') -@_handles('expression -> choice-expression') +@_handles("logical-expression -> and-expression") +@_handles("logical-expression -> or-expression") +@_handles("logical-expression -> comparison-expression") +@_handles("choice-expression -> logical-expression") +@_handles("expression -> choice-expression") def _expression(expression): - return expression + return expression # The `logical-expression`s here means that ?: can't be chained without # parentheses. `x < 0 ? -1 : (x == 0 ? 0 : 1)` is OK, but `x < 0 ? -1 : x == 0 # ? 0 : 1` is not. Parentheses are also needed in the middle: `x <= 0 ? x < 0 ? # -1 : 0 : 1` is not syntactically valid. -@_handles('choice-expression -> logical-expression "?" logical-expression' - ' ":" logical-expression') +@_handles( + 'choice-expression -> logical-expression "?" logical-expression' + ' ":" logical-expression' +) def _choice_expression(condition, question, if_true, colon, if_false): - location = parser_types.make_location( - condition.source_location.start, if_false.source_location.end) - operator_location = parser_types.make_location( - question.source_location.start, colon.source_location.end) - # The function_name is a bit weird, but should suffice for any error messages - # that might need it. 
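  # The result is a single three-argument CHOICE call, conceptually
  # f(condition, if_true, if_false); its function_name "?:" is given the span
  # from the "?" through the ":" so error messages can point at the whole
  # operator.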
- return ir_data.Expression( - function=ir_data.Function(function=ir_data.FunctionMapping.CHOICE, - args=[condition, if_true, if_false], - function_name=ir_data.Word( - text='?:', - source_location=operator_location), - source_location=location)) - - -@_handles('comparison-expression -> additive-expression') -def _no_op_comparative_expression(expression): - return expression - - -@_handles('comparison-expression ->' - ' additive-expression inequality-operator additive-expression') -def _comparative_expression(left, operator, right): - location = parser_types.make_location( - left.source_location.start, right.source_location.end) - return ir_data.Expression( - function=ir_data.Function(function=_text_to_operator(operator.text), - args=[left, right], - function_name=operator, - source_location=location)) - - -@_handles('additive-expression -> times-expression additive-expression-right*') -@_handles('times-expression -> negation-expression times-expression-right*') -@_handles('and-expression -> comparison-expression and-expression-right+') -@_handles('or-expression -> comparison-expression or-expression-right+') -def _binary_operator_expression(expression, expression_right): - """Builds the IR for a chain of equal-precedence left-associative operations. - - _binary_operator_expression transforms a right-recursive list of expression - tails into a left-associative Expression tree. For example, given the - arguments: - - 6, (Tail("+", 7), Tail("-", 8), Tail("+", 10)) - - _expression produces a structure like: - - Expression(Expression(Expression(6, "+", 7), "-", 8), "+", 10) - - This transformation is necessary because strict LR(1) grammars do not allow - left recursion. - - Note that this method is used for several productions; each of those - productions handles a different precedence level, but are identical in form. - - Arguments: - expression: An ir_data.Expression which is the head of the (expr, operator, - expr, operator, expr, ...) list. - expression_right: A list of _ExpressionTails corresponding to the (operator, - expr, operator, expr, ...) list that comes after expression. - - Returns: - An ir_data.Expression with the correct recursive structure to represent a - list of left-associative operations. - """ - e = expression - for right in expression_right.list: location = parser_types.make_location( - e.source_location.start, right.source_location.end) - e = ir_data.Expression( + condition.source_location.start, if_false.source_location.end + ) + operator_location = parser_types.make_location( + question.source_location.start, colon.source_location.end + ) + # The function_name is a bit weird, but should suffice for any error messages + # that might need it. + return ir_data.Expression( function=ir_data.Function( - function=_text_to_operator(right.operator.text), - args=[e, right.expression], - function_name=right.operator, - source_location=location), - source_location=location) - return e - - -@_handles('comparison-expression ->' - ' additive-expression equality-expression-right+') -@_handles('comparison-expression ->' - ' additive-expression less-expression-right-list') -@_handles('comparison-expression ->' - ' additive-expression greater-expression-right-list') -def _chained_comparison_expression(expression, expression_right): - """Builds the IR for a chain of comparisons, like a == b == c. - - Like _binary_operator_expression, _chained_comparison_expression transforms a - right-recursive list of expression tails into a left-associative Expression - tree. 
Unlike _binary_operator_expression, extra AND nodes are added. For - example, the following expression: + function=ir_data.FunctionMapping.CHOICE, + args=[condition, if_true, if_false], + function_name=ir_data.Word(text="?:", source_location=operator_location), + source_location=location, + ) + ) - 0 <= b <= 64 - must be translated to the conceptually-equivalent expression: - - 0 <= b && b <= 64 - - (The middle subexpression is duplicated -- this would be a problem in a - programming language like C where expressions like `x++` have side effects, - but side effects do not make sense in a data definition language like Emboss.) - - _chained_comparison_expression receives a left-hand head expression and a list - of tails, like: - - 6, (Tail("<=", b), Tail("<=", 64)) - - which it translates to a structure like: - - Expression(Expression(6, "<=", b), "&&", Expression(b, "<=", 64)) - - The Emboss grammar is constructed such that sequences of "<", "<=", and "==" - comparisons may be chained, and sequences of ">", ">=", and "==" can be - chained, but greater and less-than comparisons may not; e.g., "b < 64 > a" is - not allowed. +@_handles("comparison-expression -> additive-expression") +def _no_op_comparative_expression(expression): + return expression - Arguments: - expression: An ir_data.Expression which is the head of the (expr, operator, - expr, operator, expr, ...) list. - expression_right: A list of _ExpressionTails corresponding to the (operator, - expr, operator, expr, ...) list that comes after expression. - Returns: - An ir_data.Expression with the correct recursive structure to represent a - chain of left-associative comparison operations. - """ - sequence = [expression] - for right in expression_right.list: - sequence.append(right.operator) - sequence.append(right.expression) - comparisons = [] - for i in range(0, len(sequence) - 1, 2): - left, operator, right = sequence[i:i+3] +@_handles( + "comparison-expression ->" + " additive-expression inequality-operator additive-expression" +) +def _comparative_expression(left, operator, right): location = parser_types.make_location( - left.source_location.start, right.source_location.end) - comparisons.append(ir_data.Expression( + left.source_location.start, right.source_location.end + ) + return ir_data.Expression( function=ir_data.Function( function=_text_to_operator(operator.text), args=[left, right], function_name=operator, - source_location=location), - source_location=location)) - e = comparisons[0] - for comparison in comparisons[1:]: - location = parser_types.make_location( - e.source_location.start, comparison.source_location.end) - e = ir_data.Expression( - function=ir_data.Function( - function=ir_data.FunctionMapping.AND, - args=[e, comparison], - function_name=ir_data.Word( - text='&&', - source_location=comparison.function.args[0].source_location), - source_location=location), - source_location=location) - return e + source_location=location, + ) + ) + + +@_handles("additive-expression -> times-expression additive-expression-right*") +@_handles("times-expression -> negation-expression times-expression-right*") +@_handles("and-expression -> comparison-expression and-expression-right+") +@_handles("or-expression -> comparison-expression or-expression-right+") +def _binary_operator_expression(expression, expression_right): + """Builds the IR for a chain of equal-precedence left-associative operations. + + _binary_operator_expression transforms a right-recursive list of expression + tails into a left-associative Expression tree. 
For example, given the + arguments: + + 6, (Tail("+", 7), Tail("-", 8), Tail("+", 10)) + + _expression produces a structure like: + + Expression(Expression(Expression(6, "+", 7), "-", 8), "+", 10) + + This transformation is necessary because strict LR(1) grammars do not allow + left recursion. + + Note that this method is used for several productions; each of those + productions handles a different precedence level, but are identical in form. + + Arguments: + expression: An ir_data.Expression which is the head of the (expr, operator, + expr, operator, expr, ...) list. + expression_right: A list of _ExpressionTails corresponding to the (operator, + expr, operator, expr, ...) list that comes after expression. + + Returns: + An ir_data.Expression with the correct recursive structure to represent a + list of left-associative operations. + """ + e = expression + for right in expression_right.list: + location = parser_types.make_location( + e.source_location.start, right.source_location.end + ) + e = ir_data.Expression( + function=ir_data.Function( + function=_text_to_operator(right.operator.text), + args=[e, right.expression], + function_name=right.operator, + source_location=location, + ), + source_location=location, + ) + return e + + +@_handles( + "comparison-expression ->" " additive-expression equality-expression-right+" +) +@_handles( + "comparison-expression ->" " additive-expression less-expression-right-list" +) +@_handles( + "comparison-expression ->" " additive-expression greater-expression-right-list" +) +def _chained_comparison_expression(expression, expression_right): + """Builds the IR for a chain of comparisons, like a == b == c. + + Like _binary_operator_expression, _chained_comparison_expression transforms a + right-recursive list of expression tails into a left-associative Expression + tree. Unlike _binary_operator_expression, extra AND nodes are added. For + example, the following expression: + + 0 <= b <= 64 + + must be translated to the conceptually-equivalent expression: + + 0 <= b && b <= 64 + + (The middle subexpression is duplicated -- this would be a problem in a + programming language like C where expressions like `x++` have side effects, + but side effects do not make sense in a data definition language like Emboss.) + + _chained_comparison_expression receives a left-hand head expression and a list + of tails, like: + + 6, (Tail("<=", b), Tail("<=", 64)) + + which it translates to a structure like: + + Expression(Expression(6, "<=", b), "&&", Expression(b, "<=", 64)) + + The Emboss grammar is constructed such that sequences of "<", "<=", and "==" + comparisons may be chained, and sequences of ">", ">=", and "==" can be + chained, but greater and less-than comparisons may not; e.g., "b < 64 > a" is + not allowed. + + Arguments: + expression: An ir_data.Expression which is the head of the (expr, operator, + expr, operator, expr, ...) list. + expression_right: A list of _ExpressionTails corresponding to the (operator, + expr, operator, expr, ...) list that comes after expression. + + Returns: + An ir_data.Expression with the correct recursive structure to represent a + chain of left-associative comparison operations. 
+ """ + sequence = [expression] + for right in expression_right.list: + sequence.append(right.operator) + sequence.append(right.expression) + comparisons = [] + for i in range(0, len(sequence) - 1, 2): + left, operator, right = sequence[i : i + 3] + location = parser_types.make_location( + left.source_location.start, right.source_location.end + ) + comparisons.append( + ir_data.Expression( + function=ir_data.Function( + function=_text_to_operator(operator.text), + args=[left, right], + function_name=operator, + source_location=location, + ), + source_location=location, + ) + ) + e = comparisons[0] + for comparison in comparisons[1:]: + location = parser_types.make_location( + e.source_location.start, comparison.source_location.end + ) + e = ir_data.Expression( + function=ir_data.Function( + function=ir_data.FunctionMapping.AND, + args=[e, comparison], + function_name=ir_data.Word( + text="&&", + source_location=comparison.function.args[0].source_location, + ), + source_location=location, + ), + source_location=location, + ) + return e # _chained_comparison_expression, above, handles three types of chains: `a == b @@ -629,279 +678,315 @@ def _chained_comparison_expression(expression, expression_right): # # By using `equality-expression-right*` for the first symbol, only the first # parse is possible. -@_handles('greater-expression-right-list ->' - ' equality-expression-right* greater-expression-right' - ' equality-or-greater-expression-right*') -@_handles('less-expression-right-list ->' - ' equality-expression-right* less-expression-right' - ' equality-or-less-expression-right*') +@_handles( + "greater-expression-right-list ->" + " equality-expression-right* greater-expression-right" + " equality-or-greater-expression-right*" +) +@_handles( + "less-expression-right-list ->" + " equality-expression-right* less-expression-right" + " equality-or-less-expression-right*" +) def _chained_comparison_tails(start, middle, end): - return _List(start.list + [middle] + end.list) + return _List(start.list + [middle] + end.list) -@_handles('equality-or-greater-expression-right -> equality-expression-right') -@_handles('equality-or-greater-expression-right -> greater-expression-right') -@_handles('equality-or-less-expression-right -> equality-expression-right') -@_handles('equality-or-less-expression-right -> less-expression-right') +@_handles("equality-or-greater-expression-right -> equality-expression-right") +@_handles("equality-or-greater-expression-right -> greater-expression-right") +@_handles("equality-or-less-expression-right -> equality-expression-right") +@_handles("equality-or-less-expression-right -> less-expression-right") def _equality_or_less_or_greater(right): - return right + return right -@_handles('and-expression-right -> and-operator comparison-expression') -@_handles('or-expression-right -> or-operator comparison-expression') -@_handles('additive-expression-right -> additive-operator times-expression') -@_handles('equality-expression-right -> equality-operator additive-expression') -@_handles('greater-expression-right -> greater-operator additive-expression') -@_handles('less-expression-right -> less-operator additive-expression') -@_handles('times-expression-right ->' - ' multiplicative-operator negation-expression') +@_handles("and-expression-right -> and-operator comparison-expression") +@_handles("or-expression-right -> or-operator comparison-expression") +@_handles("additive-expression-right -> additive-operator times-expression") +@_handles("equality-expression-right -> 
equality-operator additive-expression") +@_handles("greater-expression-right -> greater-operator additive-expression") +@_handles("less-expression-right -> less-operator additive-expression") +@_handles("times-expression-right ->" " multiplicative-operator negation-expression") def _expression_right_production(operator, expression): - return _ExpressionTail(operator, expression) + return _ExpressionTail(operator, expression) # This supports a single layer of unary plus/minus, so "+5" and "-value" are # allowed, but "+-5" or "-+-something" are not. -@_handles('negation-expression -> additive-operator bottom-expression') +@_handles("negation-expression -> additive-operator bottom-expression") def _negation_expression_with_operator(operator, expression): - phantom_zero_location = ir_data.Location(start=operator.source_location.start, - end=operator.source_location.start) - return ir_data.Expression( - function=ir_data.Function( - function=_text_to_operator(operator.text), - args=[ir_data.Expression( - constant=ir_data.NumericConstant( - value='0', - source_location=phantom_zero_location), - source_location=phantom_zero_location), expression], - function_name=operator, - source_location=ir_data.Location( - start=operator.source_location.start, - end=expression.source_location.end))) - - -@_handles('negation-expression -> bottom-expression') + phantom_zero_location = ir_data.Location( + start=operator.source_location.start, end=operator.source_location.start + ) + return ir_data.Expression( + function=ir_data.Function( + function=_text_to_operator(operator.text), + args=[ + ir_data.Expression( + constant=ir_data.NumericConstant( + value="0", source_location=phantom_zero_location + ), + source_location=phantom_zero_location, + ), + expression, + ], + function_name=operator, + source_location=ir_data.Location( + start=operator.source_location.start, end=expression.source_location.end + ), + ) + ) + + +@_handles("negation-expression -> bottom-expression") def _negation_expression(expression): - return expression + return expression @_handles('bottom-expression -> "(" expression ")"') def _bottom_expression_parentheses(open_paren, expression, close_paren): - del open_paren, close_paren # Unused. - return expression + del open_paren, close_paren # Unused. + return expression @_handles('bottom-expression -> function-name "(" argument-list ")"') def _bottom_expression_function(function, open_paren, arguments, close_paren): - del open_paren # Unused. - return ir_data.Expression( - function=ir_data.Function( - function=_text_to_function(function.text), - args=arguments.list, - function_name=function, - source_location=ir_data.Location( - start=function.source_location.start, - end=close_paren.source_location.end))) + del open_paren # Unused. + return ir_data.Expression( + function=ir_data.Function( + function=_text_to_function(function.text), + args=arguments.list, + function_name=function, + source_location=ir_data.Location( + start=function.source_location.start, + end=close_paren.source_location.end, + ), + ) + ) @_handles('comma-then-expression -> "," expression') def _comma_then_expression(comma, expression): - del comma # Unused. - return expression + del comma # Unused. 
+ return expression -@_handles('argument-list -> expression comma-then-expression*') +@_handles("argument-list -> expression comma-then-expression*") def _argument_list(head, tail): - tail.list.insert(0, head) - return tail + tail.list.insert(0, head) + return tail -@_handles('argument-list ->') +@_handles("argument-list ->") def _empty_argument_list(): - return _List([]) + return _List([]) -@_handles('bottom-expression -> numeric-constant') +@_handles("bottom-expression -> numeric-constant") def _bottom_expression_from_numeric_constant(constant): - return ir_data.Expression(constant=constant) + return ir_data.Expression(constant=constant) -@_handles('bottom-expression -> constant-reference') +@_handles("bottom-expression -> constant-reference") def _bottom_expression_from_constant_reference(reference): - return ir_data.Expression(constant_reference=reference) + return ir_data.Expression(constant_reference=reference) -@_handles('bottom-expression -> builtin-reference') +@_handles("bottom-expression -> builtin-reference") def _bottom_expression_from_builtin(reference): - return ir_data.Expression(builtin_reference=reference) + return ir_data.Expression(builtin_reference=reference) -@_handles('bottom-expression -> boolean-constant') +@_handles("bottom-expression -> boolean-constant") def _bottom_expression_from_boolean_constant(boolean): - return ir_data.Expression(boolean_constant=boolean) + return ir_data.Expression(boolean_constant=boolean) -@_handles('bottom-expression -> field-reference') +@_handles("bottom-expression -> field-reference") def _bottom_expression_from_reference(reference): - return reference + return reference -@_handles('field-reference -> snake-reference field-reference-tail*') +@_handles("field-reference -> snake-reference field-reference-tail*") def _indirect_field_reference(field_reference, field_references): - if field_references.source_location.HasField('end'): - end_location = field_references.source_location.end - else: - end_location = field_reference.source_location.end - return ir_data.Expression(field_reference=ir_data.FieldReference( - path=[field_reference] + field_references.list, - source_location=parser_types.make_location( - field_reference.source_location.start, end_location))) + if field_references.source_location.HasField("end"): + end_location = field_references.source_location.end + else: + end_location = field_reference.source_location.end + return ir_data.Expression( + field_reference=ir_data.FieldReference( + path=[field_reference] + field_references.list, + source_location=parser_types.make_location( + field_reference.source_location.start, end_location + ), + ) + ) # If "Type.field" ever becomes syntactically valid, it will be necessary to # check that enum values are compile-time constants. @_handles('field-reference-tail -> "." snake-reference') def _field_reference_tail(dot, reference): - del dot # Unused. - return reference + del dot # Unused. + return reference -@_handles('numeric-constant -> Number') +@_handles("numeric-constant -> Number") def _numeric_constant(number): - # All types of numeric constant tokenize to the same symbol, because they are - # interchangeable in source code. 
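  # Examples of the normalization below (assuming these token texts reach
  # this handler):
  #   "0b1010" -> NumericConstant(value="10")
  #   "0xFF"   -> NumericConstant(value="255")
  #   "1_000"  -> NumericConstant(value="1000")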
- if number.text[0:2] == '0b': - n = int(number.text.replace('_', '')[2:], 2) - elif number.text[0:2] == '0x': - n = int(number.text.replace('_', '')[2:], 16) - else: - n = int(number.text.replace('_', ''), 10) - return ir_data.NumericConstant(value=str(n)) - - -@_handles('type-definition -> struct') -@_handles('type-definition -> bits') -@_handles('type-definition -> enum') -@_handles('type-definition -> external') + # All types of numeric constant tokenize to the same symbol, because they are + # interchangeable in source code. + if number.text[0:2] == "0b": + n = int(number.text.replace("_", "")[2:], 2) + elif number.text[0:2] == "0x": + n = int(number.text.replace("_", "")[2:], 16) + else: + n = int(number.text.replace("_", ""), 10) + return ir_data.NumericConstant(value=str(n)) + + +@_handles("type-definition -> struct") +@_handles("type-definition -> bits") +@_handles("type-definition -> enum") +@_handles("type-definition -> external") def _type_definition(type_definition): - return type_definition + return type_definition # struct StructureName: # ... fields ... # bits BitName: # ... fields ... -@_handles('struct -> "struct" type-name delimited-parameter-definition-list?' - ' ":" Comment? eol struct-body') -@_handles('bits -> "bits" type-name delimited-parameter-definition-list? ":"' - ' Comment? eol bits-body') +@_handles( + 'struct -> "struct" type-name delimited-parameter-definition-list?' + ' ":" Comment? eol struct-body' +) +@_handles( + 'bits -> "bits" type-name delimited-parameter-definition-list? ":"' + " Comment? eol bits-body" +) def _structure(struct, name, parameters, colon, comment, newline, struct_body): - """Composes the top-level IR for an Emboss structure.""" - del colon, comment, newline # Unused. - ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom( - struct.source_location.start) - ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom( - struct_body.source_location.end) - if struct_body.name: - ir_data_utils.update(struct_body.name, name) - else: - struct_body.name = ir_data_utils.copy(name) - if parameters.list: - struct_body.runtime_parameter.extend(parameters.list[0].list) - return struct_body - - -@_handles('delimited-parameter-definition-list ->' - ' "(" parameter-definition-list ")"') + """Composes the top-level IR for an Emboss structure.""" + del colon, comment, newline # Unused. 
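    # The body was already built by the struct-body/bits-body production; the
    # lines below widen its source span to start at the "struct"/"bits"
    # keyword and merge in the type name and any runtime parameters from the
    # header line.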
+ ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom( + struct.source_location.start + ) + ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom( + struct_body.source_location.end + ) + if struct_body.name: + ir_data_utils.update(struct_body.name, name) + else: + struct_body.name = ir_data_utils.copy(name) + if parameters.list: + struct_body.runtime_parameter.extend(parameters.list[0].list) + return struct_body + + +@_handles( + "delimited-parameter-definition-list ->" ' "(" parameter-definition-list ")"' +) def _delimited_parameter_definition_list(open_paren, parameters, close_paren): - del open_paren, close_paren # Unused - return parameters + del open_paren, close_paren # Unused + return parameters @_handles('parameter-definition -> snake-name ":" type') def _parameter_definition(name, double_colon, parameter_type): - del double_colon # Unused - return ir_data.RuntimeParameter(name=name, physical_type_alias=parameter_type) + del double_colon # Unused + return ir_data.RuntimeParameter(name=name, physical_type_alias=parameter_type) @_handles('parameter-definition-list-tail -> "," parameter-definition') def _parameter_definition_list_tail(comma, parameter): - del comma # Unused. - return parameter + del comma # Unused. + return parameter -@_handles('parameter-definition-list -> parameter-definition' - ' parameter-definition-list-tail*') +@_handles( + "parameter-definition-list -> parameter-definition" + " parameter-definition-list-tail*" +) def _parameter_definition_list(head, tail): - tail.list.insert(0, head) - return tail + tail.list.insert(0, head) + return tail -@_handles('parameter-definition-list ->') +@_handles("parameter-definition-list ->") def _empty_parameter_definition_list(): - return _List([]) + return _List([]) # The body of a struct: basically, the part after the first line. -@_handles('struct-body -> Indent doc-line* attribute-line*' - ' type-definition* struct-field-block Dedent') +@_handles( + "struct-body -> Indent doc-line* attribute-line*" + " type-definition* struct-field-block Dedent" +) def _struct_body(indent, docs, attributes, types, fields, dedent): - del indent, dedent # Unused. - return _structure_body(docs, attributes, types, fields, - ir_data.AddressableUnit.BYTE) + del indent, dedent # Unused. 
+ return _structure_body( + docs, attributes, types, fields, ir_data.AddressableUnit.BYTE + ) def _structure_body(docs, attributes, types, fields, addressable_unit): - """Constructs the body of a structure (bits or struct) definition.""" - return ir_data.TypeDefinition( - structure=ir_data.Structure(field=[field.field for field in fields.list]), - documentation=docs.list, - attribute=attributes.list, - subtype=types.list + [subtype for field in fields.list for subtype in - field.subtypes], - addressable_unit=addressable_unit) - - -@_handles('struct-field-block ->') -@_handles('bits-field-block ->') -@_handles('anonymous-bits-field-block ->') + """Constructs the body of a structure (bits or struct) definition.""" + return ir_data.TypeDefinition( + structure=ir_data.Structure(field=[field.field for field in fields.list]), + documentation=docs.list, + attribute=attributes.list, + subtype=types.list + + [subtype for field in fields.list for subtype in field.subtypes], + addressable_unit=addressable_unit, + ) + + +@_handles("struct-field-block ->") +@_handles("bits-field-block ->") +@_handles("anonymous-bits-field-block ->") def _empty_field_block(): - return _List([]) + return _List([]) -@_handles('struct-field-block ->' - ' conditional-struct-field-block struct-field-block') -@_handles('bits-field-block ->' - ' conditional-bits-field-block bits-field-block') -@_handles('anonymous-bits-field-block -> conditional-anonymous-bits-field-block' - ' anonymous-bits-field-block') +@_handles( + "struct-field-block ->" " conditional-struct-field-block struct-field-block" +) +@_handles("bits-field-block ->" " conditional-bits-field-block bits-field-block") +@_handles( + "anonymous-bits-field-block -> conditional-anonymous-bits-field-block" + " anonymous-bits-field-block" +) def _conditional_block_plus_field_block(conditional_block, block): - return _List(conditional_block.list + block.list) + return _List(conditional_block.list + block.list) -@_handles('struct-field-block ->' - ' unconditional-struct-field struct-field-block') -@_handles('bits-field-block ->' - ' unconditional-bits-field bits-field-block') -@_handles('anonymous-bits-field-block ->' - ' unconditional-anonymous-bits-field anonymous-bits-field-block') +@_handles("struct-field-block ->" " unconditional-struct-field struct-field-block") +@_handles("bits-field-block ->" " unconditional-bits-field bits-field-block") +@_handles( + "anonymous-bits-field-block ->" + " unconditional-anonymous-bits-field anonymous-bits-field-block" +) def _unconditional_block_plus_field_block(field, block): - """Prepends an unconditional field to block.""" - ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom( - field.source_location) - ir_data_utils.builder(field.field).existence_condition.boolean_constant.source_location.CopyFrom( - field.source_location) - ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True - return _List([field] + block.list) + """Prepends an unconditional field to block.""" + ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom( + field.source_location + ) + ir_data_utils.builder( + field.field + ).existence_condition.boolean_constant.source_location.CopyFrom( + field.source_location + ) + ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True + return _List([field] + block.list) # Struct "fields" are regular fields, inline enums, bits, or structs, anonymous # inline bits, or virtual fields. 
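A note on _structure_body above, before the field-block handlers continue: inline type definitions are not kept nested. Explicitly declared subtypes and any types synthesized by inline fields are flattened into a single subtype list on the enclosing TypeDefinition. A toy illustration of that flattening expression; the namedtuple and sample values here are hypothetical:

import collections

FieldWithType = collections.namedtuple("FieldWithType", ["field", "subtypes"])

explicit_types = ["Inner"]
fields = [
    FieldWithType(field="a", subtypes=[]),
    FieldWithType(field="b", subtypes=["BEnum"]),  # e.g. from an inline enum
]
# Same shape as: types.list + [subtype for field in fields.list
#                              for subtype in field.subtypes]
flattened = explicit_types + [s for f in fields for s in f.subtypes]
assert flattened == ["Inner", "BEnum"]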
-@_handles('unconditional-struct-field -> field')
-@_handles('unconditional-struct-field -> inline-enum-field-definition')
-@_handles('unconditional-struct-field -> inline-bits-field-definition')
-@_handles('unconditional-struct-field -> inline-struct-field-definition')
-@_handles('unconditional-struct-field -> anonymous-bits-field-definition')
-@_handles('unconditional-struct-field -> virtual-field')
+@_handles("unconditional-struct-field -> field")
+@_handles("unconditional-struct-field -> inline-enum-field-definition")
+@_handles("unconditional-struct-field -> inline-bits-field-definition")
+@_handles("unconditional-struct-field -> inline-struct-field-definition")
+@_handles("unconditional-struct-field -> anonymous-bits-field-definition")
+@_handles("unconditional-struct-field -> virtual-field")
# Bits fields are "regular" fields, inline enums or bits, or virtual fields.
#
# Inline structs and anonymous inline bits are not allowed inside of bits:
@@ -910,55 +995,66 @@ def _unconditional_block_plus_field_block(field, block):
#
# Anonymous inline bits may not include virtual fields; instead, the virtual
# field should be a direct part of the enclosing structure.
-@_handles('unconditional-anonymous-bits-field -> field')
-@_handles('unconditional-anonymous-bits-field -> inline-enum-field-definition')
-@_handles('unconditional-anonymous-bits-field -> inline-bits-field-definition')
-@_handles('unconditional-bits-field -> unconditional-anonymous-bits-field')
-@_handles('unconditional-bits-field -> virtual-field')
+@_handles("unconditional-anonymous-bits-field -> field")
+@_handles("unconditional-anonymous-bits-field -> inline-enum-field-definition")
+@_handles("unconditional-anonymous-bits-field -> inline-bits-field-definition")
+@_handles("unconditional-bits-field -> unconditional-anonymous-bits-field")
+@_handles("unconditional-bits-field -> virtual-field")
def _unconditional_field(field):
-  """Handles the unifying grammar production for a struct or bits field."""
-  return field
+    """Handles the unifying grammar production for a struct or bits field."""
+    return field


# TODO(bolms): Add 'elif' and 'else' support.
# TODO(bolms): Should nested 'if' blocks be allowed?
-@_handles('conditional-struct-field-block ->'
-          ' "if" expression ":" Comment? eol'
-          ' Indent unconditional-struct-field+ Dedent')
-@_handles('conditional-bits-field-block ->'
-          ' "if" expression ":" Comment? eol'
-          ' Indent unconditional-bits-field+ Dedent')
-@_handles('conditional-anonymous-bits-field-block ->'
-          ' "if" expression ":" Comment? eol'
-          ' Indent unconditional-anonymous-bits-field+ Dedent')
-def _conditional_field_block(if_keyword, expression, colon, comment, newline,
-                             indent, fields, dedent):
-  """Applies an existence_condition to each element of fields."""
-  del if_keyword, newline, colon, comment, indent, dedent  # Unused.
-  for field in fields.list:
-    condition = ir_data_utils.builder(field.field).existence_condition
-    condition.CopyFrom(expression)
-    condition.source_location.is_disjoint_from_parent = True
-  return fields
+@_handles(
+    "conditional-struct-field-block ->"
+    ' "if" expression ":" Comment? eol'
+    " Indent unconditional-struct-field+ Dedent"
+)
+@_handles(
+    "conditional-bits-field-block ->"
+    ' "if" expression ":" Comment? eol'
+    " Indent unconditional-bits-field+ Dedent"
+)
+@_handles(
+    "conditional-anonymous-bits-field-block ->"
+    ' "if" expression ":" Comment? eol'
+    " Indent unconditional-anonymous-bits-field+ Dedent"
+)
+def _conditional_field_block(
+    if_keyword, expression, colon, comment, newline, indent, fields, dedent
+):
+    """Applies an existence_condition to each element of fields."""
+    del if_keyword, newline, colon, comment, indent, dedent  # Unused.
+    for field in fields.list:
+        condition = ir_data_utils.builder(field.field).existence_condition
+        condition.CopyFrom(expression)
+        condition.source_location.is_disjoint_from_parent = True
+    return fields


# The body of a bit field definition: basically, the part after the first line.
-@_handles('bits-body -> Indent doc-line* attribute-line*'
-          ' type-definition* bits-field-block Dedent')
+@_handles(
+    "bits-body -> Indent doc-line* attribute-line*"
+    " type-definition* bits-field-block Dedent"
+)
def _bits_body(indent, docs, attributes, types, fields, dedent):
-  del indent, dedent  # Unused.
-  return _structure_body(docs, attributes, types, fields,
-                         ir_data.AddressableUnit.BIT)
+    del indent, dedent  # Unused.
+    return _structure_body(docs, attributes, types, fields, ir_data.AddressableUnit.BIT)


# Inline bits (defined as part of a field) are more restricted than standalone
# bits.
-@_handles('anonymous-bits-body ->'
-          ' Indent attribute-line* anonymous-bits-field-block Dedent')
+@_handles(
+    "anonymous-bits-body ->"
+    " Indent attribute-line* anonymous-bits-field-block Dedent"
+)
def _anonymous_bits_body(indent, attributes, fields, dedent):
-  del indent, dedent  # Unused.
-  return _structure_body(_List([]), attributes, _List([]), fields,
-                         ir_data.AddressableUnit.BIT)
+    del indent, dedent  # Unused.
+    return _structure_body(
+        _List([]), attributes, _List([]), fields, ir_data.AddressableUnit.BIT
+    )


# A field is:
@@ -967,30 +1063,43 @@ def _anonymous_bits_body(indent, attributes, fields, dedent):
#     -- doc
#     [attr3: value]
#     [attr4: value]
-@_handles('field ->'
-          ' field-location type snake-name abbreviation? attribute* doc?'
-          ' Comment? eol field-body?')
-def _field(location, field_type, name, abbreviation, attributes, doc, comment,
-           newline, field_body):
-  """Constructs an ir_data.Field from the given components."""
-  del comment  # Unused
-  field_ir = ir_data.Field(location=location,
-                           type=field_type,
-                           name=name,
-                           attribute=attributes.list,
-                           documentation=doc.list)
-  field = ir_data_utils.builder(field_ir)
-  if field_body.list:
-    field.attribute.extend(field_body.list[0].attribute)
-    field.documentation.extend(field_body.list[0].documentation)
-  if abbreviation.list:
-    field.abbreviation.CopyFrom(abbreviation.list[0])
-  field.source_location.start.CopyFrom(location.source_location.start)
-  if field_body.source_location.HasField('end'):
-    field.source_location.end.CopyFrom(field_body.source_location.end)
-  else:
-    field.source_location.end.CopyFrom(newline.source_location.end)
-  return _FieldWithType(field=field_ir)
+@_handles(
+    "field ->"
+    " field-location type snake-name abbreviation? attribute* doc?"
+    " Comment? eol field-body?"
+) +def _field( + location, + field_type, + name, + abbreviation, + attributes, + doc, + comment, + newline, + field_body, +): + """Constructs an ir_data.Field from the given components.""" + del comment # Unused + field_ir = ir_data.Field( + location=location, + type=field_type, + name=name, + attribute=attributes.list, + documentation=doc.list, + ) + field = ir_data_utils.builder(field_ir) + if field_body.list: + field.attribute.extend(field_body.list[0].attribute) + field.documentation.extend(field_body.list[0].documentation) + if abbreviation.list: + field.abbreviation.CopyFrom(abbreviation.list[0]) + field.source_location.start.CopyFrom(location.source_location.start) + if field_body.source_location.HasField("end"): + field.source_location.end.CopyFrom(field_body.source_location.end) + else: + field.source_location.end.CopyFrom(newline.source_location.end) + return _FieldWithType(field=field_ir) # A "virtual field" is: @@ -999,22 +1108,23 @@ def _field(location, field_type, name, abbreviation, attributes, doc, comment, # -- doc # [attr1: value] # [attr2: value] -@_handles('virtual-field ->' - ' "let" snake-name "=" expression Comment? eol field-body?') +@_handles( + "virtual-field ->" ' "let" snake-name "=" expression Comment? eol field-body?' +) def _virtual_field(let, name, equals, value, comment, newline, field_body): - """Constructs an ir_data.Field from the given components.""" - del equals, comment # Unused - field_ir = ir_data.Field(read_transform=value, name=name) - field = ir_data_utils.builder(field_ir) - if field_body.list: - field.attribute.extend(field_body.list[0].attribute) - field.documentation.extend(field_body.list[0].documentation) - field.source_location.start.CopyFrom(let.source_location.start) - if field_body.source_location.HasField('end'): - field.source_location.end.CopyFrom(field_body.source_location.end) - else: - field.source_location.end.CopyFrom(newline.source_location.end) - return _FieldWithType(field=field_ir) + """Constructs an ir_data.Field from the given components.""" + del equals, comment # Unused + field_ir = ir_data.Field(read_transform=value, name=name) + field = ir_data_utils.builder(field_ir) + if field_body.list: + field.attribute.extend(field_body.list[0].attribute) + field.documentation.extend(field_body.list[0].documentation) + field.source_location.start.CopyFrom(let.source_location.start) + if field_body.source_location.HasField("end"): + field.source_location.end.CopyFrom(field_body.source_location.end) + else: + field.source_location.end.CopyFrom(newline.source_location.end) + return _FieldWithType(field=field_ir) # An inline enum is: @@ -1025,252 +1135,292 @@ def _virtual_field(let, name, equals, value, comment, newline, field_body): # [attr4: value] # NAME = 10 # NAME2 = 20 -@_handles('inline-enum-field-definition ->' - ' field-location "enum" snake-name abbreviation? ":" Comment? eol' - ' enum-body') -def _inline_enum_field(location, enum, name, abbreviation, colon, comment, - newline, enum_body): - """Constructs an ir_data.Field for an inline enum field.""" - del enum, colon, comment, newline # Unused. - return _inline_type_field(location, name, abbreviation, enum_body) +@_handles( + "inline-enum-field-definition ->" + ' field-location "enum" snake-name abbreviation? ":" Comment? eol' + " enum-body" +) +def _inline_enum_field( + location, enum, name, abbreviation, colon, comment, newline, enum_body +): + """Constructs an ir_data.Field for an inline enum field.""" + del enum, colon, comment, newline # Unused. 
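An aside on the span rule shared by _field and _virtual_field above: a field's source span runs from its first token through its indented field-body when one was parsed, and otherwise through the end of the field's own line, which is why both handlers test field_body.source_location.HasField("end"). A compact restatement of that fallback; the names here are illustrative only:

def span_end(field_body_end, line_end):
    # field_body.source_location only carries an `end` when an indented
    # body actually exists; otherwise fall back to the field's newline.
    return field_body_end if field_body_end is not None else line_end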
+ return _inline_type_field(location, name, abbreviation, enum_body) @_handles( - 'inline-struct-field-definition ->' + "inline-struct-field-definition ->" ' field-location "struct" snake-name abbreviation? ":" Comment? eol' - ' struct-body') -def _inline_struct_field(location, struct, name, abbreviation, colon, comment, - newline, struct_body): - del struct, colon, comment, newline # Unused. - return _inline_type_field(location, name, abbreviation, struct_body) + " struct-body" +) +def _inline_struct_field( + location, struct, name, abbreviation, colon, comment, newline, struct_body +): + del struct, colon, comment, newline # Unused. + return _inline_type_field(location, name, abbreviation, struct_body) -@_handles('inline-bits-field-definition ->' - ' field-location "bits" snake-name abbreviation? ":" Comment? eol' - ' bits-body') -def _inline_bits_field(location, bits, name, abbreviation, colon, comment, - newline, bits_body): - del bits, colon, comment, newline # Unused. - return _inline_type_field(location, name, abbreviation, bits_body) +@_handles( + "inline-bits-field-definition ->" + ' field-location "bits" snake-name abbreviation? ":" Comment? eol' + " bits-body" +) +def _inline_bits_field( + location, bits, name, abbreviation, colon, comment, newline, bits_body +): + del bits, colon, comment, newline # Unused. + return _inline_type_field(location, name, abbreviation, bits_body) def _inline_type_field(location, name, abbreviation, body): - """Shared implementation of _inline_enum_field and _anonymous_bit_field.""" - field_ir = ir_data.Field(location=location, - name=name, - attribute=body.attribute, - documentation=body.documentation) - field = ir_data_utils.builder(field_ir) - # All attributes should be attached to the field, not the type definition: if - # the user wants to use type attributes, they should create a separate type - # definition and reference it. - del body.attribute[:] - type_name = ir_data_utils.copy(name) - ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text) - field.type.atomic_type.reference.source_name.extend([type_name.name]) - field.type.atomic_type.reference.source_location.CopyFrom( - type_name.source_location) - field.type.atomic_type.reference.is_local_name = True - field.type.atomic_type.source_location.CopyFrom(type_name.source_location) - field.type.source_location.CopyFrom(type_name.source_location) - if abbreviation.list: - field.abbreviation.CopyFrom(abbreviation.list[0]) - field.source_location.start.CopyFrom(location.source_location.start) - ir_data_utils.builder(body.source_location).start.CopyFrom(location.source_location.start) - if body.HasField('enumeration'): - ir_data_utils.builder(body.enumeration).source_location.CopyFrom(body.source_location) - else: - assert body.HasField('structure') - ir_data_utils.builder(body.structure).source_location.CopyFrom(body.source_location) - ir_data_utils.builder(body).name.CopyFrom(type_name) - field.source_location.end.CopyFrom(body.source_location.end) - subtypes = [body] + list(body.subtype) - del body.subtype[:] - return _FieldWithType(field=field_ir, subtypes=subtypes) - - -@_handles('anonymous-bits-field-definition ->' - ' field-location "bits" ":" Comment? eol anonymous-bits-body') -def _anonymous_bit_field(location, bits_keyword, colon, comment, newline, - bits_body): - """Constructs an ir_data.Field for an anonymous bit field.""" - del colon, comment, newline # Unused. 
- name = ir_data.NameDefinition( - name=ir_data.Word( - text=_get_anonymous_field_name(), - source_location=bits_keyword.source_location), - source_location=bits_keyword.source_location, - is_anonymous=True) - return _inline_type_field(location, name, _List([]), bits_body) - - -@_handles('field-body -> Indent doc-line* attribute-line* Dedent') + """Shared implementation of _inline_enum_field and _anonymous_bit_field.""" + field_ir = ir_data.Field( + location=location, + name=name, + attribute=body.attribute, + documentation=body.documentation, + ) + field = ir_data_utils.builder(field_ir) + # All attributes should be attached to the field, not the type definition: if + # the user wants to use type attributes, they should create a separate type + # definition and reference it. + del body.attribute[:] + type_name = ir_data_utils.copy(name) + ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel( + type_name.name.text + ) + field.type.atomic_type.reference.source_name.extend([type_name.name]) + field.type.atomic_type.reference.source_location.CopyFrom(type_name.source_location) + field.type.atomic_type.reference.is_local_name = True + field.type.atomic_type.source_location.CopyFrom(type_name.source_location) + field.type.source_location.CopyFrom(type_name.source_location) + if abbreviation.list: + field.abbreviation.CopyFrom(abbreviation.list[0]) + field.source_location.start.CopyFrom(location.source_location.start) + ir_data_utils.builder(body.source_location).start.CopyFrom( + location.source_location.start + ) + if body.HasField("enumeration"): + ir_data_utils.builder(body.enumeration).source_location.CopyFrom( + body.source_location + ) + else: + assert body.HasField("structure") + ir_data_utils.builder(body.structure).source_location.CopyFrom( + body.source_location + ) + ir_data_utils.builder(body).name.CopyFrom(type_name) + field.source_location.end.CopyFrom(body.source_location.end) + subtypes = [body] + list(body.subtype) + del body.subtype[:] + return _FieldWithType(field=field_ir, subtypes=subtypes) + + +@_handles( + "anonymous-bits-field-definition ->" + ' field-location "bits" ":" Comment? eol anonymous-bits-body' +) +def _anonymous_bit_field(location, bits_keyword, colon, comment, newline, bits_body): + """Constructs an ir_data.Field for an anonymous bit field.""" + del colon, comment, newline # Unused. + name = ir_data.NameDefinition( + name=ir_data.Word( + text=_get_anonymous_field_name(), + source_location=bits_keyword.source_location, + ), + source_location=bits_keyword.source_location, + is_anonymous=True, + ) + return _inline_type_field(location, name, _List([]), bits_body) + + +@_handles("field-body -> Indent doc-line* attribute-line* Dedent") def _field_body(indent, docs, attributes, dedent): - del indent, dedent # Unused. - return ir_data.Field(documentation=docs.list, attribute=attributes.list) + del indent, dedent # Unused. + return ir_data.Field(documentation=docs.list, attribute=attributes.list) # A parenthetically-denoted abbreviation. @_handles('abbreviation -> "(" snake-word ")"') def _abbreviation(open_paren, word, close_paren): - del open_paren, close_paren # Unused. - return word + del open_paren, close_paren # Unused. + return word # enum EnumName: # ... values ... @_handles('enum -> "enum" type-name ":" Comment? eol enum-body') def _enum(enum, name, colon, comment, newline, enum_body): - del colon, comment, newline # Unused. 
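A note on _inline_type_field above: an inline enum, bits, or struct field synthesizes a nested type named after the field, converting the snake_case field name to CamelCase, and deliberately keeps all attributes on the field rather than on the synthesized type. The real conversion lives in name_conversion.snake_to_camel; the following is only a plausible stand-in for illustration, not the actual implementation:

def snake_to_camel(name):
    # "payload_header" -> "PayloadHeader" (illustrative sketch).
    return "".join(part.capitalize() for part in name.split("_"))

assert snake_to_camel("payload_header") == "PayloadHeader"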
- ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom( - enum.source_location.start) - ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom( - enum_body.source_location.end) - ir_data_utils.builder(enum_body).name.CopyFrom(name) - return enum_body + del colon, comment, newline # Unused. + ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom( + enum.source_location.start + ) + ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom( + enum_body.source_location.end + ) + ir_data_utils.builder(enum_body).name.CopyFrom(name) + return enum_body # [enum Foo:] # name = value # name = value -@_handles('enum-body -> Indent doc-line* attribute-line* enum-value+ Dedent') +@_handles("enum-body -> Indent doc-line* attribute-line* enum-value+ Dedent") def _enum_body(indent, docs, attributes, values, dedent): - del indent, dedent # Unused. - return ir_data.TypeDefinition( - enumeration=ir_data.Enum(value=values.list), - documentation=docs.list, - attribute=attributes.list, - addressable_unit=ir_data.AddressableUnit.BIT) + del indent, dedent # Unused. + return ir_data.TypeDefinition( + enumeration=ir_data.Enum(value=values.list), + documentation=docs.list, + attribute=attributes.list, + addressable_unit=ir_data.AddressableUnit.BIT, + ) # name = value -@_handles('enum-value -> ' - ' constant-name "=" expression attribute* doc? Comment? eol enum-value-body?') -def _enum_value(name, equals, expression, attribute, documentation, comment, newline, - body): - del equals, comment, newline # Unused. - result = ir_data.EnumValue(name=name, - value=expression, - documentation=documentation.list, - attribute=attribute.list) - if body.list: - result.documentation.extend(body.list[0].documentation) - result.attribute.extend(body.list[0].attribute) - return result - - -@_handles('enum-value-body -> Indent doc-line* attribute-line* Dedent') +@_handles( + "enum-value -> " + ' constant-name "=" expression attribute* doc? Comment? eol enum-value-body?' +) +def _enum_value( + name, equals, expression, attribute, documentation, comment, newline, body +): + del equals, comment, newline # Unused. + result = ir_data.EnumValue( + name=name, + value=expression, + documentation=documentation.list, + attribute=attribute.list, + ) + if body.list: + result.documentation.extend(body.list[0].documentation) + result.attribute.extend(body.list[0].attribute) + return result + + +@_handles("enum-value-body -> Indent doc-line* attribute-line* Dedent") def _enum_value_body(indent, docs, attributes, dedent): - del indent, dedent # Unused. - return ir_data.EnumValue(documentation=docs.list, attribute=attributes.list) + del indent, dedent # Unused. + return ir_data.EnumValue(documentation=docs.list, attribute=attributes.list) # An external is just a declaration that a type exists and has certain # attributes. @_handles('external -> "external" type-name ":" Comment? eol external-body') def _external(external, name, colon, comment, newline, external_body): - del colon, comment, newline # Unused. - ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start) - if external_body.name: - ir_data_utils.update(external_body.name, name) - else: - external_body.name = ir_data_utils.copy(name) - return external_body + del colon, comment, newline # Unused. 
+ ir_data_utils.builder(external_body.source_location).start.CopyFrom( + external.source_location.start + ) + if external_body.name: + ir_data_utils.update(external_body.name, name) + else: + external_body.name = ir_data_utils.copy(name) + return external_body # This syntax implicitly requires either a documentation line or a attribute # line, or it won't parse (because no Indent/Dedent tokens will be emitted). -@_handles('external-body -> Indent doc-line* attribute-line* Dedent') +@_handles("external-body -> Indent doc-line* attribute-line* Dedent") def _external_body(indent, docs, attributes, dedent): - return ir_data.TypeDefinition( - external=ir_data.External( - # Set source_location here, since it won't be set automatically. - source_location=ir_data.Location(start=indent.source_location.start, - end=dedent.source_location.end)), - documentation=docs.list, - attribute=attributes.list) + return ir_data.TypeDefinition( + external=ir_data.External( + # Set source_location here, since it won't be set automatically. + source_location=ir_data.Location( + start=indent.source_location.start, end=dedent.source_location.end + ) + ), + documentation=docs.list, + attribute=attributes.list, + ) @_handles('field-location -> expression "[" "+" expression "]"') def _field_location(start, open_bracket, plus, size, close_bracket): - del open_bracket, plus, close_bracket # Unused. - return ir_data.FieldLocation(start=start, size=size) + del open_bracket, plus, close_bracket # Unused. + return ir_data.FieldLocation(start=start, size=size) @_handles('delimited-argument-list -> "(" argument-list ")"') def _type_argument_list(open_paren, arguments, close_paren): - del open_paren, close_paren # Unused - return arguments + del open_paren, close_paren # Unused + return arguments # A type is "TypeName" or "TypeName[length]" or "TypeName[length][length]", etc. # An array type may have an empty length ("Type[]"). This is only valid for the # outermost length (the last set of brackets), but that must be checked # elsewhere. -@_handles('type -> type-reference delimited-argument-list? type-size-specifier?' - ' array-length-specifier*') +@_handles( + "type -> type-reference delimited-argument-list? type-size-specifier?" 
+ " array-length-specifier*" +) def _type(reference, parameters, size, array_spec): - """Builds the IR for a type specifier.""" - base_type_source_location_end = reference.source_location.end - atomic_type_source_location_end = reference.source_location.end - if parameters.list: - base_type_source_location_end = parameters.source_location.end - atomic_type_source_location_end = parameters.source_location.end - if size.list: - base_type_source_location_end = size.source_location.end - base_type_location = parser_types.make_location( - reference.source_location.start, - base_type_source_location_end) - atomic_type_location = parser_types.make_location( - reference.source_location.start, - atomic_type_source_location_end) - t = ir_data.Type( - atomic_type=ir_data.AtomicType( - reference=ir_data_utils.copy(reference), - source_location=atomic_type_location, - runtime_parameter=parameters.list[0].list if parameters.list else []), - size_in_bits=size.list[0] if size.list else None, - source_location=base_type_location) - for length in array_spec.list: - location = parser_types.make_location( - t.source_location.start, length.source_location.end) - if isinstance(length, ir_data.Expression): - t = ir_data.Type( - array_type=ir_data.ArrayType(base_type=t, - element_count=length, - source_location=location), - source_location=location) - elif isinstance(length, ir_data.Empty): - t = ir_data.Type( - array_type=ir_data.ArrayType(base_type=t, - automatic=length, - source_location=location), - source_location=location) - else: - assert False, "Shouldn't be here." - return t + """Builds the IR for a type specifier.""" + base_type_source_location_end = reference.source_location.end + atomic_type_source_location_end = reference.source_location.end + if parameters.list: + base_type_source_location_end = parameters.source_location.end + atomic_type_source_location_end = parameters.source_location.end + if size.list: + base_type_source_location_end = size.source_location.end + base_type_location = parser_types.make_location( + reference.source_location.start, base_type_source_location_end + ) + atomic_type_location = parser_types.make_location( + reference.source_location.start, atomic_type_source_location_end + ) + t = ir_data.Type( + atomic_type=ir_data.AtomicType( + reference=ir_data_utils.copy(reference), + source_location=atomic_type_location, + runtime_parameter=parameters.list[0].list if parameters.list else [], + ), + size_in_bits=size.list[0] if size.list else None, + source_location=base_type_location, + ) + for length in array_spec.list: + location = parser_types.make_location( + t.source_location.start, length.source_location.end + ) + if isinstance(length, ir_data.Expression): + t = ir_data.Type( + array_type=ir_data.ArrayType( + base_type=t, element_count=length, source_location=location + ), + source_location=location, + ) + elif isinstance(length, ir_data.Empty): + t = ir_data.Type( + array_type=ir_data.ArrayType( + base_type=t, automatic=length, source_location=location + ), + source_location=location, + ) + else: + assert False, "Shouldn't be here." + return t # TODO(bolms): Should symbolic names or expressions be allowed? E.g., # UInt:FIELD_SIZE or UInt:(16 + 16)? 
@_handles('type-size-specifier -> ":" numeric-constant') def _type_size_specifier(colon, numeric_constant): - """handles the ":32" part of a type specifier like "UInt:32".""" - del colon - return ir_data.Expression(constant=numeric_constant) + """handles the ":32" part of a type specifier like "UInt:32".""" + del colon + return ir_data.Expression(constant=numeric_constant) # The distinctions between different formats of NameDefinitions, Words, and # References are enforced during parsing, but not propagated to the IR. -@_handles('type-name -> type-word') -@_handles('snake-name -> snake-word') -@_handles('constant-name -> constant-word') +@_handles("type-name -> type-word") +@_handles("snake-name -> snake-word") +@_handles("constant-name -> constant-word") def _name(word): - return ir_data.NameDefinition(name=word) + return ir_data.NameDefinition(name=word) -@_handles('type-word -> CamelWord') -@_handles('snake-word -> SnakeWord') +@_handles("type-word -> CamelWord") +@_handles("snake-word -> SnakeWord") @_handles('builtin-field-word -> "$size_in_bits"') @_handles('builtin-field-word -> "$size_in_bytes"') @_handles('builtin-field-word -> "$max_size_in_bits"') @@ -1280,7 +1430,7 @@ def _name(word): @_handles('builtin-word -> "$is_statically_sized"') @_handles('builtin-word -> "$static_size_in_bits"') @_handles('builtin-word -> "$next"') -@_handles('constant-word -> ShoutyWord') +@_handles("constant-word -> ShoutyWord") @_handles('and-operator -> "&&"') @_handles('or-operator -> "||"') @_handles('less-operator -> "<="') @@ -1297,28 +1447,29 @@ def _name(word): @_handles('function-name -> "$upper_bound"') @_handles('function-name -> "$lower_bound"') def _word(word): - return ir_data.Word(text=word.text) + return ir_data.Word(text=word.text) -@_handles('type-reference -> type-reference-tail') -@_handles('constant-reference -> constant-reference-tail') +@_handles("type-reference -> type-reference-tail") +@_handles("constant-reference -> constant-reference-tail") def _un_module_qualified_type_reference(reference): - return reference + return reference -@_handles('constant-reference-tail -> constant-word') -@_handles('type-reference-tail -> type-word') -@_handles('snake-reference -> snake-word') -@_handles('snake-reference -> builtin-field-word') +@_handles("constant-reference-tail -> constant-word") +@_handles("type-reference-tail -> type-word") +@_handles("snake-reference -> snake-word") +@_handles("snake-reference -> builtin-field-word") def _reference(word): - return ir_data.Reference(source_name=[word]) + return ir_data.Reference(source_name=[word]) -@_handles('builtin-reference -> builtin-word') +@_handles("builtin-reference -> builtin-word") def _builtin_reference(word): - return ir_data.Reference(source_name=[word], - canonical_name=ir_data.CanonicalName( - object_path=[word.text])) + return ir_data.Reference( + source_name=[word], + canonical_name=ir_data.CanonicalName(object_path=[word.text]), + ) # Because constant-references ("Enum.NAME") are used in the same contexts as @@ -1334,11 +1485,11 @@ def _builtin_reference(word): # "snake_name.snake_name". @_handles('constant-reference -> snake-reference "." constant-reference-tail') def _module_qualified_constant_reference(new_head, dot, reference): - del dot # Unused. - new_source_name = list(new_head.source_name) + list(reference.source_name) - del reference.source_name[:] - reference.source_name.extend(new_source_name) - return reference + del dot # Unused. 
+ new_source_name = list(new_head.source_name) + list(reference.source_name) + del reference.source_name[:] + reference.source_name.extend(new_source_name) + return reference @_handles('constant-reference-tail -> type-word "." constant-reference-tail') @@ -1348,69 +1499,70 @@ def _module_qualified_constant_reference(new_head, dot, reference): @_handles('type-reference-tail -> type-word "." type-reference-tail') @_handles('type-reference -> snake-word "." type-reference-tail') def _qualified_reference(word, dot, reference): - """Adds a name. or Type. qualification to the head of a reference.""" - del dot # Unused. - new_source_name = [word] + list(reference.source_name) - del reference.source_name[:] - reference.source_name.extend(new_source_name) - return reference + """Adds a name. or Type. qualification to the head of a reference.""" + del dot # Unused. + new_source_name = [word] + list(reference.source_name) + del reference.source_name[:] + reference.source_name.extend(new_source_name) + return reference # Arrays are properly translated to IR in _type(). @_handles('array-length-specifier -> "[" expression "]"') def _array_length_specifier(open_bracket, length, close_bracket): - del open_bracket, close_bracket # Unused. - return length + del open_bracket, close_bracket # Unused. + return length # An array specifier can end with empty brackets ("arr[3][]"), in which case the # array's size is inferred from the size of its enclosing field. @_handles('array-length-specifier -> "[" "]"') def _auto_array_length_specifier(open_bracket, close_bracket): - # Note that the Void's source_location is the space between the brackets (if - # any). - return ir_data.Empty( - source_location=ir_data.Location(start=open_bracket.source_location.end, - end=close_bracket.source_location.start)) + # Note that the Void's source_location is the space between the brackets (if + # any). + return ir_data.Empty( + source_location=ir_data.Location( + start=open_bracket.source_location.end, + end=close_bracket.source_location.start, + ) + ) @_handles('eol -> "\\n" comment-line*') def _eol(eol, comments): - del comments # Unused - return eol + del comments # Unused + return eol @_handles('comment-line -> Comment? "\\n"') def _comment_line(comment, eol): - del comment # Unused - return eol + del comment # Unused + return eol def _finalize_grammar(): - """_Finalize adds productions for foo*, foo+, and foo? symbols.""" - star_symbols = set() - plus_symbols = set() - option_symbols = set() - for production in _handlers: - for symbol in production.rhs: - if symbol[-1] == '*': - star_symbols.add(symbol[:-1]) - elif symbol[-1] == '+': - # symbol+ relies on the rule for symbol* - star_symbols.add(symbol[:-1]) - plus_symbols.add(symbol[:-1]) - elif symbol[-1] == '?': - option_symbols.add(symbol[:-1]) - for symbol in star_symbols: - _handles('{s}* -> {s} {s}*'.format(s=symbol))( - lambda e, r: _List([e] + r.list)) - _handles('{s}* ->'.format(s=symbol))(lambda: _List([])) - for symbol in plus_symbols: - _handles('{s}+ -> {s} {s}*'.format(s=symbol))( - lambda e, r: _List([e] + r.list)) - for symbol in option_symbols: - _handles('{s}? -> {s}'.format(s=symbol))(lambda e: _List([e])) - _handles('{s}? ->'.format(s=symbol))(lambda: _List([])) + """_Finalize adds productions for foo*, foo+, and foo? 
symbols.""" + star_symbols = set() + plus_symbols = set() + option_symbols = set() + for production in _handlers: + for symbol in production.rhs: + if symbol[-1] == "*": + star_symbols.add(symbol[:-1]) + elif symbol[-1] == "+": + # symbol+ relies on the rule for symbol* + star_symbols.add(symbol[:-1]) + plus_symbols.add(symbol[:-1]) + elif symbol[-1] == "?": + option_symbols.add(symbol[:-1]) + for symbol in star_symbols: + _handles("{s}* -> {s} {s}*".format(s=symbol))(lambda e, r: _List([e] + r.list)) + _handles("{s}* ->".format(s=symbol))(lambda: _List([])) + for symbol in plus_symbols: + _handles("{s}+ -> {s} {s}*".format(s=symbol))(lambda e, r: _List([e] + r.list)) + for symbol in option_symbols: + _handles("{s}? -> {s}".format(s=symbol))(lambda e: _List([e])) + _handles("{s}? ->".format(s=symbol))(lambda: _List([])) _finalize_grammar() @@ -1420,6 +1572,6 @@ def _finalize_grammar(): # These export the grammar used by module_ir so that parser_generator can build # a parser for the same language. -START_SYMBOL = 'module' -EXPRESSION_START_SYMBOL = 'expression' +START_SYMBOL = "module" +EXPRESSION_START_SYMBOL = "expression" PRODUCTIONS = list(_handlers.keys()) diff --git a/compiler/front_end/module_ir_test.py b/compiler/front_end/module_ir_test.py index 5faa107..ad884a8 100644 --- a/compiler/front_end/module_ir_test.py +++ b/compiler/front_end/module_ir_test.py @@ -30,12 +30,16 @@ _TESTDATA_PATH = "testdata.golden" _MINIMAL_SOURCE = pkgutil.get_data( - _TESTDATA_PATH, "span_se_log_file_status.emb").decode(encoding="UTF-8") + _TESTDATA_PATH, "span_se_log_file_status.emb" +).decode(encoding="UTF-8") _MINIMAL_SAMPLE = parser.parse_module( - tokenizer.tokenize(_MINIMAL_SOURCE, "")[0]).parse_tree -_MINIMAL_SAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, + tokenizer.tokenize(_MINIMAL_SOURCE, "")[0] +).parse_tree +_MINIMAL_SAMPLE_IR = ir_data_utils.IrDataSerializer.from_json( + ir_data.Module, pkgutil.get_data(_TESTDATA_PATH, "span_se_log_file_status.ir.txt").decode( - encoding="UTF-8") + encoding="UTF-8" + ), ) # _TEST_CASES contains test cases, separated by '===', that ensure that specific @@ -3974,230 +3978,253 @@ def _get_test_cases(): - test_case = collections.namedtuple("test_case", ["name", "parse_tree", "ir"]) - result = [] - for case in _TEST_CASES.split("==="): - name, emb, ir_text = case.split("---") - name = name.strip() - try: - ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, ir_text) - except Exception: - print(name) - raise - parse_result = parser.parse_module(tokenizer.tokenize(emb, "")[0]) - assert not parse_result.error, "{}:\n{}".format(name, parse_result.error) - result.append(test_case(name, parse_result.parse_tree, ir)) - return result + test_case = collections.namedtuple("test_case", ["name", "parse_tree", "ir"]) + result = [] + for case in _TEST_CASES.split("==="): + name, emb, ir_text = case.split("---") + name = name.strip() + try: + ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, ir_text) + except Exception: + print(name) + raise + parse_result = parser.parse_module(tokenizer.tokenize(emb, "")[0]) + assert not parse_result.error, "{}:\n{}".format(name, parse_result.error) + result.append(test_case(name, parse_result.parse_tree, ir)) + return result def _get_negative_test_cases(): - test_case = collections.namedtuple("test_case", - ["name", "text", "error_token"]) - result = [] - for case in _NEGATIVE_TEST_CASES.split("==="): - name, error_token, text = case.split("---") - name = name.strip() - error_token = 
error_token.strip() - result.append(test_case(name, text, error_token)) - return result + test_case = collections.namedtuple("test_case", ["name", "text", "error_token"]) + result = [] + for case in _NEGATIVE_TEST_CASES.split("==="): + name, error_token, text = case.split("---") + name = name.strip() + error_token = error_token.strip() + result.append(test_case(name, text, error_token)) + return result def _check_source_location(source_location, path, min_start, max_end): - """Performs sanity checks on a source_location field. + """Performs sanity checks on a source_location field. - Arguments: - source_location: The source_location to check. - path: The path, to use in error messages. - min_start: A minimum value for source_location.start, or None. - max_end: A maximum value for source_location.end, or None. + Arguments: + source_location: The source_location to check. + path: The path, to use in error messages. + min_start: A minimum value for source_location.start, or None. + max_end: A maximum value for source_location.end, or None. - Returns: - A list of error messages, or an empty list if no errors. - """ - if source_location.is_disjoint_from_parent: - # If source_location.is_disjoint_from_parent, then this source_location is - # allowed to be outside of the parent's source_location. - return [] + Returns: + A list of error messages, or an empty list if no errors. + """ + if source_location.is_disjoint_from_parent: + # If source_location.is_disjoint_from_parent, then this source_location is + # allowed to be outside of the parent's source_location. + return [] - result = [] - start = None - end = None - if not source_location.HasField("start"): - result.append("{}.start missing".format(path)) - else: - start = source_location.start - if not source_location.HasField("end"): - result.append("{}.end missing".format(path)) - else: - end = source_location.end - - if start and end: - if start.HasField("line") and end.HasField("line"): - if start.line > end.line: - result.append("{}.start.line > {}.end.line ({} vs {})".format( - path, path, start.line, end.line)) - elif start.line == end.line: - if (start.HasField("column") and end.HasField("column") and - start.column > end.column): - result.append("{}.start.column > {}.end.column ({} vs {})".format( - path, path, start.column, end.column)) - - for name, field in (("start", start), ("end", end)): - if not field: - continue - if field.HasField("line"): - if field.line <= 0: - result.append("{}.{}.line <= 0 ({})".format(path, name, field.line)) + result = [] + start = None + end = None + if not source_location.HasField("start"): + result.append("{}.start missing".format(path)) else: - result.append("{}.{}.line missing".format(path, name)) - if field.HasField("column"): - if field.column <= 0: - result.append("{}.{}.column <= 0 ({})".format(path, name, field.column)) + start = source_location.start + if not source_location.HasField("end"): + result.append("{}.end missing".format(path)) else: - result.append("{}.{}.column missing".format(path, name)) + end = source_location.end + + if start and end: + if start.HasField("line") and end.HasField("line"): + if start.line > end.line: + result.append( + "{}.start.line > {}.end.line ({} vs {})".format( + path, path, start.line, end.line + ) + ) + elif start.line == end.line: + if ( + start.HasField("column") + and end.HasField("column") + and start.column > end.column + ): + result.append( + "{}.start.column > {}.end.column ({} vs {})".format( + path, path, start.column, end.column + ) + ) + + for 
name, field in (("start", start), ("end", end)): + if not field: + continue + if field.HasField("line"): + if field.line <= 0: + result.append("{}.{}.line <= 0 ({})".format(path, name, field.line)) + else: + result.append("{}.{}.line missing".format(path, name)) + if field.HasField("column"): + if field.column <= 0: + result.append("{}.{}.column <= 0 ({})".format(path, name, field.column)) + else: + result.append("{}.{}.column missing".format(path, name)) - if min_start and start: - if min_start.line > start.line or ( - min_start.line == start.line and min_start.column > start.column): - result.append("{}.start before parent start".format(path)) + if min_start and start: + if min_start.line > start.line or ( + min_start.line == start.line and min_start.column > start.column + ): + result.append("{}.start before parent start".format(path)) - if max_end and end: - if max_end.line < end.line or ( - max_end.line == end.line and max_end.column < end.column): - result.append("{}.end after parent end".format(path)) + if max_end and end: + if max_end.line < end.line or ( + max_end.line == end.line and max_end.column < end.column + ): + result.append("{}.end after parent end".format(path)) - return result + return result def _check_all_source_locations(proto, path="", min_start=None, max_end=None): - """Performs sanity checks on all source_locations in proto. + """Performs sanity checks on all source_locations in proto. - Arguments: - proto: The proto to recursively check. - path: The path, to use in error messages. - min_start: A minimum value for source_location.start, or None. - max_end: A maximum value for source_location.end, or None. + Arguments: + proto: The proto to recursively check. + path: The path, to use in error messages. + min_start: A minimum value for source_location.start, or None. + max_end: A maximum value for source_location.end, or None. - Returns: - A list of error messages, or an empty list if no errors. - """ - if path: - path += "." + Returns: + A list of error messages, or an empty list if no errors. + """ + if path: + path += "." - errors = [] + errors = [] - child_start = None - child_end = None - # Only check the source_location value if this proto message actually has a - # source_location field. - if proto.HasField("source_location"): - errors.extend(_check_source_location(proto.source_location, - path + "source_location", - min_start, max_end)) - child_start = proto.source_location.start - child_end = proto.source_location.end + child_start = None + child_end = None + # Only check the source_location value if this proto message actually has a + # source_location field. 
+ if proto.HasField("source_location"): + errors.extend( + _check_source_location( + proto.source_location, path + "source_location", min_start, max_end + ) + ) + child_start = proto.source_location.start + child_end = proto.source_location.end - for name, spec in ir_data_fields.field_specs(proto).items(): - if name == "source_location": - continue - if not proto.HasField(name): - continue - field_path = "{}{}".format(path, name) - if spec.is_dataclass: - if spec.is_sequence: - index = 0 - for i in getattr(proto, name): - item_path = "{}[{}]".format(field_path, index) - index += 1 - errors.extend( - _check_all_source_locations(i, item_path, child_start, child_end)) - else: - errors.extend(_check_all_source_locations(getattr(proto, name), - field_path, child_start, - child_end)) + for name, spec in ir_data_fields.field_specs(proto).items(): + if name == "source_location": + continue + if not proto.HasField(name): + continue + field_path = "{}{}".format(path, name) + if spec.is_dataclass: + if spec.is_sequence: + index = 0 + for i in getattr(proto, name): + item_path = "{}[{}]".format(field_path, index) + index += 1 + errors.extend( + _check_all_source_locations( + i, item_path, child_start, child_end + ) + ) + else: + errors.extend( + _check_all_source_locations( + getattr(proto, name), field_path, child_start, child_end + ) + ) - return errors + return errors class ModuleIrTest(unittest.TestCase): - """Tests the module_ir.build_ir() function.""" + """Tests the module_ir.build_ir() function.""" - def test_build_ir(self): - ir = module_ir.build_ir(_MINIMAL_SAMPLE) - ir.source_text = _MINIMAL_SOURCE - self.assertEqual(ir, _MINIMAL_SAMPLE_IR) + def test_build_ir(self): + ir = module_ir.build_ir(_MINIMAL_SAMPLE) + ir.source_text = _MINIMAL_SOURCE + self.assertEqual(ir, _MINIMAL_SAMPLE_IR) - def test_production_coverage(self): - """Checks that all grammar productions are used somewhere in tests.""" - used_productions = set() - module_ir.build_ir(_MINIMAL_SAMPLE, used_productions) - for test in _get_test_cases(): - module_ir.build_ir(test.parse_tree, used_productions) - self.assertEqual(set(module_ir.PRODUCTIONS) - used_productions, set([])) + def test_production_coverage(self): + """Checks that all grammar productions are used somewhere in tests.""" + used_productions = set() + module_ir.build_ir(_MINIMAL_SAMPLE, used_productions) + for test in _get_test_cases(): + module_ir.build_ir(test.parse_tree, used_productions) + self.assertEqual(set(module_ir.PRODUCTIONS) - used_productions, set([])) - def test_double_negative_non_compilation(self): - """Checks that unparenthesized double unary minus/plus is a parse error.""" - for example in ("[x: - -3]", "[x: + -3]", "[x: - +3]", "[x: + +3]"): - parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0]) - self.assertTrue(parse_result.error) - self.assertEqual(7, parse_result.error.token.source_location.start.column) - for example in ("[x:-(-3)]", "[x:+(-3)]", "[x:-(+3)]", "[x:+(+3)]"): - parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0]) - self.assertFalse(parse_result.error) + def test_double_negative_non_compilation(self): + """Checks that unparenthesized double unary minus/plus is a parse error.""" + for example in ("[x: - -3]", "[x: + -3]", "[x: - +3]", "[x: + +3]"): + parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0]) + self.assertTrue(parse_result.error) + self.assertEqual(7, parse_result.error.token.source_location.start.column) + for example in ("[x:-(-3)]", "[x:+(-3)]", "[x:-(+3)]", 
"[x:+(+3)]"): + parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0]) + self.assertFalse(parse_result.error) def _make_superset_tests(): - def _make_superset_test(test): + def _make_superset_test(test): - def test_case(self): - ir = module_ir.build_ir(test.parse_tree) - is_superset, error_message = test_util.proto_is_superset(ir, test.ir) + def test_case(self): + ir = module_ir.build_ir(test.parse_tree) + is_superset, error_message = test_util.proto_is_superset(ir, test.ir) - self.assertTrue( - is_superset, - error_message + "\n" + ir_data_utils.IrDataSerializer(ir).to_json(indent=2) + "\n" + - ir_data_utils.IrDataSerializer(test.ir).to_json(indent=2)) + self.assertTrue( + is_superset, + error_message + + "\n" + + ir_data_utils.IrDataSerializer(ir).to_json(indent=2) + + "\n" + + ir_data_utils.IrDataSerializer(test.ir).to_json(indent=2), + ) - return test_case + return test_case - for test in _get_test_cases(): - test_name = "test " + test.name + " proto superset" - assert not hasattr(ModuleIrTest, test_name) - setattr(ModuleIrTest, test_name, _make_superset_test(test)) + for test in _get_test_cases(): + test_name = "test " + test.name + " proto superset" + assert not hasattr(ModuleIrTest, test_name) + setattr(ModuleIrTest, test_name, _make_superset_test(test)) def _make_source_location_tests(): - def _make_source_location_test(test): + def _make_source_location_test(test): - def test_case(self): - error_list = _check_all_source_locations( - module_ir.build_ir(test.parse_tree)) - self.assertFalse(error_list, "\n".join([test.name] + error_list)) + def test_case(self): + error_list = _check_all_source_locations( + module_ir.build_ir(test.parse_tree) + ) + self.assertFalse(error_list, "\n".join([test.name] + error_list)) - return test_case + return test_case - for test in _get_test_cases(): - test_name = "test " + test.name + " source location" - assert not hasattr(ModuleIrTest, test_name) - setattr(ModuleIrTest, test_name, _make_source_location_test(test)) + for test in _get_test_cases(): + test_name = "test " + test.name + " source location" + assert not hasattr(ModuleIrTest, test_name) + setattr(ModuleIrTest, test_name, _make_source_location_test(test)) def _make_negative_tests(): - def _make_negative_test(test): + def _make_negative_test(test): + + def test_case(self): + parse_result = parser.parse_module(tokenizer.tokenize(test.text, "")[0]) + self.assertEqual(test.error_token, parse_result.error.token.text.strip()) - def test_case(self): - parse_result = parser.parse_module(tokenizer.tokenize(test.text, "")[0]) - self.assertEqual(test.error_token, parse_result.error.token.text.strip()) + return test_case - return test_case + for test in _get_negative_test_cases(): + test_name = "test " + test.name + " compilation failure" + assert not hasattr(ModuleIrTest, test_name) + setattr(ModuleIrTest, test_name, _make_negative_test(test)) - for test in _get_negative_test_cases(): - test_name = "test " + test.name + " compilation failure" - assert not hasattr(ModuleIrTest, test_name) - setattr(ModuleIrTest, test_name, _make_negative_test(test)) _make_negative_tests() _make_superset_tests() @@ -4205,4 +4232,4 @@ def test_case(self): if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/parser.py b/compiler/front_end/parser.py index 6ece324..a6d7130 100644 --- a/compiler/front_end/parser.py +++ b/compiler/front_end/parser.py @@ -22,104 +22,110 @@ class ParserGenerationError(Exception): - """An error occurred during parser generation.""" - pass 
+ """An error occurred during parser generation.""" + + pass def parse_error_examples(error_example_text): - """Parses error examples from error_example_text. - - Arguments: - error_example_text: The text of an error example file. - - Returns: - A list of tuples, suitable for passing into generate_parser. - - Raises: - ParserGenerationError: There is a problem parsing the error examples. - """ - error_examples = error_example_text.split("\n" + "=" * 80 + "\n") - result = [] - # Everything before the first "======" line is explanatory text: ignore it. - for error_example in error_examples[1:]: - message_and_examples = error_example.split("\n" + "-" * 80 + "\n") - if len(message_and_examples) != 2: - raise ParserGenerationError( - "Expected one error message and one example section in:\n" + - error_example) - message, example_text = message_and_examples - examples = example_text.split("\n---\n") - for example in examples: - # TODO(bolms): feed a line number into tokenize, so that tokenization - # failures refer to the correct line within error_example_text. - tokens, errors = tokenizer.tokenize(example, "") - if errors: - raise ParserGenerationError(str(errors)) - - for i in range(len(tokens)): - if tokens[i].symbol == "BadWord" and tokens[i].text == "$ANY": - tokens[i] = lr1.ANY_TOKEN - - error_token = None - for i in range(len(tokens)): - if tokens[i].symbol == "BadWord" and tokens[i].text == "$ERR": - error_token = tokens[i + 1] - del tokens[i] - break - else: - raise ParserGenerationError( - "No error token marker '$ERR' in:\n" + error_example) - - result.append((tokens, error_token, message.strip(), example)) - return result + """Parses error examples from error_example_text. + + Arguments: + error_example_text: The text of an error example file. + + Returns: + A list of tuples, suitable for passing into generate_parser. + + Raises: + ParserGenerationError: There is a problem parsing the error examples. + """ + error_examples = error_example_text.split("\n" + "=" * 80 + "\n") + result = [] + # Everything before the first "======" line is explanatory text: ignore it. + for error_example in error_examples[1:]: + message_and_examples = error_example.split("\n" + "-" * 80 + "\n") + if len(message_and_examples) != 2: + raise ParserGenerationError( + "Expected one error message and one example section in:\n" + + error_example + ) + message, example_text = message_and_examples + examples = example_text.split("\n---\n") + for example in examples: + # TODO(bolms): feed a line number into tokenize, so that tokenization + # failures refer to the correct line within error_example_text. + tokens, errors = tokenizer.tokenize(example, "") + if errors: + raise ParserGenerationError(str(errors)) + + for i in range(len(tokens)): + if tokens[i].symbol == "BadWord" and tokens[i].text == "$ANY": + tokens[i] = lr1.ANY_TOKEN + + error_token = None + for i in range(len(tokens)): + if tokens[i].symbol == "BadWord" and tokens[i].text == "$ERR": + error_token = tokens[i + 1] + del tokens[i] + break + else: + raise ParserGenerationError( + "No error token marker '$ERR' in:\n" + error_example + ) + + result.append((tokens, error_token, message.strip(), example)) + return result def generate_parser(start_symbol, productions, error_examples): - """Generates a parser from grammar, and applies error_examples. 
- - Arguments: - start_symbol: the start symbol of the grammar (a string) - productions: a list of parser_types.Production in the grammar - error_examples: A list of (source tokens, error message, source text) - tuples. - - Returns: - A parser. - - Raises: - ParserGenerationError: There is a problem generating the parser. - """ - parser = lr1.Grammar(start_symbol, productions).parser() - if parser.conflicts: - raise ParserGenerationError("\n".join([str(c) for c in parser.conflicts])) - for example in error_examples: - mark_result = parser.mark_error(example[0], example[1], example[2]) - if mark_result: - raise ParserGenerationError( - "error marking example: {}\nExample:\n{}".format( - mark_result, example[3])) - return parser + """Generates a parser from grammar, and applies error_examples. + + Arguments: + start_symbol: the start symbol of the grammar (a string) + productions: a list of parser_types.Production in the grammar + error_examples: A list of (source tokens, error message, source text) + tuples. + + Returns: + A parser. + + Raises: + ParserGenerationError: There is a problem generating the parser. + """ + parser = lr1.Grammar(start_symbol, productions).parser() + if parser.conflicts: + raise ParserGenerationError("\n".join([str(c) for c in parser.conflicts])) + for example in error_examples: + mark_result = parser.mark_error(example[0], example[1], example[2]) + if mark_result: + raise ParserGenerationError( + "error marking example: {}\nExample:\n{}".format( + mark_result, example[3] + ) + ) + return parser @simple_memoizer.memoize def _load_module_parser(): - error_examples = parse_error_examples( - resources.load("compiler.front_end", "error_examples")) - return generate_parser(module_ir.START_SYMBOL, module_ir.PRODUCTIONS, - error_examples) + error_examples = parse_error_examples( + resources.load("compiler.front_end", "error_examples") + ) + return generate_parser( + module_ir.START_SYMBOL, module_ir.PRODUCTIONS, error_examples + ) @simple_memoizer.memoize def _load_expression_parser(): - return generate_parser(module_ir.EXPRESSION_START_SYMBOL, - module_ir.PRODUCTIONS, []) + return generate_parser(module_ir.EXPRESSION_START_SYMBOL, module_ir.PRODUCTIONS, []) def parse_module(tokens): - """Parses the provided Emboss token list into an Emboss module parse tree.""" - return _load_module_parser().parse(tokens) + """Parses the provided Emboss token list into an Emboss module parse tree.""" + return _load_module_parser().parse(tokens) def parse_expression(tokens): - """Parses the provided Emboss token list into an expression parse tree.""" - return _load_expression_parser().parse(tokens) + """Parses the provided Emboss token list into an expression parse tree.""" + return _load_expression_parser().parse(tokens) diff --git a/compiler/front_end/parser_test.py b/compiler/front_end/parser_test.py index 06cfd03..f5f111b 100644 --- a/compiler/front_end/parser_test.py +++ b/compiler/front_end/parser_test.py @@ -23,8 +23,8 @@ # TODO(bolms): This is repeated in lr1_test.py; separate into test utils? 
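Note that _load_module_parser and _load_expression_parser are wrapped in simple_memoizer.memoize, so the expensive LR(1) table construction runs once per process and parse_module/parse_expression stay cheap afterwards. A self-contained sketch of that lazy-singleton shape, using functools.lru_cache as a stand-in for simple_memoizer and a dummy build step (the real grammar is not reproduced here):

    import functools


    def memoize(f):
        # Stand-in for simple_memoizer.memoize: cache the no-arg result.
        return functools.lru_cache(maxsize=None)(f)


    @memoize
    def _load_parser():
        print("building parser tables (runs once)")
        return object()  # placeholder for lr1.Grammar(...).parser()


    def parse(tokens):
        # Every call reuses the single memoized parser.
        return (_load_parser(), tokens)


    parse(["a"])
    parse(["b"])  # no second "building" message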
def _parse_productions(*productions): - """Parses text into a grammar by calling Production.parse on each line.""" - return [parser_types.Production.parse(p) for p in productions] + """Parses text into a grammar by calling Production.parse on each line.""" + return [parser_types.Production.parse(p) for p in productions] _EXAMPLE_DIVIDER = "\n" + "=" * 80 + "\n" @@ -33,175 +33,266 @@ def _parse_productions(*productions): class ParserGeneratorTest(unittest.TestCase): - """Tests parser.parse_error_examples and generate_parser.""" - - def test_parse_good_error_examples(self): - errors = parser.parse_error_examples( - _EXAMPLE_DIVIDER + # ======... - "structure names must be Camel" + # Message. - _MESSAGE_ERROR_DIVIDER + # ------... - "struct $ERR FOO" + # First example. - _ERROR_DIVIDER + # --- - "struct $ERR foo" + # Second example. - _EXAMPLE_DIVIDER + # ======... - ' \n struct must be followed by ":" \n\n' + # Second message. - _MESSAGE_ERROR_DIVIDER + # ------... - "struct Foo $ERR") # Example for second message. - self.assertEqual(tokenizer.tokenize("struct FOO", "")[0], errors[0][0]) - self.assertEqual("structure names must be Camel", errors[0][2]) - self.assertEqual(tokenizer.tokenize("struct foo", "")[0], errors[1][0]) - self.assertEqual("structure names must be Camel", errors[1][2]) - self.assertEqual(tokenizer.tokenize("struct Foo ", "")[0], errors[2][0]) - self.assertEqual('struct must be followed by ":"', errors[2][2]) - - def test_parse_good_wildcard_example(self): - errors = parser.parse_error_examples( - _EXAMPLE_DIVIDER + # ======... - ' \n struct must be followed by ":" \n\n' + # Second message. - _MESSAGE_ERROR_DIVIDER + # ------... - "struct Foo $ERR $ANY") - tokens = tokenizer.tokenize("struct Foo ", "")[0] - # The $ANY token should come just before the end-of-line token in the parsed - # result. 
- tokens.insert(-1, lr1.ANY_TOKEN) - self.assertEqual(tokens, errors[0][0]) - self.assertEqual('struct must be followed by ":"', errors[0][2]) - - def test_parse_with_no_error_marker(self): - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "-- doc") - - def test_that_no_error_example_fails(self): - self.assertRaises(parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _EXAMPLE_DIVIDER + "msg" + - _MESSAGE_ERROR_DIVIDER + "example") - - def test_that_message_example_divider_must_be_on_its_own_line(self): - self.assertRaises(parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "example") - self.assertRaises(parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + "example") - self.assertRaises(parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "\nexample") - self.assertRaises(parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + " \nexample") - - def test_that_example_divider_must_be_on_its_own_line(self): - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example" + "=" * 80 - + "msg" + _MESSAGE_ERROR_DIVIDER + "example") - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example\n" + "=" * - 80 + "msg" + _MESSAGE_ERROR_DIVIDER + "example") - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example" + "=" * 80 - + "\nmsg" + _MESSAGE_ERROR_DIVIDER + "example") - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example\n" + "=" * - 80 + " \nmsg" + _MESSAGE_ERROR_DIVIDER + "example") - - def test_that_tokenization_failure_results_in_failure(self): - self.assertRaises( - parser.ParserGenerationError, - parser.parse_error_examples, - _EXAMPLE_DIVIDER + "message" + _MESSAGE_ERROR_DIVIDER + "|") - - def test_generate_parser(self): - self.assertTrue(parser.generate_parser("C", _parse_productions("C -> s"), - [])) - self.assertTrue(parser.generate_parser( - "C", _parse_productions("C -> s", "C -> d"), [])) - - def test_generated_parser_error(self): - test_parser = parser.generate_parser( - "C", _parse_productions("C -> s", "C -> d"), - [([parser_types.Token("s", "s", None), - parser_types.Token("s", "s", None)], - parser_types.Token("s", "s", None), - "double s", "ss")]) - parse_result = test_parser.parse([parser_types.Token("s", "s", None), - parser_types.Token("s", "s", None)]) - self.assertEqual(None, parse_result.parse_tree) - self.assertEqual("double s", parse_result.error.code) - - def test_conflict_error(self): - self.assertRaises( - parser.ParserGenerationError, - parser.generate_parser, - "C", _parse_productions("C -> S", "C -> D", "S -> a", "D -> a"), []) - - def test_bad_mark_error(self): - self.assertRaises(parser.ParserGenerationError, - parser.generate_parser, - "C", _parse_productions("C -> s", "C -> d"), - [([parser_types.Token("s", "s", None), - parser_types.Token("s", "s", None)], + """Tests parser.parse_error_examples and generate_parser.""" + + def test_parse_good_error_examples(self): + errors = 
parser.parse_error_examples( + _EXAMPLE_DIVIDER # ======... + + "structure names must be Camel" # Message. + + _MESSAGE_ERROR_DIVIDER # ------... + + "struct $ERR FOO" # First example. + + _ERROR_DIVIDER # --- + + "struct $ERR foo" # Second example. + + _EXAMPLE_DIVIDER # ======... + + ' \n struct must be followed by ":" \n\n' # Second message. + + _MESSAGE_ERROR_DIVIDER # ------... + + "struct Foo $ERR" + ) # Example for second message. + self.assertEqual(tokenizer.tokenize("struct FOO", "")[0], errors[0][0]) + self.assertEqual("structure names must be Camel", errors[0][2]) + self.assertEqual(tokenizer.tokenize("struct foo", "")[0], errors[1][0]) + self.assertEqual("structure names must be Camel", errors[1][2]) + self.assertEqual(tokenizer.tokenize("struct Foo ", "")[0], errors[2][0]) + self.assertEqual('struct must be followed by ":"', errors[2][2]) + + def test_parse_good_wildcard_example(self): + errors = parser.parse_error_examples( + _EXAMPLE_DIVIDER # ======... + + ' \n struct must be followed by ":" \n\n' # Second message. + + _MESSAGE_ERROR_DIVIDER # ------... + + "struct Foo $ERR $ANY" + ) + tokens = tokenizer.tokenize("struct Foo ", "")[0] + # The $ANY token should come just before the end-of-line token in the parsed + # result. + tokens.insert(-1, lr1.ANY_TOKEN) + self.assertEqual(tokens, errors[0][0]) + self.assertEqual('struct must be followed by ":"', errors[0][2]) + + def test_parse_with_no_error_marker(self): + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "-- doc", + ) + + def test_that_no_error_example_fails(self): + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + + "msg" + + _EXAMPLE_DIVIDER + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example", + ) + + def test_that_message_example_divider_must_be_on_its_own_line(self): + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "example", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + "example", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "\nexample", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + " \nexample", + ) + + def test_that_example_divider_must_be_on_its_own_line(self): + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example" + + "=" * 80 + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example\n" + + "=" * 80 + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example" + + "=" * 80 + + "\nmsg" + + _MESSAGE_ERROR_DIVIDER + + "example", + ) + self.assertRaises( + parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + + "msg" + + _MESSAGE_ERROR_DIVIDER + + "example\n" + + "=" * 80 + + " \nmsg" + + _MESSAGE_ERROR_DIVIDER + + "example", + ) + + def test_that_tokenization_failure_results_in_failure(self): + self.assertRaises( + 
parser.ParserGenerationError, + parser.parse_error_examples, + _EXAMPLE_DIVIDER + "message" + _MESSAGE_ERROR_DIVIDER + "|", + ) + + def test_generate_parser(self): + self.assertTrue(parser.generate_parser("C", _parse_productions("C -> s"), [])) + self.assertTrue( + parser.generate_parser("C", _parse_productions("C -> s", "C -> d"), []) + ) + + def test_generated_parser_error(self): + test_parser = parser.generate_parser( + "C", + _parse_productions("C -> s", "C -> d"), + [ + ( + [ parser_types.Token("s", "s", None), - "double s", "ss"), - ([parser_types.Token("s", "s", None), - parser_types.Token("s", "s", None)], parser_types.Token("s", "s", None), - "double 's'", "ss")]) - self.assertRaises(parser.ParserGenerationError, - parser.generate_parser, - "C", _parse_productions("C -> s", "C -> d"), - [([parser_types.Token("s", "s", None)], + ], + parser_types.Token("s", "s", None), + "double s", + "ss", + ) + ], + ) + parse_result = test_parser.parse( + [parser_types.Token("s", "s", None), parser_types.Token("s", "s", None)] + ) + self.assertEqual(None, parse_result.parse_tree) + self.assertEqual("double s", parse_result.error.code) + + def test_conflict_error(self): + self.assertRaises( + parser.ParserGenerationError, + parser.generate_parser, + "C", + _parse_productions("C -> S", "C -> D", "S -> a", "D -> a"), + [], + ) + + def test_bad_mark_error(self): + self.assertRaises( + parser.ParserGenerationError, + parser.generate_parser, + "C", + _parse_productions("C -> s", "C -> d"), + [ + ( + [ + parser_types.Token("s", "s", None), + parser_types.Token("s", "s", None), + ], + parser_types.Token("s", "s", None), + "double s", + "ss", + ), + ( + [ + parser_types.Token("s", "s", None), parser_types.Token("s", "s", None), - "single s", "s")]) + ], + parser_types.Token("s", "s", None), + "double 's'", + "ss", + ), + ], + ) + self.assertRaises( + parser.ParserGenerationError, + parser.generate_parser, + "C", + _parse_productions("C -> s", "C -> d"), + [ + ( + [parser_types.Token("s", "s", None)], + parser_types.Token("s", "s", None), + "single s", + "s", + ) + ], + ) class ModuleParserTest(unittest.TestCase): - """Tests for parser.parse_module(). - - Correct parses should mostly be checked in conjunction with - module_ir.build_ir, as the exact data structure returned by - parser.parse_module() is determined by the grammar defined in module_ir. - These tests only need to cover errors and sanity checking. 
- """ - - def test_error_reporting_by_example(self): - parse_result = parser.parse_module( - tokenizer.tokenize("struct LogFileStatus:\n" - " 0 [+4] UInt\n", "")[0]) - self.assertEqual(None, parse_result.parse_tree) - self.assertEqual("A name is required for a struct field.", - parse_result.error.code) - self.assertEqual('"\\n"', parse_result.error.token.symbol) - self.assertEqual(set(['"["', "SnakeWord", '"."', '":"', '"("']), - parse_result.error.expected_tokens) - - def test_error_reporting_without_example(self): - parse_result = parser.parse_module( - tokenizer.tokenize("struct LogFileStatus:\n" - " 0 [+4] UInt foo +\n", "")[0]) - self.assertEqual(None, parse_result.parse_tree) - self.assertEqual(None, parse_result.error.code) - self.assertEqual('"+"', parse_result.error.token.symbol) - self.assertEqual(set(['"("', '"\\n"', '"["', "Documentation", "Comment"]), - parse_result.error.expected_tokens) - - def test_ok_parse(self): - parse_result = parser.parse_module( - tokenizer.tokenize("struct LogFileStatus:\n" - " 0 [+4] UInt foo\n", "")[0]) - self.assertTrue(parse_result.parse_tree) - self.assertEqual(None, parse_result.error) + """Tests for parser.parse_module(). + + Correct parses should mostly be checked in conjunction with + module_ir.build_ir, as the exact data structure returned by + parser.parse_module() is determined by the grammar defined in module_ir. + These tests only need to cover errors and sanity checking. + """ + + def test_error_reporting_by_example(self): + parse_result = parser.parse_module( + tokenizer.tokenize("struct LogFileStatus:\n" " 0 [+4] UInt\n", "")[0] + ) + self.assertEqual(None, parse_result.parse_tree) + self.assertEqual( + "A name is required for a struct field.", parse_result.error.code + ) + self.assertEqual('"\\n"', parse_result.error.token.symbol) + self.assertEqual( + set(['"["', "SnakeWord", '"."', '":"', '"("']), + parse_result.error.expected_tokens, + ) + + def test_error_reporting_without_example(self): + parse_result = parser.parse_module( + tokenizer.tokenize( + "struct LogFileStatus:\n" " 0 [+4] UInt foo +\n", "" + )[0] + ) + self.assertEqual(None, parse_result.parse_tree) + self.assertEqual(None, parse_result.error.code) + self.assertEqual('"+"', parse_result.error.token.symbol) + self.assertEqual( + set(['"("', '"\\n"', '"["', "Documentation", "Comment"]), + parse_result.error.expected_tokens, + ) + + def test_ok_parse(self): + parse_result = parser.parse_module( + tokenizer.tokenize( + "struct LogFileStatus:\n" " 0 [+4] UInt foo\n", "" + )[0] + ) + self.assertTrue(parse_result.parse_tree) + self.assertEqual(None, parse_result.error) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py index 9328e8f..498b1a9 100644 --- a/compiler/front_end/symbol_resolver.py +++ b/compiler/front_end/symbol_resolver.py @@ -33,497 +33,602 @@ def ambiguous_name_error(file_name, location, name, candidate_locations): - """A name cannot be resolved because there are two or more candidates.""" - result = [error.error(file_name, location, "Ambiguous name '{}'".format(name)) - ] - for location in sorted(candidate_locations): - result.append(error.note(location.file, location.location, - "Possible resolution")) - return result + """A name cannot be resolved because there are two or more candidates.""" + result = [error.error(file_name, location, "Ambiguous name '{}'".format(name))] + for location in sorted(candidate_locations): + result.append( + 
error.note(location.file, location.location, "Possible resolution") + ) + return result def duplicate_name_error(file_name, location, name, original_location): - """A name is defined two or more times.""" - return [error.error(file_name, location, "Duplicate name '{}'".format(name)), - error.note(original_location.file, original_location.location, - "Original definition")] + """A name is defined two or more times.""" + return [ + error.error(file_name, location, "Duplicate name '{}'".format(name)), + error.note( + original_location.file, original_location.location, "Original definition" + ), + ] def missing_name_error(file_name, location, name): - return [error.error(file_name, location, "No candidate for '{}'".format(name)) - ] + return [error.error(file_name, location, "No candidate for '{}'".format(name))] def array_subfield_error(file_name, location, name): - return [error.error(file_name, location, - "Cannot access member of array '{}'".format(name))] + return [ + error.error( + file_name, location, "Cannot access member of array '{}'".format(name) + ) + ] def noncomposite_subfield_error(file_name, location, name): - return [error.error(file_name, location, - "Cannot access member of noncomposite field '{}'".format( - name))] + return [ + error.error( + file_name, + location, + "Cannot access member of noncomposite field '{}'".format(name), + ) + ] def _nested_name(canonical_name, name): - """Creates a new CanonicalName with name appended to the object_path.""" - return ir_data.CanonicalName( - module_file=canonical_name.module_file, - object_path=list(canonical_name.object_path) + [name]) + """Creates a new CanonicalName with name appended to the object_path.""" + return ir_data.CanonicalName( + module_file=canonical_name.module_file, + object_path=list(canonical_name.object_path) + [name], + ) class _Scope(dict): - """A _Scope holds data for a symbol. - - A _Scope is a dict with some additional attributes. Lexically nested names - are kept in the dict, and bookkeeping is kept in the additional attributes. - - For example, each module should have a child _Scope for each type contained in - the module. `struct` and `bits` types should have nested _Scopes for each - field; `enum` types should have nested scopes for each enumerated name. - - Attributes: - canonical_name: The absolute name of this symbol; e.g. ("file.emb", - "TypeName", "SubTypeName", "field_name") - source_location: The ir_data.SourceLocation where this symbol is defined. - visibility: LOCAL, PRIVATE, or SEARCHABLE; see below. - alias: If set, this name is merely a pointer to another name. - """ - __slots__ = ("canonical_name", "source_location", "visibility", "alias") - - # A LOCAL name is visible outside of its enclosing scope, but should not be - # found when searching for a name. That is, this name should be matched in - # the tail of a qualified reference (the 'bar' in 'foo.bar'), but not when - # searching for names (the 'foo' in 'foo.bar' should not match outside of - # 'foo's scope). This applies to public field names. - LOCAL = object() - - # A PRIVATE name is similar to LOCAL except that it is never visible outside - # its enclosing scope. This applies to abbreviations of field names: if 'a' - # is an abbreviation for field 'apple', then 'foo.a' is not a valid reference; - # instead it should be 'foo.apple'. - PRIVATE = object() - - # A SEARCHABLE name is visible as long as it is in a scope in the search list. - # This applies to type names ('Foo'), which may be found from many scopes. 
- SEARCHABLE = object() - - def __init__(self, canonical_name, source_location, visibility, alias=None): - super(_Scope, self).__init__() - self.canonical_name = canonical_name - self.source_location = source_location - self.visibility = visibility - self.alias = alias + """A _Scope holds data for a symbol. + + A _Scope is a dict with some additional attributes. Lexically nested names + are kept in the dict, and bookkeeping is kept in the additional attributes. + + For example, each module should have a child _Scope for each type contained in + the module. `struct` and `bits` types should have nested _Scopes for each + field; `enum` types should have nested scopes for each enumerated name. + + Attributes: + canonical_name: The absolute name of this symbol; e.g. ("file.emb", + "TypeName", "SubTypeName", "field_name") + source_location: The ir_data.SourceLocation where this symbol is defined. + visibility: LOCAL, PRIVATE, or SEARCHABLE; see below. + alias: If set, this name is merely a pointer to another name. + """ + + __slots__ = ("canonical_name", "source_location", "visibility", "alias") + + # A LOCAL name is visible outside of its enclosing scope, but should not be + # found when searching for a name. That is, this name should be matched in + # the tail of a qualified reference (the 'bar' in 'foo.bar'), but not when + # searching for names (the 'foo' in 'foo.bar' should not match outside of + # 'foo's scope). This applies to public field names. + LOCAL = object() + + # A PRIVATE name is similar to LOCAL except that it is never visible outside + # its enclosing scope. This applies to abbreviations of field names: if 'a' + # is an abbreviation for field 'apple', then 'foo.a' is not a valid reference; + # instead it should be 'foo.apple'. + PRIVATE = object() + + # A SEARCHABLE name is visible as long as it is in a scope in the search list. + # This applies to type names ('Foo'), which may be found from many scopes. 
+ SEARCHABLE = object() + + def __init__(self, canonical_name, source_location, visibility, alias=None): + super(_Scope, self).__init__() + self.canonical_name = canonical_name + self.source_location = source_location + self.visibility = visibility + self.alias = alias def _add_name_to_scope(name_ir, scope, canonical_name, visibility, errors): - """Adds the given name_ir to the given scope.""" - name = name_ir.text - new_scope = _Scope(canonical_name, name_ir.source_location, visibility) - if name in scope: - errors.append(duplicate_name_error( - scope.canonical_name.module_file, name_ir.source_location, name, - FileLocation(scope[name].canonical_name.module_file, - scope[name].source_location))) - else: - scope[name] = new_scope - return new_scope + """Adds the given name_ir to the given scope.""" + name = name_ir.text + new_scope = _Scope(canonical_name, name_ir.source_location, visibility) + if name in scope: + errors.append( + duplicate_name_error( + scope.canonical_name.module_file, + name_ir.source_location, + name, + FileLocation( + scope[name].canonical_name.module_file, scope[name].source_location + ), + ) + ) + else: + scope[name] = new_scope + return new_scope def _add_name_to_scope_and_normalize(name_ir, scope, visibility, errors): - """Adds the given name_ir to scope and sets its canonical_name.""" - name = name_ir.name.text - canonical_name = _nested_name(scope.canonical_name, name) - ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name) - return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility, - errors) + """Adds the given name_ir to scope and sets its canonical_name.""" + name = name_ir.name.text + canonical_name = _nested_name(scope.canonical_name, name) + ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name) + return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility, errors) def _add_struct_field_to_scope(field, scope, errors): - """Adds the name of the given field to the scope.""" - new_scope = _add_name_to_scope_and_normalize(field.name, scope, _Scope.LOCAL, - errors) - if field.HasField("abbreviation"): - _add_name_to_scope(field.abbreviation, scope, new_scope.canonical_name, - _Scope.PRIVATE, errors) - - value_builtin_name = ir_data.Word( - text="this", - source_location=ir_data.Location(is_synthetic=True), - ) - # In "inside field" scope, the name `this` maps back to the field itself. - # This is important for attributes like `[requires]`. - _add_name_to_scope(value_builtin_name, new_scope, - field.name.canonical_name, _Scope.PRIVATE, errors) + """Adds the name of the given field to the scope.""" + new_scope = _add_name_to_scope_and_normalize( + field.name, scope, _Scope.LOCAL, errors + ) + if field.HasField("abbreviation"): + _add_name_to_scope( + field.abbreviation, scope, new_scope.canonical_name, _Scope.PRIVATE, errors + ) + + value_builtin_name = ir_data.Word( + text="this", + source_location=ir_data.Location(is_synthetic=True), + ) + # In "inside field" scope, the name `this` maps back to the field itself. + # This is important for attributes like `[requires]`. 
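A note on the _Scope structure defined above: child symbols live in the dict mapping itself, bookkeeping lives in __slots__ attributes, and the three visibility levels are distinct object() sentinels that compare only by identity. A cut-down sketch of the same shape, with hypothetical names in place of the real IR types:

    class Scope(dict):
        """Symbol-table node: children in the dict, metadata on the side."""

        __slots__ = ("canonical_name", "visibility")

        LOCAL = object()       # visible when named explicitly, not searched
        PRIVATE = object()     # never visible outside the enclosing scope
        SEARCHABLE = object()  # found by scope search (e.g. type names)

        def __init__(self, canonical_name, visibility):
            super().__init__()
            self.canonical_name = canonical_name
            self.visibility = visibility


    root = Scope(("file.emb",), Scope.SEARCHABLE)
    root["Foo"] = Scope(("file.emb", "Foo"), Scope.SEARCHABLE)
    root["Foo"]["bar"] = Scope(("file.emb", "Foo", "bar"), Scope.LOCAL)
    assert root["Foo"]["bar"].visibility is Scope.LOCAL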
+ _add_name_to_scope( + value_builtin_name, new_scope, field.name.canonical_name, _Scope.PRIVATE, errors + ) def _add_parameter_name_to_scope(parameter, scope, errors): - """Adds the name of the given parameter to the scope.""" - _add_name_to_scope_and_normalize(parameter.name, scope, _Scope.LOCAL, errors) + """Adds the name of the given parameter to the scope.""" + _add_name_to_scope_and_normalize(parameter.name, scope, _Scope.LOCAL, errors) def _add_enum_value_to_scope(value, scope, errors): - """Adds the name of the enum value to scope.""" - _add_name_to_scope_and_normalize(value.name, scope, _Scope.LOCAL, errors) + """Adds the name of the enum value to scope.""" + _add_name_to_scope_and_normalize(value.name, scope, _Scope.LOCAL, errors) def _add_type_name_to_scope(type_definition, scope, errors): - """Adds the name of type_definition to the given scope.""" - new_scope = _add_name_to_scope_and_normalize(type_definition.name, scope, - _Scope.SEARCHABLE, errors) - return {"scope": new_scope} + """Adds the name of type_definition to the given scope.""" + new_scope = _add_name_to_scope_and_normalize( + type_definition.name, scope, _Scope.SEARCHABLE, errors + ) + return {"scope": new_scope} def _set_scope_for_type_definition(type_definition, scope): - """Sets the current scope for an ir_data.AddressableUnit.""" - return {"scope": scope[type_definition.name.name.text]} + """Sets the current scope for an ir_data.AddressableUnit.""" + return {"scope": scope[type_definition.name.name.text]} def _add_module_to_scope(module, scope): - """Adds the name of the module to the given scope.""" - module_symbol_table = _Scope( - ir_data.CanonicalName(module_file=module.source_file_name, - object_path=[]), - None, - _Scope.SEARCHABLE) - scope[module.source_file_name] = module_symbol_table - return {"scope": scope[module.source_file_name]} + """Adds the name of the module to the given scope.""" + module_symbol_table = _Scope( + ir_data.CanonicalName(module_file=module.source_file_name, object_path=[]), + None, + _Scope.SEARCHABLE, + ) + scope[module.source_file_name] = module_symbol_table + return {"scope": scope[module.source_file_name]} def _set_scope_for_module(module, scope): - """Adds the name of the module to the given scope.""" - return {"scope": scope[module.source_file_name]} + """Adds the name of the module to the given scope.""" + return {"scope": scope[module.source_file_name]} def _add_import_to_scope(foreign_import, table, module, errors): - if not foreign_import.local_name.text: - # This is the prelude import; ignore it. - return - _add_alias_to_scope(foreign_import.local_name, table, module.canonical_name, - [foreign_import.file_name.text], _Scope.SEARCHABLE, - errors) + if not foreign_import.local_name.text: + # This is the prelude import; ignore it. 
+ return + _add_alias_to_scope( + foreign_import.local_name, + table, + module.canonical_name, + [foreign_import.file_name.text], + _Scope.SEARCHABLE, + errors, + ) def _construct_symbol_tables(ir): - """Constructs per-module symbol tables for each module in ir.""" - symbol_tables = {} - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Module], _add_module_to_scope, - parameters={"errors": errors, "scope": symbol_tables}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.TypeDefinition], _add_type_name_to_scope, - incidental_actions={ir_data.Module: _set_scope_for_module}, - parameters={"errors": errors, "scope": symbol_tables}) - if errors: - # Ideally, we would find duplicate field names elsewhere in the module, even - # if there are duplicate type names, but field/enum names in the colliding - # types also end up colliding, leading to spurious errors. E.g., if you - # have two `struct Foo`s, then the field check will also discover a - # collision for `$size_in_bytes`, since there are two `Foo.$size_in_bytes`. + """Constructs per-module symbol tables for each module in ir.""" + symbol_tables = {} + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Module], + _add_module_to_scope, + parameters={"errors": errors, "scope": symbol_tables}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.TypeDefinition], + _add_type_name_to_scope, + incidental_actions={ir_data.Module: _set_scope_for_module}, + parameters={"errors": errors, "scope": symbol_tables}, + ) + if errors: + # Ideally, we would find duplicate field names elsewhere in the module, even + # if there are duplicate type names, but field/enum names in the colliding + # types also end up colliding, leading to spurious errors. E.g., if you + # have two `struct Foo`s, then the field check will also discover a + # collision for `$size_in_bytes`, since there are two `Foo.$size_in_bytes`. 
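The comment above explains why _construct_symbol_tables bails after the type-name pass: every traversal appends into one shared errors list, and once type names collide, each nested name (including the synthesized $size_in_bytes) would collide again and drown the real error. A sketch of that accumulate-then-bail shape, with hypothetical pass functions standing in for the traversals:

    def _pass_type_names(tables, errors):
        # Hypothetical stand-in for the type-name traversal.
        if tables.get("Foo") == "duplicate":
            errors.append("Duplicate name 'Foo'")


    def _pass_members(tables, errors):
        # Only safe once type names are known to be unique; otherwise
        # every nested name would report a spurious collision too.
        pass


    def construct(tables):
        errors = []
        _pass_type_names(tables, errors)
        if errors:
            return tables, errors  # bail early to avoid cascaded errors
        _pass_members(tables, errors)
        return tables, errors


    print(construct({"Foo": "duplicate"}))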
+ return symbol_tables, errors + + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.EnumValue], + _add_enum_value_to_scope, + incidental_actions={ + ir_data.Module: _set_scope_for_module, + ir_data.TypeDefinition: _set_scope_for_type_definition, + }, + parameters={"errors": errors, "scope": symbol_tables}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Field], + _add_struct_field_to_scope, + incidental_actions={ + ir_data.Module: _set_scope_for_module, + ir_data.TypeDefinition: _set_scope_for_type_definition, + }, + parameters={"errors": errors, "scope": symbol_tables}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.RuntimeParameter], + _add_parameter_name_to_scope, + incidental_actions={ + ir_data.Module: _set_scope_for_module, + ir_data.TypeDefinition: _set_scope_for_type_definition, + }, + parameters={"errors": errors, "scope": symbol_tables}, + ) return symbol_tables, errors - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.EnumValue], _add_enum_value_to_scope, - incidental_actions={ - ir_data.Module: _set_scope_for_module, - ir_data.TypeDefinition: _set_scope_for_type_definition, - }, - parameters={"errors": errors, "scope": symbol_tables}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], _add_struct_field_to_scope, - incidental_actions={ - ir_data.Module: _set_scope_for_module, - ir_data.TypeDefinition: _set_scope_for_type_definition, - }, - parameters={"errors": errors, "scope": symbol_tables}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.RuntimeParameter], _add_parameter_name_to_scope, - incidental_actions={ - ir_data.Module: _set_scope_for_module, - ir_data.TypeDefinition: _set_scope_for_type_definition, - }, - parameters={"errors": errors, "scope": symbol_tables}) - return symbol_tables, errors - def _add_alias_to_scope(name_ir, table, scope, alias, visibility, errors): - """Adds the given name to the scope as an alias.""" - name = name_ir.text - new_scope = _Scope(_nested_name(scope, name), name_ir.source_location, - visibility, alias) - scoped_table = table[scope.module_file] - for path_element in scope.object_path: - scoped_table = scoped_table[path_element] - if name in scoped_table: - errors.append(duplicate_name_error( - scoped_table.canonical_name.module_file, name_ir.source_location, name, - FileLocation(scoped_table[name].canonical_name.module_file, - scoped_table[name].source_location))) - else: - scoped_table[name] = new_scope - return new_scope - - -def _resolve_head_of_field_reference(field_reference, table, current_scope, - visible_scopes, source_file_name, errors): - return _resolve_reference( - field_reference.path[0], table, current_scope, - visible_scopes, source_file_name, errors) - - -def _resolve_reference(reference, table, current_scope, visible_scopes, - source_file_name, errors): - """Sets the canonical name of the given reference.""" - if reference.HasField("canonical_name"): - # This reference has already been resolved by the _resolve_field_reference - # pass. 
- return - target = _find_target_of_reference(reference, table, current_scope, - visible_scopes, source_file_name, errors) - if target is not None: - assert not target.alias - ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name) - - -def _find_target_of_reference(reference, table, current_scope, visible_scopes, - source_file_name, errors): - """Returns the resolved name of the given reference.""" - found_in_table = None - name = reference.source_name[0].text - for scope in visible_scopes: + """Adds the given name to the scope as an alias.""" + name = name_ir.text + new_scope = _Scope( + _nested_name(scope, name), name_ir.source_location, visibility, alias + ) scoped_table = table[scope.module_file] - for path_element in scope.object_path or []: - scoped_table = scoped_table[path_element] - if (name in scoped_table and - (scope == current_scope or - scoped_table[name].visibility == _Scope.SEARCHABLE)): - # Prelude is "", so explicitly check for None. - if found_in_table is not None: - # TODO(bolms): Currently, this catches the case where a module tries to - # use a name that is defined (at the same scope) in two different - # modules. It may make sense to raise duplicate_name_error whenever two - # modules define the same name (whether it is used or not), and reserve - # ambiguous_name_error for cases where a name is found in multiple - # scopes. - errors.append(ambiguous_name_error( - source_file_name, reference.source_location, name, [FileLocation( - found_in_table[name].canonical_name.module_file, - found_in_table[name].source_location), FileLocation( - scoped_table[name].canonical_name.module_file, scoped_table[ - name].source_location)])) - continue - found_in_table = scoped_table - if reference.is_local_name: - # This is a little hacky. When "is_local_name" is True, the name refers - # to a type that was defined inline. In many cases, the type should be - # found at the same scope as the field; e.g.: - # - # struct Foo: - # 0 [+1] enum bar: - # BAZ = 1 - # - # In this case, `Foo.bar` has type `Foo.Bar`. Unfortunately, things - # break down a little bit when there is an inline type in an anonymous - # `bits`: - # - # struct Foo: - # 0 [+1] bits: - # 0 [+7] enum bar: - # BAZ = 1 - # - # Types inside of anonymous `bits` are hoisted into their parent type, - # so instead of `Foo.EmbossReservedAnonymous1.Bar`, `bar`'s type is just - # `Foo.Bar`. Unfortunately, the field is still - # `Foo.EmbossReservedAnonymous1.bar`, so `bar`'s type won't be found in - # `bar`'s `current_scope`. - # - # (The name `bar` is exposed from `Foo` as an alias virtual field, so - # perhaps the correct answer is to allow type aliases, so that `Bar` can - # be found in both `Foo` and `Foo.EmbossReservedAnonymous1`. That would - # involve an entirely new feature, though.) - # - # The workaround here is to search scopes from the innermost outward, - # and just stop as soon as a match is found. This isn't ideal, because - # it relies on other bits of the front end having correctly added the - # inline type to the correct scope before symbol resolution, but it does - # work. Names with False `is_local_name` will still be checked for - # ambiguity. 
- break - if found_in_table is None: - errors.append(missing_name_error( - source_file_name, reference.source_name[0].source_location, name)) - if not errors: - for subname in reference.source_name: - if subname.text not in found_in_table: - errors.append(missing_name_error(source_file_name, - subname.source_location, subname.text)) - return None - found_in_table = found_in_table[subname.text] - while found_in_table.alias: - referenced_table = table - for name in found_in_table.alias: - referenced_table = referenced_table[name] - # TODO(bolms): This section should really be a recursive lookup - # function, which would be able to handle arbitrary aliases through - # other aliases. - # - # This should be fine for now, since the only aliases here should be - # imports, which can't refer to other imports. - assert not referenced_table.alias, "Alias found to contain alias." - found_in_table = referenced_table - return found_in_table - return None + for path_element in scope.object_path: + scoped_table = scoped_table[path_element] + if name in scoped_table: + errors.append( + duplicate_name_error( + scoped_table.canonical_name.module_file, + name_ir.source_location, + name, + FileLocation( + scoped_table[name].canonical_name.module_file, + scoped_table[name].source_location, + ), + ) + ) + else: + scoped_table[name] = new_scope + return new_scope + + +def _resolve_head_of_field_reference( + field_reference, table, current_scope, visible_scopes, source_file_name, errors +): + return _resolve_reference( + field_reference.path[0], + table, + current_scope, + visible_scopes, + source_file_name, + errors, + ) + + +def _resolve_reference( + reference, table, current_scope, visible_scopes, source_file_name, errors +): + """Sets the canonical name of the given reference.""" + if reference.HasField("canonical_name"): + # This reference has already been resolved by the _resolve_field_reference + # pass. + return + target = _find_target_of_reference( + reference, table, current_scope, visible_scopes, source_file_name, errors + ) + if target is not None: + assert not target.alias + ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name) + + +def _find_target_of_reference( + reference, table, current_scope, visible_scopes, source_file_name, errors +): + """Returns the resolved name of the given reference.""" + found_in_table = None + name = reference.source_name[0].text + for scope in visible_scopes: + scoped_table = table[scope.module_file] + for path_element in scope.object_path or []: + scoped_table = scoped_table[path_element] + if name in scoped_table and ( + scope == current_scope or scoped_table[name].visibility == _Scope.SEARCHABLE + ): + # Prelude is "", so explicitly check for None. + if found_in_table is not None: + # TODO(bolms): Currently, this catches the case where a module tries to + # use a name that is defined (at the same scope) in two different + # modules. It may make sense to raise duplicate_name_error whenever two + # modules define the same name (whether it is used or not), and reserve + # ambiguous_name_error for cases where a name is found in multiple + # scopes. 
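On the alias handling shown here: an import is recorded as an alias scope whose alias attribute holds a path into the table rather than children of its own, and lookups chase at most one hop, since imports cannot refer to other imports. A toy version over plain dicts (hypothetical module contents):

    table = {
        "util.emb": {"Crc": "crc-definition"},  # imported module's scope
        "main.emb": {},                          # importing module's scope
    }

    # "import util.emb as util" in main.emb stores a path, not a copy:
    table["main.emb"]["util"] = {"alias": ["util.emb"]}


    def chase(entry):
        # Follow at most one alias hop, as the resolver does.
        if isinstance(entry, dict) and "alias" in entry:
            target = table
            for part in entry["alias"]:
                target = target[part]
            assert "alias" not in target, "Alias found to contain alias."
            return target
        return entry


    print(chase(table["main.emb"]["util"]))  # -> {'Crc': 'crc-definition'}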
+ errors.append( + ambiguous_name_error( + source_file_name, + reference.source_location, + name, + [ + FileLocation( + found_in_table[name].canonical_name.module_file, + found_in_table[name].source_location, + ), + FileLocation( + scoped_table[name].canonical_name.module_file, + scoped_table[name].source_location, + ), + ], + ) + ) + continue + found_in_table = scoped_table + if reference.is_local_name: + # This is a little hacky. When "is_local_name" is True, the name refers + # to a type that was defined inline. In many cases, the type should be + # found at the same scope as the field; e.g.: + # + # struct Foo: + # 0 [+1] enum bar: + # BAZ = 1 + # + # In this case, `Foo.bar` has type `Foo.Bar`. Unfortunately, things + # break down a little bit when there is an inline type in an anonymous + # `bits`: + # + # struct Foo: + # 0 [+1] bits: + # 0 [+7] enum bar: + # BAZ = 1 + # + # Types inside of anonymous `bits` are hoisted into their parent type, + # so instead of `Foo.EmbossReservedAnonymous1.Bar`, `bar`'s type is just + # `Foo.Bar`. Unfortunately, the field is still + # `Foo.EmbossReservedAnonymous1.bar`, so `bar`'s type won't be found in + # `bar`'s `current_scope`. + # + # (The name `bar` is exposed from `Foo` as an alias virtual field, so + # perhaps the correct answer is to allow type aliases, so that `Bar` can + # be found in both `Foo` and `Foo.EmbossReservedAnonymous1`. That would + # involve an entirely new feature, though.) + # + # The workaround here is to search scopes from the innermost outward, + # and just stop as soon as a match is found. This isn't ideal, because + # it relies on other bits of the front end having correctly added the + # inline type to the correct scope before symbol resolution, but it does + # work. Names with False `is_local_name` will still be checked for + # ambiguity. + break + if found_in_table is None: + errors.append( + missing_name_error( + source_file_name, reference.source_name[0].source_location, name + ) + ) + if not errors: + for subname in reference.source_name: + if subname.text not in found_in_table: + errors.append( + missing_name_error( + source_file_name, subname.source_location, subname.text + ) + ) + return None + found_in_table = found_in_table[subname.text] + while found_in_table.alias: + referenced_table = table + for name in found_in_table.alias: + referenced_table = referenced_table[name] + # TODO(bolms): This section should really be a recursive lookup + # function, which would be able to handle arbitrary aliases through + # other aliases. + # + # This should be fine for now, since the only aliases here should be + # imports, which can't refer to other imports. + assert not referenced_table.alias, "Alias found to contain alias." + found_in_table = referenced_table + return found_in_table + return None def _resolve_field_reference(field_reference, source_file_name, errors, ir): - """Resolves the References inside of a FieldReference.""" - if field_reference.path[-1].HasField("canonical_name"): - # Already done. 
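The search loop above walks visible_scopes from innermost to outermost: a hit in more than one searchable scope is an ambiguity error, except that references flagged is_local_name stop at the first (innermost) match, because types hoisted out of anonymous `bits` may legitimately shadow. A simplified model of that lookup over plain dicts, with hypothetical scope contents:

    def find(name, scopes, is_local_name):
        # `scopes` is ordered innermost first, as visible_scopes is above.
        found = None
        for scope in scopes:
            if name in scope:
                if found is not None:
                    return "ambiguous: " + name
                found = scope[name]
                if is_local_name:
                    break  # innermost match wins for inline-defined types
        return found if found is not None else "missing: " + name


    inner = {"Bar": "Foo.Bar"}
    outer = {"Bar": "Other.Bar", "UInt": "prelude.UInt"}

    print(find("UInt", [inner, outer], False))  # prelude.UInt
    print(find("Bar", [inner, outer], False))   # ambiguous: Bar
    print(find("Bar", [inner, outer], True))    # Foo.Bar (innermost wins)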
- return - previous_field = ir_util.find_object_or_none(field_reference.path[0], ir) - previous_reference = field_reference.path[0] - for ref in field_reference.path[1:]: - while ir_util.field_is_virtual(previous_field): - if (previous_field.read_transform.WhichOneof("expression") == - "field_reference"): - # Pass a separate error list into the recursive _resolve_field_reference - # call so that only one copy of the error for a particular reference - # will actually surface: in particular, the one that results from a - # direct call from traverse_ir_top_down into _resolve_field_reference. - new_errors = [] - _resolve_field_reference( - previous_field.read_transform.field_reference, - previous_field.name.canonical_name.module_file, new_errors, ir) - # If the recursive _resolve_field_reference was unable to resolve the - # field, then bail. Otherwise we get a cascade of errors, where an - # error in `x` leads to errors in anything trying to reach a member of - # `x`. - if not previous_field.read_transform.field_reference.path[-1].HasField( - "canonical_name"): - return - previous_field = ir_util.find_object( - previous_field.read_transform.field_reference.path[-1], ir) - else: - errors.append( - noncomposite_subfield_error(source_file_name, - previous_reference.source_location, - previous_reference.source_name[0].text)) + """Resolves the References inside of a FieldReference.""" + if field_reference.path[-1].HasField("canonical_name"): + # Already done. return - if previous_field.type.WhichOneof("type") == "array_type": - errors.append( - array_subfield_error(source_file_name, - previous_reference.source_location, - previous_reference.source_name[0].text)) - return - assert previous_field.type.WhichOneof("type") == "atomic_type" - member_name = ir_data_utils.copy( - previous_field.type.atomic_type.reference.canonical_name) - ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text]) - previous_field = ir_util.find_object_or_none(member_name, ir) - if previous_field is None: - errors.append( - missing_name_error(source_file_name, - ref.source_name[0].source_location, - ref.source_name[0].text)) - return - ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name) - previous_reference = ref + previous_field = ir_util.find_object_or_none(field_reference.path[0], ir) + previous_reference = field_reference.path[0] + for ref in field_reference.path[1:]: + while ir_util.field_is_virtual(previous_field): + if ( + previous_field.read_transform.WhichOneof("expression") + == "field_reference" + ): + # Pass a separate error list into the recursive _resolve_field_reference + # call so that only one copy of the error for a particular reference + # will actually surface: in particular, the one that results from a + # direct call from traverse_ir_top_down into _resolve_field_reference. + new_errors = [] + _resolve_field_reference( + previous_field.read_transform.field_reference, + previous_field.name.canonical_name.module_file, + new_errors, + ir, + ) + # If the recursive _resolve_field_reference was unable to resolve the + # field, then bail. Otherwise we get a cascade of errors, where an + # error in `x` leads to errors in anything trying to reach a member of + # `x`. 
+ if not previous_field.read_transform.field_reference.path[-1].HasField( + "canonical_name" + ): + return + previous_field = ir_util.find_object( + previous_field.read_transform.field_reference.path[-1], ir + ) + else: + errors.append( + noncomposite_subfield_error( + source_file_name, + previous_reference.source_location, + previous_reference.source_name[0].text, + ) + ) + return + if previous_field.type.WhichOneof("type") == "array_type": + errors.append( + array_subfield_error( + source_file_name, + previous_reference.source_location, + previous_reference.source_name[0].text, + ) + ) + return + assert previous_field.type.WhichOneof("type") == "atomic_type" + member_name = ir_data_utils.copy( + previous_field.type.atomic_type.reference.canonical_name + ) + ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text]) + previous_field = ir_util.find_object_or_none(member_name, ir) + if previous_field is None: + errors.append( + missing_name_error( + source_file_name, + ref.source_name[0].source_location, + ref.source_name[0].text, + ) + ) + return + ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name) + previous_reference = ref def _set_visible_scopes_for_type_definition(type_definition, visible_scopes): - """Sets current_scope and visible_scopes for the given type_definition.""" - return { - "current_scope": type_definition.name.canonical_name, - - # In order to ensure that the iteration through scopes in - # _find_target_of_reference will go from innermost to outermost, it is - # important that the current scope (type_definition.name.canonical_name) - # precedes the previous visible_scopes here. - "visible_scopes": (type_definition.name.canonical_name,) + visible_scopes, - } + """Sets current_scope and visible_scopes for the given type_definition.""" + return { + "current_scope": type_definition.name.canonical_name, + # In order to ensure that the iteration through scopes in + # _find_target_of_reference will go from innermost to outermost, it is + # important that the current scope (type_definition.name.canonical_name) + # precedes the previous visible_scopes here. + "visible_scopes": (type_definition.name.canonical_name,) + visible_scopes, + } def _set_visible_scopes_for_module(module): - """Sets visible_scopes for the given module.""" - self_scope = ir_data.CanonicalName(module_file=module.source_file_name) - extra_visible_scopes = [] - for foreign_import in module.foreign_import: - # Anonymous imports are searched for top-level names; named imports are not. - # As of right now, only the prelude should be imported anonymously; other - # modules must be imported with names. - if not foreign_import.local_name.text: - extra_visible_scopes.append( - ir_data.CanonicalName(module_file=foreign_import.file_name.text)) - return {"visible_scopes": (self_scope,) + tuple(extra_visible_scopes)} + """Sets visible_scopes for the given module.""" + self_scope = ir_data.CanonicalName(module_file=module.source_file_name) + extra_visible_scopes = [] + for foreign_import in module.foreign_import: + # Anonymous imports are searched for top-level names; named imports are not. + # As of right now, only the prelude should be imported anonymously; other + # modules must be imported with names. 
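_resolve_field_reference, shown above, resolves a path like `a.b.c` one hop at a time: each step requires the previous field's type to be composite (atomic, non-array), then looks the next component up inside that type, chasing virtual fields through their read_transform first. A schematic walk with a hypothetical type table in place of the IR:

    # Hypothetical: each struct maps field name -> type name (None = scalar).
    TYPES = {
        "Outer": {"header": "Header", "count": None},
        "Header": {"length": None},
    }


    def resolve_path(struct, path):
        current = struct
        for i, component in enumerate(path):
            members = TYPES.get(current)
            if members is None:
                return "cannot access member of noncomposite field " + path[i - 1]
            if component not in members:
                return "no candidate for " + component
            current = members[component]
        return "resolved"


    print(resolve_path("Outer", ["header", "length"]))  # resolved
    print(resolve_path("Outer", ["count", "length"]))   # noncomposite error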
+ if not foreign_import.local_name.text: + extra_visible_scopes.append( + ir_data.CanonicalName(module_file=foreign_import.file_name.text) + ) + return {"visible_scopes": (self_scope,) + tuple(extra_visible_scopes)} def _set_visible_scopes_for_attribute(attribute, field, visible_scopes): - """Sets current_scope and visible_scopes for the attribute.""" - del attribute # Unused - if field is None: - return - return { - "current_scope": field.name.canonical_name, - "visible_scopes": (field.name.canonical_name,) + visible_scopes, - } + """Sets current_scope and visible_scopes for the attribute.""" + del attribute # Unused + if field is None: + return + return { + "current_scope": field.name.canonical_name, + "visible_scopes": (field.name.canonical_name,) + visible_scopes, + } + def _module_source_from_table_action(m, table): - return {"module": table[m.source_file_name]} + return {"module": table[m.source_file_name]} + def _resolve_symbols_from_table(ir, table): - """Resolves all references in the given IR, given the constructed table.""" - errors = [] - # Symbol resolution is broken into five passes. First, this code resolves any - # imports, and adds import aliases to modules. - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Import], _add_import_to_scope, - incidental_actions={ - ir_data.Module: _module_source_from_table_action, - }, - parameters={"errors": errors, "table": table}) - if errors: + """Resolves all references in the given IR, given the constructed table.""" + errors = [] + # Symbol resolution is broken into five passes. First, this code resolves any + # imports, and adds import aliases to modules. + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Import], + _add_import_to_scope, + incidental_actions={ + ir_data.Module: _module_source_from_table_action, + }, + parameters={"errors": errors, "table": table}, + ) + if errors: + return errors + # Next, this resolves all absolute references (e.g., it resolves "UInt" in + # "0:1 UInt field" to [prelude]::UInt). + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Reference], + _resolve_reference, + skip_descendants_of=(ir_data.FieldReference,), + incidental_actions={ + ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, + ir_data.Module: _set_visible_scopes_for_module, + ir_data.Attribute: _set_visible_scopes_for_attribute, + }, + parameters={"table": table, "errors": errors, "field": None}, + ) + # Lastly, head References to fields (e.g., the `a` of `a.b.c`) are resolved. + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.FieldReference], + _resolve_head_of_field_reference, + incidental_actions={ + ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, + ir_data.Module: _set_visible_scopes_for_module, + ir_data.Attribute: _set_visible_scopes_for_attribute, + }, + parameters={"table": table, "errors": errors, "field": None}, + ) return errors - # Next, this resolves all absolute references (e.g., it resolves "UInt" in - # "0:1 UInt field" to [prelude]::UInt). - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Reference], _resolve_reference, - skip_descendants_of=(ir_data.FieldReference,), - incidental_actions={ - ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, - ir_data.Module: _set_visible_scopes_for_module, - ir_data.Attribute: _set_visible_scopes_for_attribute, - }, - parameters={"table": table, "errors": errors, "field": None}) - # Lastly, head References to fields (e.g., the `a` of `a.b.c`) are resolved. 
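The staging here matters: resolve_symbols (table construction plus _resolve_symbols_from_table) must complete before resolve_field_references, because the tail of `a.b.c` can only be resolved once the head's type reference has a canonical name, and each stage returns an error list. A hedged sketch of that ordering (the driver that calls these stages is not part of this diff):

    def run_resolution(ir, resolve_symbols, resolve_field_references):
        # Stage 1: imports, type references, and heads of field references.
        errors = resolve_symbols(ir)
        if errors:
            return errors
        # Stage 2: the remaining components of a.b.c member accesses.
        return resolve_field_references(ir)


    # Toy stand-ins so the sketch runs without the real front end:
    print(run_resolution({}, lambda ir: [], lambda ir: ["unresolved: b"]))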
- traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.FieldReference], _resolve_head_of_field_reference, - incidental_actions={ - ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, - ir_data.Module: _set_visible_scopes_for_module, - ir_data.Attribute: _set_visible_scopes_for_attribute, - }, - parameters={"table": table, "errors": errors, "field": None}) - return errors def resolve_field_references(ir): - """Resolves structure member accesses ("field.subfield") in ir.""" - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.FieldReference], _resolve_field_reference, - incidental_actions={ - ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, - ir_data.Module: _set_visible_scopes_for_module, - ir_data.Attribute: _set_visible_scopes_for_attribute, - }, - parameters={"errors": errors, "field": None}) - return errors + """Resolves structure member accesses ("field.subfield") in ir.""" + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.FieldReference], + _resolve_field_reference, + incidental_actions={ + ir_data.TypeDefinition: _set_visible_scopes_for_type_definition, + ir_data.Module: _set_visible_scopes_for_module, + ir_data.Attribute: _set_visible_scopes_for_attribute, + }, + parameters={"errors": errors, "field": None}, + ) + return errors def resolve_symbols(ir): - """Resolves the symbols in all modules in ir.""" - symbol_tables, errors = _construct_symbol_tables(ir) - if errors: - return errors - return _resolve_symbols_from_table(ir, symbol_tables) + """Resolves the symbols in all modules in ir.""" + symbol_tables, errors = _construct_symbol_tables(ir) + if errors: + return errors + return _resolve_symbols_from_table(ir, symbol_tables) diff --git a/compiler/front_end/symbol_resolver_test.py b/compiler/front_end/symbol_resolver_test.py index deaf1a0..693d157 100644 --- a/compiler/front_end/symbol_resolver_test.py +++ b/compiler/front_end/symbol_resolver_test.py @@ -51,698 +51,969 @@ class ResolveSymbolsTest(unittest.TestCase): - """Tests for symbol_resolver.resolve_symbols().""" - - def _construct_ir_multiple(self, file_dict, primary_emb_name): - ir, unused_debug_info, errors = glue.parse_emboss_file( - primary_emb_name, - test_util.dict_file_reader(file_dict), - stop_before_step="resolve_symbols") - assert not errors - return ir - - def _construct_ir(self, emb_text, name="happy.emb"): - return self._construct_ir_multiple({name: emb_text}, name) - - def test_struct_field_atomic_type_resolution(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[0].structure - atomic_field1_reference = struct_ir.field[0].type.atomic_type.reference - self.assertEqual(atomic_field1_reference.canonical_name.object_path, ["UInt" - ]) - self.assertEqual(atomic_field1_reference.canonical_name.module_file, "") - atomic_field2_reference = struct_ir.field[1].type.atomic_type.reference - self.assertEqual(atomic_field2_reference.canonical_name.object_path, ["Bar" - ]) - self.assertEqual(atomic_field2_reference.canonical_name.module_file, - "happy.emb") - - def test_struct_field_enum_type_resolution(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[1].structure - atomic_field_reference = struct_ir.field[0].type.atomic_type.reference - self.assertEqual(atomic_field_reference.canonical_name.object_path, ["Qux"]) - self.assertEqual(atomic_field_reference.canonical_name.module_file, - 
"happy.emb") - - def test_struct_field_array_type_resolution(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - array_field_type = ir.module[0].type[0].structure.field[2].type.array_type - array_field_reference = array_field_type.base_type.atomic_type.reference - self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"]) - self.assertEqual(array_field_reference.canonical_name.module_file, "") - - def test_inner_type_resolution(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - array_field_type = ir.module[0].type[0].structure.field[2].type.array_type - array_field_reference = array_field_type.base_type.atomic_type.reference - self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"]) - self.assertEqual(array_field_reference.canonical_name.module_file, "") - - def test_struct_field_resolution_in_expression_in_location(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[3].structure - field0_loc = struct_ir.field[0].location - abbreviation_reference = field0_loc.size.field_reference.path[0] - self.assertEqual(abbreviation_reference.canonical_name.object_path, - ["FieldRef", "offset"]) - self.assertEqual(abbreviation_reference.canonical_name.module_file, - "happy.emb") - field0_start_left = field0_loc.start.function.args[0] - nested_abbreviation_reference = field0_start_left.field_reference.path[0] - self.assertEqual(nested_abbreviation_reference.canonical_name.object_path, - ["FieldRef", "offset"]) - self.assertEqual(nested_abbreviation_reference.canonical_name.module_file, - "happy.emb") - field1_loc = struct_ir.field[1].location - direct_reference = field1_loc.size.field_reference.path[0] - self.assertEqual(direct_reference.canonical_name.object_path, ["FieldRef", - "offset"]) - self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb") - field1_start_left = field1_loc.start.function.args[0] - nested_direct_reference = field1_start_left.field_reference.path[0] - self.assertEqual(nested_direct_reference.canonical_name.object_path, - ["FieldRef", "offset"]) - self.assertEqual(nested_direct_reference.canonical_name.module_file, - "happy.emb") - - def test_struct_field_resolution_in_expression_in_array_length(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[3].structure - field0_array_type = struct_ir.field[0].type.array_type - field0_array_element_count = field0_array_type.element_count - abbreviation_reference = field0_array_element_count.field_reference.path[0] - self.assertEqual(abbreviation_reference.canonical_name.object_path, - ["FieldRef", "offset"]) - self.assertEqual(abbreviation_reference.canonical_name.module_file, - "happy.emb") - field1_array_type = struct_ir.field[1].type.array_type - direct_reference = field1_array_type.element_count.field_reference.path[0] - self.assertEqual(direct_reference.canonical_name.object_path, ["FieldRef", - "offset"]) - self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb") - - def test_struct_parameter_resolution(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[6].structure - size_ir = struct_ir.field[0].location.size - self.assertTrue(size_ir.HasField("field_reference")) - 
self.assertEqual(size_ir.field_reference.path[0].canonical_name.object_path, - ["UsesParameter", "x"]) - - def test_enum_value_resolution_in_expression_in_enum_field(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - enum_ir = ir.module[0].type[5].enumeration - value_reference = enum_ir.value[1].value.constant_reference - self.assertEqual(value_reference.canonical_name.object_path, - ["Quux", "ABC"]) - self.assertEqual(value_reference.canonical_name.module_file, "happy.emb") - - def test_symbol_resolution_in_expression_in_void_array_length(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - struct_ir = ir.module[0].type[4].structure - array_type = struct_ir.field[0].type.array_type - # The symbol resolver should ignore void fields. - self.assertEqual("automatic", array_type.WhichOneof("size")) - - def test_name_definitions_have_correct_canonical_names(self): - ir = self._construct_ir(_HAPPY_EMB) - self.assertEqual([], symbol_resolver.resolve_symbols(ir)) - foo_name = ir.module[0].type[0].name - self.assertEqual(foo_name.canonical_name.object_path, ["Foo"]) - self.assertEqual(foo_name.canonical_name.module_file, "happy.emb") - uint_field_name = ir.module[0].type[0].structure.field[0].name - self.assertEqual(uint_field_name.canonical_name.object_path, ["Foo", - "uint_field"]) - self.assertEqual(uint_field_name.canonical_name.module_file, "happy.emb") - foo_name = ir.module[0].type[2].name - self.assertEqual(foo_name.canonical_name.object_path, ["Qux"]) - self.assertEqual(foo_name.canonical_name.module_file, "happy.emb") - - def test_duplicate_type_name(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+4] UInt field\n" - "struct Foo:\n" - " 0 [+4] UInt bar\n", "duplicate_type.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - self.assertEqual([ - [error.error("duplicate_type.emb", - ir.module[0].type[1].name.source_location, - "Duplicate name 'Foo'"), - error.note("duplicate_type.emb", - ir.module[0].type[0].name.source_location, - "Original definition")] - ], errors) - - def test_duplicate_field_name_in_struct(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+4] UInt field\n" - " 4 [+4] UInt field\n", "duplicate_field.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("duplicate_field.emb", - struct.field[1].name.source_location, - "Duplicate name 'field'"), - error.note("duplicate_field.emb", - struct.field[0].name.source_location, - "Original definition") - ]], errors) - - def test_duplicate_abbreviation_in_struct(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+4] UInt field1 (f)\n" - " 4 [+4] UInt field2 (f)\n", - "duplicate_field.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("duplicate_field.emb", - struct.field[1].abbreviation.source_location, - "Duplicate name 'f'"), - error.note("duplicate_field.emb", - struct.field[0].abbreviation.source_location, - "Original definition") - ]], errors) - - def test_abbreviation_duplicates_field_name_in_struct(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+4] UInt field\n" - " 4 [+4] UInt field2 (field)\n", - "duplicate_field.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - struct = ir.module[0].type[0].structure - self.assertEqual([[ - 
        error.error("duplicate_field.emb",
-                    struct.field[1].abbreviation.source_location,
-                    "Duplicate name 'field'"),
-        error.note("duplicate_field.emb",
-                   struct.field[0].name.source_location,
-                   "Original definition")
-    ]], errors)
-
-  def test_field_name_duplicates_abbreviation_in_struct(self):
-    ir = self._construct_ir("struct Foo:\n"
-                            " 0 [+4] UInt field (field2)\n"
-                            " 4 [+4] UInt field2\n", "duplicate_field.emb")
-    errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
-    struct = ir.module[0].type[0].structure
-    self.assertEqual([[
-        error.error("duplicate_field.emb",
-                    struct.field[1].name.source_location,
-                    "Duplicate name 'field2'"),
-        error.note("duplicate_field.emb",
-                   struct.field[0].abbreviation.source_location,
-                   "Original definition")
-    ]], errors)
-
-  def test_duplicate_value_name_in_enum(self):
-    ir = self._construct_ir("enum Foo:\n"
-                            " BAR = 1\n"
-                            " BAR = 1\n", "duplicate_enum.emb")
-    errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
-    self.assertEqual([[
-        error.error(
-            "duplicate_enum.emb",
-            ir.module[0].type[0].enumeration.value[1].name.source_location,
-            "Duplicate name 'BAR'"),
-        error.note(
-            "duplicate_enum.emb",
-            ir.module[0].type[0].enumeration.value[0].name.source_location,
-            "Original definition")
-    ]], errors)
-
-  def test_ambiguous_name(self):
-    # struct UInt will be ambiguous with the external UInt in the prelude.
-    ir = self._construct_ir("struct UInt:\n"
-                            " 0 [+4] Int:8[4] field\n"
-                            "struct Foo:\n"
-                            " 0 [+4] UInt bar\n", "ambiguous.emb")
-    errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
-    # Find the UInt definition in the prelude.
-    for type_ir in ir.module[1].type:
-      if type_ir.name.name.text == "UInt":
-        prelude_uint = type_ir
-        break
-    ambiguous_type_ir = ir.module[0].type[1].structure.field[0].type.atomic_type
-    self.assertEqual([[
-        error.error("ambiguous.emb",
-                    ambiguous_type_ir.reference.source_name[0].source_location,
-                    "Ambiguous name 'UInt'"), error.note(
-                        "", prelude_uint.name.source_location,
-                        "Possible resolution"),
-        error.note("ambiguous.emb", ir.module[0].type[0].name.source_location,
-                   "Possible resolution")
-    ]], errors)
-
-  def test_missing_name(self):
-    ir = self._construct_ir("struct Foo:\n"
-                            " 0 [+4] Bar field\n",
-                            "missing.emb")
-    errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
-    missing_type_ir = ir.module[0].type[0].structure.field[0].type.atomic_type
-    self.assertEqual([
-        [error.error("missing.emb",
-                     missing_type_ir.reference.source_name[0].source_location,
-                     "No candidate for 'Bar'")]
-    ], errors)
-
-  def test_missing_leading_name(self):
-    ir = self._construct_ir("struct Foo:\n"
-                            " 0 [+Num.FOUR] UInt field\n", "missing.emb")
-    errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
-    missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
-    self.assertEqual([
-        [error.error(
+    """Tests for symbol_resolver.resolve_symbols()."""
+
+    def _construct_ir_multiple(self, file_dict, primary_emb_name):
+        ir, unused_debug_info, errors = glue.parse_emboss_file(
+            primary_emb_name,
+            test_util.dict_file_reader(file_dict),
+            stop_before_step="resolve_symbols",
+        )
+        assert not errors
+        return ir
+
+    def _construct_ir(self, emb_text, name="happy.emb"):
+        return self._construct_ir_multiple({name: emb_text}, name)
+
+    def test_struct_field_atomic_type_resolution(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        struct_ir = ir.module[0].type[0].structure
+        atomic_field1_reference = struct_ir.field[0].type.atomic_type.reference
+        self.assertEqual(atomic_field1_reference.canonical_name.object_path, ["UInt"])
+        self.assertEqual(atomic_field1_reference.canonical_name.module_file, "")
+        atomic_field2_reference = struct_ir.field[1].type.atomic_type.reference
+        self.assertEqual(atomic_field2_reference.canonical_name.object_path, ["Bar"])
+        self.assertEqual(
+            atomic_field2_reference.canonical_name.module_file, "happy.emb"
+        )
+
+    def test_struct_field_enum_type_resolution(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        struct_ir = ir.module[0].type[1].structure
+        atomic_field_reference = struct_ir.field[0].type.atomic_type.reference
+        self.assertEqual(atomic_field_reference.canonical_name.object_path, ["Qux"])
+        self.assertEqual(atomic_field_reference.canonical_name.module_file, "happy.emb")
+
+    def test_struct_field_array_type_resolution(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
+        array_field_reference = array_field_type.base_type.atomic_type.reference
+        self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
+        self.assertEqual(array_field_reference.canonical_name.module_file, "")
+
+    def test_inner_type_resolution(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
+        array_field_reference = array_field_type.base_type.atomic_type.reference
+        self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
+        self.assertEqual(array_field_reference.canonical_name.module_file, "")
+
+    def test_struct_field_resolution_in_expression_in_location(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        struct_ir = ir.module[0].type[3].structure
+        field0_loc = struct_ir.field[0].location
+        abbreviation_reference = field0_loc.size.field_reference.path[0]
+        self.assertEqual(
+            abbreviation_reference.canonical_name.object_path, ["FieldRef", "offset"]
+        )
+        self.assertEqual(abbreviation_reference.canonical_name.module_file, "happy.emb")
+        field0_start_left = field0_loc.start.function.args[0]
+        nested_abbreviation_reference = field0_start_left.field_reference.path[0]
+        self.assertEqual(
+            nested_abbreviation_reference.canonical_name.object_path,
+            ["FieldRef", "offset"],
+        )
+        self.assertEqual(
+            nested_abbreviation_reference.canonical_name.module_file, "happy.emb"
+        )
+        field1_loc = struct_ir.field[1].location
+        direct_reference = field1_loc.size.field_reference.path[0]
+        self.assertEqual(
+            direct_reference.canonical_name.object_path, ["FieldRef", "offset"]
+        )
+        self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb")
+        field1_start_left = field1_loc.start.function.args[0]
+        nested_direct_reference = field1_start_left.field_reference.path[0]
+        self.assertEqual(
+            nested_direct_reference.canonical_name.object_path, ["FieldRef", "offset"]
+        )
+        self.assertEqual(
+            nested_direct_reference.canonical_name.module_file, "happy.emb"
+        )
+
+    def test_struct_field_resolution_in_expression_in_array_length(self):
+        ir = self._construct_ir(_HAPPY_EMB)
+        self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+        struct_ir = ir.module[0].type[3].structure
+        field0_array_type = struct_ir.field[0].type.array_type
+        field0_array_element_count = field0_array_type.element_count
+
abbreviation_reference = field0_array_element_count.field_reference.path[0] + self.assertEqual( + abbreviation_reference.canonical_name.object_path, ["FieldRef", "offset"] + ) + self.assertEqual(abbreviation_reference.canonical_name.module_file, "happy.emb") + field1_array_type = struct_ir.field[1].type.array_type + direct_reference = field1_array_type.element_count.field_reference.path[0] + self.assertEqual( + direct_reference.canonical_name.object_path, ["FieldRef", "offset"] + ) + self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb") + + def test_struct_parameter_resolution(self): + ir = self._construct_ir(_HAPPY_EMB) + self.assertEqual([], symbol_resolver.resolve_symbols(ir)) + struct_ir = ir.module[0].type[6].structure + size_ir = struct_ir.field[0].location.size + self.assertTrue(size_ir.HasField("field_reference")) + self.assertEqual( + size_ir.field_reference.path[0].canonical_name.object_path, + ["UsesParameter", "x"], + ) + + def test_enum_value_resolution_in_expression_in_enum_field(self): + ir = self._construct_ir(_HAPPY_EMB) + self.assertEqual([], symbol_resolver.resolve_symbols(ir)) + enum_ir = ir.module[0].type[5].enumeration + value_reference = enum_ir.value[1].value.constant_reference + self.assertEqual(value_reference.canonical_name.object_path, ["Quux", "ABC"]) + self.assertEqual(value_reference.canonical_name.module_file, "happy.emb") + + def test_symbol_resolution_in_expression_in_void_array_length(self): + ir = self._construct_ir(_HAPPY_EMB) + self.assertEqual([], symbol_resolver.resolve_symbols(ir)) + struct_ir = ir.module[0].type[4].structure + array_type = struct_ir.field[0].type.array_type + # The symbol resolver should ignore void fields. + self.assertEqual("automatic", array_type.WhichOneof("size")) + + def test_name_definitions_have_correct_canonical_names(self): + ir = self._construct_ir(_HAPPY_EMB) + self.assertEqual([], symbol_resolver.resolve_symbols(ir)) + foo_name = ir.module[0].type[0].name + self.assertEqual(foo_name.canonical_name.object_path, ["Foo"]) + self.assertEqual(foo_name.canonical_name.module_file, "happy.emb") + uint_field_name = ir.module[0].type[0].structure.field[0].name + self.assertEqual( + uint_field_name.canonical_name.object_path, ["Foo", "uint_field"] + ) + self.assertEqual(uint_field_name.canonical_name.module_file, "happy.emb") + foo_name = ir.module[0].type[2].name + self.assertEqual(foo_name.canonical_name.object_path, ["Qux"]) + self.assertEqual(foo_name.canonical_name.module_file, "happy.emb") + + def test_duplicate_type_name(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+4] UInt field\n" + "struct Foo:\n" + " 0 [+4] UInt bar\n", + "duplicate_type.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + self.assertEqual( + [ + [ + error.error( + "duplicate_type.emb", + ir.module[0].type[1].name.source_location, + "Duplicate name 'Foo'", + ), + error.note( + "duplicate_type.emb", + ir.module[0].type[0].name.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_duplicate_field_name_in_struct(self): + ir = self._construct_ir( + "struct Foo:\n" " 0 [+4] UInt field\n" " 4 [+4] UInt field\n", + "duplicate_field.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "duplicate_field.emb", + struct.field[1].name.source_location, + "Duplicate name 'field'", + ), + error.note( + "duplicate_field.emb", + struct.field[0].name.source_location, + 
"Original definition", + ), + ] + ], + errors, + ) + + def test_duplicate_abbreviation_in_struct(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+4] UInt field1 (f)\n" + " 4 [+4] UInt field2 (f)\n", + "duplicate_field.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "duplicate_field.emb", + struct.field[1].abbreviation.source_location, + "Duplicate name 'f'", + ), + error.note( + "duplicate_field.emb", + struct.field[0].abbreviation.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_abbreviation_duplicates_field_name_in_struct(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+4] UInt field\n" + " 4 [+4] UInt field2 (field)\n", + "duplicate_field.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "duplicate_field.emb", + struct.field[1].abbreviation.source_location, + "Duplicate name 'field'", + ), + error.note( + "duplicate_field.emb", + struct.field[0].name.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_field_name_duplicates_abbreviation_in_struct(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+4] UInt field (field2)\n" + " 4 [+4] UInt field2\n", + "duplicate_field.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "duplicate_field.emb", + struct.field[1].name.source_location, + "Duplicate name 'field2'", + ), + error.note( + "duplicate_field.emb", + struct.field[0].abbreviation.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_duplicate_value_name_in_enum(self): + ir = self._construct_ir( + "enum Foo:\n" " BAR = 1\n" " BAR = 1\n", "duplicate_enum.emb" + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + self.assertEqual( + [ + [ + error.error( + "duplicate_enum.emb", + ir.module[0].type[0].enumeration.value[1].name.source_location, + "Duplicate name 'BAR'", + ), + error.note( + "duplicate_enum.emb", + ir.module[0].type[0].enumeration.value[0].name.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_ambiguous_name(self): + # struct UInt will be ambiguous with the external UInt in the prelude. + ir = self._construct_ir( + "struct UInt:\n" + " 0 [+4] Int:8[4] field\n" + "struct Foo:\n" + " 0 [+4] UInt bar\n", + "ambiguous.emb", + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + # Find the UInt definition in the prelude. 
+ for type_ir in ir.module[1].type: + if type_ir.name.name.text == "UInt": + prelude_uint = type_ir + break + ambiguous_type_ir = ir.module[0].type[1].structure.field[0].type.atomic_type + self.assertEqual( + [ + [ + error.error( + "ambiguous.emb", + ambiguous_type_ir.reference.source_name[0].source_location, + "Ambiguous name 'UInt'", + ), + error.note( + "", prelude_uint.name.source_location, "Possible resolution" + ), + error.note( + "ambiguous.emb", + ir.module[0].type[0].name.source_location, + "Possible resolution", + ), + ] + ], + errors, + ) + + def test_missing_name(self): + ir = self._construct_ir("struct Foo:\n" " 0 [+4] Bar field\n", "missing.emb") + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + missing_type_ir = ir.module[0].type[0].structure.field[0].type.atomic_type + self.assertEqual( + [ + [ + error.error( + "missing.emb", + missing_type_ir.reference.source_name[0].source_location, + "No candidate for 'Bar'", + ) + ] + ], + errors, + ) + + def test_missing_leading_name(self): + ir = self._construct_ir( + "struct Foo:\n" " 0 [+Num.FOUR] UInt field\n", "missing.emb" + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size + self.assertEqual( + [ + [ + error.error( + "missing.emb", + missing_expr_ir.constant_reference.source_name[ + 0 + ].source_location, + "No candidate for 'Num'", + ) + ] + ], + errors, + ) + + def test_missing_trailing_name(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+Num.FOUR] UInt field\n" + "enum Num:\n" + " THREE = 3\n", "missing.emb", - missing_expr_ir.constant_reference.source_name[0].source_location, - "No candidate for 'Num'")] - ], errors) - - def test_missing_trailing_name(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+Num.FOUR] UInt field\n" - "enum Num:\n" - " THREE = 3\n", "missing.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size - self.assertEqual([ - [error.error( + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size + self.assertEqual( + [ + [ + error.error( + "missing.emb", + missing_expr_ir.constant_reference.source_name[ + 1 + ].source_location, + "No candidate for 'FOUR'", + ) + ] + ], + errors, + ) + + def test_missing_middle_name(self): + ir = self._construct_ir( + "struct Foo:\n" + " 0 [+Num.NaN.FOUR] UInt field\n" + "enum Num:\n" + " FOUR = 4\n", "missing.emb", - missing_expr_ir.constant_reference.source_name[1].source_location, - "No candidate for 'FOUR'")] - ], errors) - - def test_missing_middle_name(self): - ir = self._construct_ir("struct Foo:\n" - " 0 [+Num.NaN.FOUR] UInt field\n" - "enum Num:\n" - " FOUR = 4\n", "missing.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size - self.assertEqual([ - [error.error( - "missing.emb", - missing_expr_ir.constant_reference.source_name[1].source_location, - "No candidate for 'NaN'")] - ], errors) - - def test_inner_resolution(self): - ir = self._construct_ir( - "struct OuterStruct:\n" - "\n" - " struct InnerStruct2:\n" - " 0 [+1] InnerStruct.InnerEnum inner_enum\n" - "\n" - " struct InnerStruct:\n" - " enum InnerEnum:\n" - " ONE = 1\n" - "\n" - " 0 [+1] InnerEnum inner_enum\n" - "\n" - " 0 [+InnerStruct.InnerEnum.ONE] InnerStruct.InnerEnum inner_enum\n", - "nested.emb") - 
errors = symbol_resolver.resolve_symbols(ir) - self.assertFalse(errors) - outer_struct = ir.module[0].type[0] - inner_struct = outer_struct.subtype[1] - inner_struct_2 = outer_struct.subtype[0] - inner_enum = inner_struct.subtype[0] - self.assertEqual(["OuterStruct", "InnerStruct"], - list(inner_struct.name.canonical_name.object_path)) - self.assertEqual(["OuterStruct", "InnerStruct", "InnerEnum"], - list(inner_enum.name.canonical_name.object_path)) - self.assertEqual(["OuterStruct", "InnerStruct2"], - list(inner_struct_2.name.canonical_name.object_path)) - outer_field = outer_struct.structure.field[0] - outer_field_end_ref = outer_field.location.size.constant_reference - self.assertEqual( - ["OuterStruct", "InnerStruct", "InnerEnum", "ONE"], list( - outer_field_end_ref.canonical_name.object_path)) - self.assertEqual( - ["OuterStruct", "InnerStruct", "InnerEnum"], - list(outer_field.type.atomic_type.reference.canonical_name.object_path)) - inner_field_2_type = inner_struct_2.structure.field[0].type.atomic_type - self.assertEqual( - ["OuterStruct", "InnerStruct", "InnerEnum" - ], list(inner_field_2_type.reference.canonical_name.object_path)) - - def test_resolution_against_anonymous_bits(self): - ir = self._construct_ir("struct Struct:\n" - " 0 [+1] bits:\n" - " 7 [+1] Flag last_packet\n" - " 5 [+2] enum inline_inner_enum:\n" - " AA = 0\n" - " BB = 1\n" - " CC = 2\n" - " DD = 3\n" - " 0 [+5] UInt header_size (h)\n" - " 0 [+h] UInt:8[] header_bytes\n" - "\n" - "struct Struct2:\n" - " 0 [+1] Struct.InlineInnerEnum value\n", - "anonymity.emb") - errors = symbol_resolver.resolve_symbols(ir) - self.assertFalse(errors) - struct1 = ir.module[0].type[0] - struct1_bits_field = struct1.structure.field[0] - struct1_bits_field_type = struct1_bits_field.type.atomic_type.reference - struct1_byte_field = struct1.structure.field[4] - inner_bits = struct1.subtype[0] - inner_enum = struct1.subtype[1] - self.assertTrue(inner_bits.HasField("structure")) - self.assertTrue(inner_enum.HasField("enumeration")) - self.assertTrue(inner_bits.name.is_anonymous) - self.assertFalse(inner_enum.name.is_anonymous) - self.assertEqual(["Struct", "InlineInnerEnum"], - list(inner_enum.name.canonical_name.object_path)) - self.assertEqual( - ["Struct", "InlineInnerEnum", "AA"], - list(inner_enum.enumeration.value[0].name.canonical_name.object_path)) - self.assertEqual( - list(inner_bits.name.canonical_name.object_path), - list(struct1_bits_field_type.canonical_name.object_path)) - self.assertEqual(2, len(inner_bits.name.canonical_name.object_path)) - self.assertEqual( - ["Struct", "header_size"], - list(struct1_byte_field.location.size.field_reference.path[0]. 
- canonical_name.object_path)) - - def test_duplicate_name_in_different_inline_bits(self): - ir = self._construct_ir( - "struct Struct:\n" - " 0 [+1] bits:\n" - " 7 [+1] Flag a\n" - " 1 [+1] bits:\n" - " 0 [+1] Flag a\n", "duplicate_in_anon.emb") - errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) - supertype = ir.module[0].type[0] - self.assertEqual([[ - error.error( - "duplicate_in_anon.emb", - supertype.structure.field[3].name.source_location, - "Duplicate name 'a'"), - error.note( - "duplicate_in_anon.emb", - supertype.structure.field[1].name.source_location, - "Original definition") - ]], errors) - - def test_duplicate_name_in_same_inline_bits(self): - ir = self._construct_ir( - "struct Struct:\n" - " 0 [+1] bits:\n" - " 7 [+1] Flag a\n" - " 0 [+1] Flag a\n", "duplicate_in_anon.emb") - errors = symbol_resolver.resolve_symbols(ir) - supertype = ir.module[0].type[0] - self.assertEqual([[ - error.error( + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size + self.assertEqual( + [ + [ + error.error( + "missing.emb", + missing_expr_ir.constant_reference.source_name[ + 1 + ].source_location, + "No candidate for 'NaN'", + ) + ] + ], + errors, + ) + + def test_inner_resolution(self): + ir = self._construct_ir( + "struct OuterStruct:\n" + "\n" + " struct InnerStruct2:\n" + " 0 [+1] InnerStruct.InnerEnum inner_enum\n" + "\n" + " struct InnerStruct:\n" + " enum InnerEnum:\n" + " ONE = 1\n" + "\n" + " 0 [+1] InnerEnum inner_enum\n" + "\n" + " 0 [+InnerStruct.InnerEnum.ONE] InnerStruct.InnerEnum inner_enum\n", + "nested.emb", + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertFalse(errors) + outer_struct = ir.module[0].type[0] + inner_struct = outer_struct.subtype[1] + inner_struct_2 = outer_struct.subtype[0] + inner_enum = inner_struct.subtype[0] + self.assertEqual( + ["OuterStruct", "InnerStruct"], + list(inner_struct.name.canonical_name.object_path), + ) + self.assertEqual( + ["OuterStruct", "InnerStruct", "InnerEnum"], + list(inner_enum.name.canonical_name.object_path), + ) + self.assertEqual( + ["OuterStruct", "InnerStruct2"], + list(inner_struct_2.name.canonical_name.object_path), + ) + outer_field = outer_struct.structure.field[0] + outer_field_end_ref = outer_field.location.size.constant_reference + self.assertEqual( + ["OuterStruct", "InnerStruct", "InnerEnum", "ONE"], + list(outer_field_end_ref.canonical_name.object_path), + ) + self.assertEqual( + ["OuterStruct", "InnerStruct", "InnerEnum"], + list(outer_field.type.atomic_type.reference.canonical_name.object_path), + ) + inner_field_2_type = inner_struct_2.structure.field[0].type.atomic_type + self.assertEqual( + ["OuterStruct", "InnerStruct", "InnerEnum"], + list(inner_field_2_type.reference.canonical_name.object_path), + ) + + def test_resolution_against_anonymous_bits(self): + ir = self._construct_ir( + "struct Struct:\n" + " 0 [+1] bits:\n" + " 7 [+1] Flag last_packet\n" + " 5 [+2] enum inline_inner_enum:\n" + " AA = 0\n" + " BB = 1\n" + " CC = 2\n" + " DD = 3\n" + " 0 [+5] UInt header_size (h)\n" + " 0 [+h] UInt:8[] header_bytes\n" + "\n" + "struct Struct2:\n" + " 0 [+1] Struct.InlineInnerEnum value\n", + "anonymity.emb", + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertFalse(errors) + struct1 = ir.module[0].type[0] + struct1_bits_field = struct1.structure.field[0] + struct1_bits_field_type = struct1_bits_field.type.atomic_type.reference + struct1_byte_field = struct1.structure.field[4] + 
inner_bits = struct1.subtype[0] + inner_enum = struct1.subtype[1] + self.assertTrue(inner_bits.HasField("structure")) + self.assertTrue(inner_enum.HasField("enumeration")) + self.assertTrue(inner_bits.name.is_anonymous) + self.assertFalse(inner_enum.name.is_anonymous) + self.assertEqual( + ["Struct", "InlineInnerEnum"], + list(inner_enum.name.canonical_name.object_path), + ) + self.assertEqual( + ["Struct", "InlineInnerEnum", "AA"], + list(inner_enum.enumeration.value[0].name.canonical_name.object_path), + ) + self.assertEqual( + list(inner_bits.name.canonical_name.object_path), + list(struct1_bits_field_type.canonical_name.object_path), + ) + self.assertEqual(2, len(inner_bits.name.canonical_name.object_path)) + self.assertEqual( + ["Struct", "header_size"], + list( + struct1_byte_field.location.size.field_reference.path[ + 0 + ].canonical_name.object_path + ), + ) + + def test_duplicate_name_in_different_inline_bits(self): + ir = self._construct_ir( + "struct Struct:\n" + " 0 [+1] bits:\n" + " 7 [+1] Flag a\n" + " 1 [+1] bits:\n" + " 0 [+1] Flag a\n", "duplicate_in_anon.emb", - supertype.structure.field[2].name.source_location, - "Duplicate name 'a'"), - error.note( + ) + errors = error.filter_errors(symbol_resolver.resolve_symbols(ir)) + supertype = ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "duplicate_in_anon.emb", + supertype.structure.field[3].name.source_location, + "Duplicate name 'a'", + ), + error.note( + "duplicate_in_anon.emb", + supertype.structure.field[1].name.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_duplicate_name_in_same_inline_bits(self): + ir = self._construct_ir( + "struct Struct:\n" + " 0 [+1] bits:\n" + " 7 [+1] Flag a\n" + " 0 [+1] Flag a\n", "duplicate_in_anon.emb", - supertype.structure.field[1].name.source_location, - "Original definition") - ]], error.filter_errors(errors)) - - def test_import_type_resolution(self): - importer = ('import "ed.emb" as ed\n' - "struct Ff:\n" - " 0 [+1] ed.Gg gg\n") - imported = ("struct Gg:\n" - " 0 [+1] UInt qq\n") - ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer}, - "er.emb") - errors = symbol_resolver.resolve_symbols(ir) - self.assertEqual([], errors) - - def test_duplicate_import_name(self): - importer = ('import "ed.emb" as ed\n' - 'import "ed.emb" as ed\n' - "struct Ff:\n" - " 0 [+1] ed.Gg gg\n") - imported = ("struct Gg:\n" - " 0 [+1] UInt qq\n") - ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer}, - "er.emb") - errors = symbol_resolver.resolve_symbols(ir) - # Note: the error is on import[2] duplicating import[1] because the implicit - # prelude import is import[0]. - self.assertEqual([ - [error.error("er.emb", - ir.module[0].foreign_import[2].local_name.source_location, - "Duplicate name 'ed'"), - error.note("er.emb", - ir.module[0].foreign_import[1].local_name.source_location, - "Original definition")] - ], errors) - - def test_import_enum_resolution(self): - importer = ('import "ed.emb" as ed\n' - "struct Ff:\n" - " if ed.Gg.GG == ed.Gg.GG:\n" - " 0 [+1] UInt gg\n") - imported = ("enum Gg:\n" - " GG = 0\n") - ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer}, - "er.emb") - errors = symbol_resolver.resolve_symbols(ir) - self.assertEqual([], errors) - - def test_that_double_import_names_are_syntactically_invalid(self): - # There are currently no checks in resolve_symbols that it is not possible - # to get to symbols imported by another module, because it is syntactically - # invalid. 
This may change in the future, in which case this test should be - # fixed by adding an explicit check to resolve_symbols and checking the - # error message here. - importer = ('import "ed.emb" as ed\n' - "struct Ff:\n" - " 0 [+1] ed.ed2.Gg gg\n") - imported = 'import "ed2.emb" as ed2\n' - imported2 = ("struct Gg:\n" - " 0 [+1] UInt qq\n") - unused_ir, unused_debug_info, errors = glue.parse_emboss_file( - "er.emb", - test_util.dict_file_reader({"ed.emb": imported, - "ed2.emb": imported2, - "er.emb": importer}), - stop_before_step="resolve_symbols") - assert errors - - def test_no_error_when_inline_name_aliases_outer_name(self): - # The inline enum's complete type should be Foo.Foo. During parsing, the - # name is set to just "Foo", but symbol resolution should a) select the - # correct Foo, and b) not complain that multiple Foos could match. - ir = self._construct_ir( - "struct Foo:\n" - " 0 [+1] enum foo:\n" - " BAR = 0\n") - errors = symbol_resolver.resolve_symbols(ir) - self.assertEqual([], errors) - field = ir.module[0].type[0].structure.field[0] - self.assertEqual( - ["Foo", "Foo"], - list(field.type.atomic_type.reference.canonical_name.object_path)) - - def test_no_error_when_inline_name_in_anonymous_bits_aliases_outer_name(self): - # There is an extra layer of complexity when an inline type appears inside - # of an inline bits. - ir = self._construct_ir( - "struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] enum foo:\n" - " BAR = 0\n") - errors = symbol_resolver.resolve_symbols(ir) - self.assertEqual([], error.filter_errors(errors)) - field = ir.module[0].type[0].subtype[0].structure.field[0] - self.assertEqual( - ["Foo", "Foo"], - list(field.type.atomic_type.reference.canonical_name.object_path)) + ) + errors = symbol_resolver.resolve_symbols(ir) + supertype = ir.module[0].type[0] + self.assertEqual( + [ + [ + error.error( + "duplicate_in_anon.emb", + supertype.structure.field[2].name.source_location, + "Duplicate name 'a'", + ), + error.note( + "duplicate_in_anon.emb", + supertype.structure.field[1].name.source_location, + "Original definition", + ), + ] + ], + error.filter_errors(errors), + ) + + def test_import_type_resolution(self): + importer = 'import "ed.emb" as ed\n' "struct Ff:\n" " 0 [+1] ed.Gg gg\n" + imported = "struct Gg:\n" " 0 [+1] UInt qq\n" + ir = self._construct_ir_multiple( + {"ed.emb": imported, "er.emb": importer}, "er.emb" + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertEqual([], errors) + + def test_duplicate_import_name(self): + importer = ( + 'import "ed.emb" as ed\n' + 'import "ed.emb" as ed\n' + "struct Ff:\n" + " 0 [+1] ed.Gg gg\n" + ) + imported = "struct Gg:\n" " 0 [+1] UInt qq\n" + ir = self._construct_ir_multiple( + {"ed.emb": imported, "er.emb": importer}, "er.emb" + ) + errors = symbol_resolver.resolve_symbols(ir) + # Note: the error is on import[2] duplicating import[1] because the implicit + # prelude import is import[0]. 
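Concretely, the module's foreign_import list for this input is laid out roughly as follows (a sketch; the prelude entry carrying an empty local_name is an assumption based on the import-scope handling earlier in this diff):

    # foreign_import[0]: the implicit prelude import (empty local_name)
    # foreign_import[1]: import "ed.emb" as ed
    # foreign_import[2]: import "ed.emb" as ed   <-- reported as the duplicate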
+ self.assertEqual( + [ + [ + error.error( + "er.emb", + ir.module[0].foreign_import[2].local_name.source_location, + "Duplicate name 'ed'", + ), + error.note( + "er.emb", + ir.module[0].foreign_import[1].local_name.source_location, + "Original definition", + ), + ] + ], + errors, + ) + + def test_import_enum_resolution(self): + importer = ( + 'import "ed.emb" as ed\n' + "struct Ff:\n" + " if ed.Gg.GG == ed.Gg.GG:\n" + " 0 [+1] UInt gg\n" + ) + imported = "enum Gg:\n" " GG = 0\n" + ir = self._construct_ir_multiple( + {"ed.emb": imported, "er.emb": importer}, "er.emb" + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertEqual([], errors) + + def test_that_double_import_names_are_syntactically_invalid(self): + # There are currently no checks in resolve_symbols that it is not possible + # to get to symbols imported by another module, because it is syntactically + # invalid. This may change in the future, in which case this test should be + # fixed by adding an explicit check to resolve_symbols and checking the + # error message here. + importer = 'import "ed.emb" as ed\n' "struct Ff:\n" " 0 [+1] ed.ed2.Gg gg\n" + imported = 'import "ed2.emb" as ed2\n' + imported2 = "struct Gg:\n" " 0 [+1] UInt qq\n" + unused_ir, unused_debug_info, errors = glue.parse_emboss_file( + "er.emb", + test_util.dict_file_reader( + {"ed.emb": imported, "ed2.emb": imported2, "er.emb": importer} + ), + stop_before_step="resolve_symbols", + ) + assert errors + + def test_no_error_when_inline_name_aliases_outer_name(self): + # The inline enum's complete type should be Foo.Foo. During parsing, the + # name is set to just "Foo", but symbol resolution should a) select the + # correct Foo, and b) not complain that multiple Foos could match. + ir = self._construct_ir( + "struct Foo:\n" " 0 [+1] enum foo:\n" " BAR = 0\n" + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertEqual([], errors) + field = ir.module[0].type[0].structure.field[0] + self.assertEqual( + ["Foo", "Foo"], + list(field.type.atomic_type.reference.canonical_name.object_path), + ) + + def test_no_error_when_inline_name_in_anonymous_bits_aliases_outer_name(self): + # There is an extra layer of complexity when an inline type appears inside + # of an inline bits. 
+ ir = self._construct_ir( + "struct Foo:\n" + " 0 [+1] bits:\n" + " 0 [+4] enum foo:\n" + " BAR = 0\n" + ) + errors = symbol_resolver.resolve_symbols(ir) + self.assertEqual([], error.filter_errors(errors)) + field = ir.module[0].type[0].subtype[0].structure.field[0] + self.assertEqual( + ["Foo", "Foo"], + list(field.type.atomic_type.reference.canonical_name.object_path), + ) class ResolveFieldReferencesTest(unittest.TestCase): - """Tests for symbol_resolver.resolve_field_references().""" - - def _construct_ir_multiple(self, file_dict, primary_emb_name): - ir, unused_debug_info, errors = glue.parse_emboss_file( - primary_emb_name, - test_util.dict_file_reader(file_dict), - stop_before_step="resolve_field_references") - assert not errors - return ir - - def _construct_ir(self, emb_text, name="happy.emb"): - return self._construct_ir_multiple({name: emb_text}, name) - - def test_subfield_resolution(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg gg\n" - " 1 [+gg.qq] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = symbol_resolver.resolve_field_references(ir) - self.assertFalse(errors) - ff = ir.module[0].type[0] - location_end_path = ff.structure.field[1].location.size.field_reference.path - self.assertEqual(["Ff", "gg"], - list(location_end_path[0].canonical_name.object_path)) - self.assertEqual(["Gg", "qq"], - list(location_end_path[1].canonical_name.object_path)) - - def test_aliased_subfield_resolution(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg real_gg\n" - " 1 [+gg.qq] UInt:8[] data\n" - " let gg = real_gg\n" - "struct Gg:\n" - " 0 [+1] UInt real_qq\n" - " let qq = real_qq", "subfield.emb") - errors = symbol_resolver.resolve_field_references(ir) - self.assertFalse(errors) - ff = ir.module[0].type[0] - location_end_path = ff.structure.field[1].location.size.field_reference.path - self.assertEqual(["Ff", "gg"], - list(location_end_path[0].canonical_name.object_path)) - self.assertEqual(["Gg", "qq"], - list(location_end_path[1].canonical_name.object_path)) - - def test_aliased_aliased_subfield_resolution(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg really_real_gg\n" - " 1 [+gg.qq] UInt:8[] data\n" - " let gg = real_gg\n" - " let real_gg = really_real_gg\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = symbol_resolver.resolve_field_references(ir) - self.assertFalse(errors) - ff = ir.module[0].type[0] - location_end_path = ff.structure.field[1].location.size.field_reference.path - self.assertEqual(["Ff", "gg"], - list(location_end_path[0].canonical_name.object_path)) - self.assertEqual(["Gg", "qq"], - list(location_end_path[1].canonical_name.object_path)) - - def test_subfield_resolution_fails(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg gg\n" - " 1 [+gg.rr] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - self.assertEqual([ - [error.error("subfield.emb", ir.module[0].type[0].structure.field[ - 1].location.size.field_reference.path[1].source_name[ - 0].source_location, "No candidate for 'rr'")] - ], errors) - - def test_subfield_resolution_failure_shortcuts_further_resolution(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg gg\n" - " 1 [+gg.rr.qq] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - self.assertEqual([ - 
[error.error("subfield.emb", ir.module[0].type[0].structure.field[ - 1].location.size.field_reference.path[1].source_name[ - 0].source_location, "No candidate for 'rr'")] - ], errors) - - def test_subfield_resolution_failure_with_aliased_name(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg gg\n" - " 1 [+gg.gg] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - self.assertEqual([ - [error.error("subfield.emb", ir.module[0].type[0].structure.field[ - 1].location.size.field_reference.path[1].source_name[ - 0].source_location, "No candidate for 'gg'")] - ], errors) - - def test_subfield_resolution_failure_with_array(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg[1] gg\n" - " 1 [+gg.qq] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq\n", "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - self.assertEqual([ - [error.error("subfield.emb", ir.module[0].type[0].structure.field[ - 1].location.size.field_reference.path[0].source_name[ - 0].source_location, "Cannot access member of array 'gg'")] - ], errors) - - def test_subfield_resolution_failure_with_int(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] UInt gg_source\n" - " 1 [+gg.qq] UInt:8[] data\n" - " let gg = gg_source + 1\n", - "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - error_field = ir.module[0].type[0].structure.field[1] - error_reference = error_field.location.size.field_reference - error_location = error_reference.path[0].source_name[0].source_location - self.assertEqual([ - [error.error("subfield.emb", error_location, - "Cannot access member of noncomposite field 'gg'")] - ], errors) - - def test_subfield_resolution_failure_with_int_no_cascade(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] UInt gg_source\n" - " 1 [+qqx] UInt:8[] data\n" - " let gg = gg_source + 1\n" - " let yy = gg.no_field\n" - " let qqx = yy.x\n" - " let qqy = yy.y\n", - "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - error_field = ir.module[0].type[0].structure.field[3] - error_reference = error_field.read_transform.field_reference - error_location = error_reference.path[0].source_name[0].source_location - self.assertEqual([ - [error.error("subfield.emb", error_location, - "Cannot access member of noncomposite field 'gg'")] - ], errors) - - def test_subfield_resolution_failure_with_abbreviation(self): - ir = self._construct_ir( - "struct Ff:\n" - " 0 [+1] Gg gg\n" - " 1 [+gg.q] UInt:8[] data\n" - "struct Gg:\n" - " 0 [+1] UInt qq (q)\n", "subfield.emb") - errors = error.filter_errors(symbol_resolver.resolve_field_references(ir)) - self.assertEqual([ - # TODO(bolms): Make the error message clearer, in this case. 
-        [error.error("subfield.emb", ir.module[0].type[0].structure.field[
-            1].location.size.field_reference.path[1].source_name[
-                0].source_location, "No candidate for 'q'")]
-    ], errors)
+    """Tests for symbol_resolver.resolve_field_references()."""
+
+    def _construct_ir_multiple(self, file_dict, primary_emb_name):
+        ir, unused_debug_info, errors = glue.parse_emboss_file(
+            primary_emb_name,
+            test_util.dict_file_reader(file_dict),
+            stop_before_step="resolve_field_references",
+        )
+        assert not errors
+        return ir
+
+    def _construct_ir(self, emb_text, name="happy.emb"):
+        return self._construct_ir_multiple({name: emb_text}, name)
+
+    def test_subfield_resolution(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg gg\n"
+            " 1 [+gg.qq] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = symbol_resolver.resolve_field_references(ir)
+        self.assertFalse(errors)
+        ff = ir.module[0].type[0]
+        location_end_path = ff.structure.field[1].location.size.field_reference.path
+        self.assertEqual(
+            ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+        )
+        self.assertEqual(
+            ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+        )
+
+    def test_aliased_subfield_resolution(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg real_gg\n"
+            " 1 [+gg.qq] UInt:8[] data\n"
+            " let gg = real_gg\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt real_qq\n"
+            " let qq = real_qq",
+            "subfield.emb",
+        )
+        errors = symbol_resolver.resolve_field_references(ir)
+        self.assertFalse(errors)
+        ff = ir.module[0].type[0]
+        location_end_path = ff.structure.field[1].location.size.field_reference.path
+        self.assertEqual(
+            ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+        )
+        self.assertEqual(
+            ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+        )
+
+    def test_aliased_aliased_subfield_resolution(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg really_real_gg\n"
+            " 1 [+gg.qq] UInt:8[] data\n"
+            " let gg = real_gg\n"
+            " let real_gg = really_real_gg\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = symbol_resolver.resolve_field_references(ir)
+        self.assertFalse(errors)
+        ff = ir.module[0].type[0]
+        location_end_path = ff.structure.field[1].location.size.field_reference.path
+        self.assertEqual(
+            ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+        )
+        self.assertEqual(
+            ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+        )
+
+    def test_subfield_resolution_fails(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg gg\n"
+            " 1 [+gg.rr] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        ir.module[0]
+                        .type[0]
+                        .structure.field[1]
+                        .location.size.field_reference.path[1]
+                        .source_name[0]
+                        .source_location,
+                        "No candidate for 'rr'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_shortcuts_further_resolution(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg gg\n"
+            " 1 [+gg.rr.qq] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        ir.module[0]
+                        .type[0]
+                        .structure.field[1]
+                        .location.size.field_reference.path[1]
+                        .source_name[0]
+                        .source_location,
+                        "No candidate for 'rr'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_with_aliased_name(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg gg\n"
+            " 1 [+gg.gg] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        ir.module[0]
+                        .type[0]
+                        .structure.field[1]
+                        .location.size.field_reference.path[1]
+                        .source_name[0]
+                        .source_location,
+                        "No candidate for 'gg'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_with_array(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg[1] gg\n"
+            " 1 [+gg.qq] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        ir.module[0]
+                        .type[0]
+                        .structure.field[1]
+                        .location.size.field_reference.path[0]
+                        .source_name[0]
+                        .source_location,
+                        "Cannot access member of array 'gg'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_with_int(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] UInt gg_source\n"
+            " 1 [+gg.qq] UInt:8[] data\n"
+            " let gg = gg_source + 1\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        error_field = ir.module[0].type[0].structure.field[1]
+        error_reference = error_field.location.size.field_reference
+        error_location = error_reference.path[0].source_name[0].source_location
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        error_location,
+                        "Cannot access member of noncomposite field 'gg'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_with_int_no_cascade(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] UInt gg_source\n"
+            " 1 [+qqx] UInt:8[] data\n"
+            " let gg = gg_source + 1\n"
+            " let yy = gg.no_field\n"
+            " let qqx = yy.x\n"
+            " let qqy = yy.y\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        error_field = ir.module[0].type[0].structure.field[3]
+        error_reference = error_field.read_transform.field_reference
+        error_location = error_reference.path[0].source_name[0].source_location
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "subfield.emb",
+                        error_location,
+                        "Cannot access member of noncomposite field 'gg'",
+                    )
+                ]
+            ],
+            errors,
+        )
+
+    def test_subfield_resolution_failure_with_abbreviation(self):
+        ir = self._construct_ir(
+            "struct Ff:\n"
+            " 0 [+1] Gg gg\n"
+            " 1 [+gg.q] UInt:8[] data\n"
+            "struct Gg:\n"
+            " 0 [+1] UInt qq (q)\n",
+            "subfield.emb",
+        )
+        errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+        self.assertEqual(
+            [
+                # TODO(bolms): Make the error message clearer, in this case.
+ [ + error.error( + "subfield.emb", + ir.module[0] + .type[0] + .structure.field[1] + .location.size.field_reference.path[1] + .source_name[0] + .source_location, + "No candidate for 'q'", + ) + ] + ], + errors, + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py index 73be294..1b331a3 100644 --- a/compiler/front_end/synthetics.py +++ b/compiler/front_end/synthetics.py @@ -24,27 +24,28 @@ def _mark_as_synthetic(proto): - """Marks all source_locations in proto with is_synthetic=True.""" - if not isinstance(proto, ir_data.Message): - return - if hasattr(proto, "source_location"): - ir_data_utils.builder(proto).source_location.is_synthetic = True - for spec, value in ir_data_utils.get_set_fields(proto): - if spec.name != "source_location" and spec.is_dataclass: - if spec.is_sequence: - for i in value: - _mark_as_synthetic(i) - else: - _mark_as_synthetic(value) + """Marks all source_locations in proto with is_synthetic=True.""" + if not isinstance(proto, ir_data.Message): + return + if hasattr(proto, "source_location"): + ir_data_utils.builder(proto).source_location.is_synthetic = True + for spec, value in ir_data_utils.get_set_fields(proto): + if spec.name != "source_location" and spec.is_dataclass: + if spec.is_sequence: + for i in value: + _mark_as_synthetic(i) + else: + _mark_as_synthetic(value) def _skip_text_output_attribute(): - """Returns the IR for a [text_output: "Skip"] attribute.""" - result = ir_data.Attribute( - name=ir_data.Word(text=attributes.TEXT_OUTPUT), - value=ir_data.AttributeValue(string_constant=ir_data.String(text="Skip"))) - _mark_as_synthetic(result) - return result + """Returns the IR for a [text_output: "Skip"] attribute.""" + result = ir_data.Attribute( + name=ir_data.Word(text=attributes.TEXT_OUTPUT), + value=ir_data.AttributeValue(string_constant=ir_data.String(text="Skip")), + ) + _mark_as_synthetic(result) + return result # The existence condition for an alias for an anonymous bits' field is the union @@ -52,122 +53,133 @@ def _skip_text_output_attribute(): # for the field within. The 'x' and 'x.y' are placeholders here; they'll be # overwritten in _add_anonymous_aliases. _ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON = expression_parser.parse( - "$present(x) && $present(x.y)") + "$present(x) && $present(x.y)" +) def _add_anonymous_aliases(structure, type_definition): - """Adds synthetic alias fields for all fields in anonymous fields. - - This essentially completes the rewrite of this: - - struct Foo: - 0 [+4] bits: - 0 [+1] Flag low - 31 [+1] Flag high - - Into this: - - struct Foo: - bits EmbossReservedAnonymous0: - [text_output: "Skip"] - 0 [+1] Flag low - 31 [+1] Flag high - 0 [+4] EmbossReservedAnonymous0 emboss_reserved_anonymous_1 - let low = emboss_reserved_anonymous_1.low - let high = emboss_reserved_anonymous_1.high - - Note that this pass runs very, very early -- even before symbols have been - resolved -- so very little in ir_util will work at this point. - - Arguments: - structure: The ir_data.Structure on which to synthesize fields. - type_definition: The ir_data.TypeDefinition containing structure. 
- - Returns: - None - """ - new_fields = [] - for field in structure.field: - new_fields.append(field) - if not field.name.is_anonymous: - continue - field.attribute.extend([_skip_text_output_attribute()]) - for subtype in type_definition.subtype: - if (subtype.name.name.text == - field.type.atomic_type.reference.source_name[-1].text): - field_type = subtype - break - else: - assert False, ("Unable to find corresponding type {} for anonymous field " - "in {}.".format( - field.type.atomic_type.reference, type_definition)) - anonymous_reference = ir_data.Reference(source_name=[field.name.name]) - anonymous_field_reference = ir_data.FieldReference( - path=[anonymous_reference]) - for subfield in field_type.structure.field: - alias_field_reference = ir_data.FieldReference( - path=[ - anonymous_reference, - ir_data.Reference(source_name=[subfield.name.name]), - ] - ) - new_existence_condition = ir_data_utils.copy(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON) - existence_clauses = ir_data_utils.builder(new_existence_condition).function.args - existence_clauses[0].function.args[0].field_reference.CopyFrom( - anonymous_field_reference) - existence_clauses[1].function.args[0].field_reference.CopyFrom( - alias_field_reference) - new_read_transform = ir_data.Expression( - field_reference=ir_data_utils.copy(alias_field_reference)) - # This treats *most* of the alias field as synthetic, but not its name(s): - # leaving the name(s) as "real" means that symbol collisions with the - # surrounding structure will be properly reported to the user. - _mark_as_synthetic(new_existence_condition) - _mark_as_synthetic(new_read_transform) - new_alias = ir_data.Field( - read_transform=new_read_transform, - existence_condition=new_existence_condition, - name=ir_data_utils.copy(subfield.name)) - if subfield.HasField("abbreviation"): - ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation) - _mark_as_synthetic(new_alias.existence_condition) - _mark_as_synthetic(new_alias.read_transform) - new_fields.append(new_alias) - # Since the alias field's name(s) are "real," it is important to mark the - # original field's name(s) as synthetic, to avoid duplicate error - # messages. - _mark_as_synthetic(subfield.name) - if subfield.HasField("abbreviation"): - _mark_as_synthetic(subfield.abbreviation) - del structure.field[:] - structure.field.extend(new_fields) + """Adds synthetic alias fields for all fields in anonymous fields. + + This essentially completes the rewrite of this: + + struct Foo: + 0 [+4] bits: + 0 [+1] Flag low + 31 [+1] Flag high + + Into this: + + struct Foo: + bits EmbossReservedAnonymous0: + [text_output: "Skip"] + 0 [+1] Flag low + 31 [+1] Flag high + 0 [+4] EmbossReservedAnonymous0 emboss_reserved_anonymous_1 + let low = emboss_reserved_anonymous_1.low + let high = emboss_reserved_anonymous_1.high + + Note that this pass runs very, very early -- even before symbols have been + resolved -- so very little in ir_util will work at this point. + + Arguments: + structure: The ir_data.Structure on which to synthesize fields. + type_definition: The ir_data.TypeDefinition containing structure. 
+ + Returns: + None + """ + new_fields = [] + for field in structure.field: + new_fields.append(field) + if not field.name.is_anonymous: + continue + field.attribute.extend([_skip_text_output_attribute()]) + for subtype in type_definition.subtype: + if ( + subtype.name.name.text + == field.type.atomic_type.reference.source_name[-1].text + ): + field_type = subtype + break + else: + assert False, ( + "Unable to find corresponding type {} for anonymous field " + "in {}.".format(field.type.atomic_type.reference, type_definition) + ) + anonymous_reference = ir_data.Reference(source_name=[field.name.name]) + anonymous_field_reference = ir_data.FieldReference(path=[anonymous_reference]) + for subfield in field_type.structure.field: + alias_field_reference = ir_data.FieldReference( + path=[ + anonymous_reference, + ir_data.Reference(source_name=[subfield.name.name]), + ] + ) + new_existence_condition = ir_data_utils.copy( + _ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON + ) + existence_clauses = ir_data_utils.builder( + new_existence_condition + ).function.args + existence_clauses[0].function.args[0].field_reference.CopyFrom( + anonymous_field_reference + ) + existence_clauses[1].function.args[0].field_reference.CopyFrom( + alias_field_reference + ) + new_read_transform = ir_data.Expression( + field_reference=ir_data_utils.copy(alias_field_reference) + ) + # This treats *most* of the alias field as synthetic, but not its name(s): + # leaving the name(s) as "real" means that symbol collisions with the + # surrounding structure will be properly reported to the user. + _mark_as_synthetic(new_existence_condition) + _mark_as_synthetic(new_read_transform) + new_alias = ir_data.Field( + read_transform=new_read_transform, + existence_condition=new_existence_condition, + name=ir_data_utils.copy(subfield.name), + ) + if subfield.HasField("abbreviation"): + ir_data_utils.builder(new_alias).abbreviation.CopyFrom( + subfield.abbreviation + ) + _mark_as_synthetic(new_alias.existence_condition) + _mark_as_synthetic(new_alias.read_transform) + new_fields.append(new_alias) + # Since the alias field's name(s) are "real," it is important to mark the + # original field's name(s) as synthetic, to avoid duplicate error + # messages. 
+ _mark_as_synthetic(subfield.name) + if subfield.HasField("abbreviation"): + _mark_as_synthetic(subfield.abbreviation) + del structure.field[:] + structure.field.extend(new_fields) _SIZE_BOUNDS = { "$max_size_in_bits": expression_parser.parse("$upper_bound($size_in_bits)"), "$min_size_in_bits": expression_parser.parse("$lower_bound($size_in_bits)"), - "$max_size_in_bytes": expression_parser.parse( - "$upper_bound($size_in_bytes)"), - "$min_size_in_bytes": expression_parser.parse( - "$lower_bound($size_in_bytes)"), + "$max_size_in_bytes": expression_parser.parse("$upper_bound($size_in_bytes)"), + "$min_size_in_bytes": expression_parser.parse("$lower_bound($size_in_bytes)"), } def _add_size_bound_virtuals(structure, type_definition): - """Adds ${min,max}_size_in_{bits,bytes} virtual fields to structure.""" - names = { - ir_data.AddressableUnit.BIT: ("$max_size_in_bits", "$min_size_in_bits"), - ir_data.AddressableUnit.BYTE: ("$max_size_in_bytes", "$min_size_in_bytes"), - } - for name in names[type_definition.addressable_unit]: - bound_field = ir_data.Field( - read_transform=_SIZE_BOUNDS[name], - name=ir_data.NameDefinition(name=ir_data.Word(text=name)), - existence_condition=expression_parser.parse("true"), - attribute=[_skip_text_output_attribute()] - ) - _mark_as_synthetic(bound_field.read_transform) - structure.field.extend([bound_field]) + """Adds ${min,max}_size_in_{bits,bytes} virtual fields to structure.""" + names = { + ir_data.AddressableUnit.BIT: ("$max_size_in_bits", "$min_size_in_bits"), + ir_data.AddressableUnit.BYTE: ("$max_size_in_bytes", "$min_size_in_bytes"), + } + for name in names[type_definition.addressable_unit]: + bound_field = ir_data.Field( + read_transform=_SIZE_BOUNDS[name], + name=ir_data.NameDefinition(name=ir_data.Word(text=name)), + existence_condition=expression_parser.parse("true"), + attribute=[_skip_text_output_attribute()], + ) + _mark_as_synthetic(bound_field.read_transform) + structure.field.extend([bound_field]) # Each non-virtual field in a structure generates a clause that is passed to @@ -177,42 +189,43 @@ def _add_size_bound_virtuals(structure, type_definition): # physical fields don't end up with a zero-argument `$max()` call, which would # fail type checking. _SIZE_CLAUSE_SKELETON = expression_parser.parse( - "existence_condition ? start + size : 0") + "existence_condition ? start + size : 0" +) _SIZE_SKELETON = expression_parser.parse("$max(0)") def _add_size_virtuals(structure, type_definition): - """Adds a $size_in_bits or $size_in_bytes virtual field to structure.""" - names = { - ir_data.AddressableUnit.BIT: "$size_in_bits", - ir_data.AddressableUnit.BYTE: "$size_in_bytes", - } - size_field_name = names[type_definition.addressable_unit] - size_clauses = [] - for field in structure.field: - # Virtual fields do not have a physical location, and thus do not contribute - # to the size of the structure. - if ir_util.field_is_virtual(field): - continue - size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON) - size_clause = ir_data_utils.builder(size_clause_ir) - # Copy the appropriate clauses into `existence_condition ? 
start + size : 0` - size_clause.function.args[0].CopyFrom(field.existence_condition) - size_clause.function.args[1].function.args[0].CopyFrom(field.location.start) - size_clause.function.args[1].function.args[1].CopyFrom(field.location.size) - size_clauses.append(size_clause_ir) - size_expression = ir_data_utils.copy(_SIZE_SKELETON) - size_expression.function.args.extend(size_clauses) - _mark_as_synthetic(size_expression) - size_field = ir_data.Field( - read_transform=size_expression, - name=ir_data.NameDefinition(name=ir_data.Word(text=size_field_name)), - existence_condition=ir_data.Expression( - boolean_constant=ir_data.BooleanConstant(value=True) - ), - attribute=[_skip_text_output_attribute()] - ) - structure.field.extend([size_field]) + """Adds a $size_in_bits or $size_in_bytes virtual field to structure.""" + names = { + ir_data.AddressableUnit.BIT: "$size_in_bits", + ir_data.AddressableUnit.BYTE: "$size_in_bytes", + } + size_field_name = names[type_definition.addressable_unit] + size_clauses = [] + for field in structure.field: + # Virtual fields do not have a physical location, and thus do not contribute + # to the size of the structure. + if ir_util.field_is_virtual(field): + continue + size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON) + size_clause = ir_data_utils.builder(size_clause_ir) + # Copy the appropriate clauses into `existence_condition ? start + size : 0` + size_clause.function.args[0].CopyFrom(field.existence_condition) + size_clause.function.args[1].function.args[0].CopyFrom(field.location.start) + size_clause.function.args[1].function.args[1].CopyFrom(field.location.size) + size_clauses.append(size_clause_ir) + size_expression = ir_data_utils.copy(_SIZE_SKELETON) + size_expression.function.args.extend(size_clauses) + _mark_as_synthetic(size_expression) + size_field = ir_data.Field( + read_transform=size_expression, + name=ir_data.NameDefinition(name=ir_data.Word(text=size_field_name)), + existence_condition=ir_data.Expression( + boolean_constant=ir_data.BooleanConstant(value=True) + ), + attribute=[_skip_text_output_attribute()], + ) + structure.field.extend([size_field]) # The replacement for the "$next" keyword is a simple "start + size" expression. 
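(Reviewer note, not part of the patch.) The `x + y` skeleton below is how the `$next` rewrite splices in the previous physical field's start and size. A minimal plain-Python sketch of the same idea, using dicts in place of the IR protos -- all names here are hypothetical, for illustration only:

    def replace_next(fields):
        """Rewrites "$next" starts as (previous start + previous size)."""
        last = None
        for field in fields:
            if field["start"] == "$next":
                if last is None:
                    raise ValueError(
                        "`$next` may not be used in the first physical field")
                field["start"] = ("+", last["start"], last["size"])
            last = field
        return fields

    # 1     [+2] UInt:8[] a
    # $next [+4] UInt     b   ->  start becomes ("+", 1, 2)
    print(replace_next([{"start": 1, "size": 2}, {"start": "$next", "size": 4}]))
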
@@ -220,113 +233,134 @@ def _add_size_virtuals(structure, type_definition): _NEXT_KEYWORD_REPLACEMENT_EXPRESSION = expression_parser.parse("x + y") -def _maybe_replace_next_keyword_in_expression(expression_ir, last_location, - source_file_name, errors): - if not expression_ir.HasField("builtin_reference"): - return - if ir_data_utils.reader(expression_ir).builtin_reference.canonical_name.object_path[0] != "$next": - return - if not last_location: - errors.append([ - error.error(source_file_name, expression_ir.source_location, - "`$next` may not be used in the first physical field of a " + - "structure; perhaps you meant `0`?") - ]) - return - original_location = expression_ir.source_location - expression = ir_data_utils.builder(expression_ir) - expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION) - expression.function.args[0].CopyFrom(last_location.start) - expression.function.args[1].CopyFrom(last_location.size) - expression.source_location.CopyFrom(original_location) - _mark_as_synthetic(expression.function) +def _maybe_replace_next_keyword_in_expression( + expression_ir, last_location, source_file_name, errors +): + if not expression_ir.HasField("builtin_reference"): + return + if ( + ir_data_utils.reader( + expression_ir + ).builtin_reference.canonical_name.object_path[0] + != "$next" + ): + return + if not last_location: + errors.append( + [ + error.error( + source_file_name, + expression_ir.source_location, + "`$next` may not be used in the first physical field of a " + + "structure; perhaps you meant `0`?", + ) + ] + ) + return + original_location = expression_ir.source_location + expression = ir_data_utils.builder(expression_ir) + expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION) + expression.function.args[0].CopyFrom(last_location.start) + expression.function.args[1].CopyFrom(last_location.size) + expression.source_location.CopyFrom(original_location) + _mark_as_synthetic(expression.function) def _check_for_bad_next_keyword_in_size(expression, source_file_name, errors): - if not expression.HasField("builtin_reference"): - return - if expression.builtin_reference.canonical_name.object_path[0] != "$next": - return - errors.append([ - error.error(source_file_name, expression.source_location, - "`$next` may only be used in the start expression of a " + - "physical field.") - ]) + if not expression.HasField("builtin_reference"): + return + if expression.builtin_reference.canonical_name.object_path[0] != "$next": + return + errors.append( + [ + error.error( + source_file_name, + expression.source_location, + "`$next` may only be used in the start expression of a " + + "physical field.", + ) + ] + ) def _replace_next_keyword(structure, source_file_name, errors): - last_physical_field_location = None - new_errors = [] - for field in structure.field: - if ir_util.field_is_virtual(field): - # TODO(bolms): It could be useful to allow `$next` in a virtual field, in - # order to reuse the value (say, to allow overlapping fields in a - # mostly-packed structure), but it seems better to add `$end_of(field)`, - # `$offset_of(field)`, and `$size_of(field)` constructs of some sort, - # instead. - continue - traverse_ir.fast_traverse_node_top_down( - field.location.size, [ir_data.Expression], - _check_for_bad_next_keyword_in_size, - parameters={ - "errors": new_errors, - "source_file_name": source_file_name, - }) - # If `$next` is misused in a field size, it can end up causing a - # `RecursionError` in fast_traverse_node_top_down. 
(When the `$next` node - # in the next field is replaced, its replacement gets traversed, but the - # replacement also contains a `$next` node, leading to infinite recursion.) - # - # Technically, we could scan all of the sizes instead of bailing early, but - # it seems relatively unlikely that someone will have `$next` in multiple - # sizes and not figure out what is going on relatively quickly. - if new_errors: - errors.extend(new_errors) - return - traverse_ir.fast_traverse_node_top_down( - field.location.start, [ir_data.Expression], - _maybe_replace_next_keyword_in_expression, - parameters={ - "last_location": last_physical_field_location, - "errors": new_errors, - "source_file_name": source_file_name, - }) - # The only possible error from _maybe_replace_next_keyword_in_expression is - # `$next` occurring in the start expression of the first physical field, - # which leads to similar recursion issue if `$next` is used in the start - # expression of the next physical field. - if new_errors: - errors.extend(new_errors) - return - last_physical_field_location = field.location + last_physical_field_location = None + new_errors = [] + for field in structure.field: + if ir_util.field_is_virtual(field): + # TODO(bolms): It could be useful to allow `$next` in a virtual field, in + # order to reuse the value (say, to allow overlapping fields in a + # mostly-packed structure), but it seems better to add `$end_of(field)`, + # `$offset_of(field)`, and `$size_of(field)` constructs of some sort, + # instead. + continue + traverse_ir.fast_traverse_node_top_down( + field.location.size, + [ir_data.Expression], + _check_for_bad_next_keyword_in_size, + parameters={ + "errors": new_errors, + "source_file_name": source_file_name, + }, + ) + # If `$next` is misused in a field size, it can end up causing a + # `RecursionError` in fast_traverse_node_top_down. (When the `$next` node + # in the next field is replaced, its replacement gets traversed, but the + # replacement also contains a `$next` node, leading to infinite recursion.) + # + # Technically, we could scan all of the sizes instead of bailing early, but + # it seems relatively unlikely that someone will have `$next` in multiple + # sizes and not figure out what is going on relatively quickly. + if new_errors: + errors.extend(new_errors) + return + traverse_ir.fast_traverse_node_top_down( + field.location.start, + [ir_data.Expression], + _maybe_replace_next_keyword_in_expression, + parameters={ + "last_location": last_physical_field_location, + "errors": new_errors, + "source_file_name": source_file_name, + }, + ) + # The only possible error from _maybe_replace_next_keyword_in_expression is + # `$next` occurring in the start expression of the first physical field, + # which leads to similar recursion issue if `$next` is used in the start + # expression of the next physical field. + if new_errors: + errors.extend(new_errors) + return + last_physical_field_location = field.location def _add_virtuals_to_structure(structure, type_definition): - _add_anonymous_aliases(structure, type_definition) - _add_size_virtuals(structure, type_definition) - _add_size_bound_virtuals(structure, type_definition) + _add_anonymous_aliases(structure, type_definition) + _add_size_virtuals(structure, type_definition) + _add_size_bound_virtuals(structure, type_definition) def desugar(ir): - """Translates pure syntactic sugar to its desugared form. - - Replaces `$next` symbols with the start+length of the previous physical - field. 
- - Adds aliases for all fields in anonymous `bits` to the enclosing structure. - - Arguments: - ir: The IR to desugar. - - Returns: - A list of errors, or an empty list. - """ - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], _replace_next_keyword, - parameters={"errors": errors}) - if errors: - return errors - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Structure], _add_virtuals_to_structure) - return [] + """Translates pure syntactic sugar to its desugared form. + + Replaces `$next` symbols with the start+length of the previous physical + field. + + Adds aliases for all fields in anonymous `bits` to the enclosing structure. + + Arguments: + ir: The IR to desugar. + + Returns: + A list of errors, or an empty list. + """ + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Structure], _replace_next_keyword, parameters={"errors": errors} + ) + if errors: + return errors + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Structure], _add_virtuals_to_structure + ) + return [] diff --git a/compiler/front_end/synthetics_test.py b/compiler/front_end/synthetics_test.py index 85a3dfb..ac8e671 100644 --- a/compiler/front_end/synthetics_test.py +++ b/compiler/front_end/synthetics_test.py @@ -24,246 +24,300 @@ class SyntheticsTest(unittest.TestCase): - def _find_attribute(self, field, name): - result = None - for attribute in field.attribute: - if attribute.name.text == name: - self.assertIsNone(result) - result = attribute - self.assertIsNotNone(result) - return result + def _find_attribute(self, field, name): + result = None + for attribute in field.attribute: + if attribute.name.text == name: + self.assertIsNone(result) + result = attribute + self.assertIsNotNone(result) + return result - def _make_ir(self, emb_text): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": emb_text}), - stop_before_step="desugar") - assert not errors, errors - return ir + def _make_ir(self, emb_text): + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader({"m.emb": emb_text}), + stop_before_step="desugar", + ) + assert not errors, errors + return ir - def test_nothing_to_do(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt:8[] y\n") - self.assertEqual([], synthetics.desugar(ir)) + def test_nothing_to_do(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1] UInt:8[] y\n" + ) + self.assertEqual([], synthetics.desugar(ir)) - def test_adds_anonymous_bits_fields(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] Bar bar\n" - " 4 [+4] UInt uint\n" - " 1 [+1] bits:\n" - " 0 [+4] Bits nested_bits\n" - "enum Bar:\n" - " BAR = 0\n" - "bits Bits:\n" - " 0 [+4] UInt uint\n") - self.assertEqual([], synthetics.desugar(ir)) - structure = ir.module[0].type[0].structure - # The first field should be the anonymous bits structure. - self.assertTrue(structure.field[0].HasField("location")) - # Then the aliases generated for those structures. - self.assertEqual("bar", structure.field[1].name.name.text) - self.assertEqual("uint", structure.field[2].name.name.text) - # Then the second anonymous bits. - self.assertTrue(structure.field[3].HasField("location")) - # Then the alias from the second anonymous bits. 
- self.assertEqual("nested_bits", structure.field[4].name.name.text) + def test_adds_anonymous_bits_fields(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] bits:\n" + " 0 [+4] Bar bar\n" + " 4 [+4] UInt uint\n" + " 1 [+1] bits:\n" + " 0 [+4] Bits nested_bits\n" + "enum Bar:\n" + " BAR = 0\n" + "bits Bits:\n" + " 0 [+4] UInt uint\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + structure = ir.module[0].type[0].structure + # The first field should be the anonymous bits structure. + self.assertTrue(structure.field[0].HasField("location")) + # Then the aliases generated for those structures. + self.assertEqual("bar", structure.field[1].name.name.text) + self.assertEqual("uint", structure.field[2].name.name.text) + # Then the second anonymous bits. + self.assertTrue(structure.field[3].HasField("location")) + # Then the alias from the second anonymous bits. + self.assertEqual("nested_bits", structure.field[4].name.name.text) - def test_adds_correct_existence_condition(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar\n") - self.assertEqual([], synthetics.desugar(ir)) - bits_field = ir.module[0].type[0].structure.field[0] - alias_field = ir.module[0].type[0].structure.field[1] - self.assertEqual("bar", alias_field.name.name.text) - self.assertEqual(bits_field.name.name.text, - alias_field.existence_condition.function.args[0].function. - args[0].field_reference.path[0].source_name[-1].text) - self.assertEqual(bits_field.name.name.text, - alias_field.existence_condition.function.args[1].function. - args[0].field_reference.path[0].source_name[-1].text) - self.assertEqual("bar", - alias_field.existence_condition.function.args[1].function. - args[0].field_reference.path[1].source_name[-1].text) - self.assertEqual( - ir_data.FunctionMapping.PRESENCE, - alias_field.existence_condition.function.args[0].function.function) - self.assertEqual( - ir_data.FunctionMapping.PRESENCE, - alias_field.existence_condition.function.args[1].function.function) - self.assertEqual(ir_data.FunctionMapping.AND, - alias_field.existence_condition.function.function) + def test_adds_correct_existence_condition(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bits_field = ir.module[0].type[0].structure.field[0] + alias_field = ir.module[0].type[0].structure.field[1] + self.assertEqual("bar", alias_field.name.name.text) + self.assertEqual( + bits_field.name.name.text, + alias_field.existence_condition.function.args[0] + .function.args[0] + .field_reference.path[0] + .source_name[-1] + .text, + ) + self.assertEqual( + bits_field.name.name.text, + alias_field.existence_condition.function.args[1] + .function.args[0] + .field_reference.path[0] + .source_name[-1] + .text, + ) + self.assertEqual( + "bar", + alias_field.existence_condition.function.args[1] + .function.args[0] + .field_reference.path[1] + .source_name[-1] + .text, + ) + self.assertEqual( + ir_data.FunctionMapping.PRESENCE, + alias_field.existence_condition.function.args[0].function.function, + ) + self.assertEqual( + ir_data.FunctionMapping.PRESENCE, + alias_field.existence_condition.function.args[1].function.function, + ) + self.assertEqual( + ir_data.FunctionMapping.AND, + alias_field.existence_condition.function.function, + ) - def test_adds_correct_read_transform(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar\n") - self.assertEqual([], synthetics.desugar(ir)) - bits_field = 
ir.module[0].type[0].structure.field[0] - alias_field = ir.module[0].type[0].structure.field[1] - self.assertEqual("bar", alias_field.name.name.text) - self.assertEqual( - bits_field.name.name.text, - alias_field.read_transform.field_reference.path[0].source_name[-1].text) - self.assertEqual( - "bar", - alias_field.read_transform.field_reference.path[1].source_name[-1].text) + def test_adds_correct_read_transform(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bits_field = ir.module[0].type[0].structure.field[0] + alias_field = ir.module[0].type[0].structure.field[1] + self.assertEqual("bar", alias_field.name.name.text) + self.assertEqual( + bits_field.name.name.text, + alias_field.read_transform.field_reference.path[0].source_name[-1].text, + ) + self.assertEqual( + "bar", + alias_field.read_transform.field_reference.path[1].source_name[-1].text, + ) - def test_adds_correct_abbreviation(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar\n" - " 4 [+4] UInt baz (qux)\n") - self.assertEqual([], synthetics.desugar(ir)) - bar_alias = ir.module[0].type[0].structure.field[1] - baz_alias = ir.module[0].type[0].structure.field[2] - self.assertFalse(bar_alias.HasField("abbreviation")) - self.assertEqual("qux", baz_alias.abbreviation.text) + def test_adds_correct_abbreviation(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] bits:\n" + " 0 [+4] UInt bar\n" + " 4 [+4] UInt baz (qux)\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bar_alias = ir.module[0].type[0].structure.field[1] + baz_alias = ir.module[0].type[0].structure.field[2] + self.assertFalse(bar_alias.HasField("abbreviation")) + self.assertEqual("qux", baz_alias.abbreviation.text) - def test_anonymous_bits_sets_correct_is_synthetic(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar (b)\n") - self.assertEqual([], synthetics.desugar(ir)) - bits_field = ir.module[0].type[0].subtype[0].structure.field[0] - alias_field = ir.module[0].type[0].structure.field[1] - self.assertFalse(alias_field.name.source_location.is_synthetic) - self.assertTrue(alias_field.HasField("abbreviation")) - self.assertFalse(alias_field.abbreviation.source_location.is_synthetic) - self.assertTrue(alias_field.HasField("read_transform")) - read_alias = alias_field.read_transform - self.assertTrue(read_alias.source_location.is_synthetic) - self.assertTrue( - read_alias.field_reference.path[0].source_location.is_synthetic) - alias_condition = alias_field.existence_condition - self.assertTrue(alias_condition.source_location.is_synthetic) - self.assertTrue( - alias_condition.function.args[0].source_location.is_synthetic) - self.assertTrue(bits_field.name.source_location.is_synthetic) - self.assertTrue(bits_field.name.name.source_location.is_synthetic) - self.assertTrue(bits_field.abbreviation.source_location.is_synthetic) + def test_anonymous_bits_sets_correct_is_synthetic(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar (b)\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bits_field = ir.module[0].type[0].subtype[0].structure.field[0] + alias_field = ir.module[0].type[0].structure.field[1] + self.assertFalse(alias_field.name.source_location.is_synthetic) + self.assertTrue(alias_field.HasField("abbreviation")) + self.assertFalse(alias_field.abbreviation.source_location.is_synthetic) + self.assertTrue(alias_field.HasField("read_transform")) + read_alias = 
alias_field.read_transform + self.assertTrue(read_alias.source_location.is_synthetic) + self.assertTrue(read_alias.field_reference.path[0].source_location.is_synthetic) + alias_condition = alias_field.existence_condition + self.assertTrue(alias_condition.source_location.is_synthetic) + self.assertTrue(alias_condition.function.args[0].source_location.is_synthetic) + self.assertTrue(bits_field.name.source_location.is_synthetic) + self.assertTrue(bits_field.name.name.source_location.is_synthetic) + self.assertTrue(bits_field.abbreviation.source_location.is_synthetic) - def test_adds_text_output_skip_attribute_to_anonymous_bits(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar (b)\n") - self.assertEqual([], synthetics.desugar(ir)) - bits_field = ir.module[0].type[0].structure.field[0] - text_output_attribute = self._find_attribute(bits_field, "text_output") - self.assertEqual("Skip", text_output_attribute.value.string_constant.text) + def test_adds_text_output_skip_attribute_to_anonymous_bits(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar (b)\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bits_field = ir.module[0].type[0].structure.field[0] + text_output_attribute = self._find_attribute(bits_field, "text_output") + self.assertEqual("Skip", text_output_attribute.value.string_constant.text) - def test_skip_attribute_is_marked_as_synthetic(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] bits:\n" - " 0 [+4] UInt bar\n") - self.assertEqual([], synthetics.desugar(ir)) - bits_field = ir.module[0].type[0].structure.field[0] - attribute = self._find_attribute(bits_field, "text_output") - self.assertTrue(attribute.source_location.is_synthetic) - self.assertTrue(attribute.name.source_location.is_synthetic) - self.assertTrue(attribute.value.source_location.is_synthetic) - self.assertTrue( - attribute.value.string_constant.source_location.is_synthetic) + def test_skip_attribute_is_marked_as_synthetic(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + bits_field = ir.module[0].type[0].structure.field[0] + attribute = self._find_attribute(bits_field, "text_output") + self.assertTrue(attribute.source_location.is_synthetic) + self.assertTrue(attribute.name.source_location.is_synthetic) + self.assertTrue(attribute.value.source_location.is_synthetic) + self.assertTrue(attribute.value.string_constant.source_location.is_synthetic) - def test_adds_size_in_bytes(self): - ir = self._make_ir("struct Foo:\n" - " 1 [+l] UInt:8[] bytes\n" - " 0 [+1] UInt length (l)\n") - self.assertEqual([], synthetics.desugar(ir)) - structure = ir.module[0].type[0].structure - size_in_bytes_field = structure.field[2] - max_size_in_bytes_field = structure.field[3] - min_size_in_bytes_field = structure.field[4] - self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.MAXIMUM, - size_in_bytes_field.read_transform.function.function) - self.assertEqual("$max_size_in_bytes", - max_size_in_bytes_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.UPPER_BOUND, - max_size_in_bytes_field.read_transform.function.function) - self.assertEqual("$min_size_in_bytes", - min_size_in_bytes_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.LOWER_BOUND, - min_size_in_bytes_field.read_transform.function.function) - # The correctness of $size_in_bytes et al are tested much further down - # stream, in 
tests of the generated C++ code. + def test_adds_size_in_bytes(self): + ir = self._make_ir( + "struct Foo:\n" + " 1 [+l] UInt:8[] bytes\n" + " 0 [+1] UInt length (l)\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + structure = ir.module[0].type[0].structure + size_in_bytes_field = structure.field[2] + max_size_in_bytes_field = structure.field[3] + min_size_in_bytes_field = structure.field[4] + self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.MAXIMUM, + size_in_bytes_field.read_transform.function.function, + ) + self.assertEqual("$max_size_in_bytes", max_size_in_bytes_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.UPPER_BOUND, + max_size_in_bytes_field.read_transform.function.function, + ) + self.assertEqual("$min_size_in_bytes", min_size_in_bytes_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.LOWER_BOUND, + min_size_in_bytes_field.read_transform.function.function, + ) + # The correctness of $size_in_bytes et al are tested much further down + # stream, in tests of the generated C++ code. - def test_adds_size_in_bits(self): - ir = self._make_ir("bits Foo:\n" - " 1 [+9] UInt hi\n" - " 0 [+1] Flag lo\n") - self.assertEqual([], synthetics.desugar(ir)) - structure = ir.module[0].type[0].structure - size_in_bits_field = structure.field[2] - max_size_in_bits_field = structure.field[3] - min_size_in_bits_field = structure.field[4] - self.assertEqual("$size_in_bits", size_in_bits_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.MAXIMUM, - size_in_bits_field.read_transform.function.function) - self.assertEqual("$max_size_in_bits", - max_size_in_bits_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.UPPER_BOUND, - max_size_in_bits_field.read_transform.function.function) - self.assertEqual("$min_size_in_bits", - min_size_in_bits_field.name.name.text) - self.assertEqual(ir_data.FunctionMapping.LOWER_BOUND, - min_size_in_bits_field.read_transform.function.function) - # The correctness of $size_in_bits et al are tested much further down - # stream, in tests of the generated C++ code. + def test_adds_size_in_bits(self): + ir = self._make_ir("bits Foo:\n" " 1 [+9] UInt hi\n" " 0 [+1] Flag lo\n") + self.assertEqual([], synthetics.desugar(ir)) + structure = ir.module[0].type[0].structure + size_in_bits_field = structure.field[2] + max_size_in_bits_field = structure.field[3] + min_size_in_bits_field = structure.field[4] + self.assertEqual("$size_in_bits", size_in_bits_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.MAXIMUM, + size_in_bits_field.read_transform.function.function, + ) + self.assertEqual("$max_size_in_bits", max_size_in_bits_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.UPPER_BOUND, + max_size_in_bits_field.read_transform.function.function, + ) + self.assertEqual("$min_size_in_bits", min_size_in_bits_field.name.name.text) + self.assertEqual( + ir_data.FunctionMapping.LOWER_BOUND, + min_size_in_bits_field.read_transform.function.function, + ) + # The correctness of $size_in_bits et al are tested much further down + # stream, in tests of the generated C++ code. 
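(Reviewer note, not part of the patch.) The expression shape these tests assert on can be summarized with a small string-building sketch: the `$size_in_bits`/`$size_in_bytes` virtual is a `$max(...)` over one `condition ? start + size : 0` clause per physical field, seeded with a literal `0` so that `$max` is never zero-argument, and the bound virtuals wrap the size field in `$upper_bound`/`$lower_bound`:

    def size_expression(physical_fields):
        """Builds a $size_in_* read_transform as a string, for illustration."""
        clauses = ["0"]  # seed clause: keeps $max() from being zero-argument
        for start, size, condition in physical_fields:
            clauses.append(f"{condition} ? {start} + {size} : 0")
        return "$max(" + ", ".join(clauses) + ")"

    # bits Foo:  1 [+9] UInt hi / 0 [+1] Flag lo  (both unconditionally present)
    print(size_expression([("1", "9", "true"), ("0", "1", "true")]))
    # -> $max(0, true ? 1 + 9 : 0, true ? 0 + 1 : 0)
    # $max_size_in_bits is then $upper_bound($size_in_bits), and
    # $min_size_in_bits is $lower_bound($size_in_bits).
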
- def test_adds_text_output_skip_attribute_to_size_in_bytes(self): - ir = self._make_ir("struct Foo:\n" - " 1 [+l] UInt:8[] bytes\n" - " 0 [+1] UInt length (l)\n") - self.assertEqual([], synthetics.desugar(ir)) - size_in_bytes_field = ir.module[0].type[0].structure.field[2] - self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text) - text_output_attribute = self._find_attribute(size_in_bytes_field, - "text_output") - self.assertEqual("Skip", text_output_attribute.value.string_constant.text) + def test_adds_text_output_skip_attribute_to_size_in_bytes(self): + ir = self._make_ir( + "struct Foo:\n" + " 1 [+l] UInt:8[] bytes\n" + " 0 [+1] UInt length (l)\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + size_in_bytes_field = ir.module[0].type[0].structure.field[2] + self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text) + text_output_attribute = self._find_attribute(size_in_bytes_field, "text_output") + self.assertEqual("Skip", text_output_attribute.value.string_constant.text) - def test_replaces_next(self): - ir = self._make_ir("struct Foo:\n" - " 1 [+2] UInt:8[] a\n" - " $next [+4] UInt b\n" - " $next [+1] UInt c\n") - self.assertEqual([], synthetics.desugar(ir)) - offset_of_b = ir.module[0].type[0].structure.field[1].location.start - self.assertTrue(offset_of_b.HasField("function")) - self.assertEqual(offset_of_b.function.function, ir_data.FunctionMapping.ADDITION) - self.assertEqual(offset_of_b.function.args[0].constant.value, "1") - self.assertEqual(offset_of_b.function.args[1].constant.value, "2") - offset_of_c = ir.module[0].type[0].structure.field[2].location.start - self.assertEqual( - offset_of_c.function.args[0].function.args[0].constant.value, "1") - self.assertEqual( - offset_of_c.function.args[0].function.args[1].constant.value, "2") - self.assertEqual(offset_of_c.function.args[1].constant.value, "4") + def test_replaces_next(self): + ir = self._make_ir( + "struct Foo:\n" + " 1 [+2] UInt:8[] a\n" + " $next [+4] UInt b\n" + " $next [+1] UInt c\n" + ) + self.assertEqual([], synthetics.desugar(ir)) + offset_of_b = ir.module[0].type[0].structure.field[1].location.start + self.assertTrue(offset_of_b.HasField("function")) + self.assertEqual( + offset_of_b.function.function, ir_data.FunctionMapping.ADDITION + ) + self.assertEqual(offset_of_b.function.args[0].constant.value, "1") + self.assertEqual(offset_of_b.function.args[1].constant.value, "2") + offset_of_c = ir.module[0].type[0].structure.field[2].location.start + self.assertEqual( + offset_of_c.function.args[0].function.args[0].constant.value, "1" + ) + self.assertEqual( + offset_of_c.function.args[0].function.args[1].constant.value, "2" + ) + self.assertEqual(offset_of_c.function.args[1].constant.value, "4") - def test_next_in_first_field(self): - ir = self._make_ir("struct Foo:\n" - " $next [+2] UInt:8[] a\n" - " $next [+4] UInt b\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[0].location.start.source_location, - "`$next` may not be used in the first physical field of " + - "a structure; perhaps you meant `0`?"), - ]], synthetics.desugar(ir)) + def test_next_in_first_field(self): + ir = self._make_ir( + "struct Foo:\n" " $next [+2] UInt:8[] a\n" " $next [+4] UInt b\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[0].location.start.source_location, + "`$next` may not be used in the first physical field of " + + "a structure; perhaps you meant `0`?", + ), + ] + ], + 
synthetics.desugar(ir), + ) - def test_next_in_size(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+2] UInt:8[] a\n" - " 1 [+$next] UInt b\n") - struct = ir.module[0].type[0].structure - self.assertEqual([[ - error.error("m.emb", struct.field[1].location.size.source_location, - "`$next` may only be used in the start expression of a " + - "physical field."), - ]], synthetics.desugar(ir)) + def test_next_in_size(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+2] UInt:8[] a\n" " 1 [+$next] UInt b\n" + ) + struct = ir.module[0].type[0].structure + self.assertEqual( + [ + [ + error.error( + "m.emb", + struct.field[1].location.size.source_location, + "`$next` may only be used in the start expression of a " + + "physical field.", + ), + ] + ], + synthetics.desugar(ir), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/tokenizer.py b/compiler/front_end/tokenizer.py index 752371f..defee34 100644 --- a/compiler/front_end/tokenizer.py +++ b/compiler/front_end/tokenizer.py @@ -36,87 +36,123 @@ def tokenize(text, file_name): - # TODO(bolms): suppress end-of-line, indent, and dedent tokens between matched - # delimiters ([], (), and {}). - """Tokenizes its argument. - - Arguments: - text: The raw text of a .emb file. - file_name: The name of the file to use in errors. - - Returns: - A tuple of: - a list of parser_types.Tokens or None - a possibly-empty list of errors. - """ - tokens = [] - indent_stack = [""] - line_number = 0 - for line in text.splitlines(): - line_number += 1 - - # _tokenize_line splits the actual text into tokens. - line_tokens, errors = _tokenize_line(line, line_number, file_name) - if errors: - return None, errors - - # Lines with only whitespace and comments are not used for Indent/Dedent - # calculation, and do not produce end-of-line tokens. - for token in line_tokens: - if token.symbol != "Comment": - break - else: - tokens.extend(line_tokens) - tokens.append(parser_types.Token( - '"\\n"', "\n", parser_types.make_location( - (line_number, len(line) + 1), (line_number, len(line) + 1)))) - continue - - # Leading whitespace is whatever .lstrip() removes. - leading_whitespace = line[0:len(line) - len(line.lstrip())] - if leading_whitespace == indent_stack[-1]: - # If the current leading whitespace is equal to the last leading - # whitespace, do not emit an Indent or Dedent token. - pass - elif leading_whitespace.startswith(indent_stack[-1]): - # If the current leading whitespace is longer than the last leading - # whitespace, emit an Indent token. For the token text, take the new - # part of the whitespace. - tokens.append( - parser_types.Token( - "Indent", leading_whitespace[len(indent_stack[-1]):], - parser_types.make_location( - (line_number, len(indent_stack[-1]) + 1), - (line_number, len(leading_whitespace) + 1)))) - indent_stack.append(leading_whitespace) - else: - # Otherwise, search for the unclosed indentation level that matches - # the current indentation level. Emit a Dedent token for each - # newly-closed indentation level. - for i in range(len(indent_stack) - 1, -1, -1): - if leading_whitespace == indent_stack[i]: - break + # TODO(bolms): suppress end-of-line, indent, and dedent tokens between matched + # delimiters ([], (), and {}). + """Tokenizes its argument. + + Arguments: + text: The raw text of a .emb file. + file_name: The name of the file to use in errors. + + Returns: + A tuple of: + a list of parser_types.Tokens or None + a possibly-empty list of errors. 
+ """ + tokens = [] + indent_stack = [""] + line_number = 0 + for line in text.splitlines(): + line_number += 1 + + # _tokenize_line splits the actual text into tokens. + line_tokens, errors = _tokenize_line(line, line_number, file_name) + if errors: + return None, errors + + # Lines with only whitespace and comments are not used for Indent/Dedent + # calculation, and do not produce end-of-line tokens. + for token in line_tokens: + if token.symbol != "Comment": + break + else: + tokens.extend(line_tokens) + tokens.append( + parser_types.Token( + '"\\n"', + "\n", + parser_types.make_location( + (line_number, len(line) + 1), (line_number, len(line) + 1) + ), + ) + ) + continue + + # Leading whitespace is whatever .lstrip() removes. + leading_whitespace = line[0 : len(line) - len(line.lstrip())] + if leading_whitespace == indent_stack[-1]: + # If the current leading whitespace is equal to the last leading + # whitespace, do not emit an Indent or Dedent token. + pass + elif leading_whitespace.startswith(indent_stack[-1]): + # If the current leading whitespace is longer than the last leading + # whitespace, emit an Indent token. For the token text, take the new + # part of the whitespace. + tokens.append( + parser_types.Token( + "Indent", + leading_whitespace[len(indent_stack[-1]) :], + parser_types.make_location( + (line_number, len(indent_stack[-1]) + 1), + (line_number, len(leading_whitespace) + 1), + ), + ) + ) + indent_stack.append(leading_whitespace) + else: + # Otherwise, search for the unclosed indentation level that matches + # the current indentation level. Emit a Dedent token for each + # newly-closed indentation level. + for i in range(len(indent_stack) - 1, -1, -1): + if leading_whitespace == indent_stack[i]: + break + tokens.append( + parser_types.Token( + "Dedent", + "", + parser_types.make_location( + (line_number, len(leading_whitespace) + 1), + (line_number, len(leading_whitespace) + 1), + ), + ) + ) + del indent_stack[i] + else: + return None, [ + [ + error.error( + file_name, + parser_types.make_location( + (line_number, 1), + (line_number, len(leading_whitespace) + 1), + ), + "Bad indentation", + ) + ] + ] + + tokens.extend(line_tokens) + + # Append an end-of-line token (for non-whitespace lines). tokens.append( - parser_types.Token("Dedent", "", parser_types.make_location( - (line_number, len(leading_whitespace) + 1), - (line_number, len(leading_whitespace) + 1)))) - del indent_stack[i] - else: - return None, [[error.error( - file_name, parser_types.make_location( - (line_number, 1), (line_number, len(leading_whitespace) + 1)), - "Bad indentation")]] - - tokens.extend(line_tokens) - - # Append an end-of-line token (for non-whitespace lines). - tokens.append(parser_types.Token( - '"\\n"', "\n", parser_types.make_location( - (line_number, len(line) + 1), (line_number, len(line) + 1)))) - for i in range(len(indent_stack) - 1): - tokens.append(parser_types.Token("Dedent", "", parser_types.make_location( - (line_number + 1, 1), (line_number + 1, 1)))) - return tokens, [] + parser_types.Token( + '"\\n"', + "\n", + parser_types.make_location( + (line_number, len(line) + 1), (line_number, len(line) + 1) + ), + ) + ) + for i in range(len(indent_stack) - 1): + tokens.append( + parser_types.Token( + "Dedent", + "", + parser_types.make_location((line_number + 1, 1), (line_number + 1, 1)), + ) + ) + return tokens, [] + # Token patterns used by _tokenize_line. 
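(Reviewer note, not part of the patch.) Before the token-pattern tables below, a quick usage sketch of the reformatted `tokenize` entry point; assuming the compiler package is importable as in the tests, the Indent/Dedent bracketing can be inspected like so:

    from compiler.front_end import tokenizer

    tokens, errors = tokenizer.tokenize(
        "struct Foo:\n  0 [+1] UInt x\n", "example.emb")
    assert not errors
    print([t.symbol for t in tokens])
    # Expected shape: '"struct"', 'CamelWord', '":"', '"\\n"', 'Indent',
    # 'Number', '"["', '"+"', 'Number', '"]"', 'CamelWord', 'SnakeWord',
    # '"\\n"', 'Dedent'
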
LITERAL_TOKEN_PATTERNS = ( @@ -125,7 +161,8 @@ def tokenize(text, file_name): "$max $present $upper_bound $lower_bound $next " "$size_in_bits $size_in_bytes " "$max_size_in_bits $max_size_in_bytes $min_size_in_bits $min_size_in_bytes " - "$default struct bits enum external import as if let").split() + "$default struct bits enum external import as if let" +).split() _T = collections.namedtuple("T", ["regex", "symbol"]) REGEX_TOKEN_PATTERNS = [ # Words starting with variations of "emboss reserved" are reserved for @@ -168,56 +205,68 @@ def tokenize(text, file_name): def _tokenize_line(line, line_number, file_name): - """Tokenizes a single line of input. - - Arguments: - line: The line of text to tokenize. - line_number: The line number (used when constructing token objects). - file_name: The name of a file to use in errors. - - Returns: - A tuple of: - A list of token objects or None. - A possibly-empty list of errors. - """ - tokens = [] - offset = 0 - while offset < len(line): - best_candidate = "" - best_candidate_symbol = None - # Find the longest match. Ties go to the first match. This way, keywords - # ("struct") are matched as themselves, but words that only happen to start - # with keywords ("structure") are matched as words. - # - # There is never a reason to try to match a literal after a regex that - # could also match that literal, so check literals first. - for literal in LITERAL_TOKEN_PATTERNS: - if line[offset:].startswith(literal) and len(literal) > len( - best_candidate): - best_candidate = literal - # For Emboss, the name of a literal token is just the literal in quotes, - # so that the grammar can read a little more naturally, e.g.: - # - # expression -> expression "+" expression - # - # instead of + """Tokenizes a single line of input. + + Arguments: + line: The line of text to tokenize. + line_number: The line number (used when constructing token objects). + file_name: The name of a file to use in errors. + + Returns: + A tuple of: + A list of token objects or None. + A possibly-empty list of errors. + """ + tokens = [] + offset = 0 + while offset < len(line): + best_candidate = "" + best_candidate_symbol = None + # Find the longest match. Ties go to the first match. This way, keywords + # ("struct") are matched as themselves, but words that only happen to start + # with keywords ("structure") are matched as words. # - # expression -> expression Plus expression - best_candidate_symbol = '"' + literal + '"' - for pattern in REGEX_TOKEN_PATTERNS: - match_result = pattern.regex.match(line[offset:]) - if match_result and len(match_result.group(0)) > len(best_candidate): - best_candidate = match_result.group(0) - best_candidate_symbol = pattern.symbol - if not best_candidate: - return None, [[error.error( - file_name, parser_types.make_location( - (line_number, offset + 1), (line_number, offset + 2)), - "Unrecognized token")]] - if best_candidate_symbol: - tokens.append(parser_types.Token( - best_candidate_symbol, best_candidate, parser_types.make_location( - (line_number, offset + 1), - (line_number, offset + len(best_candidate) + 1)))) - offset += len(best_candidate) - return tokens, None + # There is never a reason to try to match a literal after a regex that + # could also match that literal, so check literals first. 
+ for literal in LITERAL_TOKEN_PATTERNS: + if line[offset:].startswith(literal) and len(literal) > len(best_candidate): + best_candidate = literal + # For Emboss, the name of a literal token is just the literal in quotes, + # so that the grammar can read a little more naturally, e.g.: + # + # expression -> expression "+" expression + # + # instead of + # + # expression -> expression Plus expression + best_candidate_symbol = '"' + literal + '"' + for pattern in REGEX_TOKEN_PATTERNS: + match_result = pattern.regex.match(line[offset:]) + if match_result and len(match_result.group(0)) > len(best_candidate): + best_candidate = match_result.group(0) + best_candidate_symbol = pattern.symbol + if not best_candidate: + return None, [ + [ + error.error( + file_name, + parser_types.make_location( + (line_number, offset + 1), (line_number, offset + 2) + ), + "Unrecognized token", + ) + ] + ] + if best_candidate_symbol: + tokens.append( + parser_types.Token( + best_candidate_symbol, + best_candidate, + parser_types.make_location( + (line_number, offset + 1), + (line_number, offset + len(best_candidate) + 1), + ), + ) + ) + offset += len(best_candidate) + return tokens, None diff --git a/compiler/front_end/tokenizer_test.py b/compiler/front_end/tokenizer_test.py index 91e1b3b..6a301e6 100644 --- a/compiler/front_end/tokenizer_test.py +++ b/compiler/front_end/tokenizer_test.py @@ -21,362 +21,479 @@ def _token_symbols(token_list): - """Given a list of tokens, returns a list of their symbol names.""" - return [token.symbol for token in token_list] + """Given a list of tokens, returns a list of their symbol names.""" + return [token.symbol for token in token_list] class TokenizerTest(unittest.TestCase): - """Tests for the tokenizer.tokenize function.""" - - def test_bad_indent_tab_versus_space(self): - # A bad indent is one that doesn't match a previous unmatched indent. 
- tokens, errors = tokenizer.tokenize(" a\n\tb", "file") - self.assertFalse(tokens) - self.assertEqual([[error.error("file", parser_types.make_location( - (2, 1), (2, 2)), "Bad indentation")]], errors) - - def test_bad_indent_tab_versus_eight_spaces(self): - tokens, errors = tokenizer.tokenize(" a\n\tb", "file") - self.assertFalse(tokens) - self.assertEqual([[error.error("file", parser_types.make_location( - (2, 1), (2, 2)), "Bad indentation")]], errors) - - def test_bad_indent_tab_versus_four_spaces(self): - tokens, errors = tokenizer.tokenize(" a\n\tb", "file") - self.assertFalse(tokens) - self.assertEqual([[error.error("file", parser_types.make_location( - (2, 1), (2, 2)), "Bad indentation")]], errors) - - def test_bad_indent_two_spaces_versus_one_space(self): - tokens, errors = tokenizer.tokenize(" a\n b", "file") - self.assertFalse(tokens) - self.assertEqual([[error.error("file", parser_types.make_location( - (2, 1), (2, 2)), "Bad indentation")]], errors) - - def test_bad_indent_matches_closed_indent(self): - tokens, errors = tokenizer.tokenize(" a\nb\n c\n d", "file") - self.assertFalse(tokens) - self.assertEqual([[error.error("file", parser_types.make_location( - (4, 1), (4, 2)), "Bad indentation")]], errors) - - def test_bad_string_after_string_with_escaped_backslash_at_end(self): - tokens, errors = tokenizer.tokenize(r'"\\""', "name") - self.assertFalse(tokens) - self.assertEqual([[error.error("name", parser_types.make_location( - (1, 5), (1, 6)), "Unrecognized token")]], errors) + """Tests for the tokenizer.tokenize function.""" + + def test_bad_indent_tab_versus_space(self): + # A bad indent is one that doesn't match a previous unmatched indent. + tokens, errors = tokenizer.tokenize(" a\n\tb", "file") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "file", + parser_types.make_location((2, 1), (2, 2)), + "Bad indentation", + ) + ] + ], + errors, + ) + + def test_bad_indent_tab_versus_eight_spaces(self): + tokens, errors = tokenizer.tokenize(" a\n\tb", "file") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "file", + parser_types.make_location((2, 1), (2, 2)), + "Bad indentation", + ) + ] + ], + errors, + ) + + def test_bad_indent_tab_versus_four_spaces(self): + tokens, errors = tokenizer.tokenize(" a\n\tb", "file") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "file", + parser_types.make_location((2, 1), (2, 2)), + "Bad indentation", + ) + ] + ], + errors, + ) + + def test_bad_indent_two_spaces_versus_one_space(self): + tokens, errors = tokenizer.tokenize(" a\n b", "file") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "file", + parser_types.make_location((2, 1), (2, 2)), + "Bad indentation", + ) + ] + ], + errors, + ) + + def test_bad_indent_matches_closed_indent(self): + tokens, errors = tokenizer.tokenize(" a\nb\n c\n d", "file") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "file", + parser_types.make_location((4, 1), (4, 2)), + "Bad indentation", + ) + ] + ], + errors, + ) + + def test_bad_string_after_string_with_escaped_backslash_at_end(self): + tokens, errors = tokenizer.tokenize(r'"\\""', "name") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "name", + parser_types.make_location((1, 5), (1, 6)), + "Unrecognized token", + ) + ] + ], + errors, + ) def _make_short_token_match_tests(): - """Makes tests for short, simple tokenization cases.""" - eol = '"\\n"' - cases = { - "Cam": ["CamelWord", eol], - "Ca9": ["CamelWord", eol], 
- "CanB": ["CamelWord", eol], - "CanBee": ["CamelWord", eol], - "CBa": ["CamelWord", eol], - "cam": ["SnakeWord", eol], - "ca9": ["SnakeWord", eol], - "can_b": ["SnakeWord", eol], - "can_bee": ["SnakeWord", eol], - "c_ba": ["SnakeWord", eol], - "cba_": ["SnakeWord", eol], - "c_b_a_": ["SnakeWord", eol], - "CAM": ["ShoutyWord", eol], - "CA9": ["ShoutyWord", eol], - "CAN_B": ["ShoutyWord", eol], - "CAN_BEE": ["ShoutyWord", eol], - "C_BA": ["ShoutyWord", eol], - "C": ["BadWord", eol], - "C1": ["BadWord", eol], - "c": ["SnakeWord", eol], - "$": ["BadWord", eol], - "_": ["BadWord", eol], - "_a": ["BadWord", eol], - "_A": ["BadWord", eol], - "Cb_A": ["BadWord", eol], - "aCb": ["BadWord", eol], - "a b": ["SnakeWord", "SnakeWord", eol], - "a\tb": ["SnakeWord", "SnakeWord", eol], - "a \t b ": ["SnakeWord", "SnakeWord", eol], - " \t ": [eol], - "a #b": ["SnakeWord", "Comment", eol], - "a#": ["SnakeWord", "Comment", eol], - "# b": ["Comment", eol], - " # b": ["Comment", eol], - " #": ["Comment", eol], - "": [], - "\n": [eol], - "\na": [eol, "SnakeWord", eol], - "a--example": ["SnakeWord", "BadDocumentation", eol], - "a ---- example": ["SnakeWord", "BadDocumentation", eol], - "a --- example": ["SnakeWord", "BadDocumentation", eol], - "a-- example": ["SnakeWord", "Documentation", eol], - "a -- -- example": ["SnakeWord", "Documentation", eol], - "a -- - example": ["SnakeWord", "Documentation", eol], - "--": ["Documentation", eol], - "-- ": ["Documentation", eol], - "-- ": ["Documentation", eol], - "$default": ['"$default"', eol], - "$defaultx": ["BadWord", eol], - "$def": ["BadWord", eol], - "x$default": ["BadWord", eol], - "9$default": ["BadWord", eol], - "struct": ['"struct"', eol], - "external": ['"external"', eol], - "bits": ['"bits"', eol], - "enum": ['"enum"', eol], - "as": ['"as"', eol], - "import": ['"import"', eol], - "true": ["BooleanConstant", eol], - "false": ["BooleanConstant", eol], - "truex": ["SnakeWord", eol], - "falsex": ["SnakeWord", eol], - "structx": ["SnakeWord", eol], - "bitsx": ["SnakeWord", eol], - "enumx": ["SnakeWord", eol], - "0b": ["BadNumber", eol], - "0x": ["BadNumber", eol], - "0b011101": ["Number", eol], - "0b0": ["Number", eol], - "0b0111_1111_0000": ["Number", eol], - "0b00_000_00": ["BadNumber", eol], - "0b0_0_0": ["BadNumber", eol], - "0b0111012": ["BadNumber", eol], - "0b011101x": ["BadWord", eol], - "0b011101b": ["BadNumber", eol], - "0B0": ["BadNumber", eol], - "0X0": ["BadNumber", eol], - "0b_": ["BadNumber", eol], - "0x_": ["BadNumber", eol], - "0b__": ["BadNumber", eol], - "0x__": ["BadNumber", eol], - "0b_0000": ["Number", eol], - "0b0000_": ["BadNumber", eol], - "0b00_____00": ["BadNumber", eol], - "0x00_000_00": ["BadNumber", eol], - "0x0_0_0": ["BadNumber", eol], - "0b____0____": ["BadNumber", eol], - "0b00000000000000000000": ["Number", eol], - "0b_00000000": ["Number", eol], - "0b0000_0000_0000": ["Number", eol], - "0b000_0000_0000": ["Number", eol], - "0b00_0000_0000": ["Number", eol], - "0b0_0000_0000": ["Number", eol], - "0b_0000_0000_0000": ["Number", eol], - "0b_000_0000_0000": ["Number", eol], - "0b_00_0000_0000": ["Number", eol], - "0b_0_0000_0000": ["Number", eol], - "0b00000000_00000000_00000000": ["Number", eol], - "0b0000000_00000000_00000000": ["Number", eol], - "0b000000_00000000_00000000": ["Number", eol], - "0b00000_00000000_00000000": ["Number", eol], - "0b0000_00000000_00000000": ["Number", eol], - "0b000_00000000_00000000": ["Number", eol], - "0b00_00000000_00000000": ["Number", eol], - "0b0_00000000_00000000": ["Number", eol], - 
"0b_00000000_00000000_00000000": ["Number", eol], - "0b_0000000_00000000_00000000": ["Number", eol], - "0b_000000_00000000_00000000": ["Number", eol], - "0b_00000_00000000_00000000": ["Number", eol], - "0b_0000_00000000_00000000": ["Number", eol], - "0b_000_00000000_00000000": ["Number", eol], - "0b_00_00000000_00000000": ["Number", eol], - "0b_0_00000000_00000000": ["Number", eol], - "0x0": ["Number", eol], - "0x00000000000000000000": ["Number", eol], - "0x_0000": ["Number", eol], - "0x_00000000": ["Number", eol], - "0x0000_0000_0000": ["Number", eol], - "0x000_0000_0000": ["Number", eol], - "0x00_0000_0000": ["Number", eol], - "0x0_0000_0000": ["Number", eol], - "0x_0000_0000_0000": ["Number", eol], - "0x_000_0000_0000": ["Number", eol], - "0x_00_0000_0000": ["Number", eol], - "0x_0_0000_0000": ["Number", eol], - "0x00000000_00000000_00000000": ["Number", eol], - "0x0000000_00000000_00000000": ["Number", eol], - "0x000000_00000000_00000000": ["Number", eol], - "0x00000_00000000_00000000": ["Number", eol], - "0x0000_00000000_00000000": ["Number", eol], - "0x000_00000000_00000000": ["Number", eol], - "0x00_00000000_00000000": ["Number", eol], - "0x0_00000000_00000000": ["Number", eol], - "0x_00000000_00000000_00000000": ["Number", eol], - "0x_0000000_00000000_00000000": ["Number", eol], - "0x_000000_00000000_00000000": ["Number", eol], - "0x_00000_00000000_00000000": ["Number", eol], - "0x_0000_00000000_00000000": ["Number", eol], - "0x_000_00000000_00000000": ["Number", eol], - "0x_00_00000000_00000000": ["Number", eol], - "0x_0_00000000_00000000": ["Number", eol], - "0x__00000000_00000000": ["BadNumber", eol], - "0x00000000_00000000_0000": ["BadNumber", eol], - "0x00000000_0000_0000": ["BadNumber", eol], - "0x_00000000000000000000": ["BadNumber", eol], - "0b_00000000000000000000": ["BadNumber", eol], - "0b00000000_00000000_0000": ["BadNumber", eol], - "0b00000000_0000_0000": ["BadNumber", eol], - "0x0000_": ["BadNumber", eol], - "0x00_____00": ["BadNumber", eol], - "0x____0____": ["BadNumber", eol], - "EmbossReserved": ["BadWord", eol], - "EmbossReservedA": ["BadWord", eol], - "EmbossReserved_": ["BadWord", eol], - "EMBOSS_RESERVED": ["BadWord", eol], - "EMBOSS_RESERVED_": ["BadWord", eol], - "EMBOSS_RESERVEDA": ["BadWord", eol], - "emboss_reserved": ["BadWord", eol], - "emboss_reserved_": ["BadWord", eol], - "emboss_reserveda": ["BadWord", eol], - "0x0123456789abcdefABCDEF": ["Number", eol], - "0": ["Number", eol], - "1": ["Number", eol], - "1a": ["BadNumber", eol], - "1g": ["BadWord", eol], - "1234567890": ["Number", eol], - "1_234_567_890": ["Number", eol], - "234_567_890": ["Number", eol], - "34_567_890": ["Number", eol], - "4_567_890": ["Number", eol], - "1_2_3_4_5_6_7_8_9_0": ["BadNumber", eol], - "1234567890_": ["BadNumber", eol], - "1__234567890": ["BadNumber", eol], - "_1234567890": ["BadWord", eol], - "[]": ['"["', '"]"', eol], - "()": ['"("', '")"', eol], - "..": ['"."', '"."', eol], - "...": ['"."', '"."', '"."', eol], - "....": ['"."', '"."', '"."', '"."', eol], - '"abc"': ["String", eol], - '""': ["String", eol], - r'"\\"': ["String", eol], - r'"\""': ["String", eol], - r'"\n"': ["String", eol], - r'"\\n"': ["String", eol], - r'"\\xyz"': ["String", eol], - r'"\\\\"': ["String", eol], - } - for c in ("[ ] ( ) ? : = + - * . 
== != < <= > >= && || , $max $present " - "$upper_bound $lower_bound $size_in_bits $size_in_bytes " - "$max_size_in_bits $max_size_in_bytes $min_size_in_bits " - "$min_size_in_bytes " - "$default struct bits enum external import as if let").split(): - cases[c] = ['"' + c + '"', eol] - - def make_test_case(case): - - def test_case(self): - tokens, errors = tokenizer.tokenize(case, "name") - symbols = _token_symbols(tokens) - self.assertFalse(errors) - self.assertEqual(symbols, cases[case]) - - return test_case - - for c in cases: - setattr(TokenizerTest, "testShortTokenMatch{!r}".format(c), - make_test_case(c)) + """Makes tests for short, simple tokenization cases.""" + eol = '"\\n"' + cases = { + "Cam": ["CamelWord", eol], + "Ca9": ["CamelWord", eol], + "CanB": ["CamelWord", eol], + "CanBee": ["CamelWord", eol], + "CBa": ["CamelWord", eol], + "cam": ["SnakeWord", eol], + "ca9": ["SnakeWord", eol], + "can_b": ["SnakeWord", eol], + "can_bee": ["SnakeWord", eol], + "c_ba": ["SnakeWord", eol], + "cba_": ["SnakeWord", eol], + "c_b_a_": ["SnakeWord", eol], + "CAM": ["ShoutyWord", eol], + "CA9": ["ShoutyWord", eol], + "CAN_B": ["ShoutyWord", eol], + "CAN_BEE": ["ShoutyWord", eol], + "C_BA": ["ShoutyWord", eol], + "C": ["BadWord", eol], + "C1": ["BadWord", eol], + "c": ["SnakeWord", eol], + "$": ["BadWord", eol], + "_": ["BadWord", eol], + "_a": ["BadWord", eol], + "_A": ["BadWord", eol], + "Cb_A": ["BadWord", eol], + "aCb": ["BadWord", eol], + "a b": ["SnakeWord", "SnakeWord", eol], + "a\tb": ["SnakeWord", "SnakeWord", eol], + "a \t b ": ["SnakeWord", "SnakeWord", eol], + " \t ": [eol], + "a #b": ["SnakeWord", "Comment", eol], + "a#": ["SnakeWord", "Comment", eol], + "# b": ["Comment", eol], + " # b": ["Comment", eol], + " #": ["Comment", eol], + "": [], + "\n": [eol], + "\na": [eol, "SnakeWord", eol], + "a--example": ["SnakeWord", "BadDocumentation", eol], + "a ---- example": ["SnakeWord", "BadDocumentation", eol], + "a --- example": ["SnakeWord", "BadDocumentation", eol], + "a-- example": ["SnakeWord", "Documentation", eol], + "a -- -- example": ["SnakeWord", "Documentation", eol], + "a -- - example": ["SnakeWord", "Documentation", eol], + "--": ["Documentation", eol], + "-- ": ["Documentation", eol], + "-- ": ["Documentation", eol], + "$default": ['"$default"', eol], + "$defaultx": ["BadWord", eol], + "$def": ["BadWord", eol], + "x$default": ["BadWord", eol], + "9$default": ["BadWord", eol], + "struct": ['"struct"', eol], + "external": ['"external"', eol], + "bits": ['"bits"', eol], + "enum": ['"enum"', eol], + "as": ['"as"', eol], + "import": ['"import"', eol], + "true": ["BooleanConstant", eol], + "false": ["BooleanConstant", eol], + "truex": ["SnakeWord", eol], + "falsex": ["SnakeWord", eol], + "structx": ["SnakeWord", eol], + "bitsx": ["SnakeWord", eol], + "enumx": ["SnakeWord", eol], + "0b": ["BadNumber", eol], + "0x": ["BadNumber", eol], + "0b011101": ["Number", eol], + "0b0": ["Number", eol], + "0b0111_1111_0000": ["Number", eol], + "0b00_000_00": ["BadNumber", eol], + "0b0_0_0": ["BadNumber", eol], + "0b0111012": ["BadNumber", eol], + "0b011101x": ["BadWord", eol], + "0b011101b": ["BadNumber", eol], + "0B0": ["BadNumber", eol], + "0X0": ["BadNumber", eol], + "0b_": ["BadNumber", eol], + "0x_": ["BadNumber", eol], + "0b__": ["BadNumber", eol], + "0x__": ["BadNumber", eol], + "0b_0000": ["Number", eol], + "0b0000_": ["BadNumber", eol], + "0b00_____00": ["BadNumber", eol], + "0x00_000_00": ["BadNumber", eol], + "0x0_0_0": ["BadNumber", eol], + "0b____0____": ["BadNumber", eol], + 
"0b00000000000000000000": ["Number", eol], + "0b_00000000": ["Number", eol], + "0b0000_0000_0000": ["Number", eol], + "0b000_0000_0000": ["Number", eol], + "0b00_0000_0000": ["Number", eol], + "0b0_0000_0000": ["Number", eol], + "0b_0000_0000_0000": ["Number", eol], + "0b_000_0000_0000": ["Number", eol], + "0b_00_0000_0000": ["Number", eol], + "0b_0_0000_0000": ["Number", eol], + "0b00000000_00000000_00000000": ["Number", eol], + "0b0000000_00000000_00000000": ["Number", eol], + "0b000000_00000000_00000000": ["Number", eol], + "0b00000_00000000_00000000": ["Number", eol], + "0b0000_00000000_00000000": ["Number", eol], + "0b000_00000000_00000000": ["Number", eol], + "0b00_00000000_00000000": ["Number", eol], + "0b0_00000000_00000000": ["Number", eol], + "0b_00000000_00000000_00000000": ["Number", eol], + "0b_0000000_00000000_00000000": ["Number", eol], + "0b_000000_00000000_00000000": ["Number", eol], + "0b_00000_00000000_00000000": ["Number", eol], + "0b_0000_00000000_00000000": ["Number", eol], + "0b_000_00000000_00000000": ["Number", eol], + "0b_00_00000000_00000000": ["Number", eol], + "0b_0_00000000_00000000": ["Number", eol], + "0x0": ["Number", eol], + "0x00000000000000000000": ["Number", eol], + "0x_0000": ["Number", eol], + "0x_00000000": ["Number", eol], + "0x0000_0000_0000": ["Number", eol], + "0x000_0000_0000": ["Number", eol], + "0x00_0000_0000": ["Number", eol], + "0x0_0000_0000": ["Number", eol], + "0x_0000_0000_0000": ["Number", eol], + "0x_000_0000_0000": ["Number", eol], + "0x_00_0000_0000": ["Number", eol], + "0x_0_0000_0000": ["Number", eol], + "0x00000000_00000000_00000000": ["Number", eol], + "0x0000000_00000000_00000000": ["Number", eol], + "0x000000_00000000_00000000": ["Number", eol], + "0x00000_00000000_00000000": ["Number", eol], + "0x0000_00000000_00000000": ["Number", eol], + "0x000_00000000_00000000": ["Number", eol], + "0x00_00000000_00000000": ["Number", eol], + "0x0_00000000_00000000": ["Number", eol], + "0x_00000000_00000000_00000000": ["Number", eol], + "0x_0000000_00000000_00000000": ["Number", eol], + "0x_000000_00000000_00000000": ["Number", eol], + "0x_00000_00000000_00000000": ["Number", eol], + "0x_0000_00000000_00000000": ["Number", eol], + "0x_000_00000000_00000000": ["Number", eol], + "0x_00_00000000_00000000": ["Number", eol], + "0x_0_00000000_00000000": ["Number", eol], + "0x__00000000_00000000": ["BadNumber", eol], + "0x00000000_00000000_0000": ["BadNumber", eol], + "0x00000000_0000_0000": ["BadNumber", eol], + "0x_00000000000000000000": ["BadNumber", eol], + "0b_00000000000000000000": ["BadNumber", eol], + "0b00000000_00000000_0000": ["BadNumber", eol], + "0b00000000_0000_0000": ["BadNumber", eol], + "0x0000_": ["BadNumber", eol], + "0x00_____00": ["BadNumber", eol], + "0x____0____": ["BadNumber", eol], + "EmbossReserved": ["BadWord", eol], + "EmbossReservedA": ["BadWord", eol], + "EmbossReserved_": ["BadWord", eol], + "EMBOSS_RESERVED": ["BadWord", eol], + "EMBOSS_RESERVED_": ["BadWord", eol], + "EMBOSS_RESERVEDA": ["BadWord", eol], + "emboss_reserved": ["BadWord", eol], + "emboss_reserved_": ["BadWord", eol], + "emboss_reserveda": ["BadWord", eol], + "0x0123456789abcdefABCDEF": ["Number", eol], + "0": ["Number", eol], + "1": ["Number", eol], + "1a": ["BadNumber", eol], + "1g": ["BadWord", eol], + "1234567890": ["Number", eol], + "1_234_567_890": ["Number", eol], + "234_567_890": ["Number", eol], + "34_567_890": ["Number", eol], + "4_567_890": ["Number", eol], + "1_2_3_4_5_6_7_8_9_0": ["BadNumber", eol], + "1234567890_": ["BadNumber", eol], + 
"1__234567890": ["BadNumber", eol], + "_1234567890": ["BadWord", eol], + "[]": ['"["', '"]"', eol], + "()": ['"("', '")"', eol], + "..": ['"."', '"."', eol], + "...": ['"."', '"."', '"."', eol], + "....": ['"."', '"."', '"."', '"."', eol], + '"abc"': ["String", eol], + '""': ["String", eol], + r'"\\"': ["String", eol], + r'"\""': ["String", eol], + r'"\n"': ["String", eol], + r'"\\n"': ["String", eol], + r'"\\xyz"': ["String", eol], + r'"\\\\"': ["String", eol], + } + for c in ( + "[ ] ( ) ? : = + - * . == != < <= > >= && || , $max $present " + "$upper_bound $lower_bound $size_in_bits $size_in_bytes " + "$max_size_in_bits $max_size_in_bytes $min_size_in_bits " + "$min_size_in_bytes " + "$default struct bits enum external import as if let" + ).split(): + cases[c] = ['"' + c + '"', eol] + + def make_test_case(case): + + def test_case(self): + tokens, errors = tokenizer.tokenize(case, "name") + symbols = _token_symbols(tokens) + self.assertFalse(errors) + self.assertEqual(symbols, cases[case]) + + return test_case + + for c in cases: + setattr(TokenizerTest, "testShortTokenMatch{!r}".format(c), make_test_case(c)) def _make_bad_char_tests(): - """Makes tests that an error is returned for bad characters.""" + """Makes tests that an error is returned for bad characters.""" - def make_test_case(case): + def make_test_case(case): - def test_case(self): - tokens, errors = tokenizer.tokenize(case, "name") - self.assertFalse(tokens) - self.assertEqual([[error.error("name", parser_types.make_location( - (1, 1), (1, 2)), "Unrecognized token")]], errors) + def test_case(self): + tokens, errors = tokenizer.tokenize(case, "name") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "name", + parser_types.make_location((1, 1), (1, 2)), + "Unrecognized token", + ) + ] + ], + errors, + ) - return test_case + return test_case - for c in "~`!@%^&\\|;'\"/{}": - setattr(TokenizerTest, "testBadChar{!r}".format(c), make_test_case(c)) + for c in "~`!@%^&\\|;'\"/{}": + setattr(TokenizerTest, "testBadChar{!r}".format(c), make_test_case(c)) def _make_bad_string_tests(): - """Makes tests that an error is returned for bad strings.""" - bad_strings = (r'"\"', '"\\\n"', r'"\\\"', r'"', r'"\q"', r'"\\\q"') + """Makes tests that an error is returned for bad strings.""" + bad_strings = (r'"\"', '"\\\n"', r'"\\\"', r'"', r'"\q"', r'"\\\q"') - def make_test_case(string): + def make_test_case(string): - def test_case(self): - tokens, errors = tokenizer.tokenize(string, "name") - self.assertFalse(tokens) - self.assertEqual([[error.error("name", parser_types.make_location( - (1, 1), (1, 2)), "Unrecognized token")]], errors) + def test_case(self): + tokens, errors = tokenizer.tokenize(string, "name") + self.assertFalse(tokens) + self.assertEqual( + [ + [ + error.error( + "name", + parser_types.make_location((1, 1), (1, 2)), + "Unrecognized token", + ) + ] + ], + errors, + ) - return test_case + return test_case - for s in bad_strings: - setattr(TokenizerTest, "testBadString{!r}".format(s), make_test_case(s)) + for s in bad_strings: + setattr(TokenizerTest, "testBadString{!r}".format(s), make_test_case(s)) def _make_multiline_tests(): - """Makes tests for indent/dedent insertion and eol insertion.""" - - c = "Comment" - eol = '"\\n"' - sw = "SnakeWord" - ind = "Indent" - ded = "Dedent" - cases = { - "a\nb\n": [sw, eol, sw, eol], - "a\n\nb\n": [sw, eol, eol, sw, eol], - "a\n#foo\nb\n": [sw, eol, c, eol, sw, eol], - "a\n #foo\nb\n": [sw, eol, c, eol, sw, eol], - "a\n b\n": [sw, eol, ind, sw, eol, ded], - "a\n 
b\n\n": [sw, eol, ind, sw, eol, eol, ded], - "a\n b\n c\n": [sw, eol, ind, sw, eol, ind, sw, eol, ded, ded], - "a\n b\n c\n": [sw, eol, ind, sw, eol, sw, eol, ded], - "a\n b\n\n c\n": [sw, eol, ind, sw, eol, eol, sw, eol, ded], - "a\n b\n #\n c\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded], - "a\n\tb\n #\n\tc\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded], - " a\n b\n c\n d\n": [ind, sw, eol, ind, sw, eol, ind, sw, eol, ded, - ded, sw, eol, ded], - } - - def make_test_case(case): - - def test_case(self): - tokens, errors = tokenizer.tokenize(case, "file") - self.assertFalse(errors) - self.assertEqual(_token_symbols(tokens), cases[case]) - - return test_case - - for c in cases: - setattr(TokenizerTest, "testMultiline{!r}".format(c), make_test_case(c)) + """Makes tests for indent/dedent insertion and eol insertion.""" + + c = "Comment" + eol = '"\\n"' + sw = "SnakeWord" + ind = "Indent" + ded = "Dedent" + cases = { + "a\nb\n": [sw, eol, sw, eol], + "a\n\nb\n": [sw, eol, eol, sw, eol], + "a\n#foo\nb\n": [sw, eol, c, eol, sw, eol], + "a\n #foo\nb\n": [sw, eol, c, eol, sw, eol], + "a\n b\n": [sw, eol, ind, sw, eol, ded], + "a\n b\n\n": [sw, eol, ind, sw, eol, eol, ded], + "a\n b\n c\n": [sw, eol, ind, sw, eol, ind, sw, eol, ded, ded], + "a\n b\n c\n": [sw, eol, ind, sw, eol, sw, eol, ded], + "a\n b\n\n c\n": [sw, eol, ind, sw, eol, eol, sw, eol, ded], + "a\n b\n #\n c\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded], + "a\n\tb\n #\n\tc\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded], + " a\n b\n c\n d\n": [ + ind, + sw, + eol, + ind, + sw, + eol, + ind, + sw, + eol, + ded, + ded, + sw, + eol, + ded, + ], + } + + def make_test_case(case): + + def test_case(self): + tokens, errors = tokenizer.tokenize(case, "file") + self.assertFalse(errors) + self.assertEqual(_token_symbols(tokens), cases[case]) + + return test_case + + for c in cases: + setattr(TokenizerTest, "testMultiline{!r}".format(c), make_test_case(c)) def _make_offset_tests(): - """Makes tests that the tokenizer fills in correct source locations.""" - cases = { - "a+": ["1:1-1:2", "1:2-1:3", "1:3-1:3"], - "a + ": ["1:1-1:2", "1:5-1:6", "1:9-1:9"], - "a\n\nb": ["1:1-1:2", "1:2-1:2", "2:1-2:1", "3:1-3:2", "3:2-3:2"], - "a\n b": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4", - "3:1-3:1"], - "a\n b\nc": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4", - "3:1-3:1", "3:1-3:2", "3:2-3:2"], - "a\n b\n c": ["1:1-1:2", "1:2-1:2", "2:1-2:2", "2:2-2:3", "2:3-2:3", - "3:2-3:3", "3:3-3:4", "3:4-3:4", "4:1-4:1", "4:1-4:1"], - } - - def make_test_case(case): - - def test_case(self): - self.assertEqual([parser_types.format_location(l.source_location) - for l in tokenizer.tokenize(case, "file")[0]], - cases[case]) - - return test_case - - for c in cases: - setattr(TokenizerTest, "testOffset{!r}".format(c), make_test_case(c)) + """Makes tests that the tokenizer fills in correct source locations.""" + cases = { + "a+": ["1:1-1:2", "1:2-1:3", "1:3-1:3"], + "a + ": ["1:1-1:2", "1:5-1:6", "1:9-1:9"], + "a\n\nb": ["1:1-1:2", "1:2-1:2", "2:1-2:1", "3:1-3:2", "3:2-3:2"], + "a\n b": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4", "3:1-3:1"], + "a\n b\nc": [ + "1:1-1:2", + "1:2-1:2", + "2:1-2:3", + "2:3-2:4", + "2:4-2:4", + "3:1-3:1", + "3:1-3:2", + "3:2-3:2", + ], + "a\n b\n c": [ + "1:1-1:2", + "1:2-1:2", + "2:1-2:2", + "2:2-2:3", + "2:3-2:3", + "3:2-3:3", + "3:3-3:4", + "3:4-3:4", + "4:1-4:1", + "4:1-4:1", + ], + } + + def make_test_case(case): + + def test_case(self): + self.assertEqual( + [ + 
parser_types.format_location(l.source_location) + for l in tokenizer.tokenize(case, "file")[0] + ], + cases[case], + ) + + return test_case + + for c in cases: + setattr(TokenizerTest, "testOffset{!r}".format(c), make_test_case(c)) + _make_short_token_match_tests() _make_bad_char_tests() @@ -385,4 +502,4 @@ def test_case(self): _make_offset_tests() if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py index 9ad4aee..f562cc8 100644 --- a/compiler/front_end/type_check.py +++ b/compiler/front_end/type_check.py @@ -23,461 +23,633 @@ def _type_check_expression(expression, source_file_name, ir, errors): - """Checks and annotates the type of an expression and all subexpressions.""" - if ir_data_utils.reader(expression).type.WhichOneof("type"): - # This expression has already been type checked. - return - expression_variety = expression.WhichOneof("expression") - if expression_variety == "constant": - _type_check_integer_constant(expression) - elif expression_variety == "constant_reference": - _type_check_constant_reference(expression, source_file_name, ir, errors) - elif expression_variety == "function": - _type_check_operation(expression, source_file_name, ir, errors) - elif expression_variety == "field_reference": - _type_check_local_reference(expression, ir, errors) - elif expression_variety == "boolean_constant": - _type_check_boolean_constant(expression) - elif expression_variety == "builtin_reference": - _type_check_builtin_reference(expression) - else: - assert False, "Unknown expression variety {!r}".format(expression_variety) + """Checks and annotates the type of an expression and all subexpressions.""" + if ir_data_utils.reader(expression).type.WhichOneof("type"): + # This expression has already been type checked. 
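This guard makes `_type_check_expression` idempotent: once a node carries a type annotation, any later visit skips it, and the rest of the function is a plain dispatch on whichever oneof member of `expression` is populated. A minimal sketch of the same memoize-then-dispatch shape, with illustrative names (`checked_ids`, `handlers`) that are not part of the Emboss IR API:

    def check_expression(expr, handlers, checked_ids):
        # Skip nodes a previous pass already annotated, mirroring the
        # `type.WhichOneof("type")` guard in _type_check_expression.
        if id(expr) in checked_ids:
            return
        checked_ids.add(id(expr))
        # Branch on the populated variant, mirroring
        # `expression.WhichOneof("expression")`; an unknown variety is a bug.
        kind = expr["kind"]
        assert kind in handlers, "Unknown expression variety {!r}".format(kind)
        handlers[kind](expr)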
+ return + expression_variety = expression.WhichOneof("expression") + if expression_variety == "constant": + _type_check_integer_constant(expression) + elif expression_variety == "constant_reference": + _type_check_constant_reference(expression, source_file_name, ir, errors) + elif expression_variety == "function": + _type_check_operation(expression, source_file_name, ir, errors) + elif expression_variety == "field_reference": + _type_check_local_reference(expression, ir, errors) + elif expression_variety == "boolean_constant": + _type_check_boolean_constant(expression) + elif expression_variety == "builtin_reference": + _type_check_builtin_reference(expression) + else: + assert False, "Unknown expression variety {!r}".format(expression_variety) def _annotate_as_integer(expression): - ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType()) + ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType()) def _annotate_as_boolean(expression): - ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType()) + ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType()) -def _type_check(expression, source_file_name, errors, type_oneof, type_name, - expression_name): - if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof: - errors.append([ - error.error(source_file_name, expression.source_location, - "{} must be {}.".format(expression_name, type_name)) - ]) +def _type_check( + expression, source_file_name, errors, type_oneof, type_name, expression_name +): + if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof: + errors.append( + [ + error.error( + source_file_name, + expression.source_location, + "{} must be {}.".format(expression_name, type_name), + ) + ] + ) def _type_check_integer(expression, source_file_name, errors, expression_name): - _type_check(expression, source_file_name, errors, "integer", - "an integer", expression_name) + _type_check( + expression, source_file_name, errors, "integer", "an integer", expression_name + ) def _type_check_boolean(expression, source_file_name, errors, expression_name): - _type_check(expression, source_file_name, errors, "boolean", "a boolean", - expression_name) + _type_check( + expression, source_file_name, errors, "boolean", "a boolean", expression_name + ) -def _kind_check_field_reference(expression, source_file_name, errors, - expression_name): - if expression.WhichOneof("expression") != "field_reference": - errors.append([ - error.error(source_file_name, expression.source_location, - "{} must be a field.".format(expression_name)) - ]) +def _kind_check_field_reference(expression, source_file_name, errors, expression_name): + if expression.WhichOneof("expression") != "field_reference": + errors.append( + [ + error.error( + source_file_name, + expression.source_location, + "{} must be a field.".format(expression_name), + ) + ] + ) def _type_check_integer_constant(expression): - _annotate_as_integer(expression) + _annotate_as_integer(expression) def _type_check_constant_reference(expression, source_file_name, ir, errors): - """Annotates the type of a constant reference.""" - referred_name = expression.constant_reference.canonical_name - referred_object = ir_util.find_object(referred_name, ir) - if isinstance(referred_object, ir_data.EnumValue): - ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(expression.constant_reference) - del expression.type.enumeration.name.canonical_name.object_path[-1] - elif 
isinstance(referred_object, ir_data.Field): - if not ir_util.field_is_virtual(referred_object): - errors.append([ - error.error(source_file_name, expression.source_location, - "Static references to physical fields are not allowed."), - error.note(referred_name.module_file, referred_object.source_location, - "{} is a physical field.".format( - referred_name.object_path[-1])), - ]) - return - _type_check_expression(referred_object.read_transform, - referred_name.module_file, ir, errors) - ir_data_utils.builder(expression).type.CopyFrom(referred_object.read_transform.type) - else: - assert False, "Unexpected constant reference type." + """Annotates the type of a constant reference.""" + referred_name = expression.constant_reference.canonical_name + referred_object = ir_util.find_object(referred_name, ir) + if isinstance(referred_object, ir_data.EnumValue): + ir_data_utils.builder(expression).type.enumeration.name.CopyFrom( + expression.constant_reference + ) + del expression.type.enumeration.name.canonical_name.object_path[-1] + elif isinstance(referred_object, ir_data.Field): + if not ir_util.field_is_virtual(referred_object): + errors.append( + [ + error.error( + source_file_name, + expression.source_location, + "Static references to physical fields are not allowed.", + ), + error.note( + referred_name.module_file, + referred_object.source_location, + "{} is a physical field.".format(referred_name.object_path[-1]), + ), + ] + ) + return + _type_check_expression( + referred_object.read_transform, referred_name.module_file, ir, errors + ) + ir_data_utils.builder(expression).type.CopyFrom( + referred_object.read_transform.type + ) + else: + assert False, "Unexpected constant reference type." def _type_check_operation(expression, source_file_name, ir, errors): - for arg in expression.function.args: - _type_check_expression(arg, source_file_name, ir, errors) - function = expression.function.function - if function in (ir_data.FunctionMapping.EQUALITY, ir_data.FunctionMapping.INEQUALITY, - ir_data.FunctionMapping.LESS, ir_data.FunctionMapping.LESS_OR_EQUAL, - ir_data.FunctionMapping.GREATER, ir_data.FunctionMapping.GREATER_OR_EQUAL): - _type_check_comparison_operator(expression, source_file_name, errors) - elif function == ir_data.FunctionMapping.CHOICE: - _type_check_choice_operator(expression, source_file_name, errors) - else: - _type_check_monomorphic_operator(expression, source_file_name, errors) + for arg in expression.function.args: + _type_check_expression(arg, source_file_name, ir, errors) + function = expression.function.function + if function in ( + ir_data.FunctionMapping.EQUALITY, + ir_data.FunctionMapping.INEQUALITY, + ir_data.FunctionMapping.LESS, + ir_data.FunctionMapping.LESS_OR_EQUAL, + ir_data.FunctionMapping.GREATER, + ir_data.FunctionMapping.GREATER_OR_EQUAL, + ): + _type_check_comparison_operator(expression, source_file_name, errors) + elif function == ir_data.FunctionMapping.CHOICE: + _type_check_choice_operator(expression, source_file_name, errors) + else: + _type_check_monomorphic_operator(expression, source_file_name, errors) def _type_check_monomorphic_operator(expression, source_file_name, errors): - """Type checks an operator that accepts only one set of argument types.""" - args = expression.function.args - int_args = _type_check_integer - bool_args = _type_check_boolean - field_args = _kind_check_field_reference - int_result = _annotate_as_integer - bool_result = _annotate_as_boolean - binary = ("Left argument", "Right argument") - n_ary = ("Argument {}".format(n) for 
n in range(len(args))) - functions = { - ir_data.FunctionMapping.ADDITION: (int_result, int_args, binary, 2, 2, - "operator"), - ir_data.FunctionMapping.SUBTRACTION: (int_result, int_args, binary, 2, 2, - "operator"), - ir_data.FunctionMapping.MULTIPLICATION: (int_result, int_args, binary, 2, 2, - "operator"), - ir_data.FunctionMapping.AND: (bool_result, bool_args, binary, 2, 2, "operator"), - ir_data.FunctionMapping.OR: (bool_result, bool_args, binary, 2, 2, "operator"), - ir_data.FunctionMapping.MAXIMUM: (int_result, int_args, n_ary, 1, None, - "function"), - ir_data.FunctionMapping.PRESENCE: (bool_result, field_args, n_ary, 1, 1, - "function"), - ir_data.FunctionMapping.UPPER_BOUND: (int_result, int_args, n_ary, 1, 1, - "function"), - ir_data.FunctionMapping.LOWER_BOUND: (int_result, int_args, n_ary, 1, 1, - "function"), - } - function = expression.function.function - (set_result_type, check_arg, arg_names, min_args, max_args, - kind) = functions[function] - for argument, name in zip(args, arg_names): - assert name is not None, "Too many arguments to function!" - check_arg(argument, source_file_name, errors, - "{} of {} '{}'".format(name, kind, - expression.function.function_name.text)) - if len(args) < min_args: - errors.append([ - error.error(source_file_name, expression.source_location, + """Type checks an operator that accepts only one set of argument types.""" + args = expression.function.args + int_args = _type_check_integer + bool_args = _type_check_boolean + field_args = _kind_check_field_reference + int_result = _annotate_as_integer + bool_result = _annotate_as_boolean + binary = ("Left argument", "Right argument") + n_ary = ("Argument {}".format(n) for n in range(len(args))) + functions = { + ir_data.FunctionMapping.ADDITION: ( + int_result, + int_args, + binary, + 2, + 2, + "operator", + ), + ir_data.FunctionMapping.SUBTRACTION: ( + int_result, + int_args, + binary, + 2, + 2, + "operator", + ), + ir_data.FunctionMapping.MULTIPLICATION: ( + int_result, + int_args, + binary, + 2, + 2, + "operator", + ), + ir_data.FunctionMapping.AND: (bool_result, bool_args, binary, 2, 2, "operator"), + ir_data.FunctionMapping.OR: (bool_result, bool_args, binary, 2, 2, "operator"), + ir_data.FunctionMapping.MAXIMUM: ( + int_result, + int_args, + n_ary, + 1, + None, + "function", + ), + ir_data.FunctionMapping.PRESENCE: ( + bool_result, + field_args, + n_ary, + 1, + 1, + "function", + ), + ir_data.FunctionMapping.UPPER_BOUND: ( + int_result, + int_args, + n_ary, + 1, + 1, + "function", + ), + ir_data.FunctionMapping.LOWER_BOUND: ( + int_result, + int_args, + n_ary, + 1, + 1, + "function", + ), + } + function = expression.function.function + (set_result_type, check_arg, arg_names, min_args, max_args, kind) = functions[ + function + ] + for argument, name in zip(args, arg_names): + assert name is not None, "Too many arguments to function!" 
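Each entry in the `functions` table above is a six-tuple: a result annotator, a per-argument type checker, human-readable argument names, minimum and maximum arity (`None` meaning unbounded, as for `$max`), and whether messages should say "operator" or "function". A condensed, standalone sketch of the same table-driven check; the `Sig` tuple and the sample entries are illustrative stand-ins, not the real `ir_data.FunctionMapping` keys:

    import collections

    Sig = collections.namedtuple("Sig", ["check_arg", "min_args", "max_args"])

    SIGNATURES = {
        "+": Sig(lambda a: isinstance(a, int), 2, 2),        # binary operator
        "$max": Sig(lambda a: isinstance(a, int), 1, None),  # n-ary function
    }

    def signature_errors(name, args):
        sig = SIGNATURES[name]
        errors = ["argument {} has the wrong type".format(i)
                  for i, arg in enumerate(args) if not sig.check_arg(arg)]
        if len(args) < sig.min_args:
            errors.append("'{}' requires at least {} argument(s).".format(
                name, sig.min_args))
        if sig.max_args is not None and len(args) > sig.max_args:
            errors.append("'{}' requires at most {} argument(s).".format(
                name, sig.max_args))
        return errors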
+ check_arg( + argument, + source_file_name, + errors, + "{} of {} '{}'".format(name, kind, expression.function.function_name.text), + ) + if len(args) < min_args: + errors.append( + [ + error.error( + source_file_name, + expression.source_location, "{} '{}' requires {} {} argument{}.".format( - kind.title(), expression.function.function_name.text, + kind.title(), + expression.function.function_name.text, "exactly" if min_args == max_args else "at least", - min_args, "s" if min_args > 1 else "")) - ]) - if max_args is not None and len(args) > max_args: - errors.append([ - error.error(source_file_name, expression.source_location, + min_args, + "s" if min_args > 1 else "", + ), + ) + ] + ) + if max_args is not None and len(args) > max_args: + errors.append( + [ + error.error( + source_file_name, + expression.source_location, "{} '{}' requires {} {} argument{}.".format( - kind.title(), expression.function.function_name.text, + kind.title(), + expression.function.function_name.text, "exactly" if min_args == max_args else "at most", - max_args, "s" if max_args > 1 else "")) - ]) - set_result_type(expression) + max_args, + "s" if max_args > 1 else "", + ), + ) + ] + ) + set_result_type(expression) def _type_check_local_reference(expression, ir, errors): - """Annotates the type of a local reference.""" - referrent = ir_util.find_object(expression.field_reference.path[-1], ir) - assert referrent, "Local reference should be non-None after name resolution." - if isinstance(referrent, ir_data.RuntimeParameter): - parameter = referrent - _set_expression_type_from_physical_type_reference( - expression, parameter.physical_type_alias.atomic_type.reference, ir) - return - field = referrent - if ir_util.field_is_virtual(field): - _type_check_expression(field.read_transform, - expression.field_reference.path[0], ir, errors) - ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) - return - if not field.type.HasField("atomic_type"): - ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType()) - else: - _set_expression_type_from_physical_type_reference( - expression, field.type.atomic_type.reference, ir) + """Annotates the type of a local reference.""" + referrent = ir_util.find_object(expression.field_reference.path[-1], ir) + assert referrent, "Local reference should be non-None after name resolution." + if isinstance(referrent, ir_data.RuntimeParameter): + parameter = referrent + _set_expression_type_from_physical_type_reference( + expression, parameter.physical_type_alias.atomic_type.reference, ir + ) + return + field = referrent + if ir_util.field_is_virtual(field): + _type_check_expression( + field.read_transform, expression.field_reference.path[0], ir, errors + ) + ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) + return + if not field.type.HasField("atomic_type"): + ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType()) + else: + _set_expression_type_from_physical_type_reference( + expression, field.type.atomic_type.reference, ir + ) def unbounded_expression_type_for_physical_type(type_definition): - """Gets the ExpressionType for a field of the given TypeDefinition. - - Arguments: - type_definition: an ir_data.AddressableUnit. - - Returns: - An ir_data.ExpressionType with the corresponding expression type filled in: - for example, [prelude].UInt will result in an ExpressionType with the - `integer` field filled in. - - The returned ExpressionType will not have any bounds set. 
- """ - # TODO(bolms): Add a `[value_type]` attribute for `external`s. - if ir_util.get_boolean_attribute(type_definition.attribute, - attributes.IS_INTEGER): - return ir_data.ExpressionType(integer=ir_data.IntegerType()) - elif tuple(type_definition.name.canonical_name.object_path) == ("Flag",): - # This is a hack: the Flag type should say that it is a boolean. - return ir_data.ExpressionType(boolean=ir_data.BooleanType()) - elif type_definition.HasField("enumeration"): - return ir_data.ExpressionType( - enumeration=ir_data.EnumType( - name=ir_data.Reference( - canonical_name=type_definition.name.canonical_name))) - else: - return ir_data.ExpressionType(opaque=ir_data.OpaqueType()) - - -def _set_expression_type_from_physical_type_reference(expression, - type_reference, ir): - """Sets the type of an expression to match a physical type.""" - field_type = ir_util.find_object(type_reference, ir) - assert field_type, "Field type should be non-None after name resolution." - ir_data_utils.builder(expression).type.CopyFrom( - unbounded_expression_type_for_physical_type(field_type)) + """Gets the ExpressionType for a field of the given TypeDefinition. + + Arguments: + type_definition: an ir_data.AddressableUnit. + + Returns: + An ir_data.ExpressionType with the corresponding expression type filled in: + for example, [prelude].UInt will result in an ExpressionType with the + `integer` field filled in. + + The returned ExpressionType will not have any bounds set. + """ + # TODO(bolms): Add a `[value_type]` attribute for `external`s. + if ir_util.get_boolean_attribute(type_definition.attribute, attributes.IS_INTEGER): + return ir_data.ExpressionType(integer=ir_data.IntegerType()) + elif tuple(type_definition.name.canonical_name.object_path) == ("Flag",): + # This is a hack: the Flag type should say that it is a boolean. + return ir_data.ExpressionType(boolean=ir_data.BooleanType()) + elif type_definition.HasField("enumeration"): + return ir_data.ExpressionType( + enumeration=ir_data.EnumType( + name=ir_data.Reference( + canonical_name=type_definition.name.canonical_name + ) + ) + ) + else: + return ir_data.ExpressionType(opaque=ir_data.OpaqueType()) + + +def _set_expression_type_from_physical_type_reference(expression, type_reference, ir): + """Sets the type of an expression to match a physical type.""" + field_type = ir_util.find_object(type_reference, ir) + assert field_type, "Field type should be non-None after name resolution." 
+ ir_data_utils.builder(expression).type.CopyFrom( + unbounded_expression_type_for_physical_type(field_type) + ) def _annotate_parameter_type(parameter, ir, source_file_name, errors): - if parameter.physical_type_alias.WhichOneof("type") != "atomic_type": - errors.append([ - error.error( - source_file_name, parameter.physical_type_alias.source_location, - "Parameters cannot be arrays.") - ]) - return - _set_expression_type_from_physical_type_reference( - parameter, parameter.physical_type_alias.atomic_type.reference, ir) + if parameter.physical_type_alias.WhichOneof("type") != "atomic_type": + errors.append( + [ + error.error( + source_file_name, + parameter.physical_type_alias.source_location, + "Parameters cannot be arrays.", + ) + ] + ) + return + _set_expression_type_from_physical_type_reference( + parameter, parameter.physical_type_alias.atomic_type.reference, ir + ) def _types_are_compatible(a, b): - """Returns true if a and b have compatible types.""" - if a.type.WhichOneof("type") != b.type.WhichOneof("type"): - return False - elif a.type.WhichOneof("type") == "enumeration": - return (ir_util.hashable_form_of_reference(a.type.enumeration.name) == - ir_util.hashable_form_of_reference(b.type.enumeration.name)) - elif a.type.WhichOneof("type") in ("integer", "boolean"): - # All integers are compatible with integers; booleans are compatible with - # booleans - return True - else: - assert False, "_types_are_compatible works with enums, integers, booleans." + """Returns true if a and b have compatible types.""" + if a.type.WhichOneof("type") != b.type.WhichOneof("type"): + return False + elif a.type.WhichOneof("type") == "enumeration": + return ir_util.hashable_form_of_reference( + a.type.enumeration.name + ) == ir_util.hashable_form_of_reference(b.type.enumeration.name) + elif a.type.WhichOneof("type") in ("integer", "boolean"): + # All integers are compatible with integers; booleans are compatible with + # booleans + return True + else: + assert False, "_types_are_compatible works with enums, integers, booleans." def _type_check_comparison_operator(expression, source_file_name, errors): - """Checks the type of a comparison operator (==, !=, <, >, >=, <=).""" - # Applying less than or greater than to a boolean is likely a mistake, so - # only equality and inequality are allowed for booleans. - if expression.function.function in (ir_data.FunctionMapping.EQUALITY, - ir_data.FunctionMapping.INEQUALITY): - acceptable_types = ("integer", "boolean", "enumeration") - acceptable_types_for_humans = "an integer, boolean, or enum" - else: - acceptable_types = ("integer", "enumeration") - acceptable_types_for_humans = "an integer or enum" - left = expression.function.args[0] - right = expression.function.args[1] - for (argument, name) in ((left, "Left"), (right, "Right")): - if argument.type.WhichOneof("type") not in acceptable_types: - errors.append([ - error.error(source_file_name, argument.source_location, - "{} argument of operator '{}' must be {}.".format( - name, expression.function.function_name.text, - acceptable_types_for_humans)) - ]) - return - if not _types_are_compatible(left, right): - errors.append([ - error.error(source_file_name, expression.source_location, + """Checks the type of a comparison operator (==, !=, <, >, >=, <=).""" + # Applying less than or greater than to a boolean is likely a mistake, so + # only equality and inequality are allowed for booleans. 
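That comment is the whole rule: `==` and `!=` accept booleans alongside integers and enums, the ordering operators do not, and in every case both operands must be mutually compatible and the result is boolean. Summarized as a small lookup (the operator spellings are illustrative; the real code branches on `ir_data.FunctionMapping` values):

    _EQUALITY_OPS = frozenset(["==", "!="])

    def acceptable_operand_kinds(op):
        # Ordering a boolean is likely a mistake, so <, <=, >, >= reject it.
        if op in _EQUALITY_OPS:
            return ("integer", "boolean", "enumeration")
        return ("integer", "enumeration")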
+ if expression.function.function in ( + ir_data.FunctionMapping.EQUALITY, + ir_data.FunctionMapping.INEQUALITY, + ): + acceptable_types = ("integer", "boolean", "enumeration") + acceptable_types_for_humans = "an integer, boolean, or enum" + else: + acceptable_types = ("integer", "enumeration") + acceptable_types_for_humans = "an integer or enum" + left = expression.function.args[0] + right = expression.function.args[1] + for argument, name in ((left, "Left"), (right, "Right")): + if argument.type.WhichOneof("type") not in acceptable_types: + errors.append( + [ + error.error( + source_file_name, + argument.source_location, + "{} argument of operator '{}' must be {}.".format( + name, + expression.function.function_name.text, + acceptable_types_for_humans, + ), + ) + ] + ) + return + if not _types_are_compatible(left, right): + errors.append( + [ + error.error( + source_file_name, + expression.source_location, "Both arguments of operator '{}' must have the same " - "type.".format(expression.function.function_name.text)) - ]) - _annotate_as_boolean(expression) + "type.".format(expression.function.function_name.text), + ) + ] + ) + _annotate_as_boolean(expression) def _type_check_choice_operator(expression, source_file_name, errors): - """Checks the type of the choice operator cond ? if_true : if_false.""" - condition = expression.function.args[0] - if condition.type.WhichOneof("type") != "boolean": - errors.append([ - error.error(source_file_name, condition.source_location, - "Condition of operator '?:' must be a boolean.") - ]) - if_true = expression.function.args[1] - if if_true.type.WhichOneof("type") not in ("integer", "boolean", - "enumeration"): - errors.append([ - error.error(source_file_name, if_true.source_location, + """Checks the type of the choice operator cond ? if_true : if_false.""" + condition = expression.function.args[0] + if condition.type.WhichOneof("type") != "boolean": + errors.append( + [ + error.error( + source_file_name, + condition.source_location, + "Condition of operator '?:' must be a boolean.", + ) + ] + ) + if_true = expression.function.args[1] + if if_true.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"): + errors.append( + [ + error.error( + source_file_name, + if_true.source_location, "If-true clause of operator '?:' must be an integer, " - "boolean, or enum.") - ]) - return - if_false = expression.function.args[2] - if not _types_are_compatible(if_true, if_false): - errors.append([ - error.error(source_file_name, expression.source_location, + "boolean, or enum.", + ) + ] + ) + return + if_false = expression.function.args[2] + if not _types_are_compatible(if_true, if_false): + errors.append( + [ + error.error( + source_file_name, + expression.source_location, "The if-true and if-false clauses of operator '?:' must " - "have the same type.") - ]) - if if_true.type.WhichOneof("type") == "integer": - _annotate_as_integer(expression) - elif if_true.type.WhichOneof("type") == "boolean": - _annotate_as_boolean(expression) - elif if_true.type.WhichOneof("type") == "enumeration": - ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(if_true.type.enumeration.name) - else: - assert False, "Unexpected type for if_true." 
+ "have the same type.", + ) + ] + ) + if if_true.type.WhichOneof("type") == "integer": + _annotate_as_integer(expression) + elif if_true.type.WhichOneof("type") == "boolean": + _annotate_as_boolean(expression) + elif if_true.type.WhichOneof("type") == "enumeration": + ir_data_utils.builder(expression).type.enumeration.name.CopyFrom( + if_true.type.enumeration.name + ) + else: + assert False, "Unexpected type for if_true." def _type_check_boolean_constant(expression): - _annotate_as_boolean(expression) + _annotate_as_boolean(expression) def _type_check_builtin_reference(expression): - name = expression.builtin_reference.canonical_name.object_path[0] - if name == "$is_statically_sized": - _annotate_as_boolean(expression) - elif name == "$static_size_in_bits": - _annotate_as_integer(expression) - else: - assert False, "Unknown builtin '{}'.".format(name) + name = expression.builtin_reference.canonical_name.object_path[0] + if name == "$is_statically_sized": + _annotate_as_boolean(expression) + elif name == "$static_size_in_bits": + _annotate_as_integer(expression) + else: + assert False, "Unknown builtin '{}'.".format(name) def _type_check_array_size(expression, source_file_name, errors): - _type_check_integer(expression, source_file_name, errors, "Array size") + _type_check_integer(expression, source_file_name, errors, "Array size") def _type_check_field_location(location, source_file_name, errors): - _type_check_integer(location.start, source_file_name, errors, - "Start of field") - _type_check_integer(location.size, source_file_name, errors, "Size of field") + _type_check_integer(location.start, source_file_name, errors, "Start of field") + _type_check_integer(location.size, source_file_name, errors, "Size of field") def _type_check_field_existence_condition(field, source_file_name, errors): - _type_check_boolean(field.existence_condition, source_file_name, errors, - "Existence condition") + _type_check_boolean( + field.existence_condition, source_file_name, errors, "Existence condition" + ) def _type_name_for_error_messages(expression_type): - if expression_type.WhichOneof("type") == "integer": - return "integer" - elif expression_type.WhichOneof("type") == "enumeration": - # TODO(bolms): Should this be the fully-qualified name? - return expression_type.enumeration.name.canonical_name.object_path[-1] - assert False, "Shouldn't be here." + if expression_type.WhichOneof("type") == "integer": + return "integer" + elif expression_type.WhichOneof("type") == "enumeration": + # TODO(bolms): Should this be the fully-qualified name? + return expression_type.enumeration.name.canonical_name.object_path[-1] + assert False, "Shouldn't be here." 
def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors): - """Checks the types of parameters to a parameterized physical type.""" - referenced_type = ir_util.find_object(atomic_type.reference.canonical_name, - ir) - if (len(referenced_type.runtime_parameter) != - len(atomic_type.runtime_parameter)): - errors.append([ - error.error( - source_file_name, atomic_type.source_location, - "Type {} requires {} parameter{}; {} parameter{} given.".format( - referenced_type.name.name.text, - len(referenced_type.runtime_parameter), - "" if len(referenced_type.runtime_parameter) == 1 else "s", - len(atomic_type.runtime_parameter), - "" if len(atomic_type.runtime_parameter) == 1 else "s")), - error.note( - atomic_type.reference.canonical_name.module_file, - referenced_type.source_location, - "Definition of type {}.".format(referenced_type.name.name.text)) - ]) - return - for i in range(len(referenced_type.runtime_parameter)): - if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in ( - "integer", "boolean", "enumeration"): - # _type_check_parameter will catch invalid parameter types at the - # definition site; no need for another, probably-confusing error at any - # usage sites. - continue - if (atomic_type.runtime_parameter[i].type.WhichOneof("type") != - referenced_type.runtime_parameter[i].type.WhichOneof("type")): - errors.append([ - error.error( - source_file_name, - atomic_type.runtime_parameter[i].source_location, - "Parameter {} of type {} must be {}, not {}.".format( - i, referenced_type.name.name.text, - _type_name_for_error_messages( - referenced_type.runtime_parameter[i].type), - _type_name_for_error_messages( - atomic_type.runtime_parameter[i].type))), - error.note( - atomic_type.reference.canonical_name.module_file, - referenced_type.runtime_parameter[i].source_location, - "Parameter {} of {}.".format(i, referenced_type.name.name.text)) - ]) + """Checks the types of parameters to a parameterized physical type.""" + referenced_type = ir_util.find_object(atomic_type.reference.canonical_name, ir) + if len(referenced_type.runtime_parameter) != len(atomic_type.runtime_parameter): + errors.append( + [ + error.error( + source_file_name, + atomic_type.source_location, + "Type {} requires {} parameter{}; {} parameter{} given.".format( + referenced_type.name.name.text, + len(referenced_type.runtime_parameter), + "" if len(referenced_type.runtime_parameter) == 1 else "s", + len(atomic_type.runtime_parameter), + "" if len(atomic_type.runtime_parameter) == 1 else "s", + ), + ), + error.note( + atomic_type.reference.canonical_name.module_file, + referenced_type.source_location, + "Definition of type {}.".format(referenced_type.name.name.text), + ), + ] + ) + return + for i in range(len(referenced_type.runtime_parameter)): + if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in ( + "integer", + "boolean", + "enumeration", + ): + # _type_check_parameter will catch invalid parameter types at the + # definition site; no need for another, probably-confusing error at any + # usage sites. 
+ continue + if atomic_type.runtime_parameter[i].type.WhichOneof( + "type" + ) != referenced_type.runtime_parameter[i].type.WhichOneof("type"): + errors.append( + [ + error.error( + source_file_name, + atomic_type.runtime_parameter[i].source_location, + "Parameter {} of type {} must be {}, not {}.".format( + i, + referenced_type.name.name.text, + _type_name_for_error_messages( + referenced_type.runtime_parameter[i].type + ), + _type_name_for_error_messages( + atomic_type.runtime_parameter[i].type + ), + ), + ), + error.note( + atomic_type.reference.canonical_name.module_file, + referenced_type.runtime_parameter[i].source_location, + "Parameter {} of {}.".format(i, referenced_type.name.name.text), + ), + ] + ) def _type_check_parameter(runtime_parameter, source_file_name, errors): - """Checks the type of a parameter to a physical type.""" - if runtime_parameter.type.WhichOneof("type") not in ("integer", - "enumeration"): - errors.append([ - error.error(source_file_name, + """Checks the type of a parameter to a physical type.""" + if runtime_parameter.type.WhichOneof("type") not in ("integer", "enumeration"): + errors.append( + [ + error.error( + source_file_name, runtime_parameter.physical_type_alias.source_location, - "Runtime parameters must be integer or enum.") - ]) + "Runtime parameters must be integer or enum.", + ) + ] + ) def annotate_types(ir): - """Adds type annotations to all expressions in ir. - - annotate_types adds type information to all expressions (and subexpressions) - in the IR. Additionally, it checks expressions for internal type consistency: - it will generate an error for constructs like "1 + true", where the types of - the operands are not accepted by the operator. - - Arguments: - ir: an IR to which to add type annotations - - Returns: - A (possibly empty) list of errors. - """ - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Expression], _type_check_expression, - skip_descendants_of={ir_data.Expression}, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.RuntimeParameter], _annotate_parameter_type, - parameters={"errors": errors}) - return errors + """Adds type annotations to all expressions in ir. + + annotate_types adds type information to all expressions (and subexpressions) + in the IR. Additionally, it checks expressions for internal type consistency: + it will generate an error for constructs like "1 + true", where the types of + the operands are not accepted by the operator. + + Arguments: + ir: an IR to which to add type annotations + + Returns: + A (possibly empty) list of errors. + """ + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Expression], + _type_check_expression, + skip_descendants_of={ir_data.Expression}, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.RuntimeParameter], + _annotate_parameter_type, + parameters={"errors": errors}, + ) + return errors def check_types(ir): - """Checks that expressions within the IR have the correct top-level types. - - check_types ensures that expressions at the top level have correct types; in - particular, it ensures that array sizes are integers ("UInt[true]" is not a - valid array type) and that the starts and ends of ranges are integers. - - Arguments: - ir: an IR to type check. - - Returns: - A (possibly empty) list of errors. 
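`annotate_types` above threads one shared `errors` list through two `fast_traverse_ir_top_down` passes via `parameters`, and `skip_descendants_of={ir_data.Expression}` ensures each expression tree is entered exactly once, since `_type_check_expression` recurses into its own subexpressions. A sketch of a plausible caller; the ordering, annotate first and only then run the top-level checks, is an assumption based on the docstrings here rather than a confirmed call sequence:

    def run_type_checks(ir):
        # Annotate every expression and check internal consistency first.
        errors = annotate_types(ir)
        if errors:
            return errors
        # Only then check top-level uses: array sizes, field locations,
        # existence conditions, and passed parameters.
        return check_types(ir)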
- """ - errors = [] - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.FieldLocation], _type_check_field_location, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.ArrayType, ir_data.Expression], _type_check_array_size, - skip_descendants_of={ir_data.AtomicType}, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], _type_check_field_existence_condition, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.RuntimeParameter], _type_check_parameter, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.AtomicType], _type_check_passed_parameters, - parameters={"errors": errors}) - return errors + """Checks that expressions within the IR have the correct top-level types. + + check_types ensures that expressions at the top level have correct types; in + particular, it ensures that array sizes are integers ("UInt[true]" is not a + valid array type) and that the starts and ends of ranges are integers. + + Arguments: + ir: an IR to type check. + + Returns: + A (possibly empty) list of errors. + """ + errors = [] + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.FieldLocation], + _type_check_field_location, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.ArrayType, ir_data.Expression], + _type_check_array_size, + skip_descendants_of={ir_data.AtomicType}, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.Field], + _type_check_field_existence_condition, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.RuntimeParameter], + _type_check_parameter, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.AtomicType], + _type_check_passed_parameters, + parameters={"errors": errors}, + ) + return errors diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py index d308fed..995f20d 100644 --- a/compiler/front_end/type_check_test.py +++ b/compiler/front_end/type_check_test.py @@ -24,657 +24,1053 @@ class TypeAnnotationTest(unittest.TestCase): - def _make_ir(self, emb_text): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": emb_text}), - stop_before_step="annotate_types") - assert not errors, errors - return ir - - def test_adds_integer_constant_type(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+1] UInt:8[] y\n") - self.assertEqual([], type_check.annotate_types(ir)) - expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "integer") - - def test_adds_boolean_constant_type(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+true] UInt:8[] y\n") - self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)), - ir_data_utils.IrDataSerializer(ir).to_json(indent=2)) - expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "boolean") - - def test_adds_enum_constant_type(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+Enum.VALUE] UInt x\n" - "enum Enum:\n" - " VALUE = 1\n") - self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) - expression = ir.module[0].type[0].structure.field[0].location.size - self.assertEqual(expression.type.WhichOneof("type"), "enumeration") - 
-    enum_type_name = expression.type.enumeration.name.canonical_name
-    self.assertEqual(enum_type_name.module_file, "m.emb")
-    self.assertEqual(enum_type_name.object_path[0], "Enum")
-
-  def test_adds_enum_field_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] Enum x\n"
-                       " 1 [+x] UInt y\n"
-                       "enum Enum:\n"
-                       " VALUE = 1\n")
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
-    enum_type_name = expression.type.enumeration.name.canonical_name
-    self.assertEqual(enum_type_name.module_file, "m.emb")
-    self.assertEqual(enum_type_name.object_path[0], "Enum")
-
-  def test_adds_integer_operation_types(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1+1] UInt:8[] y\n")
-    self.assertEqual([], type_check.annotate_types(ir))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "integer")
-    self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
-                     "integer")
-    self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
-                     "integer")
-
-  def test_adds_enum_operation_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+Enum.VAL==Enum.VAL] UInt:8[] y\n"
-                       "enum Enum:\n"
-                       " VAL = 1\n")
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "boolean")
-    self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
-                     "enumeration")
-    self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
-                     "enumeration")
-
-  def test_adds_enum_comparison_operation_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n"
-                       "enum Enum:\n"
-                       " VAL = 1\n")
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "boolean")
-    self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
-                     "enumeration")
-    self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
-                     "enumeration")
-
-  def test_adds_integer_field_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+x] UInt:8[] y\n")
-    self.assertEqual([], type_check.annotate_types(ir))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "integer")
-
-  def test_adds_opaque_field_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] Bar x\n"
-                       " 1 [+x] UInt:8[] y\n"
-                       "struct Bar:\n"
-                       " 0 [+1] UInt z\n")
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "opaque")
-
-  def test_adds_opaque_field_type_for_array(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+x] UInt:8[] y\n")
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual(expression.type.WhichOneof("type"), "opaque")
-
-  def test_error_on_bad_plus_operand_types(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1+true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[1].source_location,
-                     "Right argument of operator '+' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_minus_operand_types(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1-true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[1].source_location,
-                     "Right argument of operator '-' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_times_operand_types(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1*true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[1].source_location,
-                     "Right argument of operator '*' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_equality_left_operand(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+x==x] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Left argument of operator '==' must be an integer, "
-                     "boolean, or enum.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_equality_right_operand(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+1==x] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[1].source_location,
-                     "Right argument of operator '==' must be an integer, "
-                     "boolean, or enum.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_equality_mismatched_operands_int_bool(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1==true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
                     "Both arguments of operator '==' must have the same "
-                     "type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_mismatched_comparison_operands(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8 x\n"
-                       " 1 [+x>=Bar.BAR] UInt:8[] y\n"
-                       "enum Bar:\n"
-                       " BAR = 1\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Both arguments of operator '>=' must have the same "
-                     "type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_equality_mismatched_operands_bool_int(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+true!=1] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Both arguments of operator '!=' must have the same "
-                     "type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_equality_mismatched_operands_enum_enum(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+Bar.BAR==Baz.BAZ] UInt:8[] y\n"
-                       "enum Bar:\n"
-                       " BAR = 1\n"
-                       "enum Baz:\n"
-                       " BAZ = 1\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Both arguments of operator '==' must have the same "
-                     "type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_choice_condition_operand(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+5 ? 0 : 1] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    condition_arg = expression.function.args[0]
-    self.assertEqual([
-        [error.error("m.emb", condition_arg.source_location,
-                     "Condition of operator '?:' must be a boolean.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_choice_if_true_operand(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true ? x : x] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    if_true_arg = expression.function.args[1]
-    self.assertEqual([
-        [error.error("m.emb", if_true_arg.source_location,
-                     "If-true clause of operator '?:' must be an integer, "
-                     "boolean, or enum.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_choice_of_bools(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true ? true : false] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    self.assertEqual("boolean", expression.type.WhichOneof("type"))
-
-  def test_choice_of_integers(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true ? 0 : 100] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([], type_check.annotate_types(ir))
-    self.assertEqual("integer", expression.type.WhichOneof("type"))
-
-  def test_choice_of_enums(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] enum xx:\n"
-                       " XX = 1\n"
-                       " YY = 1\n"
-                       " 1 [+true ? Xx.XX : Xx.YY] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
-    self.assertEqual("enumeration", expression.type.WhichOneof("type"))
-    self.assertFalse(expression.type.enumeration.HasField("value"))
-    self.assertEqual(
-        "m.emb", expression.type.enumeration.name.canonical_name.module_file)
-    self.assertEqual(
-        "Foo", expression.type.enumeration.name.canonical_name.object_path[0])
-    self.assertEqual(
-        "Xx", expression.type.enumeration.name.canonical_name.object_path[1])
-
-  def test_error_on_bad_choice_mismatched_operands(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true ? 0 : true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "The if-true and if-false clauses of operator '?:' must "
-                     "have the same type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_choice_mismatched_enum_operands(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true ? Baz.BAZ : Bar.BAR] UInt:8[] y\n"
-                       "enum Bar:\n"
-                       " BAR = 1\n"
-                       "enum Baz:\n"
-                       " BAZ = 1\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "The if-true and if-false clauses of operator '?:' must "
-                     "have the same type.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_left_operand_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+true+1] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Left argument of operator '+' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_opaque_operand_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+x+1] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Left argument of operator '+' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_left_comparison_operand_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+true<1] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error(
-            "m.emb", expression.function.args[0].source_location,
-            "Left argument of operator '<' must be an integer or enum.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_right_comparison_operand_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1>=true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error(
-            "m.emb", expression.function.args[1].source_location,
-            "Right argument of operator '>=' must be an integer or enum.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_bad_boolean_operand_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1&&true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Left argument of operator '&&' must be a boolean.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_max_return_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $max(1, 2, 3) [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([], type_check.annotate_types(ir))
-    self.assertEqual("integer", expression.type.WhichOneof("type"))
-
-  def test_error_on_bad_max_argument(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $max(Bar.XX) [+1] UInt:8[] x\n"
-                       "enum Bar:\n"
-                       " XX = 0\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Argument 0 of function '$max' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_no_max_argument(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $max() [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$max' requires at least 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_upper_bound_return_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $upper_bound(3) [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([], type_check.annotate_types(ir))
-    self.assertEqual("integer", expression.type.WhichOneof("type"))
-
-  def test_upper_bound_too_few_arguments(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $upper_bound() [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$upper_bound' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_upper_bound_too_many_arguments(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $upper_bound(1, 2) [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$upper_bound' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_upper_bound_wrong_argument_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $upper_bound(Bar.XX) [+1] UInt:8[] x\n"
-                       "enum Bar:\n"
-                       " XX = 0\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error(
-            "m.emb", expression.function.args[0].source_location,
-            "Argument 0 of function '$upper_bound' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_lower_bound_return_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $lower_bound(3) [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([], type_check.annotate_types(ir))
-    self.assertEqual("integer", expression.type.WhichOneof("type"))
-
-  def test_lower_bound_too_few_arguments(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $lower_bound() [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$lower_bound' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_lower_bound_too_many_arguments(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $lower_bound(1, 2) [+1] UInt:8[] x\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$lower_bound' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_lower_bound_wrong_argument_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " $lower_bound(Bar.XX) [+1] UInt:8[] x\n"
-                       "enum Bar:\n"
-                       " XX = 0\n")
-    expression = ir.module[0].type[0].structure.field[0].location.start
-    self.assertEqual([
-        [error.error(
-            "m.emb", expression.function.args[0].source_location,
-            "Argument 0 of function '$lower_bound' must be an integer.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_static_reference_to_physical_field(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt x\n"
-                       " let y = Foo.x\n")
-    static_ref = ir.module[0].type[0].structure.field[1].read_transform
-    physical_field = ir.module[0].type[0].structure.field[0]
-    self.assertEqual([
-        [error.error("m.emb", static_ref.source_location,
-                     "Static references to physical fields are not allowed."),
-         error.note("m.emb", physical_field.source_location,
-                    "x is a physical field.")]
-    ], type_check.annotate_types(ir))
-
-  def test_error_on_non_field_argument_to_has(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " if $present(0):\n"
-                       " 0 [+1] UInt x\n")
-    expression = ir.module[0].type[0].structure.field[0].existence_condition
-    self.assertEqual([
-        [error.error("m.emb", expression.function.args[0].source_location,
-                     "Argument 0 of function '$present' must be a field.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_no_argument_has(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " if $present():\n"
-                       " 0 [+1] UInt x\n")
-    expression = ir.module[0].type[0].structure.field[0].existence_condition
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$present' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_error_on_too_many_argument_has(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " if $present(y, y):\n"
-                       " 0 [+1] UInt x\n"
-                       " 1 [+1] UInt y\n")
-    expression = ir.module[0].type[0].structure.field[0].existence_condition
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Function '$present' requires exactly 1 argument.")]
-    ], error.filter_errors(type_check.annotate_types(ir)))
-
-  def test_checks_that_parameters_are_atomic_types(self):
-    ir = self._make_ir("struct Foo(y: UInt:8[1]):\n"
-                       " 0 [+1] UInt x\n")
-    error_parameter = ir.module[0].type[0].runtime_parameter[0]
-    error_location = error_parameter.physical_type_alias.source_location
-    self.assertEqual(
-        [[error.error("m.emb", error_location,
-                      "Parameters cannot be arrays.")]],
-        error.filter_errors(type_check.annotate_types(ir)))
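The added run below is the same test suite re-emitted in Black's formatting conventions; no test logic changes. A self-contained way to check that a hunk like this is formatting-only is to compare the ASTs of the two sides (a sketch; `old_src`/`new_src` are hypothetical stand-ins for the deleted and added text, not part of this diff):

```python
import ast

def is_formatting_only(old_src: str, new_src: str) -> bool:
    """True if two sources parse to identical ASTs (docstrings included)."""
    # ast.dump omits line/column attributes by default, so pure whitespace
    # and string-concatenation layout changes compare equal.
    return ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))

old_src = "def f(a,\n      b):\n  return a+b\n"
new_src = "def f(a, b):\n    return a + b\n"
assert is_formatting_only(old_src, new_src)
```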
+    def _make_ir(self, emb_text):
+        ir, unused_debug_info, errors = glue.parse_emboss_file(
+            "m.emb",
+            test_util.dict_file_reader({"m.emb": emb_text}),
+            stop_before_step="annotate_types",
+        )
+        assert not errors, errors
+        return ir
+
+    def test_adds_integer_constant_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1] UInt:8[] y\n"
+        )
+        self.assertEqual([], type_check.annotate_types(ir))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "integer")
+
+    def test_adds_boolean_constant_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+true] UInt:8[] y\n"
+        )
+        self.assertEqual(
+            [],
+            error.filter_errors(type_check.annotate_types(ir)),
+            ir_data_utils.IrDataSerializer(ir).to_json(indent=2),
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+
+    def test_adds_enum_constant_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+Enum.VALUE] UInt x\n"
+            "enum Enum:\n"
+            " VALUE = 1\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[0].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+        enum_type_name = expression.type.enumeration.name.canonical_name
+        self.assertEqual(enum_type_name.module_file, "m.emb")
+        self.assertEqual(enum_type_name.object_path[0], "Enum")
+
+    def test_adds_enum_field_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] Enum x\n"
+            " 1 [+x] UInt y\n"
+            "enum Enum:\n"
+            " VALUE = 1\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+        enum_type_name = expression.type.enumeration.name.canonical_name
+        self.assertEqual(enum_type_name.module_file, "m.emb")
+        self.assertEqual(enum_type_name.object_path[0], "Enum")
+
+    def test_adds_integer_operation_types(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1+1] UInt:8[] y\n"
+        )
+        self.assertEqual([], type_check.annotate_types(ir))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "integer")
+        self.assertEqual(expression.function.args[0].type.WhichOneof("type"), "integer")
+        self.assertEqual(expression.function.args[1].type.WhichOneof("type"), "integer")
+
+    def test_adds_enum_operation_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+Enum.VAL==Enum.VAL] UInt:8[] y\n"
+            "enum Enum:\n"
+            " VAL = 1\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+        self.assertEqual(
+            expression.function.args[0].type.WhichOneof("type"), "enumeration"
+        )
+        self.assertEqual(
+            expression.function.args[1].type.WhichOneof("type"), "enumeration"
+        )
+
+    def test_adds_enum_comparison_operation_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n"
+            "enum Enum:\n"
+            " VAL = 1\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+        self.assertEqual(
+            expression.function.args[0].type.WhichOneof("type"), "enumeration"
+        )
+        self.assertEqual(
+            expression.function.args[1].type.WhichOneof("type"), "enumeration"
+        )
+
+    def test_adds_integer_field_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+x] UInt:8[] y\n"
+        )
+        self.assertEqual([], type_check.annotate_types(ir))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "integer")
+
+    def test_adds_opaque_field_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] Bar x\n"
+            " 1 [+x] UInt:8[] y\n"
+            "struct Bar:\n"
+            " 0 [+1] UInt z\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+
+    def test_adds_opaque_field_type_for_array(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " 0 [+1] UInt:8[] x\n" " 1 [+x] UInt:8[] y\n"
+        )
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+
+    def test_error_on_bad_plus_operand_types(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1+true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[1].source_location,
+                        "Right argument of operator '+' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_minus_operand_types(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1-true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[1].source_location,
+                        "Right argument of operator '-' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_times_operand_types(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1*true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[1].source_location,
+                        "Right argument of operator '*' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_equality_left_operand(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+x==x] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Left argument of operator '==' must be an integer, "
+                        "boolean, or enum.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_equality_right_operand(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+1==x] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[1].source_location,
+                        "Right argument of operator '==' must be an integer, "
+                        "boolean, or enum.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_equality_mismatched_operands_int_bool(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1==true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Both arguments of operator '==' must have the same " "type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_mismatched_comparison_operands(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8 x\n"
+            " 1 [+x>=Bar.BAR] UInt:8[] y\n"
+            "enum Bar:\n"
+            " BAR = 1\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Both arguments of operator '>=' must have the same " "type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_equality_mismatched_operands_bool_int(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+true!=1] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Both arguments of operator '!=' must have the same " "type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_equality_mismatched_operands_enum_enum(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+Bar.BAR==Baz.BAZ] UInt:8[] y\n"
+            "enum Bar:\n"
+            " BAR = 1\n"
+            "enum Baz:\n"
+            " BAZ = 1\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Both arguments of operator '==' must have the same " "type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_choice_condition_operand(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+5 ? 0 : 1] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        condition_arg = expression.function.args[0]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        condition_arg.source_location,
+                        "Condition of operator '?:' must be a boolean.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_choice_if_true_operand(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true ? x : x] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        if_true_arg = expression.function.args[1]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        if_true_arg.source_location,
+                        "If-true clause of operator '?:' must be an integer, "
+                        "boolean, or enum.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_choice_of_bools(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true ? true : false] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        self.assertEqual("boolean", expression.type.WhichOneof("type"))
+
+    def test_choice_of_integers(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true ? 0 : 100] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual([], type_check.annotate_types(ir))
+        self.assertEqual("integer", expression.type.WhichOneof("type"))
+
+    def test_choice_of_enums(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] enum xx:\n"
+            " XX = 1\n"
+            " YY = 1\n"
+            " 1 [+true ? Xx.XX : Xx.YY] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+        self.assertEqual("enumeration", expression.type.WhichOneof("type"))
+        self.assertFalse(expression.type.enumeration.HasField("value"))
+        self.assertEqual(
+            "m.emb", expression.type.enumeration.name.canonical_name.module_file
+        )
+        self.assertEqual(
+            "Foo", expression.type.enumeration.name.canonical_name.object_path[0]
+        )
+        self.assertEqual(
+            "Xx", expression.type.enumeration.name.canonical_name.object_path[1]
+        )
+
+    def test_error_on_bad_choice_mismatched_operands(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true ? 0 : true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "The if-true and if-false clauses of operator '?:' must "
+                        "have the same type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_choice_mismatched_enum_operands(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true ? Baz.BAZ : Bar.BAR] UInt:8[] y\n"
+            "enum Bar:\n"
+            " BAR = 1\n"
+            "enum Baz:\n"
+            " BAZ = 1\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "The if-true and if-false clauses of operator '?:' must "
+                        "have the same type.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_left_operand_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+true+1] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Left argument of operator '+' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_opaque_operand_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+x+1] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Left argument of operator '+' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_left_comparison_operand_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+true<1] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Left argument of operator '<' must be an integer or enum.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_right_comparison_operand_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1>=true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[1].source_location,
+                        "Right argument of operator '>=' must be an integer or enum.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_bad_boolean_operand_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1&&true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Left argument of operator '&&' must be a boolean.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_max_return_type(self):
+        ir = self._make_ir("struct Foo:\n" " $max(1, 2, 3) [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual([], type_check.annotate_types(ir))
+        self.assertEqual("integer", expression.type.WhichOneof("type"))
+
+    def test_error_on_bad_max_argument(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " $max(Bar.XX) [+1] UInt:8[] x\n"
+            "enum Bar:\n"
+            " XX = 0\n"
+        )
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Argument 0 of function '$max' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_no_max_argument(self):
+        ir = self._make_ir("struct Foo:\n" " $max() [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$max' requires at least 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_upper_bound_return_type(self):
+        ir = self._make_ir("struct Foo:\n" " $upper_bound(3) [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual([], type_check.annotate_types(ir))
+        self.assertEqual("integer", expression.type.WhichOneof("type"))
+
+    def test_upper_bound_too_few_arguments(self):
+        ir = self._make_ir("struct Foo:\n" " $upper_bound() [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$upper_bound' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_upper_bound_too_many_arguments(self):
+        ir = self._make_ir("struct Foo:\n" " $upper_bound(1, 2) [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$upper_bound' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_upper_bound_wrong_argument_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " $upper_bound(Bar.XX) [+1] UInt:8[] x\n"
+            "enum Bar:\n"
+            " XX = 0\n"
+        )
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Argument 0 of function '$upper_bound' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_lower_bound_return_type(self):
+        ir = self._make_ir("struct Foo:\n" " $lower_bound(3) [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual([], type_check.annotate_types(ir))
+        self.assertEqual("integer", expression.type.WhichOneof("type"))
+
+    def test_lower_bound_too_few_arguments(self):
+        ir = self._make_ir("struct Foo:\n" " $lower_bound() [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$lower_bound' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_lower_bound_too_many_arguments(self):
+        ir = self._make_ir("struct Foo:\n" " $lower_bound(1, 2) [+1] UInt:8[] x\n")
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$lower_bound' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_lower_bound_wrong_argument_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " $lower_bound(Bar.XX) [+1] UInt:8[] x\n"
+            "enum Bar:\n"
+            " XX = 0\n"
+        )
+        expression = ir.module[0].type[0].structure.field[0].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Argument 0 of function '$lower_bound' must be an integer.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_static_reference_to_physical_field(self):
+        ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = Foo.x\n")
+        static_ref = ir.module[0].type[0].structure.field[1].read_transform
+        physical_field = ir.module[0].type[0].structure.field[0]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        static_ref.source_location,
+                        "Static references to physical fields are not allowed.",
+                    ),
+                    error.note(
+                        "m.emb",
+                        physical_field.source_location,
+                        "x is a physical field.",
+                    ),
+                ]
+            ],
+            type_check.annotate_types(ir),
+        )
+
+    def test_error_on_non_field_argument_to_has(self):
+        ir = self._make_ir(
+            "struct Foo:\n" " if $present(0):\n" " 0 [+1] UInt x\n"
+        )
+        expression = ir.module[0].type[0].structure.field[0].existence_condition
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.function.args[0].source_location,
+                        "Argument 0 of function '$present' must be a field.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_no_argument_has(self):
+        ir = self._make_ir("struct Foo:\n" " if $present():\n" " 0 [+1] UInt x\n")
+        expression = ir.module[0].type[0].structure.field[0].existence_condition
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$present' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_error_on_too_many_argument_has(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " if $present(y, y):\n"
+            " 0 [+1] UInt x\n"
+            " 1 [+1] UInt y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[0].existence_condition
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Function '$present' requires exactly 1 argument.",
+                    )
+                ]
+            ],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
+
+    def test_checks_that_parameters_are_atomic_types(self):
+        ir = self._make_ir("struct Foo(y: UInt:8[1]):\n" " 0 [+1] UInt x\n")
+        error_parameter = ir.module[0].type[0].runtime_parameter[0]
+        error_location = error_parameter.physical_type_alias.source_location
+        self.assertEqual(
+            [[error.error("m.emb", error_location, "Parameters cannot be arrays.")]],
+            error.filter_errors(type_check.annotate_types(ir)),
+        )
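Throughout both classes, expected diagnostics are asserted as a list of error *groups*: each group is one primary error followed by any attached notes. A minimal, self-contained sketch of that shape (the dict helper here is illustrative, not the real compiler/util/error.py API):

```python
def make_finding(file, location, message, severity="error"):
    # Stand-in for error.error(...) / error.note(...); fields are hypothetical.
    return {"file": file, "location": location,
            "message": message, "severity": severity}

expected = [
    [  # one group: a primary error plus a supporting note
        make_finding("m.emb", (3, 10), "Size of field must be an integer."),
        make_finding("m.emb", (2, 3), "x is a physical field.", severity="note"),
    ]
]
assert len(expected) == 1 and expected[0][0]["severity"] == "error"
```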
 
 
 class TypeCheckTest(unittest.TestCase):
-  def _make_ir(self, emb_text):
-    ir, unused_debug_info, errors = glue.parse_emboss_file(
-        "m.emb",
-        test_util.dict_file_reader({"m.emb": emb_text}),
-        stop_before_step="check_types")
-    assert not errors, errors
-    return ir
-
-  def test_error_on_opaque_type_in_field_start(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " x [+10] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Start of field must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_boolean_type_in_field_start(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " true [+10] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.start
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Start of field must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_opaque_type_in_field_size(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+x] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Size of field must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_boolean_type_in_field_size(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+true] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].location.size
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Size of field must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_opaque_type_in_array_size(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+9] UInt:8[x] y\n")
-    expression = (ir.module[0].type[0].structure.field[1].type.array_type.
-                  element_count)
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Array size must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_boolean_type_in_array_size(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " 1 [+9] UInt:8[true] y\n")
-    expression = (ir.module[0].type[0].structure.field[1].type.array_type.
-                  element_count)
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Array size must be an integer.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_integer_type_in_existence_condition(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] UInt:8[] x\n"
-                       " if 1:\n"
-                       " 1 [+9] UInt:8[] y\n")
-    expression = ir.module[0].type[0].structure.field[1].existence_condition
-    self.assertEqual([
-        [error.error("m.emb", expression.source_location,
-                     "Existence condition must be a boolean.")]
-    ], type_check.check_types(ir))
-
-  def test_error_on_non_integer_non_enum_parameter(self):
-    ir = self._make_ir("struct Foo(f: Flag):\n"
-                       " 0 [+1] UInt:8[] x\n")
-    parameter = ir.module[0].type[0].runtime_parameter[0]
-    self.assertEqual(
-        [[error.error("m.emb", parameter.physical_type_alias.source_location,
-                      "Runtime parameters must be integer or enum.")]],
-        type_check.check_types(ir))
-
-  def test_error_on_failure_to_pass_parameter(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] Bar b\n"
-                       "struct Bar(f: UInt:6):\n"
-                       " 0 [+1] UInt:8[] x\n")
-    type_ir = ir.module[0].type[0].structure.field[0].type
-    bar = ir.module[0].type[1]
-    self.assertEqual(
-        [[
-            error.error("m.emb", type_ir.source_location,
-                        "Type Bar requires 1 parameter; 0 parameters given."),
-            error.note("m.emb", bar.source_location,
-                       "Definition of type Bar.")
-        ]],
-        type_check.check_types(ir))
-
-  def test_error_on_passing_unneeded_parameter(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] Bar(1) b\n"
-                       "struct Bar:\n"
-                       " 0 [+1] UInt:8[] x\n")
-    type_ir = ir.module[0].type[0].structure.field[0].type
-    bar = ir.module[0].type[1]
-    self.assertEqual(
-        [[
-            error.error("m.emb", type_ir.source_location,
-                        "Type Bar requires 0 parameters; 1 parameter given."),
-            error.note("m.emb", bar.source_location,
-                       "Definition of type Bar.")
-        ]],
-        type_check.check_types(ir))
-
-  def test_error_on_passing_wrong_parameter_type(self):
-    ir = self._make_ir("struct Foo:\n"
-                       " 0 [+1] Bar(1) b\n"
-                       "enum Baz:\n"
-                       " QUX = 1\n"
-                       "struct Bar(n: Baz):\n"
-                       " 0 [+1] UInt:8[] x\n")
-    type_ir = ir.module[0].type[0].structure.field[0].type
-    usage_parameter_ir = type_ir.atomic_type.runtime_parameter[0]
-    source_parameter_ir = ir.module[0].type[2].runtime_parameter[0]
-    self.assertEqual(
-        [[
-            error.error("m.emb", usage_parameter_ir.source_location,
-                        "Parameter 0 of type Bar must be Baz, not integer."),
-            error.note("m.emb", source_parameter_ir.source_location,
-                       "Parameter 0 of Bar.")
-        ]],
-        type_check.check_types(ir))
+    def _make_ir(self, emb_text):
+        ir, unused_debug_info, errors = glue.parse_emboss_file(
+            "m.emb",
+            test_util.dict_file_reader({"m.emb": emb_text}),
+            stop_before_step="check_types",
+        )
+        assert not errors, errors
+        return ir
+
+    def test_error_on_opaque_type_in_field_start(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " x [+10] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Start of field must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_boolean_type_in_field_start(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " true [+10] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.start
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Start of field must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_opaque_type_in_field_size(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+x] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Size of field must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_boolean_type_in_field_size(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+true] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].location.size
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Size of field must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_opaque_type_in_array_size(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+9] UInt:8[x] y\n"
+        )
+        expression = (
+            ir.module[0].type[0].structure.field[1].type.array_type.element_count
+        )
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Array size must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_boolean_type_in_array_size(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " 1 [+9] UInt:8[true] y\n"
+        )
+        expression = (
+            ir.module[0].type[0].structure.field[1].type.array_type.element_count
+        )
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Array size must be an integer.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_integer_type_in_existence_condition(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] UInt:8[] x\n"
+            " if 1:\n"
+            " 1 [+9] UInt:8[] y\n"
+        )
+        expression = ir.module[0].type[0].structure.field[1].existence_condition
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        expression.source_location,
+                        "Existence condition must be a boolean.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_non_integer_non_enum_parameter(self):
+        ir = self._make_ir("struct Foo(f: Flag):\n" " 0 [+1] UInt:8[] x\n")
+        parameter = ir.module[0].type[0].runtime_parameter[0]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        parameter.physical_type_alias.source_location,
+                        "Runtime parameters must be integer or enum.",
+                    )
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_failure_to_pass_parameter(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] Bar b\n"
+            "struct Bar(f: UInt:6):\n"
+            " 0 [+1] UInt:8[] x\n"
+        )
+        type_ir = ir.module[0].type[0].structure.field[0].type
+        bar = ir.module[0].type[1]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        type_ir.source_location,
+                        "Type Bar requires 1 parameter; 0 parameters given.",
+                    ),
+                    error.note("m.emb", bar.source_location, "Definition of type Bar."),
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_passing_unneeded_parameter(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] Bar(1) b\n"
+            "struct Bar:\n"
+            " 0 [+1] UInt:8[] x\n"
+        )
+        type_ir = ir.module[0].type[0].structure.field[0].type
+        bar = ir.module[0].type[1]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        type_ir.source_location,
+                        "Type Bar requires 0 parameters; 1 parameter given.",
+                    ),
+                    error.note("m.emb", bar.source_location, "Definition of type Bar."),
+                ]
+            ],
+            type_check.check_types(ir),
+        )
+
+    def test_error_on_passing_wrong_parameter_type(self):
+        ir = self._make_ir(
+            "struct Foo:\n"
+            " 0 [+1] Bar(1) b\n"
+            "enum Baz:\n"
+            " QUX = 1\n"
+            "struct Bar(n: Baz):\n"
+            " 0 [+1] UInt:8[] x\n"
+        )
+        type_ir = ir.module[0].type[0].structure.field[0].type
+        usage_parameter_ir = type_ir.atomic_type.runtime_parameter[0]
+        source_parameter_ir = ir.module[0].type[2].runtime_parameter[0]
+        self.assertEqual(
+            [
+                [
+                    error.error(
+                        "m.emb",
+                        usage_parameter_ir.source_location,
+                        "Parameter 0 of type Bar must be Baz, not integer.",
+                    ),
+                    error.note(
+                        "m.emb",
+                        source_parameter_ir.source_location,
+                        "Parameter 0 of Bar.",
+                    ),
+                ]
+            ],
+            type_check.check_types(ir),
+        )
 
 
 if __name__ == "__main__":
-  unittest.main()
+    unittest.main()
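Both test classes above share the same harness: parse an .emb source from an in-memory dict and stop the front end just before the pass under test, so the test can invoke that pass directly and inspect the IR. A sketch of that pattern in isolation (import paths are assumed from the repository layout shown in these diffs):

```python
from compiler.front_end import glue
from compiler.util import test_util

# Run the front end on an in-memory module, halting before the pass we want
# to exercise ("annotate_types" here; other tests use "check_types" or
# "set_write_methods").
ir, _debug_info, errors = glue.parse_emboss_file(
    "m.emb",
    test_util.dict_file_reader({"m.emb": "struct Foo:\n 0 [+1] UInt x\n"}),
    stop_before_step="annotate_types",
)
assert not errors, errors  # earlier passes must succeed for the test to be valid
```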
+ """ + found, indexes = _recursively_find_field_reference_path(expression) + if found == 1: + return indexes + else: + return None def _recursively_find_field_reference_path(expression): - """Recursive implementation of _find_field_reference_path.""" - if expression.WhichOneof("expression") == "field_reference": - return 1, [] - elif expression.WhichOneof("expression") == "function": - field_count = 0 - path = [] - for index in range(len(expression.function.args)): - arg = expression.function.args[index] - arg_result = _recursively_find_field_reference_path(arg) - arg_field_count, arg_path = arg_result - if arg_field_count == 1 and field_count == 0: - path = [index] + arg_path - field_count += arg_field_count - if field_count == 1: - return field_count, path + """Recursive implementation of _find_field_reference_path.""" + if expression.WhichOneof("expression") == "field_reference": + return 1, [] + elif expression.WhichOneof("expression") == "function": + field_count = 0 + path = [] + for index in range(len(expression.function.args)): + arg = expression.function.args[index] + arg_result = _recursively_find_field_reference_path(arg) + arg_field_count, arg_path = arg_result + if arg_field_count == 1 and field_count == 0: + path = [index] + arg_path + field_count += arg_field_count + if field_count == 1: + return field_count, path + else: + return field_count, [] else: - return field_count, [] - else: - return 0, [] + return 0, [] def _invert_expression(expression, ir): - """For the given expression, searches for an algebraic inverse expression. - - That is, it takes the notional equation: - - $logical_value = expression - - and, if there is exactly one `field_reference` in `expression`, it will - attempt to solve the equation for that field. For example, if the expression - is `x + 1`, it will iteratively transform: - - $logical_value = x + 1 - $logical_value - 1 = x + 1 - 1 - $logical_value - 1 = x - - and finally return `x` and `$logical_value - 1`. - - The purpose of this transformation is to find an assignment statement that can - be used to write back through certain virtual fields. E.g., given: - - struct Foo: - 0 [+1] UInt raw_value - let actual_value = raw_value + 100 - - it should be possible to write a value to the `actual_value` field, and have - it set `raw_value` to the appropriate value. - - Arguments: - expression: an ir_data.Expression to be inverted. - ir: the full IR, for looking up symbols. - - Returns: - (field_reference, inverse_expression) if expression can be inverted, - otherwise None. - """ - reference_path = _find_field_reference_path(expression) - if reference_path is None: - return None - subexpression = expression - result = ir_data.Expression( - builtin_reference=ir_data.Reference( - canonical_name=ir_data.CanonicalName( - module_file="", - object_path=["$logical_value"] - ), - source_name=[ir_data.Word( - text="$logical_value", - source_location=ir_data.Location(is_synthetic=True) - )], - source_location=ir_data.Location(is_synthetic=True) - ), - type=expression.type, - source_location=ir_data.Location(is_synthetic=True) - ) - - # This loop essentially starts with: - # - # f(g(x)) == $logical_value - # - # and ends with - # - # x == g_inv(f_inv($logical_value)) - # - # At each step, `subexpression` has one layer removed, and `result` has a - # corresponding inverse function applied. 
So, for example, it might start - # with: - # - # 2 + ((3 - x) - 10) == $logical_value - # - # On each iteration, `subexpression` and `result` will become: - # - # (3 - x) - 10 == $logical_value - 2 [subtract 2 from both sides] - # (3 - x) == ($logical_value - 2) + 10 [add 10 to both sides] - # x == 3 - (($logical_value - 2) + 10) [subtract both sides from 3] - # - # This is an extremely limited algebraic solver, but it covers common-enough - # cases. - # - # Note that any equation that can be solved here becomes part of Emboss's - # contract, forever, so be conservative in expanding its solving capabilities! - for index in reference_path: - if subexpression.function.function == ir_data.FunctionMapping.ADDITION: - result = ir_data.Expression( - function=ir_data.Function( - function=ir_data.FunctionMapping.SUBTRACTION, - args=[ - result, - subexpression.function.args[1 - index], - ] - ), - type=ir_data.ExpressionType(integer=ir_data.IntegerType()) - ) - elif subexpression.function.function == ir_data.FunctionMapping.SUBTRACTION: - if index == 0: - result = ir_data.Expression( - function=ir_data.Function( - function=ir_data.FunctionMapping.ADDITION, - args=[ - result, - subexpression.function.args[1], - ] - ), - type=ir_data.ExpressionType(integer=ir_data.IntegerType()) - ) - else: - result = ir_data.Expression( - function=ir_data.Function( - function=ir_data.FunctionMapping.SUBTRACTION, - args=[ - subexpression.function.args[0], - result, - ] + """For the given expression, searches for an algebraic inverse expression. + + That is, it takes the notional equation: + + $logical_value = expression + + and, if there is exactly one `field_reference` in `expression`, it will + attempt to solve the equation for that field. For example, if the expression + is `x + 1`, it will iteratively transform: + + $logical_value = x + 1 + $logical_value - 1 = x + 1 - 1 + $logical_value - 1 = x + + and finally return `x` and `$logical_value - 1`. + + The purpose of this transformation is to find an assignment statement that can + be used to write back through certain virtual fields. E.g., given: + + struct Foo: + 0 [+1] UInt raw_value + let actual_value = raw_value + 100 + + it should be possible to write a value to the `actual_value` field, and have + it set `raw_value` to the appropriate value. + + Arguments: + expression: an ir_data.Expression to be inverted. + ir: the full IR, for looking up symbols. + + Returns: + (field_reference, inverse_expression) if expression can be inverted, + otherwise None. 
+ """ + reference_path = _find_field_reference_path(expression) + if reference_path is None: + return None + subexpression = expression + result = ir_data.Expression( + builtin_reference=ir_data.Reference( + canonical_name=ir_data.CanonicalName( + module_file="", object_path=["$logical_value"] ), - type=ir_data.ExpressionType(integer=ir_data.IntegerType()) - ) - else: - return None - subexpression = subexpression.function.args[index] - expression_bounds.compute_constraints_of_expression(result, ir) - return subexpression, result + source_name=[ + ir_data.Word( + text="$logical_value", + source_location=ir_data.Location(is_synthetic=True), + ) + ], + source_location=ir_data.Location(is_synthetic=True), + ), + type=expression.type, + source_location=ir_data.Location(is_synthetic=True), + ) + + # This loop essentially starts with: + # + # f(g(x)) == $logical_value + # + # and ends with + # + # x == g_inv(f_inv($logical_value)) + # + # At each step, `subexpression` has one layer removed, and `result` has a + # corresponding inverse function applied. So, for example, it might start + # with: + # + # 2 + ((3 - x) - 10) == $logical_value + # + # On each iteration, `subexpression` and `result` will become: + # + # (3 - x) - 10 == $logical_value - 2 [subtract 2 from both sides] + # (3 - x) == ($logical_value - 2) + 10 [add 10 to both sides] + # x == 3 - (($logical_value - 2) + 10) [subtract both sides from 3] + # + # This is an extremely limited algebraic solver, but it covers common-enough + # cases. + # + # Note that any equation that can be solved here becomes part of Emboss's + # contract, forever, so be conservative in expanding its solving capabilities! + for index in reference_path: + if subexpression.function.function == ir_data.FunctionMapping.ADDITION: + result = ir_data.Expression( + function=ir_data.Function( + function=ir_data.FunctionMapping.SUBTRACTION, + args=[ + result, + subexpression.function.args[1 - index], + ], + ), + type=ir_data.ExpressionType(integer=ir_data.IntegerType()), + ) + elif subexpression.function.function == ir_data.FunctionMapping.SUBTRACTION: + if index == 0: + result = ir_data.Expression( + function=ir_data.Function( + function=ir_data.FunctionMapping.ADDITION, + args=[ + result, + subexpression.function.args[1], + ], + ), + type=ir_data.ExpressionType(integer=ir_data.IntegerType()), + ) + else: + result = ir_data.Expression( + function=ir_data.Function( + function=ir_data.FunctionMapping.SUBTRACTION, + args=[ + subexpression.function.args[0], + result, + ], + ), + type=ir_data.ExpressionType(integer=ir_data.IntegerType()), + ) + else: + return None + subexpression = subexpression.function.args[index] + expression_bounds.compute_constraints_of_expression(result, ir) + return subexpression, result def _add_write_method(field, ir): - """Adds an appropriate write_method to field, if applicable. - - Currently, the "alias" write_method will be added for virtual fields of the - form `let v = some_field_reference` when `some_field_reference` is a physical - field or a writeable alias. The "physical" write_method will be added for - physical fields. The "transform" write_method will be added when the virtual - field's value is an easily-invertible function of a single writeable field. - All other fields will have the "read_only" write_method; i.e., they will not - be writeable. - - Arguments: - field: an ir_data.Field to which to add a write_method. - ir: The IR in which to look up field_references. 
- - Returns: - None - """ - if field.HasField("write_method"): - # Do not recompute anything. - return - - if not ir_util.field_is_virtual(field): - # If the field is not virtual, writes are physical. - ir_data_utils.builder(field).write_method.physical = True - return - - field_checker = ir_data_utils.reader(field) - field_builder = ir_data_utils.builder(field) - - # A virtual field cannot be a direct alias if it has an additional - # requirement. - requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES) - if (field_checker.read_transform.WhichOneof("expression") != "field_reference" or - requires_attr is not None): - inverse = _invert_expression(field.read_transform, ir) - if inverse: - field_reference, function_body = inverse - referenced_field = ir_util.find_object( - field_reference.field_reference.path[-1], ir) - if not isinstance(referenced_field, ir_data.Field): - reference_is_read_only = True - else: - _add_write_method(referenced_field, ir) - reference_is_read_only = referenced_field.write_method.read_only - if not reference_is_read_only: - field_builder.write_method.transform.destination.CopyFrom( - field_reference.field_reference) - field_builder.write_method.transform.function_body.CopyFrom(function_body) - else: - # If the virtual field's expression is invertible, but its target field - # is read-only, it is also read-only. + """Adds an appropriate write_method to field, if applicable. + + Currently, the "alias" write_method will be added for virtual fields of the + form `let v = some_field_reference` when `some_field_reference` is a physical + field or a writeable alias. The "physical" write_method will be added for + physical fields. The "transform" write_method will be added when the virtual + field's value is an easily-invertible function of a single writeable field. + All other fields will have the "read_only" write_method; i.e., they will not + be writeable. + + Arguments: + field: an ir_data.Field to which to add a write_method. + ir: The IR in which to look up field_references. + + Returns: + None + """ + if field.HasField("write_method"): + # Do not recompute anything. + return + + if not ir_util.field_is_virtual(field): + # If the field is not virtual, writes are physical. + ir_data_utils.builder(field).write_method.physical = True + return + + field_checker = ir_data_utils.reader(field) + field_builder = ir_data_utils.builder(field) + + # A virtual field cannot be a direct alias if it has an additional + # requirement. + requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES) + if ( + field_checker.read_transform.WhichOneof("expression") != "field_reference" + or requires_attr is not None + ): + inverse = _invert_expression(field.read_transform, ir) + if inverse: + field_reference, function_body = inverse + referenced_field = ir_util.find_object( + field_reference.field_reference.path[-1], ir + ) + if not isinstance(referenced_field, ir_data.Field): + reference_is_read_only = True + else: + _add_write_method(referenced_field, ir) + reference_is_read_only = referenced_field.write_method.read_only + if not reference_is_read_only: + field_builder.write_method.transform.destination.CopyFrom( + field_reference.field_reference + ) + field_builder.write_method.transform.function_body.CopyFrom( + function_body + ) + else: + # If the virtual field's expression is invertible, but its target field + # is read-only, it is also read-only. 
+ field_builder.write_method.read_only = True + else: + # If the virtual field's expression is not invertible, it is + # read-only. + field_builder.write_method.read_only = True + return + + referenced_field = ir_util.find_object( + field.read_transform.field_reference.path[-1], ir + ) + if not isinstance(referenced_field, ir_data.Field): + # If the virtual field aliases a non-field (i.e., a parameter), it is + # read-only. field_builder.write_method.read_only = True - else: - # If the virtual field's expression is not invertible, it is - # read-only. - field_builder.write_method.read_only = True - return - - referenced_field = ir_util.find_object( - field.read_transform.field_reference.path[-1], ir) - if not isinstance(referenced_field, ir_data.Field): - # If the virtual field aliases a non-field (i.e., a parameter), it is - # read-only. - field_builder.write_method.read_only = True - return - - _add_write_method(referenced_field, ir) - if referenced_field.write_method.read_only: - # If the virtual field directly aliases a read-only field, it is read-only. - field_builder.write_method.read_only = True - return - - # Otherwise, it can be written as a direct alias. - field_builder.write_method.alias.CopyFrom( - field.read_transform.field_reference) + return + + _add_write_method(referenced_field, ir) + if referenced_field.write_method.read_only: + # If the virtual field directly aliases a read-only field, it is read-only. + field_builder.write_method.read_only = True + return + + # Otherwise, it can be written as a direct alias. + field_builder.write_method.alias.CopyFrom(field.read_transform.field_reference) def set_write_methods(ir): - """Sets the write_method member of all ir_data.Fields in ir. + """Sets the write_method member of all ir_data.Fields in ir. - Arguments: - ir: The IR to which to add write_methods. + Arguments: + ir: The IR to which to add write_methods. - Returns: - A list of errors, or an empty list. - """ - traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method) - return [] + Returns: + A list of errors, or an empty list. 
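+
+ Note that this pass works by mutation: each field's write_method is
+ set in place (via ir_data_utils.builder), and the returned error list
+ is currently always empty.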
+ """ + traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method) + return [] diff --git a/compiler/front_end/write_inference_test.py b/compiler/front_end/write_inference_test.py index d1de5f2..c6afa2f 100644 --- a/compiler/front_end/write_inference_test.py +++ b/compiler/front_end/write_inference_test.py @@ -23,194 +23,193 @@ class WriteInferenceTest(unittest.TestCase): - def _make_ir(self, emb_text): - ir, unused_debug_info, errors = glue.parse_emboss_file( - "m.emb", - test_util.dict_file_reader({"m.emb": emb_text}), - stop_before_step="set_write_methods") - assert not errors, errors - return ir - - def test_adds_physical_write_method(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - self.assertTrue( - ir.module[0].type[0].structure.field[0].write_method.physical) - - def test_adds_read_only_write_method_to_non_alias_virtual(self): - ir = self._make_ir("struct Foo:\n" - " let x = 5\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - self.assertTrue( - ir.module[0].type[0].structure.field[0].write_method.read_only) - - def test_adds_alias_write_method_to_alias_of_physical_field(self): - ir = self._make_ir("struct Foo:\n" - " let x = y\n" - " 0 [+1] UInt y\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertTrue(field.write_method.HasField("alias")) - self.assertEqual( - "y", field.write_method.alias.path[0].canonical_name.object_path[-1]) - - def test_adds_alias_write_method_to_alias_of_alias_of_physical_field(self): - ir = self._make_ir("struct Foo:\n" - " let x = z\n" - " let z = y\n" - " 0 [+1] UInt y\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertTrue(field.write_method.HasField("alias")) - self.assertEqual( - "z", field.write_method.alias.path[0].canonical_name.object_path[-1]) - - def test_adds_read_only_write_method_to_alias_of_read_only(self): - ir = self._make_ir("struct Foo:\n" - " let x = y\n" - " let y = 5\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertTrue(field.write_method.read_only) - - def test_adds_read_only_write_method_to_alias_of_alias_of_read_only(self): - ir = self._make_ir("struct Foo:\n" - " let x = z\n" - " let z = y\n" - " let y = 5\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertTrue(field.write_method.read_only) - - def test_adds_read_only_write_method_to_alias_of_parameter(self): - ir = self._make_ir("struct Foo(x: UInt:8):\n" - " let y = x\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertTrue(field.write_method.read_only) - - def test_adds_transform_write_method_to_base_value_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = x + 50\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, - transform.function_body.function.function) - arg0, arg1 = transform.function_body.function.args - self.assertEqual("$logical_value", - 
arg0.builtin_reference.canonical_name.object_path[0]) - self.assertEqual("50", arg1.constant.value) - - def test_adds_transform_write_method_to_negative_base_value_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = x - 50\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.ADDITION, - transform.function_body.function.function) - arg0, arg1 = transform.function_body.function.args - self.assertEqual("$logical_value", - arg0.builtin_reference.canonical_name.object_path[0]) - self.assertEqual("50", arg1.constant.value) - - def test_adds_transform_write_method_to_reversed_base_value_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = 50 + x\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, - transform.function_body.function.function) - arg0, arg1 = transform.function_body.function.args - self.assertEqual("$logical_value", - arg0.builtin_reference.canonical_name.object_path[0]) - self.assertEqual("50", arg1.constant.value) - - def test_adds_transform_write_method_to_reversed_negative_base_value_field( - self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = 50 - x\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, - transform.function_body.function.function) - arg0, arg1 = transform.function_body.function.args - self.assertEqual("50", arg0.constant.value) - self.assertEqual("$logical_value", - arg1.builtin_reference.canonical_name.object_path[0]) - - def test_adds_transform_write_method_to_nested_invertible_field(self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = 30 + (50 - x)\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, - transform.function_body.function.function) - arg0, arg1 = transform.function_body.function.args - self.assertEqual("50", arg0.constant.value) - self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, arg1.function.function) - arg10, arg11 = arg1.function.args - self.assertEqual("$logical_value", - arg10.builtin_reference.canonical_name.object_path[0]) - self.assertEqual("30", arg11.constant.value) - - def test_does_not_add_transform_write_method_for_parameter_target(self): - ir = self._make_ir("struct Foo(x: UInt:8):\n" - " let y = 50 + x\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[0] - self.assertEqual("read_only", field.write_method.WhichOneof("method")) - - def 
test_adds_transform_write_method_with_complex_auxiliary_subexpression( - self): - ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " let y = x - $max(Foo.$size_in_bytes, Foo.z)\n" - " let z = 500\n") - self.assertEqual([], write_inference.set_write_methods(ir)) - field = ir.module[0].type[0].structure.field[1] - transform = field.write_method.transform - self.assertTrue(transform) - self.assertEqual( - "x", - transform.destination.path[0].canonical_name.object_path[-1]) - self.assertEqual(ir_data.FunctionMapping.ADDITION, - transform.function_body.function.function) - args = transform.function_body.function.args - self.assertEqual("$logical_value", - args[0].builtin_reference.canonical_name.object_path[0]) - self.assertEqual(field.read_transform.function.args[1], args[1]) + def _make_ir(self, emb_text): + ir, unused_debug_info, errors = glue.parse_emboss_file( + "m.emb", + test_util.dict_file_reader({"m.emb": emb_text}), + stop_before_step="set_write_methods", + ) + assert not errors, errors + return ir + + def test_adds_physical_write_method(self): + ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + self.assertTrue(ir.module[0].type[0].structure.field[0].write_method.physical) + + def test_adds_read_only_write_method_to_non_alias_virtual(self): + ir = self._make_ir("struct Foo:\n" " let x = 5\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + self.assertTrue(ir.module[0].type[0].structure.field[0].write_method.read_only) + + def test_adds_alias_write_method_to_alias_of_physical_field(self): + ir = self._make_ir("struct Foo:\n" " let x = y\n" " 0 [+1] UInt y\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertTrue(field.write_method.HasField("alias")) + self.assertEqual( + "y", field.write_method.alias.path[0].canonical_name.object_path[-1] + ) + + def test_adds_alias_write_method_to_alias_of_alias_of_physical_field(self): + ir = self._make_ir( + "struct Foo:\n" " let x = z\n" " let z = y\n" " 0 [+1] UInt y\n" + ) + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertTrue(field.write_method.HasField("alias")) + self.assertEqual( + "z", field.write_method.alias.path[0].canonical_name.object_path[-1] + ) + + def test_adds_read_only_write_method_to_alias_of_read_only(self): + ir = self._make_ir("struct Foo:\n" " let x = y\n" " let y = 5\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertTrue(field.write_method.read_only) + + def test_adds_read_only_write_method_to_alias_of_alias_of_read_only(self): + ir = self._make_ir( + "struct Foo:\n" " let x = z\n" " let z = y\n" " let y = 5\n" + ) + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertTrue(field.write_method.read_only) + + def test_adds_read_only_write_method_to_alias_of_parameter(self): + ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = x\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertTrue(field.write_method.read_only) + + def test_adds_transform_write_method_to_base_value_field(self): + ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = x + 50\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = 
ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.SUBTRACTION, + transform.function_body.function.function, + ) + arg0, arg1 = transform.function_body.function.args + self.assertEqual( + "$logical_value", arg0.builtin_reference.canonical_name.object_path[0] + ) + self.assertEqual("50", arg1.constant.value) + + def test_adds_transform_write_method_to_negative_base_value_field(self): + ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = x - 50\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.ADDITION, transform.function_body.function.function + ) + arg0, arg1 = transform.function_body.function.args + self.assertEqual( + "$logical_value", arg0.builtin_reference.canonical_name.object_path[0] + ) + self.assertEqual("50", arg1.constant.value) + + def test_adds_transform_write_method_to_reversed_base_value_field(self): + ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = 50 + x\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.SUBTRACTION, + transform.function_body.function.function, + ) + arg0, arg1 = transform.function_body.function.args + self.assertEqual( + "$logical_value", arg0.builtin_reference.canonical_name.object_path[0] + ) + self.assertEqual("50", arg1.constant.value) + + def test_adds_transform_write_method_to_reversed_negative_base_value_field(self): + ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = 50 - x\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.SUBTRACTION, + transform.function_body.function.function, + ) + arg0, arg1 = transform.function_body.function.args + self.assertEqual("50", arg0.constant.value) + self.assertEqual( + "$logical_value", arg1.builtin_reference.canonical_name.object_path[0] + ) + + def test_adds_transform_write_method_to_nested_invertible_field(self): + ir = self._make_ir( + "struct Foo:\n" " 0 [+1] UInt x\n" " let y = 30 + (50 - x)\n" + ) + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.SUBTRACTION, + transform.function_body.function.function, + ) + arg0, arg1 = transform.function_body.function.args + self.assertEqual("50", arg0.constant.value) + self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, arg1.function.function) + arg10, arg11 = arg1.function.args + self.assertEqual( + "$logical_value", 
arg10.builtin_reference.canonical_name.object_path[0] + ) + self.assertEqual("30", arg11.constant.value) + + def test_does_not_add_transform_write_method_for_parameter_target(self): + ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = 50 + x\n") + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[0] + self.assertEqual("read_only", field.write_method.WhichOneof("method")) + + def test_adds_transform_write_method_with_complex_auxiliary_subexpression(self): + ir = self._make_ir( + "struct Foo:\n" + " 0 [+1] UInt x\n" + " let y = x - $max(Foo.$size_in_bytes, Foo.z)\n" + " let z = 500\n" + ) + self.assertEqual([], write_inference.set_write_methods(ir)) + field = ir.module[0].type[0].structure.field[1] + transform = field.write_method.transform + self.assertTrue(transform) + self.assertEqual( + "x", transform.destination.path[0].canonical_name.object_path[-1] + ) + self.assertEqual( + ir_data.FunctionMapping.ADDITION, transform.function_body.function.function + ) + args = transform.function_body.function.args + self.assertEqual( + "$logical_value", args[0].builtin_reference.canonical_name.object_path[0] + ) + self.assertEqual(field.read_transform.function.args[1], args[1]) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py index 6e04280..a83cf51 100644 --- a/compiler/util/attribute_util.py +++ b/compiler/util/attribute_util.py @@ -31,61 +31,94 @@ def _attribute_name_for_errors(attr): - if ir_data_utils.reader(attr).back_end.text: - return f"({attr.back_end.text}) {attr.name.text}" - else: - return attr.name.text + if ir_data_utils.reader(attr).back_end.text: + return f"({attr.back_end.text}) {attr.name.text}" + else: + return attr.name.text # Attribute type checkers def _is_constant_boolean(attr, module_source_file): - """Checks if the given attr is a constant boolean.""" - if not attr.value.expression.type.boolean.HasField("value"): - return [[error.error(module_source_file, - attr.value.source_location, - _BAD_TYPE_MESSAGE.format( - name=_attribute_name_for_errors(attr), - type="a constant boolean"))]] - return [] + """Checks if the given attr is a constant boolean.""" + if not attr.value.expression.type.boolean.HasField("value"): + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + _BAD_TYPE_MESSAGE.format( + name=_attribute_name_for_errors(attr), type="a constant boolean" + ), + ) + ] + ] + return [] def _is_boolean(attr, module_source_file): - """Checks if the given attr is a boolean.""" - if attr.value.expression.type.WhichOneof("type") != "boolean": - return [[error.error(module_source_file, - attr.value.source_location, - _BAD_TYPE_MESSAGE.format( - name=_attribute_name_for_errors(attr), - type="a boolean"))]] - return [] + """Checks if the given attr is a boolean.""" + if attr.value.expression.type.WhichOneof("type") != "boolean": + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + _BAD_TYPE_MESSAGE.format( + name=_attribute_name_for_errors(attr), type="a boolean" + ), + ) + ] + ] + return [] def _is_constant_integer(attr, module_source_file): - """Checks if the given attr is an integer constant expression.""" - if (not attr.value.HasField("expression") or - attr.value.expression.type.WhichOneof("type") != "integer"): - return [[error.error(module_source_file, - attr.value.source_location, - _BAD_TYPE_MESSAGE.format( - name=_attribute_name_for_errors(attr), - type="an 
integer"))]] - if not ir_util.is_constant(attr.value.expression): - return [[error.error(module_source_file, - attr.value.source_location, - _MUST_BE_CONSTANT_MESSAGE.format( - name=_attribute_name_for_errors(attr)))]] - return [] + """Checks if the given attr is an integer constant expression.""" + if ( + not attr.value.HasField("expression") + or attr.value.expression.type.WhichOneof("type") != "integer" + ): + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + _BAD_TYPE_MESSAGE.format( + name=_attribute_name_for_errors(attr), type="an integer" + ), + ) + ] + ] + if not ir_util.is_constant(attr.value.expression): + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + _MUST_BE_CONSTANT_MESSAGE.format( + name=_attribute_name_for_errors(attr) + ), + ) + ] + ] + return [] def _is_string(attr, module_source_file): - """Checks if the given attr is a string.""" - if not attr.value.HasField("string_constant"): - return [[error.error(module_source_file, - attr.value.source_location, - _BAD_TYPE_MESSAGE.format( - name=_attribute_name_for_errors(attr), - type="a string"))]] - return [] + """Checks if the given attr is a string.""" + if not attr.value.HasField("string_constant"): + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + _BAD_TYPE_MESSAGE.format( + name=_attribute_name_for_errors(attr), type="a string" + ), + ) + ] + ] + return [] # Provide more readable names for these functions when used in attribute type @@ -97,215 +130,287 @@ def _is_string(attr, module_source_file): def string_from_list(valid_values): - """Checks if the given attr has one of the valid_values.""" - def _string_from_list(attr, module_source_file): - if ir_data_utils.reader(attr).value.string_constant.text not in valid_values: - return [[error.error(module_source_file, - attr.value.source_location, - "Attribute '{name}' must be '{options}'.".format( - name=_attribute_name_for_errors(attr), - options="' or '".join(sorted(valid_values))))]] - return [] - return _string_from_list - - -def check_attributes_in_ir(ir, - *, - back_end=None, - types=None, - module_attributes=None, - struct_attributes=None, - bits_attributes=None, - enum_attributes=None, - enum_value_attributes=None, - external_attributes=None, - structure_virtual_field_attributes=None, - structure_physical_field_attributes=None): - """Performs basic checks on all attributes in the given ir. - - This function calls _check_attributes on each attribute list in ir. - - Arguments: - ir: An ir_data.EmbossIr to check. - back_end: A string specifying the attribute qualifier to check (such as - `cpp` for `[(cpp) namespace = "foo"]`), or None to check unqualified - attributes. - - Attributes with a different qualifier will not be checked. - types: A map from attribute names to validators, such as: - { - "maximum_bits": attribute_util.INTEGER_CONSTANT, - "requires": attribute_util.BOOLEAN, - } - module_attributes: A set of (attribute_name, is_default) tuples specifying - the attributes that are allowed at module scope. - struct_attributes: A set of (attribute_name, is_default) tuples specifying - the attributes that are allowed at `struct` scope. - bits_attributes: A set of (attribute_name, is_default) tuples specifying - the attributes that are allowed at `bits` scope. - enum_attributes: A set of (attribute_name, is_default) tuples specifying - the attributes that are allowed at `enum` scope. 
- enum_value_attributes: A set of (attribute_name, is_default) tuples - specifying the attributes that are allowed at the scope of enum values. - external_attributes: A set of (attribute_name, is_default) tuples - specifying the attributes that are allowed at `external` scope. - structure_virtual_field_attributes: A set of (attribute_name, is_default) - tuples specifying the attributes that are allowed at the scope of - virtual fields (`let` fields) in structures (both `struct` and `bits`). - structure_physical_field_attributes: A set of (attribute_name, is_default) - tuples specifying the attributes that are allowed at the scope of - physical fields in structures (both `struct` and `bits`). - - Returns: - A list of lists of error.error, or an empty list if there were no errors. - """ - - def check_module(module, errors): - errors.extend(_check_attributes( - module.attribute, types, back_end, module_attributes, - "module '{}'".format( - module.source_file_name), module.source_file_name)) - - def check_type_definition(type_definition, source_file_name, errors): - if type_definition.HasField("structure"): - if type_definition.addressable_unit == ir_data.AddressableUnit.BYTE: - errors.extend(_check_attributes( - type_definition.attribute, types, back_end, struct_attributes, - "struct '{}'".format( - type_definition.name.name.text), source_file_name)) - elif type_definition.addressable_unit == ir_data.AddressableUnit.BIT: - errors.extend(_check_attributes( - type_definition.attribute, types, back_end, bits_attributes, - "bits '{}'".format( - type_definition.name.name.text), source_file_name)) - else: - assert False, "Unexpected addressable_unit '{}'".format( - type_definition.addressable_unit) - elif type_definition.HasField("enumeration"): - errors.extend(_check_attributes( - type_definition.attribute, types, back_end, enum_attributes, - "enum '{}'".format( - type_definition.name.name.text), source_file_name)) - elif type_definition.HasField("external"): - errors.extend(_check_attributes( - type_definition.attribute, types, back_end, external_attributes, - "external '{}'".format( - type_definition.name.name.text), source_file_name)) - - def check_struct_field(field, source_file_name, errors): - if ir_util.field_is_virtual(field): - field_attributes = structure_virtual_field_attributes - field_adjective = "virtual " - else: - field_attributes = structure_physical_field_attributes - field_adjective = "" - errors.extend(_check_attributes( - field.attribute, types, back_end, field_attributes, - "{}struct field '{}'".format(field_adjective, field.name.name.text), - source_file_name)) - - def check_enum_value(value, source_file_name, errors): - errors.extend(_check_attributes( - value.attribute, types, back_end, enum_value_attributes, - "enum value '{}'".format(value.name.name.text), source_file_name)) - - errors = [] - # TODO(bolms): Add a check that only known $default'ed attributes are - # used. 
- traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Module], check_module, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.TypeDefinition], check_type_definition, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.Field], check_struct_field, - parameters={"errors": errors}) - traverse_ir.fast_traverse_ir_top_down( - ir, [ir_data.EnumValue], check_enum_value, - parameters={"errors": errors}) - return errors - - -def _check_attributes(attribute_list, types, back_end, attribute_specs, - context_name, module_source_file): - """Performs basic checks on the given list of attributes. - - Checks the given attribute_list for duplicates, unknown attributes, attributes - with incorrect type, and attributes whose values are not constant. - - Arguments: - attribute_list: An iterable of ir_data.Attribute. - back_end: The qualifier for attributes to check, or None. - attribute_specs: A dict of attribute names to _Attribute structures - specifying the allowed attributes. - context_name: A name for the context of these attributes, such as "struct - 'Foo'" or "module 'm.emb'". Used in error messages. - module_source_file: The value of module.source_file_name from the module - containing 'attribute_list'. Used in error messages. - - Returns: - A list of lists of error.Errors. An empty list indicates no errors were - found. - """ - if attribute_specs is None: - attribute_specs = [] - errors = [] - already_seen_attributes = {} - for attr in attribute_list: - field_checker = ir_data_utils.reader(attr) - if field_checker.back_end.text: - if attr.back_end.text != back_end: - continue - else: - if back_end is not None: - continue - attribute_name = _attribute_name_for_errors(attr) - attr_key = (field_checker.name.text, field_checker.is_default) - if attr_key in already_seen_attributes: - original_attr = already_seen_attributes[attr_key] - errors.append([ - error.error(module_source_file, - attr.source_location, - "Duplicate attribute '{}'.".format(attribute_name)), - error.note(module_source_file, - original_attr.source_location, - "Original attribute")]) - continue - already_seen_attributes[attr_key] = attr - - if attr_key not in attribute_specs: - if attr.is_default: - error_message = "Attribute '{}' may not be defaulted on {}.".format( - attribute_name, context_name) - else: - error_message = "Unknown attribute '{}' on {}.".format(attribute_name, - context_name) - errors.append([error.error(module_source_file, - attr.name.source_location, - error_message)]) - else: - errors.extend(types[attr.name.text](attr, module_source_file)) - return errors + """Checks if the given attr has one of the valid_values.""" + + def _string_from_list(attr, module_source_file): + if ir_data_utils.reader(attr).value.string_constant.text not in valid_values: + return [ + [ + error.error( + module_source_file, + attr.value.source_location, + "Attribute '{name}' must be '{options}'.".format( + name=_attribute_name_for_errors(attr), + options="' or '".join(sorted(valid_values)), + ), + ) + ] + ] + return [] + + return _string_from_list + + +def check_attributes_in_ir( + ir, + *, + back_end=None, + types=None, + module_attributes=None, + struct_attributes=None, + bits_attributes=None, + enum_attributes=None, + enum_value_attributes=None, + external_attributes=None, + structure_virtual_field_attributes=None, + structure_physical_field_attributes=None, +): + """Performs basic checks on all attributes in the given ir. 
+ + This function calls _check_attributes on each attribute list in ir. + + Arguments: + ir: An ir_data.EmbossIr to check. + back_end: A string specifying the attribute qualifier to check (such as + `cpp` for `[(cpp) namespace = "foo"]`), or None to check unqualified + attributes. + + Attributes with a different qualifier will not be checked. + types: A map from attribute names to validators, such as: + { + "maximum_bits": attribute_util.INTEGER_CONSTANT, + "requires": attribute_util.BOOLEAN, + } + module_attributes: A set of (attribute_name, is_default) tuples specifying + the attributes that are allowed at module scope. + struct_attributes: A set of (attribute_name, is_default) tuples specifying + the attributes that are allowed at `struct` scope. + bits_attributes: A set of (attribute_name, is_default) tuples specifying + the attributes that are allowed at `bits` scope. + enum_attributes: A set of (attribute_name, is_default) tuples specifying + the attributes that are allowed at `enum` scope. + enum_value_attributes: A set of (attribute_name, is_default) tuples + specifying the attributes that are allowed at the scope of enum values. + external_attributes: A set of (attribute_name, is_default) tuples + specifying the attributes that are allowed at `external` scope. + structure_virtual_field_attributes: A set of (attribute_name, is_default) + tuples specifying the attributes that are allowed at the scope of + virtual fields (`let` fields) in structures (both `struct` and `bits`). + structure_physical_field_attributes: A set of (attribute_name, is_default) + tuples specifying the attributes that are allowed at the scope of + physical fields in structures (both `struct` and `bits`). + + Returns: + A list of lists of error.error, or an empty list if there were no errors. 
+ """ + + def check_module(module, errors): + errors.extend( + _check_attributes( + module.attribute, + types, + back_end, + module_attributes, + "module '{}'".format(module.source_file_name), + module.source_file_name, + ) + ) + + def check_type_definition(type_definition, source_file_name, errors): + if type_definition.HasField("structure"): + if type_definition.addressable_unit == ir_data.AddressableUnit.BYTE: + errors.extend( + _check_attributes( + type_definition.attribute, + types, + back_end, + struct_attributes, + "struct '{}'".format(type_definition.name.name.text), + source_file_name, + ) + ) + elif type_definition.addressable_unit == ir_data.AddressableUnit.BIT: + errors.extend( + _check_attributes( + type_definition.attribute, + types, + back_end, + bits_attributes, + "bits '{}'".format(type_definition.name.name.text), + source_file_name, + ) + ) + else: + assert False, "Unexpected addressable_unit '{}'".format( + type_definition.addressable_unit + ) + elif type_definition.HasField("enumeration"): + errors.extend( + _check_attributes( + type_definition.attribute, + types, + back_end, + enum_attributes, + "enum '{}'".format(type_definition.name.name.text), + source_file_name, + ) + ) + elif type_definition.HasField("external"): + errors.extend( + _check_attributes( + type_definition.attribute, + types, + back_end, + external_attributes, + "external '{}'".format(type_definition.name.name.text), + source_file_name, + ) + ) + + def check_struct_field(field, source_file_name, errors): + if ir_util.field_is_virtual(field): + field_attributes = structure_virtual_field_attributes + field_adjective = "virtual " + else: + field_attributes = structure_physical_field_attributes + field_adjective = "" + errors.extend( + _check_attributes( + field.attribute, + types, + back_end, + field_attributes, + "{}struct field '{}'".format(field_adjective, field.name.name.text), + source_file_name, + ) + ) + + def check_enum_value(value, source_file_name, errors): + errors.extend( + _check_attributes( + value.attribute, + types, + back_end, + enum_value_attributes, + "enum value '{}'".format(value.name.name.text), + source_file_name, + ) + ) + + errors = [] + # TODO(bolms): Add a check that only known $default'ed attributes are + # used. + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Module], check_module, parameters={"errors": errors} + ) + traverse_ir.fast_traverse_ir_top_down( + ir, + [ir_data.TypeDefinition], + check_type_definition, + parameters={"errors": errors}, + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.Field], check_struct_field, parameters={"errors": errors} + ) + traverse_ir.fast_traverse_ir_top_down( + ir, [ir_data.EnumValue], check_enum_value, parameters={"errors": errors} + ) + return errors + + +def _check_attributes( + attribute_list, types, back_end, attribute_specs, context_name, module_source_file +): + """Performs basic checks on the given list of attributes. + + Checks the given attribute_list for duplicates, unknown attributes, attributes + with incorrect type, and attributes whose values are not constant. + + Arguments: + attribute_list: An iterable of ir_data.Attribute. + back_end: The qualifier for attributes to check, or None. + attribute_specs: A dict of attribute names to _Attribute structures + specifying the allowed attributes. + context_name: A name for the context of these attributes, such as "struct + 'Foo'" or "module 'm.emb'". Used in error messages. 
+ module_source_file: The value of module.source_file_name from the module
+ containing 'attribute_list'. Used in error messages.
+
+ Returns:
+ A list of lists of error.Errors. An empty list indicates no errors were
+ found.
+ """
+ if attribute_specs is None:
+ attribute_specs = []
+ errors = []
+ already_seen_attributes = {}
+ for attr in attribute_list:
+ field_checker = ir_data_utils.reader(attr)
+ if field_checker.back_end.text:
+ if attr.back_end.text != back_end:
+ continue
+ else:
+ if back_end is not None:
+ continue
+ attribute_name = _attribute_name_for_errors(attr)
+ attr_key = (field_checker.name.text, field_checker.is_default)
+ if attr_key in already_seen_attributes:
+ original_attr = already_seen_attributes[attr_key]
+ errors.append(
+ [
+ error.error(
+ module_source_file,
+ attr.source_location,
+ "Duplicate attribute '{}'.".format(attribute_name),
+ ),
+ error.note(
+ module_source_file,
+ original_attr.source_location,
+ "Original attribute",
+ ),
+ ]
+ )
+ continue
+ already_seen_attributes[attr_key] = attr
+
+ if attr_key not in attribute_specs:
+ if attr.is_default:
+ error_message = "Attribute '{}' may not be defaulted on {}.".format(
+ attribute_name, context_name
+ )
+ else:
+ error_message = "Unknown attribute '{}' on {}.".format(
+ attribute_name, context_name
+ )
+ errors.append(
+ [
+ error.error(
+ module_source_file, attr.name.source_location, error_message
+ )
+ ]
+ )
+ else:
+ errors.extend(types[attr.name.text](attr, module_source_file))
+ return errors


 def gather_default_attributes(obj, defaults):
- """Gathers default attributes for an IR object
-
- This is designed to be able to be used as-is as an incidental action in an IR
- traversal to accumulate defaults for child nodes.
-
- Arguments:
- defaults: A dict of `{ "defaults": { attr.name.text: attr } }`
-
- Returns:
- A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
- provided by `obj` added/overridden.
- """
- defaults = defaults.copy()
- for attr in obj.attribute:
- if attr.is_default:
- defaulted_attr = ir_data_utils.copy(attr)
- defaulted_attr.is_default = False
- defaults[attr.name.text] = defaulted_attr
- return {"defaults": defaults}
+ """Gathers default attributes for an IR object.
+
+ This is designed to be usable as-is as an incidental action in an IR
+ traversal, accumulating defaults for child nodes.
+
+ Arguments:
+ obj: The IR node whose attributes should be scanned for defaults.
+ defaults: A dict of `{ attr.name.text: attr }` containing defaults
+ gathered from enclosing nodes.
+
+ Returns:
+ A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
+ provided by `obj` added/overridden.
+ """
+ defaults = defaults.copy()
+ for attr in obj.attribute:
+ if attr.is_default:
+ defaulted_attr = ir_data_utils.copy(attr)
+ defaulted_attr.is_default = False
+ defaults[attr.name.text] = defaulted_attr
+ return {"defaults": defaults}
diff --git a/compiler/util/error.py b/compiler/util/error.py
index e408c71..a22fa4a 100644
--- a/compiler/util/error.py
+++ b/compiler/util/error.py
@@ -66,186 +66,199 @@
 BOLD = "\033[0;1m"
 RESET = "\033[0m"
+
 def _copy(location):
- location = ir_data_utils.copy(location)
- if not location:
- location = parser_types.make_location((0,0), (0,0))
- return location
+ location = ir_data_utils.copy(location)
+ if not location:
+ location = parser_types.make_location((0, 0), (0, 0))
+ return location


 def error(source_file, location, message):
- """Returns an object representing an error message."""
- return _Message(source_file, _copy(location), ERROR, message)
+ """Returns an object representing an error message."""
+ return _Message(source_file, _copy(location), ERROR, message)


 def warn(source_file, location, message):
- """Returns an object representing a warning."""
- return _Message(source_file, _copy(location), WARNING, message)
+ """Returns an object representing a warning."""
+ return _Message(source_file, _copy(location), WARNING, message)


 def note(source_file, location, message):
- """Returns and object representing an informational note."""
- return _Message(source_file, _copy(location), NOTE, message)
+ """Returns an object representing an informational note."""
+ return _Message(source_file, _copy(location), NOTE, message)


 class _Message(object):
- """_Message holds a human-readable message."""
- __slots__ = ("location", "source_file", "severity", "message")
+ """_Message holds a human-readable message."""
+
+ __slots__ = ("location", "source_file", "severity", "message")
+
+ def __init__(self, source_file, location, severity, message):
+ self.location = location
+ self.source_file = source_file
+ self.severity = severity
+ self.message = message
+
+ def format(self, source_code):
+ """Formats the _Message for display.
+
+ Arguments:
+ source_code: A dict of file names to source texts. This is used to
+ render source snippets.
+
+ Returns:
+ A list of tuples.
+
+ The first element of each tuple is an escape sequence used to put a Unix
+ terminal into a particular color mode. For use in non-Unix-terminal
+ output, the string will match one of the color names exported by this
+ module.
+
+ The second element is a string containing text to show to the user.
+
+ The text will not end with a newline character, nor will it include a
+ RESET color element.
+
+ To show non-colorized output, simply write the second element of each
+ tuple, then a newline at the end.
+
+ To show colorized output, write both the first and second element of each
+ tuple, then a newline at the end. Before exiting to the operating system,
+ a RESET sequence should be emitted.
+ """
+ # TODO(bolms): Figure out how to get Vim, Emacs, etc. to parse Emboss error
+ # messages.
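+ # As a concrete example of the contract described above,
+ # format_errors() (later in this file) renders each (color, text)
+ # pair as color + text + RESET when colorizing, and as just text
+ # otherwise.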
+ severity_colors = { + ERROR: (BRIGHT_RED, BOLD), + WARNING: (BRIGHT_MAGENTA, BOLD), + NOTE: (BRIGHT_BLACK, WHITE), + } + + result = [] + if self.location.is_synthetic: + pos = "[compiler bug]" + else: + pos = parser_types.format_position(self.location.start) + source_name = self.source_file or "[prelude]" + if not self.location.is_synthetic and self.source_file in source_code: + source_lines = source_code[self.source_file].splitlines() + source_line = source_lines[self.location.start.line - 1] + else: + source_line = "" + lines = self.message.splitlines() + for i in range(len(lines)): + line = lines[i] + # This is a little awkward, but we want to suppress the final newline in + # the message. This newline is final if and only if it is the last line + # of the message and there is no source snippet. + if i != len(lines) - 1 or source_line: + line += "\n" + result.append((BOLD, "{}:{}: ".format(source_name, pos))) + if i == 0: + severity = self.severity + else: + severity = NOTE + result.append((severity_colors[severity][0], "{}: ".format(severity))) + result.append((severity_colors[severity][1], line)) + if source_line: + result.append((WHITE, source_line + "\n")) + indicator_indent = " " * (self.location.start.column - 1) + if self.location.start.line == self.location.end.line: + indicator_caret = "^" * max( + 1, self.location.end.column - self.location.start.column + ) + else: + indicator_caret = "^" + result.append((BRIGHT_GREEN, indicator_indent + indicator_caret)) + return result + + def __repr__(self): + return ( + "Message({source_file!r}, make_location(({start_line!r}, " + "{start_column!r}), ({end_line!r}, {end_column!r}), " + "{is_synthetic!r}), {severity!r}, {message!r})" + ).format( + source_file=self.source_file, + start_line=self.location.start.line, + start_column=self.location.start.column, + end_line=self.location.end.line, + end_column=self.location.end.column, + is_synthetic=self.location.is_synthetic, + severity=self.severity, + message=self.message, + ) + + def __eq__(self, other): + return ( + self.__class__ == other.__class__ + and self.location == other.location + and self.source_file == other.source_file + and self.severity == other.severity + and self.message == other.message + ) + + def __ne__(self, other): + return not self == other - def __init__(self, source_file, location, severity, message): - self.location = location - self.source_file = source_file - self.severity = severity - self.message = message - def format(self, source_code): - """Formats the _Message for display. +def split_errors(errors): + """Splits errors into (user_errors, synthetic_errors). Arguments: - source_code: A dict of file names to source texts. This is used to - render source snippets. + errors: A list of lists of _Message, which is a list of bundles of + associated messages. Returns: - A list of tuples. - - The first element of each tuple is an escape sequence used to put a Unix - terminal into a particular color mode. For use in non-Unix-terminal - output, the string will match one of the color names exported by this - module. - - The second element is a string containing text to show to the user. + (user_errors, synthetic_errors), where both user_errors and + synthetic_errors are lists of lists of _Message. synthetic_errors will + contain all bundles that reference any synthetic source_location, and + user_errors will contain the rest. - The text will not end with a newline character, nor will it include a - RESET color element. 
- - To show non-colorized output, simply write the second element of each - tuple, then a newline at the end. - - To show colorized output, write both the first and second element of each - tuple, then a newline at the end. Before exiting to the operating system, - a RESET sequence should be emitted. + The intent is that user_errors can be shown to end users, while + synthetic_errors should generally be suppressed. """ - # TODO(bolms): Figure out how to get Vim, Emacs, etc. to parse Emboss error - # messages. - severity_colors = { - ERROR: (BRIGHT_RED, BOLD), - WARNING: (BRIGHT_MAGENTA, BOLD), - NOTE: (BRIGHT_BLACK, WHITE) - } - - result = [] - if self.location.is_synthetic: - pos = "[compiler bug]" - else: - pos = parser_types.format_position(self.location.start) - source_name = self.source_file or "[prelude]" - if not self.location.is_synthetic and self.source_file in source_code: - source_lines = source_code[self.source_file].splitlines() - source_line = source_lines[self.location.start.line - 1] - else: - source_line = "" - lines = self.message.splitlines() - for i in range(len(lines)): - line = lines[i] - # This is a little awkward, but we want to suppress the final newline in - # the message. This newline is final if and only if it is the last line - # of the message and there is no source snippet. - if i != len(lines) - 1 or source_line: - line += "\n" - result.append((BOLD, "{}:{}: ".format(source_name, pos))) - if i == 0: - severity = self.severity - else: - severity = NOTE - result.append((severity_colors[severity][0], "{}: ".format(severity))) - result.append((severity_colors[severity][1], line)) - if source_line: - result.append((WHITE, source_line + "\n")) - indicator_indent = " " * (self.location.start.column - 1) - if self.location.start.line == self.location.end.line: - indicator_caret = "^" * max( - 1, self.location.end.column - self.location.start.column) - else: - indicator_caret = "^" - result.append((BRIGHT_GREEN, indicator_indent + indicator_caret)) - return result - - def __repr__(self): - return ("Message({source_file!r}, make_location(({start_line!r}, " - "{start_column!r}), ({end_line!r}, {end_column!r}), " - "{is_synthetic!r}), {severity!r}, {message!r})").format( - source_file=self.source_file, - start_line=self.location.start.line, - start_column=self.location.start.column, - end_line=self.location.end.line, - end_column=self.location.end.column, - is_synthetic=self.location.is_synthetic, - severity=self.severity, - message=self.message) - - def __eq__(self, other): - return ( - self.__class__ == other.__class__ and self.location == other.location - and self.source_file == other.source_file and - self.severity == other.severity and self.message == other.message) - - def __ne__(self, other): - return not self == other - - -def split_errors(errors): - """Splits errors into (user_errors, synthetic_errors). - - Arguments: - errors: A list of lists of _Message, which is a list of bundles of - associated messages. - - Returns: - (user_errors, synthetic_errors), where both user_errors and - synthetic_errors are lists of lists of _Message. synthetic_errors will - contain all bundles that reference any synthetic source_location, and - user_errors will contain the rest. - - The intent is that user_errors can be shown to end users, while - synthetic_errors should generally be suppressed. 
- """ - synthetic_errors = [] - user_errors = [] - for error_block in errors: - if any(message.location.is_synthetic for message in error_block): - synthetic_errors.append(error_block) - else: - user_errors.append(error_block) - return user_errors, synthetic_errors + synthetic_errors = [] + user_errors = [] + for error_block in errors: + if any(message.location.is_synthetic for message in error_block): + synthetic_errors.append(error_block) + else: + user_errors.append(error_block) + return user_errors, synthetic_errors def filter_errors(errors): - """Returns the non-synthetic errors from `errors`.""" - return split_errors(errors)[0] + """Returns the non-synthetic errors from `errors`.""" + return split_errors(errors)[0] def format_errors(errors, source_codes, use_color=False): - """Formats error messages with source code snippets.""" - result = [] - for error_group in errors: - assert error_group, "Found empty error_group!" - for message in error_group: - if use_color: - result.append("".join(e[0] + e[1] + RESET - for e in message.format(source_codes))) - else: - result.append("".join(e[1] for e in message.format(source_codes))) - return "\n".join(result) + """Formats error messages with source code snippets.""" + result = [] + for error_group in errors: + assert error_group, "Found empty error_group!" + for message in error_group: + if use_color: + result.append( + "".join(e[0] + e[1] + RESET for e in message.format(source_codes)) + ) + else: + result.append("".join(e[1] for e in message.format(source_codes))) + return "\n".join(result) def make_error_from_parse_error(file_name, parse_error): - return [error(file_name, - parse_error.token.source_location, - "{code}\n" - "Found {text!r} ({symbol}), expected {expected}.".format( - code=parse_error.code or "Syntax error", - text=parse_error.token.text, - symbol=parse_error.token.symbol, - expected=", ".join(parse_error.expected_tokens)))] - - + return [ + error( + file_name, + parse_error.token.source_location, + "{code}\n" + "Found {text!r} ({symbol}), expected {expected}.".format( + code=parse_error.code or "Syntax error", + text=parse_error.token.text, + symbol=parse_error.token.symbol, + expected=", ".join(parse_error.expected_tokens), + ), + ) + ] diff --git a/compiler/util/error_test.py b/compiler/util/error_test.py index 7d2577f..23beddd 100644 --- a/compiler/util/error_test.py +++ b/compiler/util/error_test.py @@ -21,325 +21,463 @@ class MessageTest(unittest.TestCase): - """Tests for _Message, as returned by error, warn, and note.""" + """Tests for _Message, as returned by error, warn, and note.""" - def test_error(self): - error_message = error.error("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "Bad thing") - self.assertEqual("foo.emb", error_message.source_file) - self.assertEqual(error.ERROR, error_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (3, 6)), - error_message.location) - self.assertEqual("Bad thing", error_message.message) - sourceless_format = error_message.format({}) - sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"}) - self.assertEqual("foo.emb:3:4: error: Bad thing", - "".join([x[1] for x in sourceless_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing"), # Message - ], sourceless_format) - self.assertEqual("foo.emb:3:4: error: Bad thing\n" - "abcdefghijklm\n" - " ^^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # 
Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing\n"), # Message - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^^"), # Error column indicator - ], sourced_format) + def test_error(self): + error_message = error.error( + "foo.emb", parser_types.make_location((3, 4), (3, 6)), "Bad thing" + ) + self.assertEqual("foo.emb", error_message.source_file) + self.assertEqual(error.ERROR, error_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (3, 6)), error_message.location + ) + self.assertEqual("Bad thing", error_message.message) + sourceless_format = error_message.format({}) + sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"}) + self.assertEqual( + "foo.emb:3:4: error: Bad thing", "".join([x[1] for x in sourceless_format]) + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing"), # Message + ], + sourceless_format, + ) + self.assertEqual( + "foo.emb:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing\n"), # Message + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^^"), # Error column indicator + ], + sourced_format, + ) - def test_synthetic_error(self): - error_message = error.error("foo.emb", parser_types.make_location( - (3, 4), (3, 6), True), "Bad thing") - sourceless_format = error_message.format({}) - sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"}) - self.assertEqual("foo.emb:[compiler bug]: error: Bad thing", - "".join([x[1] for x in sourceless_format])) - self.assertEqual([ - (error.BOLD, "foo.emb:[compiler bug]: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing"), # Message - ], sourceless_format) - self.assertEqual("foo.emb:[compiler bug]: error: Bad thing", - "".join([x[1] for x in sourced_format])) - self.assertEqual([ - (error.BOLD, "foo.emb:[compiler bug]: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing"), # Message - ], sourced_format) + def test_synthetic_error(self): + error_message = error.error( + "foo.emb", parser_types.make_location((3, 4), (3, 6), True), "Bad thing" + ) + sourceless_format = error_message.format({}) + sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"}) + self.assertEqual( + "foo.emb:[compiler bug]: error: Bad thing", + "".join([x[1] for x in sourceless_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:[compiler bug]: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing"), # Message + ], + sourceless_format, + ) + self.assertEqual( + "foo.emb:[compiler bug]: error: Bad thing", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:[compiler bug]: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing"), # Message + ], + sourced_format, + ) - def test_prelude_as_file_name(self): - error_message = error.error("", parser_types.make_location( - (3, 4), (3, 6)), "Bad thing") - self.assertEqual("", error_message.source_file) - self.assertEqual(error.ERROR, error_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (3, 6)), - error_message.location) - self.assertEqual("Bad thing", error_message.message) - sourceless_format = 
error_message.format({}) - sourced_format = error_message.format({"": "\n\nabcdefghijklm"}) - self.assertEqual("[prelude]:3:4: error: Bad thing", - "".join([x[1] for x in sourceless_format])) - self.assertEqual([(error.BOLD, "[prelude]:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing"), # Message - ], sourceless_format) - self.assertEqual("[prelude]:3:4: error: Bad thing\n" - "abcdefghijklm\n" - " ^^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "[prelude]:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing\n"), # Message - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^^"), # Error column indicator - ], sourced_format) + def test_prelude_as_file_name(self): + error_message = error.error( + "", parser_types.make_location((3, 4), (3, 6)), "Bad thing" + ) + self.assertEqual("", error_message.source_file) + self.assertEqual(error.ERROR, error_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (3, 6)), error_message.location + ) + self.assertEqual("Bad thing", error_message.message) + sourceless_format = error_message.format({}) + sourced_format = error_message.format({"": "\n\nabcdefghijklm"}) + self.assertEqual( + "[prelude]:3:4: error: Bad thing", + "".join([x[1] for x in sourceless_format]), + ) + self.assertEqual( + [ + (error.BOLD, "[prelude]:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing"), # Message + ], + sourceless_format, + ) + self.assertEqual( + "[prelude]:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "[prelude]:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing\n"), # Message + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^^"), # Error column indicator + ], + sourced_format, + ) - def test_multiline_error_source(self): - error_message = error.error("foo.emb", parser_types.make_location( - (3, 4), (4, 6)), "Bad thing") - self.assertEqual("foo.emb", error_message.source_file) - self.assertEqual(error.ERROR, error_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (4, 6)), - error_message.location) - self.assertEqual("Bad thing", error_message.message) - sourceless_format = error_message.format({}) - sourced_format = error_message.format( - {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"}) - self.assertEqual("foo.emb:3:4: error: Bad thing", - "".join([x[1] for x in sourceless_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing"), # Message - ], sourceless_format) - self.assertEqual("foo.emb:3:4: error: Bad thing\n" - "abcdefghijklm\n" - " ^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing\n"), # Message - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^"), # Error column indicator - ], sourced_format) + def test_multiline_error_source(self): + error_message = error.error( + "foo.emb", parser_types.make_location((3, 4), (4, 6)), "Bad thing" + ) + self.assertEqual("foo.emb", error_message.source_file) + self.assertEqual(error.ERROR, error_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (4, 6)), error_message.location + ) + 
self.assertEqual("Bad thing", error_message.message) + sourceless_format = error_message.format({}) + sourced_format = error_message.format( + {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"} + ) + self.assertEqual( + "foo.emb:3:4: error: Bad thing", "".join([x[1] for x in sourceless_format]) + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing"), # Message + ], + sourceless_format, + ) + self.assertEqual( + "foo.emb:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing\n"), # Message + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^"), # Error column indicator + ], + sourced_format, + ) - def test_multiline_error(self): - error_message = error.error("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "Bad thing\nSome explanation\nMore explanation") - self.assertEqual("foo.emb", error_message.source_file) - self.assertEqual(error.ERROR, error_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (3, 6)), - error_message.location) - self.assertEqual("Bad thing\nSome explanation\nMore explanation", - error_message.message) - sourceless_format = error_message.format({}) - sourced_format = error_message.format( - {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"}) - self.assertEqual("foo.emb:3:4: error: Bad thing\n" - "foo.emb:3:4: note: Some explanation\n" - "foo.emb:3:4: note: More explanation", - "".join([x[1] for x in sourceless_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing\n"), # Message - (error.BOLD, "foo.emb:3:4: "), # Location, line 2 - (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2 - (error.WHITE, "Some explanation\n"), # Message, line 2 - (error.BOLD, "foo.emb:3:4: "), # Location, line 3 - (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3 - (error.WHITE, "More explanation"), # Message, line 3 - ], sourceless_format) - self.assertEqual("foo.emb:3:4: error: Bad thing\n" - "foo.emb:3:4: note: Some explanation\n" - "foo.emb:3:4: note: More explanation\n" - "abcdefghijklm\n" - " ^^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_RED, "error: "), # Severity - (error.BOLD, "Bad thing\n"), # Message - (error.BOLD, "foo.emb:3:4: "), # Location, line 2 - (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2 - (error.WHITE, "Some explanation\n"), # Message, line 2 - (error.BOLD, "foo.emb:3:4: "), # Location, line 3 - (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3 - (error.WHITE, "More explanation\n"), # Message, line 3 - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^^"), # Column indicator - ], sourced_format) + def test_multiline_error(self): + error_message = error.error( + "foo.emb", + parser_types.make_location((3, 4), (3, 6)), + "Bad thing\nSome explanation\nMore explanation", + ) + self.assertEqual("foo.emb", error_message.source_file) + self.assertEqual(error.ERROR, error_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (3, 6)), error_message.location + ) + self.assertEqual( + "Bad thing\nSome explanation\nMore explanation", error_message.message + ) + sourceless_format = error_message.format({}) + sourced_format = 
error_message.format( + {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"} + ) + self.assertEqual( + "foo.emb:3:4: error: Bad thing\n" + "foo.emb:3:4: note: Some explanation\n" + "foo.emb:3:4: note: More explanation", + "".join([x[1] for x in sourceless_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing\n"), # Message + (error.BOLD, "foo.emb:3:4: "), # Location, line 2 + (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2 + (error.WHITE, "Some explanation\n"), # Message, line 2 + (error.BOLD, "foo.emb:3:4: "), # Location, line 3 + (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3 + (error.WHITE, "More explanation"), # Message, line 3 + ], + sourceless_format, + ) + self.assertEqual( + "foo.emb:3:4: error: Bad thing\n" + "foo.emb:3:4: note: Some explanation\n" + "foo.emb:3:4: note: More explanation\n" + "abcdefghijklm\n" + " ^^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_RED, "error: "), # Severity + (error.BOLD, "Bad thing\n"), # Message + (error.BOLD, "foo.emb:3:4: "), # Location, line 2 + (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2 + (error.WHITE, "Some explanation\n"), # Message, line 2 + (error.BOLD, "foo.emb:3:4: "), # Location, line 3 + (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3 + (error.WHITE, "More explanation\n"), # Message, line 3 + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^^"), # Column indicator + ], + sourced_format, + ) - def test_warn(self): - warning_message = error.warn("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "Not good thing") - self.assertEqual("foo.emb", warning_message.source_file) - self.assertEqual(error.WARNING, warning_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (3, 6)), - warning_message.location) - self.assertEqual("Not good thing", warning_message.message) - sourced_format = warning_message.format({"foo.emb": "\n\nabcdefghijklm"}) - self.assertEqual("foo.emb:3:4: warning: Not good thing\n" - "abcdefghijklm\n" - " ^^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_MAGENTA, "warning: "), # Severity - (error.BOLD, "Not good thing\n"), # Message - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^^"), # Column indicator - ], sourced_format) + def test_warn(self): + warning_message = error.warn( + "foo.emb", parser_types.make_location((3, 4), (3, 6)), "Not good thing" + ) + self.assertEqual("foo.emb", warning_message.source_file) + self.assertEqual(error.WARNING, warning_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (3, 6)), warning_message.location + ) + self.assertEqual("Not good thing", warning_message.message) + sourced_format = warning_message.format({"foo.emb": "\n\nabcdefghijklm"}) + self.assertEqual( + "foo.emb:3:4: warning: Not good thing\n" "abcdefghijklm\n" " ^^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_MAGENTA, "warning: "), # Severity + (error.BOLD, "Not good thing\n"), # Message + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^^"), # Column indicator + ], + sourced_format, + ) - def test_note(self): - note_message = error.note("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "OK thing") - 
self.assertEqual("foo.emb", note_message.source_file) - self.assertEqual(error.NOTE, note_message.severity) - self.assertEqual(parser_types.make_location((3, 4), (3, 6)), - note_message.location) - self.assertEqual("OK thing", note_message.message) - sourced_format = note_message.format({"foo.emb": "\n\nabcdefghijklm"}) - self.assertEqual("foo.emb:3:4: note: OK thing\n" - "abcdefghijklm\n" - " ^^", "".join([x[1] for x in sourced_format])) - self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location - (error.BRIGHT_BLACK, "note: "), # Severity - (error.WHITE, "OK thing\n"), # Message - (error.WHITE, "abcdefghijklm\n"), # Source snippet - (error.BRIGHT_GREEN, " ^^"), # Column indicator - ], sourced_format) + def test_note(self): + note_message = error.note( + "foo.emb", parser_types.make_location((3, 4), (3, 6)), "OK thing" + ) + self.assertEqual("foo.emb", note_message.source_file) + self.assertEqual(error.NOTE, note_message.severity) + self.assertEqual( + parser_types.make_location((3, 4), (3, 6)), note_message.location + ) + self.assertEqual("OK thing", note_message.message) + sourced_format = note_message.format({"foo.emb": "\n\nabcdefghijklm"}) + self.assertEqual( + "foo.emb:3:4: note: OK thing\n" "abcdefghijklm\n" " ^^", + "".join([x[1] for x in sourced_format]), + ) + self.assertEqual( + [ + (error.BOLD, "foo.emb:3:4: "), # Location + (error.BRIGHT_BLACK, "note: "), # Severity + (error.WHITE, "OK thing\n"), # Message + (error.WHITE, "abcdefghijklm\n"), # Source snippet + (error.BRIGHT_GREEN, " ^^"), # Column indicator + ], + sourced_format, + ) - def test_equality(self): - note_message = error.note("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "thing") - self.assertEqual(note_message, - error.note("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "thing")) - self.assertNotEqual(note_message, - error.warn("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "thing")) - self.assertNotEqual(note_message, - error.note("foo2.emb", parser_types.make_location( - (3, 4), (3, 6)), "thing")) - self.assertNotEqual(note_message, - error.note("foo.emb", parser_types.make_location( - (2, 4), (3, 6)), "thing")) - self.assertNotEqual(note_message, - error.note("foo.emb", parser_types.make_location( - (3, 4), (3, 6)), "thing2")) + def test_equality(self): + note_message = error.note( + "foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing" + ) + self.assertEqual( + note_message, + error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing"), + ) + self.assertNotEqual( + note_message, + error.warn("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing"), + ) + self.assertNotEqual( + note_message, + error.note("foo2.emb", parser_types.make_location((3, 4), (3, 6)), "thing"), + ) + self.assertNotEqual( + note_message, + error.note("foo.emb", parser_types.make_location((2, 4), (3, 6)), "thing"), + ) + self.assertNotEqual( + note_message, + error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing2"), + ) class StringTest(unittest.TestCase): - """Tests for strings.""" + """Tests for strings.""" - # These strings are a fixed part of the API. + # These strings are a fixed part of the API. 
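+    # Each constant is a standard ANSI SGR escape sequence: "\033[" opens the
+    # sequence, semicolon-separated parameters select attributes (0 = reset,
+    # 1 = bold/bright, 30-37 = foreground colors), and "m" terminates it.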
- def test_color_strings(self): - self.assertEqual("\033[0;30m", error.BLACK) - self.assertEqual("\033[0;31m", error.RED) - self.assertEqual("\033[0;32m", error.GREEN) - self.assertEqual("\033[0;33m", error.YELLOW) - self.assertEqual("\033[0;34m", error.BLUE) - self.assertEqual("\033[0;35m", error.MAGENTA) - self.assertEqual("\033[0;36m", error.CYAN) - self.assertEqual("\033[0;37m", error.WHITE) - self.assertEqual("\033[0;1;30m", error.BRIGHT_BLACK) - self.assertEqual("\033[0;1;31m", error.BRIGHT_RED) - self.assertEqual("\033[0;1;32m", error.BRIGHT_GREEN) - self.assertEqual("\033[0;1;33m", error.BRIGHT_YELLOW) - self.assertEqual("\033[0;1;34m", error.BRIGHT_BLUE) - self.assertEqual("\033[0;1;35m", error.BRIGHT_MAGENTA) - self.assertEqual("\033[0;1;36m", error.BRIGHT_CYAN) - self.assertEqual("\033[0;1;37m", error.BRIGHT_WHITE) - self.assertEqual("\033[0;1m", error.BOLD) - self.assertEqual("\033[0m", error.RESET) + def test_color_strings(self): + self.assertEqual("\033[0;30m", error.BLACK) + self.assertEqual("\033[0;31m", error.RED) + self.assertEqual("\033[0;32m", error.GREEN) + self.assertEqual("\033[0;33m", error.YELLOW) + self.assertEqual("\033[0;34m", error.BLUE) + self.assertEqual("\033[0;35m", error.MAGENTA) + self.assertEqual("\033[0;36m", error.CYAN) + self.assertEqual("\033[0;37m", error.WHITE) + self.assertEqual("\033[0;1;30m", error.BRIGHT_BLACK) + self.assertEqual("\033[0;1;31m", error.BRIGHT_RED) + self.assertEqual("\033[0;1;32m", error.BRIGHT_GREEN) + self.assertEqual("\033[0;1;33m", error.BRIGHT_YELLOW) + self.assertEqual("\033[0;1;34m", error.BRIGHT_BLUE) + self.assertEqual("\033[0;1;35m", error.BRIGHT_MAGENTA) + self.assertEqual("\033[0;1;36m", error.BRIGHT_CYAN) + self.assertEqual("\033[0;1;37m", error.BRIGHT_WHITE) + self.assertEqual("\033[0;1m", error.BOLD) + self.assertEqual("\033[0m", error.RESET) - def test_error_strings(self): - self.assertEqual("error", error.ERROR) - self.assertEqual("warning", error.WARNING) - self.assertEqual("note", error.NOTE) + def test_error_strings(self): + self.assertEqual("error", error.ERROR) + self.assertEqual("warning", error.WARNING) + self.assertEqual("note", error.NOTE) class SplitErrorsTest(unittest.TestCase): - def test_split_errors(self): - user_error = [ - error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((3, 4), (5, 6)), - "Note: bad thing referrent") - ] - user_error_2 = [ - error.error("foo.emb", parser_types.make_location((8, 9), (10, 11)), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)), - "Note: bad thing referrent") - ] - synthetic_error = [ - error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((3, 4), (5, 6), True), - "Note: bad thing referrent") - ] - synthetic_error_2 = [ - error.error("foo.emb", - parser_types.make_location((8, 9), (10, 11), True), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)), - "Note: bad thing referrent") - ] - user_errors, synthetic_errors = error.split_errors( - [user_error, synthetic_error]) - self.assertEqual([user_error], user_errors) - self.assertEqual([synthetic_error], synthetic_errors) - user_errors, synthetic_errors = error.split_errors( - [synthetic_error, user_error]) - self.assertEqual([user_error], user_errors) - self.assertEqual([synthetic_error], synthetic_errors) - user_errors, synthetic_errors = error.split_errors( - [synthetic_error, 
user_error, synthetic_error_2, user_error_2]) - self.assertEqual([user_error, user_error_2], user_errors) - self.assertEqual([synthetic_error, synthetic_error_2], synthetic_errors) + def test_split_errors(self): + user_error = [ + error.error( + "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing" + ), + error.note( + "foo.emb", + parser_types.make_location((3, 4), (5, 6)), + "Note: bad thing referrent", + ), + ] + user_error_2 = [ + error.error( + "foo.emb", parser_types.make_location((8, 9), (10, 11)), "Bad thing" + ), + error.note( + "foo.emb", + parser_types.make_location((10, 11), (12, 13)), + "Note: bad thing referrent", + ), + ] + synthetic_error = [ + error.error( + "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing" + ), + error.note( + "foo.emb", + parser_types.make_location((3, 4), (5, 6), True), + "Note: bad thing referrent", + ), + ] + synthetic_error_2 = [ + error.error( + "foo.emb", + parser_types.make_location((8, 9), (10, 11), True), + "Bad thing", + ), + error.note( + "foo.emb", + parser_types.make_location((10, 11), (12, 13)), + "Note: bad thing referrent", + ), + ] + user_errors, synthetic_errors = error.split_errors( + [user_error, synthetic_error] + ) + self.assertEqual([user_error], user_errors) + self.assertEqual([synthetic_error], synthetic_errors) + user_errors, synthetic_errors = error.split_errors( + [synthetic_error, user_error] + ) + self.assertEqual([user_error], user_errors) + self.assertEqual([synthetic_error], synthetic_errors) + user_errors, synthetic_errors = error.split_errors( + [synthetic_error, user_error, synthetic_error_2, user_error_2] + ) + self.assertEqual([user_error, user_error_2], user_errors) + self.assertEqual([synthetic_error, synthetic_error_2], synthetic_errors) - def test_filter_errors(self): - user_error = [ - error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((3, 4), (5, 6)), - "Note: bad thing referrent") - ] - synthetic_error = [ - error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((3, 4), (5, 6), True), - "Note: bad thing referrent") - ] - synthetic_error_2 = [ - error.error("foo.emb", - parser_types.make_location((8, 9), (10, 11), True), - "Bad thing"), - error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)), - "Note: bad thing referrent") - ] - self.assertEqual( - [user_error], - error.filter_errors([synthetic_error, user_error, synthetic_error_2])) + def test_filter_errors(self): + user_error = [ + error.error( + "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing" + ), + error.note( + "foo.emb", + parser_types.make_location((3, 4), (5, 6)), + "Note: bad thing referrent", + ), + ] + synthetic_error = [ + error.error( + "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing" + ), + error.note( + "foo.emb", + parser_types.make_location((3, 4), (5, 6), True), + "Note: bad thing referrent", + ), + ] + synthetic_error_2 = [ + error.error( + "foo.emb", + parser_types.make_location((8, 9), (10, 11), True), + "Bad thing", + ), + error.note( + "foo.emb", + parser_types.make_location((10, 11), (12, 13)), + "Note: bad thing referrent", + ), + ] + self.assertEqual( + [user_error], + error.filter_errors([synthetic_error, user_error, synthetic_error_2]), + ) class FormatErrorsTest(unittest.TestCase): - def test_format_errors(self): - errors = [[error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), - 
"note")]] - sources = {"foo.emb": "x\ny\nz bcd\nq\n"} - self.assertEqual("foo.emb:3:4: note: note\n" - "z bcd\n" - " ^^", error.format_errors(errors, sources)) - bold = error.BOLD - reset = error.RESET - white = error.WHITE - bright_black = error.BRIGHT_BLACK - bright_green = error.BRIGHT_GREEN - self.assertEqual(bold + "foo.emb:3:4: " + reset + bright_black + "note: " + - reset + white + "note\n" + - reset + white + "z bcd\n" + - reset + bright_green + " ^^" + reset, - error.format_errors(errors, sources, use_color=True)) + def test_format_errors(self): + errors = [ + [error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "note")] + ] + sources = {"foo.emb": "x\ny\nz bcd\nq\n"} + self.assertEqual( + "foo.emb:3:4: note: note\n" "z bcd\n" " ^^", + error.format_errors(errors, sources), + ) + bold = error.BOLD + reset = error.RESET + white = error.WHITE + bright_black = error.BRIGHT_BLACK + bright_green = error.BRIGHT_GREEN + self.assertEqual( + bold + + "foo.emb:3:4: " + + reset + + bright_black + + "note: " + + reset + + white + + "note\n" + + reset + + white + + "z bcd\n" + + reset + + bright_green + + " ^^" + + reset, + error.format_errors(errors, sources, use_color=True), + ) + if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/util/expression_parser.py b/compiler/util/expression_parser.py index 708f23b..3981548 100644 --- a/compiler/util/expression_parser.py +++ b/compiler/util/expression_parser.py @@ -20,28 +20,28 @@ def parse(text): - """Parses text as an Expression. - - This parses text using the expression subset of the Emboss grammar, and - returns an ir_data.Expression. The expression only needs to be syntactically - valid; it will not go through symbol resolution or type checking. This - function is not intended to be called on arbitrary input; it asserts that the - text successfully parses, but does not return errors. - - Arguments: - text: The text of an Emboss expression, like "4 + 5" or "$max(1, a, b)". - - Returns: - An ir_data.Expression corresponding to the textual form. - - Raises: - AssertionError if text is not a well-formed Emboss expression, and - assertions are enabled. - """ - tokens, errors = tokenizer.tokenize(text, "") - assert not errors, "{!r}".format(errors) - # tokenizer.tokenize always inserts a newline token at the end, which breaks - # expression parsing. - parse_result = parser.parse_expression(tokens[:-1]) - assert not parse_result.error, "{!r}".format(parse_result.error) - return module_ir.build_ir(parse_result.parse_tree) + """Parses text as an Expression. + + This parses text using the expression subset of the Emboss grammar, and + returns an ir_data.Expression. The expression only needs to be syntactically + valid; it will not go through symbol resolution or type checking. This + function is not intended to be called on arbitrary input; it asserts that the + text successfully parses, but does not return errors. + + Arguments: + text: The text of an Emboss expression, like "4 + 5" or "$max(1, a, b)". + + Returns: + An ir_data.Expression corresponding to the textual form. + + Raises: + AssertionError if text is not a well-formed Emboss expression, and + assertions are enabled. + """ + tokens, errors = tokenizer.tokenize(text, "") + assert not errors, "{!r}".format(errors) + # tokenizer.tokenize always inserts a newline token at the end, which breaks + # expression parsing. 
+ parse_result = parser.parse_expression(tokens[:-1]) + assert not parse_result.error, "{!r}".format(parse_result.error) + return module_ir.build_ir(parse_result.parse_tree) diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py index 624b09a..af8c2f7 100644 --- a/compiler/util/ir_data.py +++ b/compiler/util/ir_data.py @@ -27,79 +27,95 @@ @dataclasses.dataclass class Message: - """Base class for IR data objects. + """Base class for IR data objects. - Historically protocol buffers were used for serializing this data which has - led to some legacy naming conventions and references. In particular this - class is named `Message` in the sense of a protocol buffer message, - indicating that it is intended to just be data that is used by other higher - level services. + Historically protocol buffers were used for serializing this data which has + led to some legacy naming conventions and references. In particular this + class is named `Message` in the sense of a protocol buffer message, + indicating that it is intended to just be data that is used by other higher + level services. - There are some other legacy idioms leftover from the protocol buffer-based - definition such as support for "oneof" and optional fields. - """ - - IR_DATACLASS: ClassVar[object] = object() - field_specs: ClassVar[ir_data_fields.FilteredIrFieldSpecs] - - def __post_init__(self): - """Called by dataclass subclasses after init. - - Post-processes any lists passed in to use our custom list type. + There are some other legacy idioms leftover from the protocol buffer-based + definition such as support for "oneof" and optional fields. """ - # Convert any lists passed in to CopyValuesList - for spec in self.field_specs.sequence_field_specs: - cur_val = getattr(self, spec.name) - if isinstance(cur_val, ir_data_fields.TemporaryCopyValuesList): - copy_val = cur_val.temp_list - else: - copy_val = ir_data_fields.CopyValuesList(spec.data_type) - if cur_val: - copy_val.shallow_copy(cur_val) - setattr(self, spec.name, copy_val) - - # This hook adds a 15% overhead to end-to-end code generation in some cases - # so we guard it in a `__debug__` block. Users can opt-out of this check by - # running python with the `-O` flag, ie: `python3 -O ./embossc`. - if __debug__: - def __setattr__(self, name: str, value) -> None: - """Debug-only hook that adds basic type checking for ir_data fields.""" - if spec := self.field_specs.all_field_specs.get(name): - if not ( - # Check if it's the expected type - isinstance(value, spec.data_type) or - # Oneof fields are a special case - spec.is_oneof or - # Optional fields can be set to None - (spec.container is ir_data_fields.FieldContainer.OPTIONAL and - value is None) or - # Sequences can be a few variants of lists - (spec.is_sequence and - isinstance(value, ( - list, ir_data_fields.TemporaryCopyValuesList, - ir_data_fields.CopyValuesList))) or - # An enum value can be an int - (spec.is_enum and isinstance(value, int))): - raise AttributeError( - f"Cannot set {value} (type {value.__class__}) for type" - "{spec.data_type}") - object.__setattr__(self, name, value) - - # Non-PEP8 name to mimic the Google Protobuf interface. - def HasField(self, name): # pylint:disable=invalid-name - """Indicates if this class has the given field defined and it is set.""" - return getattr(self, name, None) is not None - - # Non-PEP8 name to mimic the Google Protobuf interface. - def WhichOneof(self, oneof_name): # pylint:disable=invalid-name - """Indicates which field has been set for the oneof value. 
-
-    Returns None if no field has been set.
-    """
-    for field_name, oneof in self.field_specs.oneof_mappings:
-      if oneof == oneof_name and self.HasField(field_name):
-        return field_name
-    return None
+
+    IR_DATACLASS: ClassVar[object] = object()
+    field_specs: ClassVar[ir_data_fields.FilteredIrFieldSpecs]
+
+    def __post_init__(self):
+        """Called by dataclass subclasses after init.
+
+        Post-processes any lists passed in to use our custom list type.
+        """
+        # Convert any lists passed in to CopyValuesList
+        for spec in self.field_specs.sequence_field_specs:
+            cur_val = getattr(self, spec.name)
+            if isinstance(cur_val, ir_data_fields.TemporaryCopyValuesList):
+                copy_val = cur_val.temp_list
+            else:
+                copy_val = ir_data_fields.CopyValuesList(spec.data_type)
+                if cur_val:
+                    copy_val.shallow_copy(cur_val)
+            setattr(self, spec.name, copy_val)
+
+    # This hook adds a 15% overhead to end-to-end code generation in some cases
+    # so we guard it in a `__debug__` block. Users can opt out of this check by
+    # running python with the `-O` flag, i.e.: `python3 -O ./embossc`.
+    if __debug__:
+
+        def __setattr__(self, name: str, value) -> None:
+            """Debug-only hook that adds basic type checking for ir_data fields."""
+            if spec := self.field_specs.all_field_specs.get(name):
+                if not (
+                    # Check if it's the expected type
+                    isinstance(value, spec.data_type)
+                    or
+                    # Oneof fields are a special case
+                    spec.is_oneof
+                    or
+                    # Optional fields can be set to None
+                    (
+                        spec.container is ir_data_fields.FieldContainer.OPTIONAL
+                        and value is None
+                    )
+                    or
+                    # Sequences can be a few variants of lists
+                    (
+                        spec.is_sequence
+                        and isinstance(
+                            value,
+                            (
+                                list,
+                                ir_data_fields.TemporaryCopyValuesList,
+                                ir_data_fields.CopyValuesList,
+                            ),
+                        )
+                    )
+                    or
+                    # An enum value can be an int
+                    (spec.is_enum and isinstance(value, int))
+                ):
+                    raise AttributeError(
+                        f"Cannot set {value} (type {value.__class__}) for type "
+                        f"{spec.data_type}"
+                    )
+            object.__setattr__(self, name, value)
+
+    # Non-PEP8 name to mimic the Google Protobuf interface.
+    def HasField(self, name):  # pylint:disable=invalid-name
+        """Indicates if this class has the given field defined and it is set."""
+        return getattr(self, name, None) is not None
+
+    # Non-PEP8 name to mimic the Google Protobuf interface.
+    def WhichOneof(self, oneof_name):  # pylint:disable=invalid-name
+        """Indicates which field has been set for the oneof value.
+
+        Returns None if no field has been set.
+ """ + for field_name, oneof in self.field_specs.oneof_mappings: + if oneof == oneof_name and self.HasField(field_name): + return field_name + return None ################################################################################ @@ -108,28 +124,28 @@ def WhichOneof(self, oneof_name): # pylint:disable=invalid-name @dataclasses.dataclass class Position(Message): - """A zero-width position within a source file.""" + """A zero-width position within a source file.""" - line: int = 0 - """Line (starts from 1).""" - column: int = 0 - """Column (starts from 1).""" + line: int = 0 + """Line (starts from 1).""" + column: int = 0 + """Column (starts from 1).""" @dataclasses.dataclass class Location(Message): - """A half-open start:end range within a source file.""" + """A half-open start:end range within a source file.""" - start: Optional[Position] = None - """Beginning of the range""" - end: Optional[Position] = None - """One column past the end of the range.""" + start: Optional[Position] = None + """Beginning of the range""" + end: Optional[Position] = None + """One column past the end of the range.""" - is_disjoint_from_parent: Optional[bool] = None - """True if this Location is outside of the parent object's Location.""" + is_disjoint_from_parent: Optional[bool] = None + """True if this Location is outside of the parent object's Location.""" - is_synthetic: Optional[bool] = None - """True if this Location's parent was synthesized, and does not directly + is_synthetic: Optional[bool] = None + """True if this Location's parent was synthesized, and does not directly appear in the source file. The Emboss front end uses this field to cull @@ -139,124 +155,124 @@ class Location(Message): @dataclasses.dataclass class Word(Message): - """IR for a bare word in the source file. + """IR for a bare word in the source file. - This is used in NameDefinitions and References. - """ + This is used in NameDefinitions and References. + """ - text: Optional[str] = None - source_location: Optional[Location] = None + text: Optional[str] = None + source_location: Optional[Location] = None @dataclasses.dataclass class String(Message): - """IR for a string in the source file.""" + """IR for a string in the source file.""" - text: Optional[str] = None - source_location: Optional[Location] = None + text: Optional[str] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Documentation(Message): - text: Optional[str] = None - source_location: Optional[Location] = None + text: Optional[str] = None + source_location: Optional[Location] = None @dataclasses.dataclass class BooleanConstant(Message): - """IR for a boolean constant.""" + """IR for a boolean constant.""" - value: Optional[bool] = None - source_location: Optional[Location] = None + value: Optional[bool] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Empty(Message): - """Placeholder message for automatic element counts for arrays.""" + """Placeholder message for automatic element counts for arrays.""" - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class NumericConstant(Message): - """IR for any numeric constant.""" + """IR for any numeric constant.""" - # Numeric constants are stored as decimal strings; this is the simplest way - # to store the full -2**63..+2**64 range. - # - # TODO(bolms): switch back to int, and just use strings during - # serialization, now that we're free of proto. 
- value: Optional[str] = None - source_location: Optional[Location] = None + # Numeric constants are stored as decimal strings; this is the simplest way + # to store the full -2**63..+2**64 range. + # + # TODO(bolms): switch back to int, and just use strings during + # serialization, now that we're free of proto. + value: Optional[str] = None + source_location: Optional[Location] = None class FunctionMapping(int, enum.Enum): - """Enum of supported function types""" - - UNKNOWN = 0 - ADDITION = 1 - """`+`""" - SUBTRACTION = 2 - """`-`""" - MULTIPLICATION = 3 - """`*`""" - EQUALITY = 4 - """`==`""" - INEQUALITY = 5 - """`!=`""" - AND = 6 - """`&&`""" - OR = 7 - """`||`""" - LESS = 8 - """`<`""" - LESS_OR_EQUAL = 9 - """`<=`""" - GREATER = 10 - """`>`""" - GREATER_OR_EQUAL = 11 - """`>=`""" - CHOICE = 12 - """`?:`""" - MAXIMUM = 13 - """`$max()`""" - PRESENCE = 14 - """`$present()`""" - UPPER_BOUND = 15 - """`$upper_bound()`""" - LOWER_BOUND = 16 - """`$lower_bound()`""" + """Enum of supported function types""" + + UNKNOWN = 0 + ADDITION = 1 + """`+`""" + SUBTRACTION = 2 + """`-`""" + MULTIPLICATION = 3 + """`*`""" + EQUALITY = 4 + """`==`""" + INEQUALITY = 5 + """`!=`""" + AND = 6 + """`&&`""" + OR = 7 + """`||`""" + LESS = 8 + """`<`""" + LESS_OR_EQUAL = 9 + """`<=`""" + GREATER = 10 + """`>`""" + GREATER_OR_EQUAL = 11 + """`>=`""" + CHOICE = 12 + """`?:`""" + MAXIMUM = 13 + """`$max()`""" + PRESENCE = 14 + """`$present()`""" + UPPER_BOUND = 15 + """`$upper_bound()`""" + LOWER_BOUND = 16 + """`$lower_bound()`""" @dataclasses.dataclass class Function(Message): - """IR for a single function (+, -, *, ==, $max, etc.) in an expression.""" + """IR for a single function (+, -, *, ==, $max, etc.) in an expression.""" - function: Optional[FunctionMapping] = None - args: list["Expression"] = ir_data_fields.list_field(lambda: Expression) - function_name: Optional[Word] = None - source_location: Optional[Location] = None + function: Optional[FunctionMapping] = None + args: list["Expression"] = ir_data_fields.list_field(lambda: Expression) + function_name: Optional[Word] = None + source_location: Optional[Location] = None @dataclasses.dataclass class CanonicalName(Message): - """CanonicalName is the unique, absolute name for some object. + """CanonicalName is the unique, absolute name for some object. - A CanonicalName is the unique, absolute name for some object (Type, field, - etc.) in the IR. It is used both in the definitions of objects ("struct - Foo"), and in references to objects (a field of type "Foo"). - """ + A CanonicalName is the unique, absolute name for some object (Type, field, + etc.) in the IR. It is used both in the definitions of objects ("struct + Foo"), and in references to objects (a field of type "Foo"). + """ - module_file: str = ir_data_fields.str_field() - """The module_file is the Module.source_file_name of the Module in which this + module_file: str = ir_data_fields.str_field() + """The module_file is the Module.source_file_name of the Module in which this object's definition appears. Note that the Prelude always has a Module.source_file_name of "", and thus references to Prelude names will have module_file == "". """ - object_path: list[str] = ir_data_fields.list_field(str) - """The object_path is the canonical path to the object definition within its + object_path: list[str] = ir_data_fields.list_field(str) + """The object_path is the canonical path to the object definition within its module file. 
For example, the field "bar" would have an object path of @@ -279,160 +295,160 @@ class CanonicalName(Message): @dataclasses.dataclass class NameDefinition(Message): - """NameDefinition is IR for the name of an object, within the object. + """NameDefinition is IR for the name of an object, within the object. - That is, a TypeDefinition or Field will hold a NameDefinition as its - name. - """ + That is, a TypeDefinition or Field will hold a NameDefinition as its + name. + """ - name: Optional[Word] = None - """The name, as directly generated from the source text. + name: Optional[Word] = None + """The name, as directly generated from the source text. name.text will match the last element of canonical_name.object_path. Note that in some cases, the exact string in name.text may not appear in the source text. """ - canonical_name: Optional[CanonicalName] = None - """The CanonicalName that will appear in References. + canonical_name: Optional[CanonicalName] = None + """The CanonicalName that will appear in References. This field is technically redundant: canonical_name.module_file should always match the source_file_name of the enclosing Module, and canonical_name.object_path should always match the names of parent nodes. """ - is_anonymous: Optional[bool] = None - """If true, indicates that this is an automatically-generated name, which + is_anonymous: Optional[bool] = None + """If true, indicates that this is an automatically-generated name, which should not be visible outside of its immediate namespace. """ - source_location: Optional[Location] = None - """The location of this NameDefinition in source code.""" + source_location: Optional[Location] = None + """The location of this NameDefinition in source code.""" @dataclasses.dataclass class Reference(Message): - """A Reference holds the canonical name of something defined elsewhere. + """A Reference holds the canonical name of something defined elsewhere. - For example, take this fragment: + For example, take this fragment: - struct Foo: - 0:3 UInt size (s) - 4:s Int:8[] payload + struct Foo: + 0:3 UInt size (s) + 4:s Int:8[] payload - "Foo", "size", and "payload" will become NameDefinitions in their - corresponding Field and Message IR objects, while "UInt", the second "s", - and "Int" are References. Note that the second "s" will have a - canonical_name.object_path of ["Foo", "size"], not ["Foo", "s"]: the - Reference always holds the single "true" name of the object, regardless of - what appears in the .emb. - """ + "Foo", "size", and "payload" will become NameDefinitions in their + corresponding Field and Message IR objects, while "UInt", the second "s", + and "Int" are References. Note that the second "s" will have a + canonical_name.object_path of ["Foo", "size"], not ["Foo", "s"]: the + Reference always holds the single "true" name of the object, regardless of + what appears in the .emb. + """ - canonical_name: Optional[CanonicalName] = None - """The canonical name of the object being referred to. + canonical_name: Optional[CanonicalName] = None + """The canonical name of the object being referred to. This name should be used to find the object in the IR. """ - source_name: list[Word] = ir_data_fields.list_field(Word) - """The source_name is the name the user entered in the source file. + source_name: list[Word] = ir_data_fields.list_field(Word) + """The source_name is the name the user entered in the source file. 
The source_name could be either relative or absolute, and may be an alias (and thus not match any part of the canonical_name). Back ends should use canonical_name for name lookup, and reserve source_name for error messages. """ - is_local_name: Optional[bool] = None - """If true, then symbol resolution should only look at local names when + is_local_name: Optional[bool] = None + """If true, then symbol resolution should only look at local names when resolving source_name. This is used so that the names of inline types aren't "ambiguous" if there happens to be another type with the same name at a parent scope. """ - # TODO(bolms): Allow absolute paths starting with ".". + # TODO(bolms): Allow absolute paths starting with ".". - source_location: Optional[Location] = None - """Note that this is the source_location of the *Reference*, not of the + source_location: Optional[Location] = None + """Note that this is the source_location of the *Reference*, not of the object to which it refers. """ @dataclasses.dataclass class FieldReference(Message): - """IR for a "field" or "field.sub.subsub" reference in an expression. - - The first element of "path" is the "base" field, which should be directly - readable in the (runtime) context of the expression. For example: - - struct Foo: - 0:1 UInt header_size (h) - 0:h UInt:8[] header_bytes - - The "h" will translate to ["Foo", "header_size"], which will be the first - (and in this case only) element of "path". - - Subsequent path elements should be treated as subfields. For example, in: - - struct Foo: - struct Sizes: - 0:1 UInt header_size - 1:2 UInt body_size - 0 [+2] Sizes sizes - 0 [+sizes.header_size] UInt:8[] header - sizes.header_size [+sizes.body_size] UInt:8[] body - - The references to "sizes.header_size" will have a path of [["Foo", - "sizes"], ["Foo", "Sizes", "header_size"]]. Note that each path element is - a fully-qualified reference; some back ends (C++, Python) may only use the - last element, while others (C) may use the complete path. - - This representation is a bit awkward, and is fundamentally limited to a - dotted list of static field names. It does not allow an expression like - `array[n]` on the left side of a `.`. At this point, it is an artifact of - the era during which I (bolms@) thought I could get away with skipping - compiler-y things. - """ + """IR for a "field" or "field.sub.subsub" reference in an expression. + + The first element of "path" is the "base" field, which should be directly + readable in the (runtime) context of the expression. For example: - # TODO(bolms): Add composite types to the expression type system, and - # replace FieldReference with a "member access" Expression kind. Further, - # move the symbol resolution for FieldReferences that is currently in - # symbol_resolver.py into type_check.py. + struct Foo: + 0:1 UInt header_size (h) + 0:h UInt:8[] header_bytes + + The "h" will translate to ["Foo", "header_size"], which will be the first + (and in this case only) element of "path". + + Subsequent path elements should be treated as subfields. For example, in: + + struct Foo: + struct Sizes: + 0:1 UInt header_size + 1:2 UInt body_size + 0 [+2] Sizes sizes + 0 [+sizes.header_size] UInt:8[] header + sizes.header_size [+sizes.body_size] UInt:8[] body + + The references to "sizes.header_size" will have a path of [["Foo", + "sizes"], ["Foo", "Sizes", "header_size"]]. 
Note that each path element is + a fully-qualified reference; some back ends (C++, Python) may only use the + last element, while others (C) may use the complete path. + + This representation is a bit awkward, and is fundamentally limited to a + dotted list of static field names. It does not allow an expression like + `array[n]` on the left side of a `.`. At this point, it is an artifact of + the era during which I (bolms@) thought I could get away with skipping + compiler-y things. + """ - # TODO(bolms): Make the above change before declaring the IR to be "stable". + # TODO(bolms): Add composite types to the expression type system, and + # replace FieldReference with a "member access" Expression kind. Further, + # move the symbol resolution for FieldReferences that is currently in + # symbol_resolver.py into type_check.py. - path: list[Reference] = ir_data_fields.list_field(Reference) - source_location: Optional[Location] = None + # TODO(bolms): Make the above change before declaring the IR to be "stable". + + path: list[Reference] = ir_data_fields.list_field(Reference) + source_location: Optional[Location] = None @dataclasses.dataclass class OpaqueType(Message): - pass + pass @dataclasses.dataclass class IntegerType(Message): - """Type of an integer expression.""" - - # For optimization, the modular congruence of an integer expression is - # tracked. This consists of a modulus and a modular_value, such that for - # all possible values of expression, expression MOD modulus == - # modular_value. - # - # The modulus may be the special value "infinity" to indicate that the - # expression's value is exactly modular_value; otherwise, it should be a - # positive integer. - # - # A modulus of 1 places no constraints on the value. - # - # The modular_value should always be a nonnegative integer that is smaller - # than the modulus. - # - # Note that this is specifically the *modulus*, which is not equivalent to - # the value from C's '%' operator when the dividend is negative: in C, -7 % - # 4 == -3, but the modular_value here would be 1. Python uses modulus: in - # Python, -7 % 4 == 1. - modulus: Optional[str] = None - """The modulus portion of the modular congruence of an integer expression. + """Type of an integer expression.""" + + # For optimization, the modular congruence of an integer expression is + # tracked. This consists of a modulus and a modular_value, such that for + # all possible values of expression, expression MOD modulus == + # modular_value. + # + # The modulus may be the special value "infinity" to indicate that the + # expression's value is exactly modular_value; otherwise, it should be a + # positive integer. + # + # A modulus of 1 places no constraints on the value. + # + # The modular_value should always be a nonnegative integer that is smaller + # than the modulus. + # + # Note that this is specifically the *modulus*, which is not equivalent to + # the value from C's '%' operator when the dividend is negative: in C, -7 % + # 4 == -3, but the modular_value here would be 1. Python uses modulus: in + # Python, -7 % 4 == 1. + modulus: Optional[str] = None + """The modulus portion of the modular congruence of an integer expression. The modulus may be the special value "infinity" to indicate that the expression's value is exactly modular_value; otherwise, it should be a @@ -440,171 +456,165 @@ class IntegerType(Message): A modulus of 1 places no constraints on the value. 
""" - modular_value: Optional[str] = None - """ The modular_value portion of the modular congruence of an integer expression. + modular_value: Optional[str] = None + """ The modular_value portion of the modular congruence of an integer expression. The modular_value should always be a nonnegative integer that is smaller than the modulus. """ - # The minimum and maximum values of an integer are tracked and checked so - # that Emboss can implement reliable arithmetic with no operations - # overflowing either 64-bit unsigned or 64-bit signed 2's-complement - # integers. - # - # Note that constant subexpressions are allowed to overflow, as long as the - # final, computed constant value of the subexpression fits in a 64-bit - # value. - # - # The minimum_value may take the value "-infinity", and the maximum_value - # may take the value "infinity". These sentinel values indicate that - # Emboss has no bound information for the Expression, and therefore the - # Expression may only be evaluated during compilation; the back end should - # never need to compile such an expression into the target language (e.g., - # C++). - minimum_value: Optional[str] = None - maximum_value: Optional[str] = None + # The minimum and maximum values of an integer are tracked and checked so + # that Emboss can implement reliable arithmetic with no operations + # overflowing either 64-bit unsigned or 64-bit signed 2's-complement + # integers. + # + # Note that constant subexpressions are allowed to overflow, as long as the + # final, computed constant value of the subexpression fits in a 64-bit + # value. + # + # The minimum_value may take the value "-infinity", and the maximum_value + # may take the value "infinity". These sentinel values indicate that + # Emboss has no bound information for the Expression, and therefore the + # Expression may only be evaluated during compilation; the back end should + # never need to compile such an expression into the target language (e.g., + # C++). + minimum_value: Optional[str] = None + maximum_value: Optional[str] = None @dataclasses.dataclass class BooleanType(Message): - value: Optional[bool] = None + value: Optional[bool] = None @dataclasses.dataclass class EnumType(Message): - name: Optional[Reference] = None - value: Optional[str] = None + name: Optional[Reference] = None + value: Optional[str] = None @dataclasses.dataclass class ExpressionType(Message): - opaque: Optional[OpaqueType] = ir_data_fields.oneof_field("type") - integer: Optional[IntegerType] = ir_data_fields.oneof_field("type") - boolean: Optional[BooleanType] = ir_data_fields.oneof_field("type") - enumeration: Optional[EnumType] = ir_data_fields.oneof_field("type") + opaque: Optional[OpaqueType] = ir_data_fields.oneof_field("type") + integer: Optional[IntegerType] = ir_data_fields.oneof_field("type") + boolean: Optional[BooleanType] = ir_data_fields.oneof_field("type") + enumeration: Optional[EnumType] = ir_data_fields.oneof_field("type") @dataclasses.dataclass class Expression(Message): - """IR for an expression. + """IR for an expression. - An Expression is a potentially-recursive data structure. It can either - represent a leaf node (constant or reference) or an operation combining - other Expressions (function). - """ + An Expression is a potentially-recursive data structure. It can either + represent a leaf node (constant or reference) or an operation combining + other Expressions (function). 
+    """

-  constant: Optional[NumericConstant] = ir_data_fields.oneof_field("expression")
-  constant_reference: Optional[Reference] = ir_data_fields.oneof_field(
-      "expression"
-  )
-  function: Optional[Function] = ir_data_fields.oneof_field("expression")
-  field_reference: Optional[FieldReference] = ir_data_fields.oneof_field(
-      "expression"
-  )
-  boolean_constant: Optional[BooleanConstant] = ir_data_fields.oneof_field(
-      "expression"
-  )
-  builtin_reference: Optional[Reference] = ir_data_fields.oneof_field(
-      "expression"
-  )
+
+    constant: Optional[NumericConstant] = ir_data_fields.oneof_field("expression")
+    constant_reference: Optional[Reference] = ir_data_fields.oneof_field("expression")
+    function: Optional[Function] = ir_data_fields.oneof_field("expression")
+    field_reference: Optional[FieldReference] = ir_data_fields.oneof_field("expression")
+    boolean_constant: Optional[BooleanConstant] = ir_data_fields.oneof_field(
+        "expression"
+    )
+    builtin_reference: Optional[Reference] = ir_data_fields.oneof_field("expression")

-  type: Optional[ExpressionType] = None
-  source_location: Optional[Location] = None
+    type: Optional[ExpressionType] = None
+    source_location: Optional[Location] = None


 @dataclasses.dataclass
 class ArrayType(Message):
-  """IR for an array type ("Int:8[12]" or "Message[2]" or "UInt[3][2]")."""
+    """IR for an array type ("Int:8[12]" or "Message[2]" or "UInt[3][2]")."""

-  base_type: Optional["Type"] = None
+    base_type: Optional["Type"] = None

-  element_count: Optional[Expression] = ir_data_fields.oneof_field("size")
-  automatic: Optional[Empty] = ir_data_fields.oneof_field("size")
+    element_count: Optional[Expression] = ir_data_fields.oneof_field("size")
+    automatic: Optional[Empty] = ir_data_fields.oneof_field("size")

-  source_location: Optional[Location] = None
+    source_location: Optional[Location] = None


 @dataclasses.dataclass
 class AtomicType(Message):
-  """IR for a non-array type ("UInt" or "Foo(Version.SIX)")."""
+    """IR for a non-array type ("UInt" or "Foo(Version.SIX)")."""

-  reference: Optional[Reference] = None
-  runtime_parameter: list[Expression] = ir_data_fields.list_field(Expression)
-  source_location: Optional[Location] = None
+    reference: Optional[Reference] = None
+    runtime_parameter: list[Expression] = ir_data_fields.list_field(Expression)
+    source_location: Optional[Location] = None


 @dataclasses.dataclass
 class Type(Message):
-  """IR for a type reference ("UInt", "Int:8[12]", etc.)."""
+    """IR for a type reference ("UInt", "Int:8[12]", etc.)."""

-  atomic_type: Optional[AtomicType] = ir_data_fields.oneof_field("type")
-  array_type: Optional[ArrayType] = ir_data_fields.oneof_field("type")
+    atomic_type: Optional[AtomicType] = ir_data_fields.oneof_field("type")
+    array_type: Optional[ArrayType] = ir_data_fields.oneof_field("type")

-  size_in_bits: Optional[Expression] = None
-  source_location: Optional[Location] = None
+    size_in_bits: Optional[Expression] = None
+    source_location: Optional[Location] = None


 @dataclasses.dataclass
 class AttributeValue(Message):
-  """IR for a attribute value."""
+    """IR for an attribute value."""

-  # TODO(bolms): Make String a type of Expression, and replace
-  # AttributeValue with Expression.
+ expression: Optional[Expression] = ir_data_fields.oneof_field("value") + string_constant: Optional[String] = ir_data_fields.oneof_field("value") - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Attribute(Message): - """IR for a [name = value] attribute.""" + """IR for a [name = value] attribute.""" - name: Optional[Word] = None - value: Optional[AttributeValue] = None - back_end: Optional[Word] = None - is_default: Optional[bool] = None - source_location: Optional[Location] = None + name: Optional[Word] = None + value: Optional[AttributeValue] = None + back_end: Optional[Word] = None + is_default: Optional[bool] = None + source_location: Optional[Location] = None @dataclasses.dataclass class WriteTransform(Message): - """IR which defines an expression-based virtual field write scheme. + """IR which defines an expression-based virtual field write scheme. - E.g., for a virtual field like `x_plus_one`: + E.g., for a virtual field like `x_plus_one`: - struct Foo: - 0 [+1] UInt x - let x_plus_one = x + 1 + struct Foo: + 0 [+1] UInt x + let x_plus_one = x + 1 - ... the `WriteMethod` would be `transform`, with `$logical_value - 1` for - `function_body` and `x` for `destination`. - """ + ... the `WriteMethod` would be `transform`, with `$logical_value - 1` for + `function_body` and `x` for `destination`. + """ - function_body: Optional[Expression] = None - destination: Optional[FieldReference] = None + function_body: Optional[Expression] = None + destination: Optional[FieldReference] = None @dataclasses.dataclass class WriteMethod(Message): - """IR which defines the method used for writing to a virtual field.""" + """IR which defines the method used for writing to a virtual field.""" - physical: Optional[bool] = ir_data_fields.oneof_field("method") - """A physical Field can be written directly.""" + physical: Optional[bool] = ir_data_fields.oneof_field("method") + """A physical Field can be written directly.""" - read_only: Optional[bool] = ir_data_fields.oneof_field("method") - """A read_only Field cannot be written.""" + read_only: Optional[bool] = ir_data_fields.oneof_field("method") + """A read_only Field cannot be written.""" - alias: Optional[FieldReference] = ir_data_fields.oneof_field("method") - """An alias is a direct, untransformed forward of another field; it can be + alias: Optional[FieldReference] = ir_data_fields.oneof_field("method") + """An alias is a direct, untransformed forward of another field; it can be implemented by directly returning a reference to the aliased field. Aliases are the only kind of virtual field that may have an opaque type. """ - transform: Optional[WriteTransform] = ir_data_fields.oneof_field("method") - """A transform is a way of turning a logical value into a value which should + transform: Optional[WriteTransform] = ir_data_fields.oneof_field("method") + """A transform is a way of turning a logical value into a value which should be written to another field. 
A virtual field like `let y = x + 1` would @@ -615,47 +625,47 @@ class WriteMethod(Message): @dataclasses.dataclass class FieldLocation(Message): - """IR for a field location.""" + """IR for a field location.""" - start: Optional[Expression] = None - size: Optional[Expression] = None - source_location: Optional[Location] = None + start: Optional[Expression] = None + size: Optional[Expression] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Field(Message): # pylint:disable=too-many-instance-attributes - """IR for a field in a struct definition. + """IR for a field in a struct definition. - There are two kinds of Field: physical fields have location and (physical) - type; virtual fields have read_transform. Although there are differences, - in many situations physical and virtual fields are treated the same way, - and they can be freely intermingled in the source file. - """ + There are two kinds of Field: physical fields have location and (physical) + type; virtual fields have read_transform. Although there are differences, + in many situations physical and virtual fields are treated the same way, + and they can be freely intermingled in the source file. + """ - location: Optional[FieldLocation] = None - """The physical location of the field.""" - type: Optional[Type] = None - """The physical type of the field.""" + location: Optional[FieldLocation] = None + """The physical location of the field.""" + type: Optional[Type] = None + """The physical type of the field.""" - read_transform: Optional[Expression] = None - """The value of a virtual field.""" + read_transform: Optional[Expression] = None + """The value of a virtual field.""" - write_method: Optional[WriteMethod] = None - """How this virtual field should be written.""" + write_method: Optional[WriteMethod] = None + """How this virtual field should be written.""" - name: Optional[NameDefinition] = None - """The name of the field.""" - abbreviation: Optional[Word] = None - """An optional short name for the field, only visible inside the enclosing bits/struct.""" - attribute: list[Attribute] = ir_data_fields.list_field(Attribute) - """Field-specific attributes.""" - documentation: list[Documentation] = ir_data_fields.list_field(Documentation) - """Field-specific documentation.""" + name: Optional[NameDefinition] = None + """The name of the field.""" + abbreviation: Optional[Word] = None + """An optional short name for the field, only visible inside the enclosing bits/struct.""" + attribute: list[Attribute] = ir_data_fields.list_field(Attribute) + """Field-specific attributes.""" + documentation: list[Documentation] = ir_data_fields.list_field(Documentation) + """Field-specific documentation.""" - # TODO(bolms): Document conditional fields better, and replace some of this - # explanation with a reference to the documentation. - existence_condition: Optional[Expression] = None - """The field only exists when existence_condition evaluates to true. + # TODO(bolms): Document conditional fields better, and replace some of this + # explanation with a reference to the documentation. + existence_condition: Optional[Expression] = None + """The field only exists when existence_condition evaluates to true. For example: ``` @@ -690,17 +700,17 @@ class Field(Message): # pylint:disable=too-many-instance-attributes `bar`: those fields only conditionally exist in the structure. 
""" - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Structure(Message): - """IR for a bits or struct definition.""" + """IR for a bits or struct definition.""" - field: list[Field] = ir_data_fields.list_field(Field) + field: list[Field] = ir_data_fields.list_field(Field) - fields_in_dependency_order: list[int] = ir_data_fields.list_field(int) - """The fields in `field` are listed in the order they appear in the original + fields_in_dependency_order: list[int] = ir_data_fields.list_field(int) + """The fields in `field` are listed in the order they appear in the original .emb. For text format output, this can lead to poor results. Take the following @@ -734,66 +744,66 @@ class Structure(Message): be `{ 0, 1, 2, 3, ... }`. """ - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class External(Message): - """IR for an external type declaration.""" + """IR for an external type declaration.""" - # Externals have no values other than name and attribute list, which are - # common to all type definitions. + # Externals have no values other than name and attribute list, which are + # common to all type definitions. - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class EnumValue(Message): - """IR for a single value within an enumerated type.""" + """IR for a single value within an enumerated type.""" - name: Optional[NameDefinition] = None - """The name of the enum value.""" - value: Optional[Expression] = None - """The value of the enum value.""" - documentation: list[Documentation] = ir_data_fields.list_field(Documentation) - """Value-specific documentation.""" - attribute: list[Attribute] = ir_data_fields.list_field(Attribute) - """Value-specific attributes.""" + name: Optional[NameDefinition] = None + """The name of the enum value.""" + value: Optional[Expression] = None + """The value of the enum value.""" + documentation: list[Documentation] = ir_data_fields.list_field(Documentation) + """Value-specific documentation.""" + attribute: list[Attribute] = ir_data_fields.list_field(Attribute) + """Value-specific attributes.""" - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Enum(Message): - """IR for an enumerated type definition.""" + """IR for an enumerated type definition.""" - value: list[EnumValue] = ir_data_fields.list_field(EnumValue) - source_location: Optional[Location] = None + value: list[EnumValue] = ir_data_fields.list_field(EnumValue) + source_location: Optional[Location] = None @dataclasses.dataclass class Import(Message): - """IR for an import statement in a module.""" + """IR for an import statement in a module.""" - file_name: Optional[String] = None - """The file to import.""" - local_name: Optional[Word] = None - """The name to use within this module.""" - source_location: Optional[Location] = None + file_name: Optional[String] = None + """The file to import.""" + local_name: Optional[Word] = None + """The name to use within this module.""" + source_location: Optional[Location] = None @dataclasses.dataclass class RuntimeParameter(Message): - """IR for a runtime parameter definition.""" + """IR for a runtime parameter definition.""" - name: Optional[NameDefinition] = None - """The name of the parameter.""" - type: Optional[ExpressionType] = None - """The type of the parameter.""" + name: Optional[NameDefinition] = 
None + """The name of the parameter.""" + type: Optional[ExpressionType] = None + """The type of the parameter.""" - # TODO(bolms): Actually implement the set builder type notation. - physical_type_alias: Optional[Type] = None - """For convenience and readability, physical types may be used in the .emb + # TODO(bolms): Actually implement the set builder type notation. + physical_type_alias: Optional[Type] = None + """For convenience and readability, physical types may be used in the .emb source instead of a full expression type. That way, users can write @@ -809,80 +819,78 @@ class RuntimeParameter(Message): is filled in after initial parsing is finished. """ - source_location: Optional[Location] = None + source_location: Optional[Location] = None class AddressableUnit(int, enum.Enum): - """The "addressable unit" is the size of the smallest unit that can be read + """The "addressable unit" is the size of the smallest unit that can be read - from the backing store that this type expects. For `struct`s, this is - BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends - on the specific type - """ + from the backing store that this type expects. For `struct`s, this is + BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends + on the specific type + """ - NONE = 0 - BIT = 1 - BYTE = 8 + NONE = 0 + BIT = 1 + BYTE = 8 @dataclasses.dataclass class TypeDefinition(Message): - """Container IR for a type definition (struct, union, etc.)""" - - external: Optional[External] = ir_data_fields.oneof_field("type") - enumeration: Optional[Enum] = ir_data_fields.oneof_field("type") - structure: Optional[Structure] = ir_data_fields.oneof_field("type") - - name: Optional[NameDefinition] = None - """The name of the type.""" - attribute: list[Attribute] = ir_data_fields.list_field(Attribute) - """All attributes attached to the type.""" - documentation: list[Documentation] = ir_data_fields.list_field(Documentation) - """Docs for the type.""" - # pylint:disable=undefined-variable - subtype: list["TypeDefinition"] = ir_data_fields.list_field( - lambda: TypeDefinition - ) - """Subtypes of this type.""" - addressable_unit: Optional[AddressableUnit] = None - - runtime_parameter: list[RuntimeParameter] = ir_data_fields.list_field( - RuntimeParameter - ) - """If the type requires parameters at runtime, these are its parameters. + """Container IR for a type definition (struct, union, etc.)""" + + external: Optional[External] = ir_data_fields.oneof_field("type") + enumeration: Optional[Enum] = ir_data_fields.oneof_field("type") + structure: Optional[Structure] = ir_data_fields.oneof_field("type") + + name: Optional[NameDefinition] = None + """The name of the type.""" + attribute: list[Attribute] = ir_data_fields.list_field(Attribute) + """All attributes attached to the type.""" + documentation: list[Documentation] = ir_data_fields.list_field(Documentation) + """Docs for the type.""" + # pylint:disable=undefined-variable + subtype: list["TypeDefinition"] = ir_data_fields.list_field(lambda: TypeDefinition) + """Subtypes of this type.""" + addressable_unit: Optional[AddressableUnit] = None + + runtime_parameter: list[RuntimeParameter] = ir_data_fields.list_field( + RuntimeParameter + ) + """If the type requires parameters at runtime, these are its parameters. These are currently only allowed on structures, but in the future they should be allowed on externals. 
""" - source_location: Optional[Location] = None + source_location: Optional[Location] = None @dataclasses.dataclass class Module(Message): - """The IR for an individual Emboss module (file).""" + """The IR for an individual Emboss module (file).""" - attribute: list[Attribute] = ir_data_fields.list_field(Attribute) - """Module-level attributes.""" - type: list[TypeDefinition] = ir_data_fields.list_field(TypeDefinition) - """Module-level type definitions.""" - documentation: list[Documentation] = ir_data_fields.list_field(Documentation) - """Module-level docs.""" - foreign_import: list[Import] = ir_data_fields.list_field(Import) - """Other modules imported.""" - source_text: Optional[str] = None - """The original source code.""" - source_location: Optional[Location] = None - """Source code covered by this IR.""" - source_file_name: Optional[str] = None - """Name of the source file.""" + attribute: list[Attribute] = ir_data_fields.list_field(Attribute) + """Module-level attributes.""" + type: list[TypeDefinition] = ir_data_fields.list_field(TypeDefinition) + """Module-level type definitions.""" + documentation: list[Documentation] = ir_data_fields.list_field(Documentation) + """Module-level docs.""" + foreign_import: list[Import] = ir_data_fields.list_field(Import) + """Other modules imported.""" + source_text: Optional[str] = None + """The original source code.""" + source_location: Optional[Location] = None + """Source code covered by this IR.""" + source_file_name: Optional[str] = None + """Name of the source file.""" @dataclasses.dataclass class EmbossIr(Message): - """The top-level IR for an Emboss module and all of its dependencies.""" + """The top-level IR for an Emboss module and all of its dependencies.""" - module: list[Module] = ir_data_fields.list_field(Module) - """All modules. + module: list[Module] = ir_data_fields.list_field(Module) + """All modules. The first entry will be the main module; back ends should generate code corresponding to that module. 
The second entry will be the diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py index 52f3a8a..002ea3b 100644 --- a/compiler/util/ir_data_fields.py +++ b/compiler/util/ir_data_fields.py @@ -57,11 +57,11 @@ class IrDataclassInstance(Protocol): - """Type bound for an IR dataclass instance.""" + """Type bound for an IR dataclass instance.""" - __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] - IR_DATACLASS: ClassVar[object] - field_specs: ClassVar["FilteredIrFieldSpecs"] + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + IR_DATACLASS: ClassVar[object] + field_specs: ClassVar["FilteredIrFieldSpecs"] IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance) @@ -72,249 +72,245 @@ class IrDataclassInstance(Protocol): def _is_ir_dataclass(obj): - return hasattr(obj, _IR_DATACLASSS_ATTR) + return hasattr(obj, _IR_DATACLASSS_ATTR) class CopyValuesList(list[CopyValuesListT]): - """A list that makes copies of any value that is inserted""" + """A list that makes copies of any value that is inserted""" - def __init__( - self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None - ): - if iterable: - super().__init__(iterable) - else: - super().__init__() - self.value_type = value_type + def __init__( + self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None + ): + if iterable: + super().__init__(iterable) + else: + super().__init__() + self.value_type = value_type - def _copy(self, obj: Any): - if _is_ir_dataclass(obj): - return copy(obj) - return self.value_type(obj) + def _copy(self, obj: Any): + if _is_ir_dataclass(obj): + return copy(obj) + return self.value_type(obj) - def extend(self, iterable: Iterable) -> None: - return super().extend([self._copy(i) for i in iterable]) + def extend(self, iterable: Iterable) -> None: + return super().extend([self._copy(i) for i in iterable]) - def shallow_copy(self, iterable: Iterable) -> None: - """Explicitly performs a shallow copy of the provided list""" - return super().extend(iterable) + def shallow_copy(self, iterable: Iterable) -> None: + """Explicitly performs a shallow copy of the provided list""" + return super().extend(iterable) - def append(self, obj: Any) -> None: - return super().append(self._copy(obj)) + def append(self, obj: Any) -> None: + return super().append(self._copy(obj)) - def insert(self, index: SupportsIndex, obj: Any) -> None: - return super().insert(index, self._copy(obj)) + def insert(self, index: SupportsIndex, obj: Any) -> None: + return super().insert(index, self._copy(obj)) class TemporaryCopyValuesList(NamedTuple): - """Class used to temporarily hold a CopyValuesList while copying and - constructing an IR dataclass. - """ + """Class used to temporarily hold a CopyValuesList while copying and + constructing an IR dataclass. + """ - temp_list: CopyValuesList + temp_list: CopyValuesList class FieldContainer(enum.Enum): - """Indicates a fields container type""" + """Indicates a fields container type""" - NONE = 0 - OPTIONAL = 1 - LIST = 2 + NONE = 0 + OPTIONAL = 1 + LIST = 2 class FieldSpec(NamedTuple): - """Indicates the container and type of a field. + """Indicates the container and type of a field. - `FieldSpec` objects are accessed millions of times during runs so we cache as - many operations as possible. 
- - `is_dataclass`: `dataclasses.is_dataclass(data_type)` - - `is_sequence`: `container is FieldContainer.LIST` - - `is_enum`: `issubclass(data_type, enum.Enum)` - - `is_oneof`: `oneof is not None` + `FieldSpec` objects are accessed millions of times during runs so we cache as + many operations as possible. + - `is_dataclass`: `dataclasses.is_dataclass(data_type)` + - `is_sequence`: `container is FieldContainer.LIST` + - `is_enum`: `issubclass(data_type, enum.Enum)` + - `is_oneof`: `oneof is not None` - Use `make_field_spec` to automatically fill in the cached operations. - """ + Use `make_field_spec` to automatically fill in the cached operations. + """ - name: str - data_type: type - container: FieldContainer - oneof: Optional[str] - is_dataclass: bool - is_sequence: bool - is_enum: bool - is_oneof: bool + name: str + data_type: type + container: FieldContainer + oneof: Optional[str] + is_dataclass: bool + is_sequence: bool + is_enum: bool + is_oneof: bool def make_field_spec( name: str, data_type: type, container: FieldContainer, oneof: Optional[str] ): - """Builds a field spec with cached type queries.""" - return FieldSpec( - name, - data_type, - container, - oneof, - is_dataclass=_is_ir_dataclass(data_type), - is_sequence=container is FieldContainer.LIST, - is_enum=issubclass(data_type, enum.Enum), - is_oneof=oneof is not None, - ) + """Builds a field spec with cached type queries.""" + return FieldSpec( + name, + data_type, + container, + oneof, + is_dataclass=_is_ir_dataclass(data_type), + is_sequence=container is FieldContainer.LIST, + is_enum=issubclass(data_type, enum.Enum), + is_oneof=oneof is not None, + ) def build_default(field_spec: FieldSpec): - """Builds a default instance of the given field""" - if field_spec.is_sequence: - return CopyValuesList(field_spec.data_type) - if field_spec.is_enum: - return field_spec.data_type(int()) - return field_spec.data_type() + """Builds a default instance of the given field""" + if field_spec.is_sequence: + return CopyValuesList(field_spec.data_type) + if field_spec.is_enum: + return field_spec.data_type(int()) + return field_spec.data_type() class FilteredIrFieldSpecs: - """Provides cached views of an IR dataclass' fields.""" + """Provides cached views of an IR dataclass' fields.""" - def __init__(self, specs: Mapping[str, FieldSpec]): - self.all_field_specs = specs - self.field_specs = tuple(specs.values()) - self.dataclass_field_specs = { - k: v for k, v in specs.items() if v.is_dataclass - } - self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof} - self.sequence_field_specs = tuple( - v for v in specs.values() if v.is_sequence - ) - self.oneof_mappings = tuple( - (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof - ) + def __init__(self, specs: Mapping[str, FieldSpec]): + self.all_field_specs = specs + self.field_specs = tuple(specs.values()) + self.dataclass_field_specs = {k: v for k, v in specs.items() if v.is_dataclass} + self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof} + self.sequence_field_specs = tuple(v for v in specs.values() if v.is_sequence) + self.oneof_mappings = tuple( + (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof + ) def all_ir_classes(mod): - """Retrieves a list of all IR dataclass definitions in the given module.""" - return ( - v - for v in mod.__dict__.values() - if isinstance(type, v.__class__) and _is_ir_dataclass(v) - ) + """Retrieves a list of all IR dataclass definitions in the given module.""" + return ( + v + for v in 
mod.__dict__.values() + if isinstance(type, v.__class__) and _is_ir_dataclass(v) + ) class IrDataclassSpecs: - """Maintains a cache of all IR dataclass specs.""" + """Maintains a cache of all IR dataclass specs.""" - spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {} + spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {} - @classmethod - def get_mod_specs(cls, mod): - """Gets the IR dataclass specs for the given module.""" - return { - ir_class: FilteredIrFieldSpecs(_field_specs(ir_class)) - for ir_class in all_ir_classes(mod) - } + @classmethod + def get_mod_specs(cls, mod): + """Gets the IR dataclass specs for the given module.""" + return { + ir_class: FilteredIrFieldSpecs(_field_specs(ir_class)) + for ir_class in all_ir_classes(mod) + } - @classmethod - def get_specs(cls, data_class): - """Gets the field specs for the given class. The specs will be cached.""" - if data_class not in cls.spec_cache: - mod = sys.modules[data_class.__module__] - cls.spec_cache.update(cls.get_mod_specs(mod)) - return cls.spec_cache[data_class] + @classmethod + def get_specs(cls, data_class): + """Gets the field specs for the given class. The specs will be cached.""" + if data_class not in cls.spec_cache: + mod = sys.modules[data_class.__module__] + cls.spec_cache.update(cls.get_mod_specs(mod)) + return cls.spec_cache[data_class] def cache_message_specs(mod, cls): - """Adds a cached `field_specs` attribute to IR dataclasses in `mod` - excluding the given base `cls`. + """Adds a cached `field_specs` attribute to IR dataclasses in `mod` + excluding the given base `cls`. - This needs to be done after the dataclass decorators run and create the - wrapped classes. - """ - for data_class in all_ir_classes(mod): - if data_class is not cls: - data_class.field_specs = IrDataclassSpecs.get_specs(data_class) + This needs to be done after the dataclass decorators run and create the + wrapped classes. + """ + for data_class in all_ir_classes(mod): + if data_class is not cls: + data_class.field_specs = IrDataclassSpecs.get_specs(data_class) def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]: - """Gets the IR data field names and types for the given IR data class""" - # Get the dataclass fields - class_fields = dataclasses.fields(cast(Any, cls)) - - # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute - # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`. - # Instead we manually subsitute the type by extracting the list of classes - # from the class' module and manually substituting. - mod_ns = { - k: v - for k, v in sys.modules[cls.__module__].__dict__.items() - if isinstance(type, v.__class__) - } - - # Now extract the concrete type out of optionals - result: MutableMapping[str, FieldSpec] = {} - for class_field in class_fields: - if class_field.name.startswith("_"): - continue - container_type = FieldContainer.NONE - type_hint = class_field.type - oneof = class_field.metadata.get("oneof") - - # Check if this type is wrapped - origin = get_origin(type_hint) - # Get the wrapped types if there are any - args = get_args(type_hint) - if origin is not None: - # Extract the type. - type_hint = args[0] - - # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we - # have to check if it's a `Union` instead of just using `Optional`. - if origin == Union: - # Make sure this is an `Optional` and not another `Union` type. 
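# For reference, the `typing` introspection relied on here is standard-library
# behavior, not something introduced by this change (a minimal, runnable
# sketch; `Optional[T]` really is `Union[T, None]`, hence the `Union` check):
#
#     from typing import Optional, Union, get_args, get_origin
#
#     assert get_origin(Optional[int]) is Union
#     assert get_args(Optional[int]) == (int, type(None))
#     assert get_origin(list[int]) is list
#     assert get_args(list[int]) == (int,)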
-      assert len(args) == 2 and args[1] == type(None)
-      container_type = FieldContainer.OPTIONAL
-    elif origin == list:
-      container_type = FieldContainer.LIST
-    else:
-      raise TypeError(f"Field has invalid container type: {origin}")
-
-    # Resolve any forward references.
-    if isinstance(type_hint, str):
-      type_hint = mod_ns[type_hint]
-    if isinstance(type_hint, ForwardRef):
-      type_hint = mod_ns[type_hint.__forward_arg__]
-
-    result[class_field.name] = make_field_spec(
-        class_field.name, type_hint, container_type, oneof
-    )
+    """Gets the IR data field names and types for the given IR data class"""
+    # Get the dataclass fields
+    class_fields = dataclasses.fields(cast(Any, cls))
+
+    # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
+    # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`.
+    # Instead we manually substitute the type by extracting the list of classes
+    # from the class' module and manually substituting.
+    mod_ns = {
+        k: v
+        for k, v in sys.modules[cls.__module__].__dict__.items()
+        if isinstance(type, v.__class__)
+    }

-  return result
+    # Now extract the concrete type out of optionals
+    result: MutableMapping[str, FieldSpec] = {}
+    for class_field in class_fields:
+        if class_field.name.startswith("_"):
+            continue
+        container_type = FieldContainer.NONE
+        type_hint = class_field.type
+        oneof = class_field.metadata.get("oneof")
+
+        # Check if this type is wrapped
+        origin = get_origin(type_hint)
+        # Get the wrapped types if there are any
+        args = get_args(type_hint)
+        if origin is not None:
+            # Extract the type.
+            type_hint = args[0]
+
+        # Under the hood `typing.Optional` is just a `Union[T, None]` so we
+        # have to check if it's a `Union` instead of just using `Optional`.
+        if origin == Union:
+            # Make sure this is an `Optional` and not another `Union` type.
+            assert len(args) == 2 and args[1] == type(None)
+            container_type = FieldContainer.OPTIONAL
+        elif origin == list:
+            container_type = FieldContainer.LIST
+        else:
+            raise TypeError(f"Field has invalid container type: {origin}")
+
+        # Resolve any forward references.
+        if isinstance(type_hint, str):
+            type_hint = mod_ns[type_hint]
+        if isinstance(type_hint, ForwardRef):
+            type_hint = mod_ns[type_hint.__forward_arg__]
+
+        result[class_field.name] = make_field_spec(
+            class_field.name, type_hint, container_type, oneof
+        )
+
+    return result


 def field_specs(obj: Union[IrDataT, type[IrDataT]]) -> Mapping[str, FieldSpec]:
-  """Retrieves the fields specs for the the give data type.
-
-  The results of this method are cached to reduce lookup overhead.
-  """
-  cls = obj if isinstance(obj, type) else type(obj)
-  if cls is type(None):
-    raise TypeError("field_specs called with invalid type: NoneType")
-  return IrDataclassSpecs.get_specs(cls).all_field_specs
+    """Retrieves the field specs for the given data type.
+
+    The results of this method are cached to reduce lookup overhead.
+    """
+    cls = obj if isinstance(obj, type) else type(obj)
+    if cls is type(None):
+        raise TypeError("field_specs called with invalid type: NoneType")
+    return IrDataclassSpecs.get_specs(cls).all_field_specs


 def fields_and_values(
     ir: IrDataT,
     value_filt: Optional[Callable[[Any], bool]] = None,
 ):
-  """Retrieves the fields and their values for a given IR data class.
-
-  Args:
-    ir: The IR data class or a read-only wrapper of an IR data class.
-    value_filt: Optional filter used to exclude values.
-  """
-  set_fields: list[Tuple[FieldSpec, Any]] = []
-  specs: FilteredIrFieldSpecs = ir.field_specs
-  for spec in specs.field_specs:
-    value = getattr(ir, spec.name)
-    if not value_filt or value_filt(value):
-      set_fields.append((spec, value))
-  return set_fields
+    """Retrieves the fields and their values for a given IR data class.
+
+    Args:
+      ir: The IR data class or a read-only wrapper of an IR data class.
+      value_filt: Optional filter used to exclude values.
+    """
+    set_fields: list[Tuple[FieldSpec, Any]] = []
+    specs: FilteredIrFieldSpecs = ir.field_specs
+    for spec in specs.field_specs:
+        value = getattr(ir, spec.name)
+        if not value_filt or value_filt(value):
+            set_fields.append((spec, value))
+    return set_fields
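# A usage sketch for `field_specs` and `fields_and_values` above (the `Sample`
# message is hypothetical; any ir_data.Message-style dataclass built with this
# module's helpers would behave the same):
#
#     specs = field_specs(Sample)         # {"name": FieldSpec(...), ...}, cached
#     pairs = fields_and_values(
#         Sample(name="x"), lambda v: v is not None
#     )                                   # [(FieldSpec for `name`, "x"), ...]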


 # `copy` is one of the hottest paths of embossc. We've taken steps to
@@ -333,117 +329,119 @@ def fields_and_values(
 # 5. None checks are only done in `copy()`, `_copy_set_fields` only
 #    references `_copy()` to avoid this step.
 def _copy_set_fields(ir: IrDataT):
-  values: MutableMapping[str, Any] = {}
-
-  specs: FilteredIrFieldSpecs = ir.field_specs
-  for spec in specs.field_specs:
-    value = getattr(ir, spec.name)
-    if value is not None:
-      if spec.is_sequence:
-        if spec.is_dataclass:
-          copy_value = CopyValuesList(spec.data_type, (_copy(v) for v in value))
-          value = TemporaryCopyValuesList(copy_value)
-        else:
-          copy_value = CopyValuesList(spec.data_type, value)
-          value = TemporaryCopyValuesList(copy_value)
-      elif spec.is_dataclass:
-        value = _copy(value)
-      values[spec.name] = value
-  return values
+    values: MutableMapping[str, Any] = {}
+
+    specs: FilteredIrFieldSpecs = ir.field_specs
+    for spec in specs.field_specs:
+        value = getattr(ir, spec.name)
+        if value is not None:
+            if spec.is_sequence:
+                if spec.is_dataclass:
+                    copy_value = CopyValuesList(
+                        spec.data_type, (_copy(v) for v in value)
+                    )
+                    value = TemporaryCopyValuesList(copy_value)
+                else:
+                    copy_value = CopyValuesList(spec.data_type, value)
+                    value = TemporaryCopyValuesList(copy_value)
+            elif spec.is_dataclass:
+                value = _copy(value)
+            values[spec.name] = value
+    return values


 def _copy(ir: IrDataT) -> IrDataT:
-  return type(ir)(**_copy_set_fields(ir))  # type: ignore[misc]
+    return type(ir)(**_copy_set_fields(ir))  # type: ignore[misc]


 def copy(ir: IrDataT) -> Optional[IrDataT]:
-  """Creates a copy of the given IR data class"""
-  if not ir:
-    return None
-  return _copy(ir)
+    """Creates a copy of the given IR data class"""
+    if not ir:
+        return None
+    return _copy(ir)


 def update(ir: IrDataT, template: IrDataT):
-  """Updates `ir`s fields with all set fields in the template."""
-  for k, v in _copy_set_fields(template).items():
-    if isinstance(v, TemporaryCopyValuesList):
-      v = v.temp_list
-    setattr(ir, k, v)
+    """Updates `ir`'s fields with all set fields in the template."""
+    for k, v in _copy_set_fields(template).items():
+        if isinstance(v, TemporaryCopyValuesList):
+            v = v.temp_list
+        setattr(ir, k, v)


 class OneOfField:
-  """Decorator for a "oneof" field.
+    """Decorator for a "oneof" field.

-  Tracks when the field is set and will unset othe fields in the associated
-  oneof group.
+    Tracks when the field is set and will unset other fields in the associated
+    oneof group.

-  Note: Decorators only work if dataclass slots aren't used.
-  """
+    Note: Decorators only work if dataclass slots aren't used.
+    """
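# Behavior sketch for `OneOfField`, mirroring `test_set_oneof_field` in
# ir_data_fields_test.py further down in this diff: assigning one member of a
# oneof group clears the other members of the same group.
#
#     msg = OneofFieldTest()
#     msg.int_field_1 = 10
#     msg.int_field_2 = 20   # same "type_1" group, so int_field_1 is unset
#     assert msg.int_field_1 is None and msg.int_field_2 == 20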

-  def __init__(self, oneof: str) -> None:
-    super().__init__()
-    self.oneof = oneof
-    self.owner_type = None
-    self.proxy_name: str = ""
-    self.name: str = ""
+    def __init__(self, oneof: str) -> None:
+        super().__init__()
+        self.oneof = oneof
+        self.owner_type = None
+        self.proxy_name: str = ""
+        self.name: str = ""

-  def __set_name__(self, owner, name):
-    self.name = name
-    self.proxy_name = f"_{name}"
-    self.owner_type = owner
-    # Add our empty proxy field to the class.
-    setattr(owner, self.proxy_name, None)
+    def __set_name__(self, owner, name):
+        self.name = name
+        self.proxy_name = f"_{name}"
+        self.owner_type = owner
+        # Add our empty proxy field to the class.
+        setattr(owner, self.proxy_name, None)

-  def __get__(self, obj, objtype=None):
-    return getattr(obj, self.proxy_name)
+    def __get__(self, obj, objtype=None):
+        return getattr(obj, self.proxy_name)

-  def __set__(self, obj, value):
-    if value is self:
-      # This will happen if the dataclass uses the default value, we just
-      # default to None.
-      value = None
+    def __set__(self, obj, value):
+        if value is self:
+            # This will happen if the dataclass uses the default value; we just
+            # default to None.
+            value = None

-    if value is not None:
-      # Clear the others
-      for name, oneof in IrDataclassSpecs.get_specs(
-          self.owner_type
-      ).oneof_mappings:
-        if oneof == self.oneof and name != self.name:
-          setattr(obj, name, None)
+        if value is not None:
+            # Clear the others
+            for name, oneof in IrDataclassSpecs.get_specs(
+                self.owner_type
+            ).oneof_mappings:
+                if oneof == self.oneof and name != self.name:
+                    setattr(obj, name, None)

-    setattr(obj, self.proxy_name, value)
+        setattr(obj, self.proxy_name, value)


 def oneof_field(name: str):
-  """Alternative for `datclasses.field` that sets up a oneof variable"""
-  return dataclasses.field(  # pylint:disable=invalid-field-call
-      default=OneOfField(name), metadata={"oneof": name}, init=True
-  )
+    """Alternative for `dataclasses.field` that sets up a oneof variable"""
+    return dataclasses.field(  # pylint:disable=invalid-field-call
+        default=OneOfField(name), metadata={"oneof": name}, init=True
+    )


 def str_field():
-  """Helper used to define a defaulted str field"""
-  return dataclasses.field(default_factory=str)  # pylint:disable=invalid-field-call
+    """Helper used to define a defaulted str field"""
+    return dataclasses.field(default_factory=str)  # pylint:disable=invalid-field-call


 def list_field(cls_or_fn):
-  """Helper used to define a defaulted list field.
-
-  A lambda can be used to defer resolution of a field type that references its
-  container type, for example:
-  ```
-  class Foo:
-    subtypes: list['Foo'] = list_field(lambda: Foo)
-    names: list[str] = list_field(str)
-  ```
-
-  Args:
-    cls_or_fn: The class type or a function that resolves to the class type.
-  """
-
-  def list_factory(c):
-    return CopyValuesList(c if isinstance(c, type) else c())
-
-  return dataclasses.field(  # pylint:disable=invalid-field-call
-      default_factory=lambda: list_factory(cls_or_fn)
-  )
+    """Helper used to define a defaulted list field.
+
+    A lambda can be used to defer resolution of a field type that references its
+    container type, for example:
+    ```
+    class Foo:
+      subtypes: list['Foo'] = list_field(lambda: Foo)
+      names: list[str] = list_field(str)
+    ```
+
+    Args:
+      cls_or_fn: The class type or a function that resolves to the class type.
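# Behavior sketch for `list_field`, mirroring `test_list_param_is_copied` in
# ir_data_fields_test.py further down in this diff: list arguments passed to a
# constructor are copied into a CopyValuesList rather than aliased.
#
#     seq = [5, 6, 7]
#     msg = ListCopyTestClass(non_union_field=2, seq_field=seq)
#     assert msg.seq_field == seq and msg.seq_field is not seq
#     assert isinstance(msg.seq_field, CopyValuesList)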
+ """ + + def list_factory(c): + return CopyValuesList(c if isinstance(c, type) else c()) + + return dataclasses.field( # pylint:disable=invalid-field-call + default_factory=lambda: list_factory(cls_or_fn) + ) diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py index fa99943..2853fde 100644 --- a/compiler/util/ir_data_fields_test.py +++ b/compiler/util/ir_data_fields_test.py @@ -25,230 +25,229 @@ class TestEnum(enum.Enum): - """Used to test python Enum handling.""" + """Used to test python Enum handling.""" - UNKNOWN = 0 - VALUE_1 = 1 - VALUE_2 = 2 + UNKNOWN = 0 + VALUE_1 = 1 + VALUE_2 = 2 @dataclasses.dataclass class Opaque(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers""" @dataclasses.dataclass class ClassWithUnion(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers""" - opaque: Optional[Opaque] = ir_data_fields.oneof_field("type") - integer: Optional[int] = ir_data_fields.oneof_field("type") - boolean: Optional[bool] = ir_data_fields.oneof_field("type") - enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type") - non_union_field: int = 0 + opaque: Optional[Opaque] = ir_data_fields.oneof_field("type") + integer: Optional[int] = ir_data_fields.oneof_field("type") + boolean: Optional[bool] = ir_data_fields.oneof_field("type") + enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type") + non_union_field: int = 0 @dataclasses.dataclass class ClassWithTwoUnions(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers""" - opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1") - integer: Optional[int] = ir_data_fields.oneof_field("type_1") - boolean: Optional[bool] = ir_data_fields.oneof_field("type_2") - enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2") - non_union_field: int = 0 - seq_field: list[int] = ir_data_fields.list_field(int) + opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1") + integer: Optional[int] = ir_data_fields.oneof_field("type_1") + boolean: Optional[bool] = ir_data_fields.oneof_field("type_2") + enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2") + non_union_field: int = 0 + seq_field: list[int] = ir_data_fields.list_field(int) @dataclasses.dataclass class NestedClass(ir_data.Message): - """Used for testing data field helpers""" + """Used for testing data field helpers""" - one_union_class: Optional[ClassWithUnion] = None - two_union_class: Optional[ClassWithTwoUnions] = None + one_union_class: Optional[ClassWithUnion] = None + two_union_class: Optional[ClassWithTwoUnions] = None @dataclasses.dataclass class ListCopyTestClass(ir_data.Message): - """Used to test behavior or extending a sequence.""" + """Used to test behavior or extending a sequence.""" - non_union_field: int = 0 - seq_field: list[int] = ir_data_fields.list_field(int) + non_union_field: int = 0 + seq_field: list[int] = ir_data_fields.list_field(int) @dataclasses.dataclass class OneofFieldTest(ir_data.Message): - """Basic test class for oneof fields""" + """Basic test class for oneof fields""" - int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1") - int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1") - normal_field: bool = True + int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1") + int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1") + normal_field: bool = True class 
OneOfTest(unittest.TestCase): - """Tests for the the various oneof field helpers""" - - def test_field_attribute(self): - """Test the `oneof_field` helper.""" - test_field = ir_data_fields.oneof_field("type_1") - self.assertIsNotNone(test_field) - self.assertTrue(test_field.init) - self.assertIsInstance(test_field.default, ir_data_fields.OneOfField) - self.assertEqual(test_field.metadata.get("oneof"), "type_1") - - def test_init_default(self): - """Test creating an instance with default fields""" - one_of_field_test = OneofFieldTest() - self.assertIsNone(one_of_field_test.int_field_1) - self.assertIsNone(one_of_field_test.int_field_2) - self.assertTrue(one_of_field_test.normal_field) - - def test_init(self): - """Test creating an instance with non-default fields""" - one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) - self.assertEqual(one_of_field_test.int_field_1, 10) - self.assertIsNone(one_of_field_test.int_field_2) - self.assertFalse(one_of_field_test.normal_field) - - def test_set_oneof_field(self): - """Tests setting oneof fields causes others in the group to be unset""" - one_of_field_test = OneofFieldTest() - one_of_field_test.int_field_1 = 10 - self.assertEqual(one_of_field_test.int_field_1, 10) - self.assertEqual(one_of_field_test.int_field_2, None) - one_of_field_test.int_field_2 = 20 - self.assertEqual(one_of_field_test.int_field_1, None) - self.assertEqual(one_of_field_test.int_field_2, 20) - - # Do it again - one_of_field_test.int_field_1 = 10 - self.assertEqual(one_of_field_test.int_field_1, 10) - self.assertEqual(one_of_field_test.int_field_2, None) - one_of_field_test.int_field_2 = 20 - self.assertEqual(one_of_field_test.int_field_1, None) - self.assertEqual(one_of_field_test.int_field_2, 20) - - # Now create a new instance and make sure changes to it are not reflected - # on the original object. 
- one_of_field_test_2 = OneofFieldTest() - one_of_field_test_2.int_field_1 = 1000 - self.assertEqual(one_of_field_test_2.int_field_1, 1000) - self.assertEqual(one_of_field_test_2.int_field_2, None) - self.assertEqual(one_of_field_test.int_field_1, None) - self.assertEqual(one_of_field_test.int_field_2, 20) - - def test_set_to_none(self): - """Tests explicitly setting a oneof field to None""" - one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) - self.assertEqual(one_of_field_test.int_field_1, 10) - self.assertIsNone(one_of_field_test.int_field_2) - self.assertFalse(one_of_field_test.normal_field) - - # Clear the set fields - one_of_field_test.int_field_1 = None - self.assertIsNone(one_of_field_test.int_field_1) - self.assertIsNone(one_of_field_test.int_field_2) - self.assertFalse(one_of_field_test.normal_field) - - # Set another field - one_of_field_test.int_field_2 = 200 - self.assertIsNone(one_of_field_test.int_field_1) - self.assertEqual(one_of_field_test.int_field_2, 200) - self.assertFalse(one_of_field_test.normal_field) - - # Clear the already unset field - one_of_field_test.int_field_1 = None - self.assertIsNone(one_of_field_test.int_field_1) - self.assertEqual(one_of_field_test.int_field_2, 200) - self.assertFalse(one_of_field_test.normal_field) - - def test_oneof_specs(self): - """Tests the `oneof_field_specs` filter""" - expected = { - "int_field_1": ir_data_fields.make_field_spec( - "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1" - ), - "int_field_2": ir_data_fields.make_field_spec( - "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1" - ), - } - actual = ir_data_fields.IrDataclassSpecs.get_specs( - OneofFieldTest - ).oneof_field_specs - self.assertDictEqual(actual, expected) - - def test_oneof_mappings(self): - """Tests the `oneof_mappings` function""" - expected = (("int_field_1", "type_1"), ("int_field_2", "type_1")) - actual = ir_data_fields.IrDataclassSpecs.get_specs( - OneofFieldTest - ).oneof_mappings - self.assertTupleEqual(actual, expected) + """Tests for the the various oneof field helpers""" + + def test_field_attribute(self): + """Test the `oneof_field` helper.""" + test_field = ir_data_fields.oneof_field("type_1") + self.assertIsNotNone(test_field) + self.assertTrue(test_field.init) + self.assertIsInstance(test_field.default, ir_data_fields.OneOfField) + self.assertEqual(test_field.metadata.get("oneof"), "type_1") + + def test_init_default(self): + """Test creating an instance with default fields""" + one_of_field_test = OneofFieldTest() + self.assertIsNone(one_of_field_test.int_field_1) + self.assertIsNone(one_of_field_test.int_field_2) + self.assertTrue(one_of_field_test.normal_field) + + def test_init(self): + """Test creating an instance with non-default fields""" + one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) + self.assertEqual(one_of_field_test.int_field_1, 10) + self.assertIsNone(one_of_field_test.int_field_2) + self.assertFalse(one_of_field_test.normal_field) + + def test_set_oneof_field(self): + """Tests setting oneof fields causes others in the group to be unset""" + one_of_field_test = OneofFieldTest() + one_of_field_test.int_field_1 = 10 + self.assertEqual(one_of_field_test.int_field_1, 10) + self.assertEqual(one_of_field_test.int_field_2, None) + one_of_field_test.int_field_2 = 20 + self.assertEqual(one_of_field_test.int_field_1, None) + self.assertEqual(one_of_field_test.int_field_2, 20) + + # Do it again + one_of_field_test.int_field_1 = 10 + 
self.assertEqual(one_of_field_test.int_field_1, 10) + self.assertEqual(one_of_field_test.int_field_2, None) + one_of_field_test.int_field_2 = 20 + self.assertEqual(one_of_field_test.int_field_1, None) + self.assertEqual(one_of_field_test.int_field_2, 20) + + # Now create a new instance and make sure changes to it are not reflected + # on the original object. + one_of_field_test_2 = OneofFieldTest() + one_of_field_test_2.int_field_1 = 1000 + self.assertEqual(one_of_field_test_2.int_field_1, 1000) + self.assertEqual(one_of_field_test_2.int_field_2, None) + self.assertEqual(one_of_field_test.int_field_1, None) + self.assertEqual(one_of_field_test.int_field_2, 20) + + def test_set_to_none(self): + """Tests explicitly setting a oneof field to None""" + one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False) + self.assertEqual(one_of_field_test.int_field_1, 10) + self.assertIsNone(one_of_field_test.int_field_2) + self.assertFalse(one_of_field_test.normal_field) + + # Clear the set fields + one_of_field_test.int_field_1 = None + self.assertIsNone(one_of_field_test.int_field_1) + self.assertIsNone(one_of_field_test.int_field_2) + self.assertFalse(one_of_field_test.normal_field) + + # Set another field + one_of_field_test.int_field_2 = 200 + self.assertIsNone(one_of_field_test.int_field_1) + self.assertEqual(one_of_field_test.int_field_2, 200) + self.assertFalse(one_of_field_test.normal_field) + + # Clear the already unset field + one_of_field_test.int_field_1 = None + self.assertIsNone(one_of_field_test.int_field_1) + self.assertEqual(one_of_field_test.int_field_2, 200) + self.assertFalse(one_of_field_test.normal_field) + + def test_oneof_specs(self): + """Tests the `oneof_field_specs` filter""" + expected = { + "int_field_1": ir_data_fields.make_field_spec( + "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1" + ), + "int_field_2": ir_data_fields.make_field_spec( + "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1" + ), + } + actual = ir_data_fields.IrDataclassSpecs.get_specs( + OneofFieldTest + ).oneof_field_specs + self.assertDictEqual(actual, expected) + + def test_oneof_mappings(self): + """Tests the `oneof_mappings` function""" + expected = (("int_field_1", "type_1"), ("int_field_2", "type_1")) + actual = ir_data_fields.IrDataclassSpecs.get_specs( + OneofFieldTest + ).oneof_mappings + self.assertTupleEqual(actual, expected) class IrDataFieldsTest(unittest.TestCase): - """Tests misc methods in ir_data_fields""" - - def test_copy(self): - """Tests copying a data class works as expected""" - union = ClassWithTwoUnions( - opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3] - ) - nested_class = NestedClass(two_union_class=union) - nested_class_copy = ir_data_fields.copy(nested_class) - self.assertIsNotNone(nested_class_copy) - self.assertIsNot(nested_class, nested_class_copy) - self.assertEqual(nested_class_copy, nested_class) - - empty_copy = ir_data_fields.copy(None) - self.assertIsNone(empty_copy) - - def test_copy_values_list(self): - """Tests that CopyValuesList copies values""" - data_list = ir_data_fields.CopyValuesList(ListCopyTestClass) - self.assertEqual(len(data_list), 0) - - list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7]) - list_tests = [ir_data_fields.copy(list_test) for _ in range(4)] - data_list.extend(list_tests) - self.assertEqual(len(data_list), 4) - for i in data_list: - self.assertEqual(i, list_test) - - def test_list_param_is_copied(self): - """Test that lists passed to constructors are 
converted to CopyValuesList""" - seq_field = [5, 6, 7] - list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field) - self.assertEqual(len(list_test.seq_field), len(seq_field)) - self.assertIsNot(list_test.seq_field, seq_field) - self.assertEqual(list_test.seq_field, seq_field) - self.assertTrue( - isinstance(list_test.seq_field, ir_data_fields.CopyValuesList) - ) - - def test_copy_oneof(self): - """Tests copying an IR data class that has oneof fields.""" - oneof_test = OneofFieldTest() - oneof_test.int_field_1 = 10 - oneof_test.normal_field = False - self.assertEqual(oneof_test.int_field_1, 10) - self.assertEqual(oneof_test.normal_field, False) - - oneof_copy = ir_data_fields.copy(oneof_test) - self.assertIsNotNone(oneof_copy) - self.assertEqual(oneof_copy.int_field_1, 10) - self.assertIsNone(oneof_copy.int_field_2) - self.assertEqual(oneof_copy.normal_field, False) - - oneof_copy.int_field_2 = 100 - self.assertEqual(oneof_copy.int_field_2, 100) - self.assertIsNone(oneof_copy.int_field_1) - self.assertEqual(oneof_test.int_field_1, 10) - self.assertEqual(oneof_test.normal_field, False) + """Tests misc methods in ir_data_fields""" + + def test_copy(self): + """Tests copying a data class works as expected""" + union = ClassWithTwoUnions( + opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3] + ) + nested_class = NestedClass(two_union_class=union) + nested_class_copy = ir_data_fields.copy(nested_class) + self.assertIsNotNone(nested_class_copy) + self.assertIsNot(nested_class, nested_class_copy) + self.assertEqual(nested_class_copy, nested_class) + + empty_copy = ir_data_fields.copy(None) + self.assertIsNone(empty_copy) + + def test_copy_values_list(self): + """Tests that CopyValuesList copies values""" + data_list = ir_data_fields.CopyValuesList(ListCopyTestClass) + self.assertEqual(len(data_list), 0) + + list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7]) + list_tests = [ir_data_fields.copy(list_test) for _ in range(4)] + data_list.extend(list_tests) + self.assertEqual(len(data_list), 4) + for i in data_list: + self.assertEqual(i, list_test) + + def test_list_param_is_copied(self): + """Test that lists passed to constructors are converted to CopyValuesList""" + seq_field = [5, 6, 7] + list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field) + self.assertEqual(len(list_test.seq_field), len(seq_field)) + self.assertIsNot(list_test.seq_field, seq_field) + self.assertEqual(list_test.seq_field, seq_field) + self.assertTrue(isinstance(list_test.seq_field, ir_data_fields.CopyValuesList)) + + def test_copy_oneof(self): + """Tests copying an IR data class that has oneof fields.""" + oneof_test = OneofFieldTest() + oneof_test.int_field_1 = 10 + oneof_test.normal_field = False + self.assertEqual(oneof_test.int_field_1, 10) + self.assertEqual(oneof_test.normal_field, False) + + oneof_copy = ir_data_fields.copy(oneof_test) + self.assertIsNotNone(oneof_copy) + self.assertEqual(oneof_copy.int_field_1, 10) + self.assertIsNone(oneof_copy.int_field_2) + self.assertEqual(oneof_copy.normal_field, False) + + oneof_copy.int_field_2 = 100 + self.assertEqual(oneof_copy.int_field_2, 100) + self.assertIsNone(oneof_copy.int_field_1) + self.assertEqual(oneof_test.int_field_1, 10) + self.assertEqual(oneof_test.normal_field, False) ir_data_fields.cache_message_specs( - sys.modules[OneofFieldTest.__module__], ir_data.Message) + sys.modules[OneofFieldTest.__module__], ir_data.Message +) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git 
a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py index 37cc9a3..2d38eac 100644 --- a/compiler/util/ir_data_utils.py +++ b/compiler/util/ir_data_utils.py @@ -85,383 +85,382 @@ def is_leaf_synthetic(data): def field_specs(ir: Union[MessageT, type[MessageT]]): - """Retrieves the field specs for the IR data class""" - data_type = ir if isinstance(ir, type) else type(ir) - return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs + """Retrieves the field specs for the IR data class""" + data_type = ir if isinstance(ir, type) else type(ir) + return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs class IrDataSerializer: - """Provides methods for serializing IR data objects""" - - def __init__(self, ir: MessageT): - assert ir is not None - self.ir = ir - - def _to_dict( - self, - ir: MessageT, - field_func: Callable[ - [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]] - ], - ) -> MutableMapping[str, Any]: - assert ir is not None - values: MutableMapping[str, Any] = {} - for spec, value in field_func(ir): - if value is not None and spec.is_dataclass: - if spec.is_sequence: - value = [self._to_dict(v, field_func) for v in value] - else: - value = self._to_dict(value, field_func) - values[spec.name] = value - return values - - def to_dict(self, exclude_none: bool = False): - """Converts the IR data class to a dictionary.""" - - def non_empty(ir): - return fields_and_values( - ir, lambda v: v is not None and (not isinstance(v, list) or len(v)) - ) - - def all_fields(ir): - return fields_and_values(ir) - - # It's tempting to use `dataclasses.asdict` here, but that does a deep - # copy which is overkill for the current usage; mainly as an intermediary - # for `to_json` and `repr`. - return self._to_dict(self.ir, non_empty if exclude_none else all_fields) - - def to_json(self, *args, **kwargs): - """Converts the IR data class to a JSON string""" - return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs) - - @staticmethod - def from_json(data_cls, data): - """Constructs an IR data class from the given JSON string""" - as_dict = json.loads(data) - return IrDataSerializer.from_dict(data_cls, as_dict) - - def copy_from_dict(self, data): - """Deserializes the data and overwrites the IR data class with it""" - cls = type(self.ir) - data_copy = IrDataSerializer.from_dict(cls, data) - for k in field_specs(cls): - setattr(self.ir, k, getattr(data_copy, k)) - - @staticmethod - def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum: - if isinstance(val, str): - return getattr(enum_cls, val) - return enum_cls(val) - - @staticmethod - def _enum_type_hook(enum_cls: type[enum.Enum]): - return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val) - - @staticmethod - def _from_dict(data_cls: type[MessageT], data): - class_fields: MutableMapping[str, Any] = {} - for name, spec in ir_data_fields.field_specs(data_cls).items(): - if (value := data.get(name)) is not None: - if spec.is_dataclass: - if spec.is_sequence: - class_fields[name] = [ - IrDataSerializer._from_dict(spec.data_type, v) for v in value - ] - else: - class_fields[name] = IrDataSerializer._from_dict( - spec.data_type, value - ) - else: - if spec.data_type in ( - ir_data.FunctionMapping, - ir_data.AddressableUnit, - ): - class_fields[name] = IrDataSerializer._enum_type_converter( - spec.data_type, value + """Provides methods for serializing IR data objects""" + + def __init__(self, ir: MessageT): + assert ir is not None + self.ir = ir + + def _to_dict( + 
self, + ir: MessageT, + field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]], + ) -> MutableMapping[str, Any]: + assert ir is not None + values: MutableMapping[str, Any] = {} + for spec, value in field_func(ir): + if value is not None and spec.is_dataclass: + if spec.is_sequence: + value = [self._to_dict(v, field_func) for v in value] + else: + value = self._to_dict(value, field_func) + values[spec.name] = value + return values + + def to_dict(self, exclude_none: bool = False): + """Converts the IR data class to a dictionary.""" + + def non_empty(ir): + return fields_and_values( + ir, lambda v: v is not None and (not isinstance(v, list) or len(v)) ) - else: - if spec.is_sequence: - class_fields[name] = value - else: - class_fields[name] = spec.data_type(value) - return data_cls(**class_fields) - @staticmethod - def from_dict(data_cls: type[MessageT], data): - """Creates a new IR data instance from a serialized dict""" - return IrDataSerializer._from_dict(data_cls, data) + def all_fields(ir): + return fields_and_values(ir) + + # It's tempting to use `dataclasses.asdict` here, but that does a deep + # copy which is overkill for the current usage; mainly as an intermediary + # for `to_json` and `repr`. + return self._to_dict(self.ir, non_empty if exclude_none else all_fields) + + def to_json(self, *args, **kwargs): + """Converts the IR data class to a JSON string""" + return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs) + + @staticmethod + def from_json(data_cls, data): + """Constructs an IR data class from the given JSON string""" + as_dict = json.loads(data) + return IrDataSerializer.from_dict(data_cls, as_dict) + + def copy_from_dict(self, data): + """Deserializes the data and overwrites the IR data class with it""" + cls = type(self.ir) + data_copy = IrDataSerializer.from_dict(cls, data) + for k in field_specs(cls): + setattr(self.ir, k, getattr(data_copy, k)) + + @staticmethod + def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum: + if isinstance(val, str): + return getattr(enum_cls, val) + return enum_cls(val) + + @staticmethod + def _enum_type_hook(enum_cls: type[enum.Enum]): + return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val) + + @staticmethod + def _from_dict(data_cls: type[MessageT], data): + class_fields: MutableMapping[str, Any] = {} + for name, spec in ir_data_fields.field_specs(data_cls).items(): + if (value := data.get(name)) is not None: + if spec.is_dataclass: + if spec.is_sequence: + class_fields[name] = [ + IrDataSerializer._from_dict(spec.data_type, v) + for v in value + ] + else: + class_fields[name] = IrDataSerializer._from_dict( + spec.data_type, value + ) + else: + if spec.data_type in ( + ir_data.FunctionMapping, + ir_data.AddressableUnit, + ): + class_fields[name] = IrDataSerializer._enum_type_converter( + spec.data_type, value + ) + else: + if spec.is_sequence: + class_fields[name] = value + else: + class_fields[name] = spec.data_type(value) + return data_cls(**class_fields) + + @staticmethod + def from_dict(data_cls: type[MessageT], data): + """Creates a new IR data instance from a serialized dict""" + return IrDataSerializer._from_dict(data_cls, data) class _IrDataSequenceBuilder(MutableSequence[MessageT]): - """Wrapper for a list of IR elements + """Wrapper for a list of IR elements - Simply wraps the returned values during indexed access and iteration with - IrDataBuilders. - """ + Simply wraps the returned values during indexed access and iteration with + IrDataBuilders. 
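# A round-trip sketch for IrDataSerializer above (`module_ir` is a hypothetical
# ir_data.Module instance; the entry points are the ones defined in this file,
# and extra arguments to `to_json` are forwarded to `json.dumps`):
#
#     json_text = IrDataSerializer(module_ir).to_json(indent=2)
#     restored = IrDataSerializer.from_json(ir_data.Module, json_text)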
+ """ - def __init__(self, target: MutableSequence[MessageT]): - self._target = target + def __init__(self, target: MutableSequence[MessageT]): + self._target = target - def __delitem__(self, key): - del self._target[key] + def __delitem__(self, key): + del self._target[key] - def __getitem__(self, key): - return _IrDataBuilder(self._target.__getitem__(key)) + def __getitem__(self, key): + return _IrDataBuilder(self._target.__getitem__(key)) - def __setitem__(self, key, value): - self._target[key] = value + def __setitem__(self, key, value): + self._target[key] = value - def __iter__(self): - itr = iter(self._target) - for i in itr: - yield _IrDataBuilder(i) + def __iter__(self): + itr = iter(self._target) + for i in itr: + yield _IrDataBuilder(i) - def __repr__(self): - return repr(self._target) + def __repr__(self): + return repr(self._target) - def __len__(self): - return len(self._target) + def __len__(self): + return len(self._target) - def __eq__(self, other): - return self._target == other + def __eq__(self, other): + return self._target == other - def __ne__(self, other): - return self._target != other + def __ne__(self, other): + return self._target != other - def insert(self, index, value): - self._target.insert(index, value) + def insert(self, index, value): + self._target.insert(index, value) - def extend(self, values): - self._target.extend(values) + def extend(self, values): + self._target.extend(values) class _IrDataBuilder(Generic[MessageT]): - """Wrapper for an IR element""" - - def __init__(self, ir: MessageT) -> None: - assert ir is not None - self.ir: MessageT = ir - - def __setattr__(self, __name: str, __value: Any) -> None: - if __name == "ir": - # This our proxy object - object.__setattr__(self, __name, __value) - else: - # Passthrough to the proxy object - ir: MessageT = object.__getattribute__(self, "ir") - setattr(ir, __name, __value) - - def __getattribute__(self, name: str) -> Any: - """Hook for `getattr` that handles adding missing fields. - - If the field is missing inserts it, and then returns either the raw value - for basic types - or a new IrBuilder wrapping the field to handle the next field access in a - longer chain. - """ - - # Check if getting one of the builder attributes - if name in ("CopyFrom", "ir"): - return object.__getattribute__(self, name) + """Wrapper for an IR element""" - # Get our target object by bypassing our getattr hook - ir: MessageT = object.__getattribute__(self, "ir") - if ir is None: - return object.__getattribute__(self, name) + def __init__(self, ir: MessageT) -> None: + assert ir is not None + self.ir: MessageT = ir - if name in ("HasField", "WhichOneof"): - return getattr(ir, name) - - field_spec = field_specs(ir).get(name) - if field_spec is None: - raise AttributeError( - f"No field {name} on {type(ir).__module__}.{type(ir).__name__}." - ) - - obj = getattr(ir, name, None) - if obj is None: - # Create a default and store it - obj = ir_data_fields.build_default(field_spec) - setattr(ir, name, obj) + def __setattr__(self, __name: str, __value: Any) -> None: + if __name == "ir": + # This our proxy object + object.__setattr__(self, __name, __value) + else: + # Passthrough to the proxy object + ir: MessageT = object.__getattribute__(self, "ir") + setattr(ir, __name, __value) + + def __getattribute__(self, name: str) -> Any: + """Hook for `getattr` that handles adding missing fields. 
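# A usage sketch for _IrDataBuilder via `builder` (defined just below), using
# the same field chain as the `reader` example later in this file: assigning
# through a chain of unset fields creates the intermediate messages on demand.
#
#     function = ir_data.Function()
#     builder(function).function_name.source_location.end.column = 8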
+ + If the field is missing inserts it, and then returns either the raw value + for basic types + or a new IrBuilder wrapping the field to handle the next field access in a + longer chain. + """ + + # Check if getting one of the builder attributes + if name in ("CopyFrom", "ir"): + return object.__getattribute__(self, name) + + # Get our target object by bypassing our getattr hook + ir: MessageT = object.__getattribute__(self, "ir") + if ir is None: + return object.__getattribute__(self, name) + + if name in ("HasField", "WhichOneof"): + return getattr(ir, name) + + field_spec = field_specs(ir).get(name) + if field_spec is None: + raise AttributeError( + f"No field {name} on {type(ir).__module__}.{type(ir).__name__}." + ) - if field_spec.is_dataclass: - obj = ( - _IrDataSequenceBuilder(obj) - if field_spec.is_sequence - else _IrDataBuilder(obj) - ) + obj = getattr(ir, name, None) + if obj is None: + # Create a default and store it + obj = ir_data_fields.build_default(field_spec) + setattr(ir, name, obj) + + if field_spec.is_dataclass: + obj = ( + _IrDataSequenceBuilder(obj) + if field_spec.is_sequence + else _IrDataBuilder(obj) + ) - return obj + return obj - def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name - """Updates the fields of this class with values set in the template""" - update(cast(type[MessageT], self), template) + def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name + """Updates the fields of this class with values set in the template""" + update(cast(type[MessageT], self), template) def builder(target: MessageT) -> MessageT: - """Create a wrapper around the target to help build an IR Data structure""" - # Check if the target is already a builder. - if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)): - return target + """Create a wrapper around the target to help build an IR Data structure""" + # Check if the target is already a builder. + if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)): + return target - # Builders are only valid for IR data classes. - if not hasattr(type(target), "IR_DATACLASS"): - raise TypeError(f"Builder target {type(target)} is not an ir_data.message") + # Builders are only valid for IR data classes. + if not hasattr(type(target), "IR_DATACLASS"): + raise TypeError(f"Builder target {type(target)} is not an ir_data.message") - # Create a builder and cast it to the target type to expose type hinting for - # the wrapped type. - return cast(MessageT, _IrDataBuilder(target)) + # Create a builder and cast it to the target type to expose type hinting for + # the wrapped type. 
+    return cast(MessageT, _IrDataBuilder(target))


 def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
-  """Helper that builds an FieldChecker that pretends to be an IR class"""
-  if spec.is_sequence:
-    return []
-  if spec.is_dataclass:
-    return _ReadOnlyFieldChecker(spec)
-  return ir_data_fields.build_default(spec)
+    """Helper that builds a FieldChecker that pretends to be an IR class"""
+    if spec.is_sequence:
+        return []
+    if spec.is_dataclass:
+        return _ReadOnlyFieldChecker(spec)
+    return ir_data_fields.build_default(spec)


 def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
-  if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
-    return ir_or_spec.data_type
-  return type(ir_or_spec)
+    if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+        return ir_or_spec.data_type
+    return type(ir_or_spec)


 class _ReadOnlyFieldChecker:
-  """Class used the chain calls to fields that aren't set"""
-
-  def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
-    self.ir_or_spec = ir_or_spec
+    """Class used to chain calls to fields that aren't set"""
+
+    def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
+        self.ir_or_spec = ir_or_spec
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "ir_or_spec":
+            return object.__setattr__(self, name, value)
+
+        raise AttributeError(f"Cannot set {name} on read-only wrapper")
+
+    def __getattribute__(
+        self, name: str
+    ) -> Any:  # pylint:disable=too-many-return-statements
+        ir_or_spec = object.__getattribute__(self, "ir_or_spec")
+        if name == "ir_or_spec":
+            return ir_or_spec
+
+        field_type = _field_type(ir_or_spec)
+        spec = field_specs(field_type).get(name)
+        if not spec:
+            if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+                if name == "HasField":
+                    return lambda x: False
+                if name == "WhichOneof":
+                    return lambda x: None
+            return object.__getattribute__(ir_or_spec, name)
+
+        if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+            # Just pretending
+            return _field_checker_from_spec(spec)
+
+        value = getattr(ir_or_spec, name)
+        if value is None:
+            return _field_checker_from_spec(spec)
-  def __setattr__(self, name: str, value: Any) -> None:
-    if name == "ir_or_spec":
-      return object.__setattr__(self, name, value)
-
-    raise AttributeError(f"Cannot set {name} on read-only wrapper")
-
-  def __getattribute__(self, name: str) -> Any:  # pylint:disable=too-many-return-statements
-    ir_or_spec = object.__getattribute__(self, "ir_or_spec")
-    if name == "ir_or_spec":
-      return ir_or_spec
-
-    field_type = _field_type(ir_or_spec)
-    spec = field_specs(field_type).get(name)
-    if not spec:
-      if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
-        if name == "HasField":
-          return lambda x: False
-        if name == "WhichOneof":
-          return lambda x: None
-      return object.__getattribute__(ir_or_spec, name)
-
-    if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
-      # Just pretending
-      return _field_checker_from_spec(spec)
-
-    value = getattr(ir_or_spec, name)
-    if value is None:
-      return _field_checker_from_spec(spec)
-
-    if spec.is_dataclass:
-      if spec.is_sequence:
-        return [_ReadOnlyFieldChecker(i) for i in value]
-      return _ReadOnlyFieldChecker(value)
+        if spec.is_dataclass:
+            if spec.is_sequence:
+                return [_ReadOnlyFieldChecker(i) for i in value]
+            return _ReadOnlyFieldChecker(value)

-    return value
+        return value

-  def __eq__(self, other):
-    if isinstance(other, _ReadOnlyFieldChecker):
-      other = other.ir_or_spec
-    return self.ir_or_spec == other
+    def __eq__(self, other):
+        if
isinstance(other, _ReadOnlyFieldChecker): + other = other.ir_or_spec + return self.ir_or_spec == other - def __ne__(self, other): - return not self == other + def __ne__(self, other): + return not self == other def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT: - """Builds a read-only wrapper that can be used to check chains of possibly - unset fields. - - This wrapper explicitly does not alter the wrapped object and is only - intended for reading contents. - - For example, a `reader` lets you do: - ``` - def get_function_name_end_column(function: ir_data.Function): - return reader(function).function_name.source_location.end.column - ``` - - Instead of: - ``` - def get_function_name_end_column(function: ir_data.Function): - if function.function_name: - if function.function_name.source_location: - if function.function_name.source_location.end: - return function.function_name.source_location.end.column - return 0 - ``` - """ - # Create a read-only wrapper if it's not already one. - if not isinstance(obj, _ReadOnlyFieldChecker): - obj = _ReadOnlyFieldChecker(obj) - - # Cast it back to the original type. - return cast(MessageT, obj) + """Builds a read-only wrapper that can be used to check chains of possibly + unset fields. + + This wrapper explicitly does not alter the wrapped object and is only + intended for reading contents. + + For example, a `reader` lets you do: + ``` + def get_function_name_end_column(function: ir_data.Function): + return reader(function).function_name.source_location.end.column + ``` + + Instead of: + ``` + def get_function_name_end_column(function: ir_data.Function): + if function.function_name: + if function.function_name.source_location: + if function.function_name.source_location.end: + return function.function_name.source_location.end.column + return 0 + ``` + """ + # Create a read-only wrapper if it's not already one. + if not isinstance(obj, _ReadOnlyFieldChecker): + obj = _ReadOnlyFieldChecker(obj) + + # Cast it back to the original type. + return cast(MessageT, obj) def _extract_ir( ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None], ) -> Optional[ir_data_fields.IrDataclassInstance]: - if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker): - ir_or_spec = ir_or_wrapper.ir_or_spec - if isinstance(ir_or_spec, ir_data_fields.FieldSpec): - # This is a placeholder entry, no fields are set. - return None - ir_or_wrapper = ir_or_spec - elif isinstance(ir_or_wrapper, _IrDataBuilder): - ir_or_wrapper = ir_or_wrapper.ir - return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper) + if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker): + ir_or_spec = ir_or_wrapper.ir_or_spec + if isinstance(ir_or_spec, ir_data_fields.FieldSpec): + # This is a placeholder entry, no fields are set. + return None + ir_or_wrapper = ir_or_spec + elif isinstance(ir_or_wrapper, _IrDataBuilder): + ir_or_wrapper = ir_or_wrapper.ir + return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper) def fields_and_values( ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker], value_filt: Optional[Callable[[Any], bool]] = None, ) -> list[Tuple[ir_data_fields.FieldSpec, Any]]: - """Retrieves the fields and their values for a given IR data class. + """Retrieves the fields and their values for a given IR data class. - Args: - ir: The IR data class or a read-only wrapper of an IR data class. - value_filt: Optional filter used to exclude values. 
-  """
-  if (ir := _extract_ir(ir_wrapper)) is None:
-    return []
+    Args:
+      ir_wrapper: The IR data class or a read-only wrapper of an IR data class.
+      value_filt: Optional filter used to exclude values.
+    """
+    if (ir := _extract_ir(ir_wrapper)) is None:
+        return []

-  return ir_data_fields.fields_and_values(ir, value_filt)
+    return ir_data_fields.fields_and_values(ir, value_filt)


 def get_set_fields(ir: MessageT):
-  """Retrieves the field spec and value of fields that are set in the given IR data class.
+    """Retrieves the field spec and value of fields that are set in the given IR data class.

-  A value is considered "set" if it is not None.
-  """
-  return fields_and_values(ir, lambda v: v is not None)
+    A value is considered "set" if it is not None.
+    """
+    return fields_and_values(ir, lambda v: v is not None)


 def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
-  """Creates a copy of the given IR data class"""
-  if (ir := _extract_ir(ir_wrapper)) is None:
-    return None
-  ir_copy = ir_data_fields.copy(ir)
-  return cast(MessageT, ir_copy)
+    """Creates a copy of the given IR data class"""
+    if (ir := _extract_ir(ir_wrapper)) is None:
+        return None
+    ir_copy = ir_data_fields.copy(ir)
+    return cast(MessageT, ir_copy)


 def update(ir: MessageT, template: MessageT):
-  """Updates `ir`s fields with all set fields in the template."""
-  if not (template_ir := _extract_ir(template)):
-    return
+    """Updates `ir`'s fields with all set fields in the template."""
+    if not (template_ir := _extract_ir(template)):
+        return

-  ir_data_fields.update(
-      cast(ir_data_fields.IrDataclassInstance, ir), template_ir
-  )
+    ir_data_fields.update(cast(ir_data_fields.IrDataclassInstance, ir), template_ir)
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
index 1c000ec..82baac6 100644
--- a/compiler/util/ir_data_utils_test.py
+++ b/compiler/util/ir_data_utils_test.py
@@ -26,625 +26,608 @@


 class TestEnum(enum.Enum):
-  """Used to test python Enum handling."""
+    """Used to test Python Enum handling."""

-  UNKNOWN = 0
-  VALUE_1 = 1
-  VALUE_2 = 2
+    UNKNOWN = 0
+    VALUE_1 = 1
+    VALUE_2 = 2


 @dataclasses.dataclass
 class Opaque(ir_data.Message):
-  """Used for testing data field helpers"""
+    """Used for testing data field helpers"""


 @dataclasses.dataclass
 class ClassWithUnion(ir_data.Message):
-  """Used for testing data field helpers"""
+    """Used for testing data field helpers"""

-  opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
-  integer: Optional[int] = ir_data_fields.oneof_field("type")
-  boolean: Optional[bool] = ir_data_fields.oneof_field("type")
-  enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
-  non_union_field: int = 0
+    opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
+    integer: Optional[int] = ir_data_fields.oneof_field("type")
+    boolean: Optional[bool] = ir_data_fields.oneof_field("type")
+    enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
+    non_union_field: int = 0


 @dataclasses.dataclass
 class ClassWithTwoUnions(ir_data.Message):
-  """Used for testing data field helpers"""
+    """Used for testing data field helpers"""

-  opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
-  integer: Optional[int] = ir_data_fields.oneof_field("type_1")
-  boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
-  enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
-  non_union_field: int = 0
-  seq_field: list[int] = ir_data_fields.list_field(int)
+    opaque: Optional[Opaque] =
ir_data_fields.oneof_field("type_1") + integer: Optional[int] = ir_data_fields.oneof_field("type_1") + boolean: Optional[bool] = ir_data_fields.oneof_field("type_2") + enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2") + non_union_field: int = 0 + seq_field: list[int] = ir_data_fields.list_field(int) class IrDataUtilsTest(unittest.TestCase): - """Tests for the miscellaneous utility functions in ir_data_utils.py.""" - - def test_field_specs(self): - """Tests the `field_specs` method""" - fields = ir_data_utils.field_specs(ir_data.TypeDefinition) - self.assertIsNotNone(fields) - expected_fields = ( - "external", - "enumeration", - "structure", - "name", - "attribute", - "documentation", - "subtype", - "addressable_unit", - "runtime_parameter", - "source_location", - ) - self.assertEqual(len(fields), len(expected_fields)) - field_names = fields.keys() - for k in expected_fields: - self.assertIn(k, field_names) - - # Try a sequence - expected_field = ir_data_fields.make_field_spec( - "attribute", ir_data.Attribute, ir_data_fields.FieldContainer.LIST, None - ) - self.assertEqual(fields["attribute"], expected_field) - - # Try a scalar - expected_field = ir_data_fields.make_field_spec( - "addressable_unit", - ir_data.AddressableUnit, - ir_data_fields.FieldContainer.OPTIONAL, - None, - ) - self.assertEqual(fields["addressable_unit"], expected_field) - - # Try a IR data class - expected_field = ir_data_fields.make_field_spec( - "source_location", - ir_data.Location, - ir_data_fields.FieldContainer.OPTIONAL, - None, - ) - self.assertEqual(fields["source_location"], expected_field) - - # Try an oneof field - expected_field = ir_data_fields.make_field_spec( - "external", - ir_data.External, - ir_data_fields.FieldContainer.OPTIONAL, - oneof="type", - ) - self.assertEqual(fields["external"], expected_field) - - # Try non-optional scalar - fields = ir_data_utils.field_specs(ir_data.Position) - expected_field = ir_data_fields.make_field_spec( - "line", int, ir_data_fields.FieldContainer.NONE, None - ) - self.assertEqual(fields["line"], expected_field) - - fields = ir_data_utils.field_specs(ir_data.ArrayType) - expected_field = ir_data_fields.make_field_spec( - "base_type", ir_data.Type, ir_data_fields.FieldContainer.OPTIONAL, None - ) - self.assertEqual(fields["base_type"], expected_field) - - def test_is_sequence(self): - """Tests for the `FieldSpec.is_sequence` helper""" - type_def = ir_data.TypeDefinition( - attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ), - ] - ) - fields = ir_data_utils.field_specs(ir_data.TypeDefinition) - # Test against a repeated field - self.assertTrue(fields["attribute"].is_sequence) - # Test against a nested IR data type - self.assertFalse(fields["name"].is_sequence) - # Test against a plain scalar type - fields = ir_data_utils.field_specs(type_def.attribute[0]) - self.assertFalse(fields["is_default"].is_sequence) - - def test_is_dataclass(self): - """Tests FieldSpec.is_dataclass against ir_data""" - type_def = ir_data.TypeDefinition( - attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ), - ] - ) - fields = ir_data_utils.field_specs(ir_data.TypeDefinition) - # Test against a repeated field that holds IR data structs - self.assertTrue(fields["attribute"].is_dataclass) - # Test against a nested IR data type - self.assertTrue(fields["name"].is_dataclass) - # Test against a plain scalar type - 
fields = ir_data_utils.field_specs(type_def.attribute[0]) - self.assertFalse(fields["is_default"].is_dataclass) - # Test against a repeated field that holds scalars - fields = ir_data_utils.field_specs(ir_data.Structure) - self.assertFalse(fields["fields_in_dependency_order"].is_dataclass) - - def test_get_set_fields(self): - """Tests that get set fields works""" - type_def = ir_data.TypeDefinition( - attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ), - ] - ) - set_fields = ir_data_utils.get_set_fields(type_def) - expected_fields = set( - ["attribute", "documentation", "subtype", "runtime_parameter"] - ) - self.assertEqual(len(set_fields), len(expected_fields)) - found_fields = set() - for k, v in set_fields: - self.assertIn(k.name, expected_fields) - found_fields.add(k.name) - self.assertEqual(v, getattr(type_def, k.name)) - - self.assertSetEqual(found_fields, expected_fields) - - def test_copy(self): - """Tests the `copy` helper""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - attribute_copy = ir_data_utils.copy(attribute) - - # Should be equivalent - self.assertEqual(attribute, attribute_copy) - # But not the same instance - self.assertIsNot(attribute, attribute_copy) - - # Let's do a sequence - type_def = ir_data.TypeDefinition(attribute=[attribute]) - type_def_copy = ir_data_utils.copy(type_def) - - # Should be equivalent - self.assertEqual(type_def, type_def_copy) - # But not the same instance - self.assertIsNot(type_def, type_def_copy) - self.assertIsNot(type_def.attribute, type_def_copy.attribute) - - def test_update(self): - """Tests the `update` helper""" - attribute_template = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - attribute = ir_data.Attribute(is_default=True) - ir_data_utils.update(attribute, attribute_template) - self.assertIsNotNone(attribute.value) - self.assertIsNot(attribute.value, attribute_template.value) - self.assertIsNotNone(attribute.name) - self.assertIsNot(attribute.name, attribute_template.name) - - # Value not present in template should be untouched - self.assertTrue(attribute.is_default) + """Tests for the miscellaneous utility functions in ir_data_utils.py.""" + + def test_field_specs(self): + """Tests the `field_specs` method""" + fields = ir_data_utils.field_specs(ir_data.TypeDefinition) + self.assertIsNotNone(fields) + expected_fields = ( + "external", + "enumeration", + "structure", + "name", + "attribute", + "documentation", + "subtype", + "addressable_unit", + "runtime_parameter", + "source_location", + ) + self.assertEqual(len(fields), len(expected_fields)) + field_names = fields.keys() + for k in expected_fields: + self.assertIn(k, field_names) + + # Try a sequence + expected_field = ir_data_fields.make_field_spec( + "attribute", ir_data.Attribute, ir_data_fields.FieldContainer.LIST, None + ) + self.assertEqual(fields["attribute"], expected_field) + + # Try a scalar + expected_field = ir_data_fields.make_field_spec( + "addressable_unit", + ir_data.AddressableUnit, + ir_data_fields.FieldContainer.OPTIONAL, + None, + ) + self.assertEqual(fields["addressable_unit"], expected_field) + + # Try a IR data class + expected_field = ir_data_fields.make_field_spec( + "source_location", + ir_data.Location, + ir_data_fields.FieldContainer.OPTIONAL, + None, + ) + self.assertEqual(fields["source_location"], 
expected_field) + + # Try an oneof field + expected_field = ir_data_fields.make_field_spec( + "external", + ir_data.External, + ir_data_fields.FieldContainer.OPTIONAL, + oneof="type", + ) + self.assertEqual(fields["external"], expected_field) + + # Try non-optional scalar + fields = ir_data_utils.field_specs(ir_data.Position) + expected_field = ir_data_fields.make_field_spec( + "line", int, ir_data_fields.FieldContainer.NONE, None + ) + self.assertEqual(fields["line"], expected_field) + + fields = ir_data_utils.field_specs(ir_data.ArrayType) + expected_field = ir_data_fields.make_field_spec( + "base_type", ir_data.Type, ir_data_fields.FieldContainer.OPTIONAL, None + ) + self.assertEqual(fields["base_type"], expected_field) + + def test_is_sequence(self): + """Tests for the `FieldSpec.is_sequence` helper""" + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ] + ) + fields = ir_data_utils.field_specs(ir_data.TypeDefinition) + # Test against a repeated field + self.assertTrue(fields["attribute"].is_sequence) + # Test against a nested IR data type + self.assertFalse(fields["name"].is_sequence) + # Test against a plain scalar type + fields = ir_data_utils.field_specs(type_def.attribute[0]) + self.assertFalse(fields["is_default"].is_sequence) + + def test_is_dataclass(self): + """Tests FieldSpec.is_dataclass against ir_data""" + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ] + ) + fields = ir_data_utils.field_specs(ir_data.TypeDefinition) + # Test against a repeated field that holds IR data structs + self.assertTrue(fields["attribute"].is_dataclass) + # Test against a nested IR data type + self.assertTrue(fields["name"].is_dataclass) + # Test against a plain scalar type + fields = ir_data_utils.field_specs(type_def.attribute[0]) + self.assertFalse(fields["is_default"].is_dataclass) + # Test against a repeated field that holds scalars + fields = ir_data_utils.field_specs(ir_data.Structure) + self.assertFalse(fields["fields_in_dependency_order"].is_dataclass) + + def test_get_set_fields(self): + """Tests that get set fields works""" + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ] + ) + set_fields = ir_data_utils.get_set_fields(type_def) + expected_fields = set( + ["attribute", "documentation", "subtype", "runtime_parameter"] + ) + self.assertEqual(len(set_fields), len(expected_fields)) + found_fields = set() + for k, v in set_fields: + self.assertIn(k.name, expected_fields) + found_fields.add(k.name) + self.assertEqual(v, getattr(type_def, k.name)) + + self.assertSetEqual(found_fields, expected_fields) + + def test_copy(self): + """Tests the `copy` helper""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + attribute_copy = ir_data_utils.copy(attribute) + + # Should be equivalent + self.assertEqual(attribute, attribute_copy) + # But not the same instance + self.assertIsNot(attribute, attribute_copy) + + # Let's do a sequence + type_def = ir_data.TypeDefinition(attribute=[attribute]) + type_def_copy = ir_data_utils.copy(type_def) + + # Should be equivalent + self.assertEqual(type_def, type_def_copy) + # But not the same instance 
+ self.assertIsNot(type_def, type_def_copy) + self.assertIsNot(type_def.attribute, type_def_copy.attribute) + + def test_update(self): + """Tests the `update` helper""" + attribute_template = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + attribute = ir_data.Attribute(is_default=True) + ir_data_utils.update(attribute, attribute_template) + self.assertIsNotNone(attribute.value) + self.assertIsNot(attribute.value, attribute_template.value) + self.assertIsNotNone(attribute.name) + self.assertIsNot(attribute.name, attribute_template.name) + + # Value not present in template should be untouched + self.assertTrue(attribute.is_default) class IrDataBuilderTest(unittest.TestCase): - """Tests for IrDataBuilder""" - - def test_ir_data_builder(self): - """Tests that basic builder chains work""" - # We start with an empty type - type_def = ir_data.TypeDefinition() - self.assertFalse(type_def.HasField("name")) - self.assertIsNone(type_def.name) - - # Now setup a builder - builder = ir_data_utils.builder(type_def) - - # Assign to a sub-child - builder.name.name = ir_data.Word(text="phil") - - # Verify the wrapped struct is updated - self.assertIsNotNone(type_def.name) - self.assertIsNotNone(type_def.name.name) - self.assertIsNotNone(type_def.name.name.text) - self.assertEqual(type_def.name.name.text, "phil") - - def test_ir_data_builder_bad_field(self): - """Tests accessing an undefined field name fails""" - type_def = ir_data.TypeDefinition() - builder = ir_data_utils.builder(type_def) - self.assertRaises(AttributeError, lambda: builder.foo) - # Make sure it's not set on our IR data class either - self.assertRaises(AttributeError, getattr, type_def, "foo") - - def test_ir_data_builder_sequence(self): - """Tests that sequences are properly wrapped""" - # We start with an empty type - type_def = ir_data.TypeDefinition() - self.assertTrue(type_def.HasField("attribute")) - self.assertEqual(len(type_def.attribute), 0) - - # Now setup a builder - builder = ir_data_utils.builder(type_def) - - # Assign to a sequence - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - - builder.attribute.append(attribute) - self.assertEqual(builder.attribute, [attribute]) - self.assertTrue(type_def.HasField("attribute")) - self.assertEqual(len(type_def.attribute), 1) - self.assertEqual(type_def.attribute[0], attribute) - - # Lets make it longer and then try iterating - builder.attribute.append(attribute) - self.assertEqual(len(type_def.attribute), 2) - for attr in builder.attribute: - # Modify the attributes - attr.name.text = "bob" - - # Make sure we can build up auto-default entries from a sequence item - builder.attribute.append(ir_data.Attribute()) - builder.attribute[-1].value.expression = ir_data.Expression() - builder.attribute[-1].name.text = "bob" - - # Create an attribute to compare against - new_attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="bob"), - ) - - self.assertEqual(len(type_def.attribute), 3) - for attr in type_def.attribute: - self.assertEqual(attr, new_attribute) - - # Make sure the list type is a CopyValuesList - self.assertIsInstance( - type_def.attribute, - ir_data_fields.CopyValuesList, - f"Instance is: {type(type_def.attribute)}", - ) - - def test_copy_from(self) -> None: - """Tests that `CopyFrom` works.""" - location = ir_data.Location( - 
start=ir_data.Position(line=1, column=1), - end=ir_data.Position(line=1, column=2), - ) - expression_ir = ir_data.Expression(source_location=location) - template: ir_data.Expression = expression_parser.parse("x + y") - expression = ir_data_utils.builder(expression_ir) - expression.CopyFrom(template) - self.assertIsNotNone(expression_ir.function) - self.assertIsInstance(expression.function, ir_data_utils._IrDataBuilder) - self.assertIsInstance( - expression.function.args, ir_data_utils._IrDataSequenceBuilder - ) - self.assertTrue(expression_ir.function.args) - - def test_copy_from_list(self): - specs = ir_data_utils.field_specs(ir_data.Function) - args_spec = specs["args"] - self.assertTrue(args_spec.is_dataclass) - template: ir_data.Expression = expression_parser.parse("x + y") - self.assertIsNotNone(template) - self.assertIsInstance(template, ir_data.Expression) - self.assertIsInstance(template.function, ir_data.Function) - self.assertIsInstance(template.function.args, ir_data_fields.CopyValuesList) - - location = ir_data.Location( - start=ir_data.Position(line=1, column=1), - end=ir_data.Position(line=1, column=2), - ) - expression_ir = ir_data.Expression(source_location=location) - self.assertIsInstance(expression_ir, ir_data.Expression) - self.assertIsNone(expression_ir.function) - - expression_builder = ir_data_utils.builder(expression_ir) - self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder) - expression_builder.CopyFrom(template) - self.assertIsNotNone(expression_ir.function) - self.assertIsInstance(expression_ir.function, ir_data.Function) - self.assertIsNotNone(expression_ir.function.args) - self.assertIsInstance( - expression_ir.function.args, ir_data_fields.CopyValuesList - ) - - self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder) - self.assertIsInstance( - expression_builder.function, ir_data_utils._IrDataBuilder - ) - self.assertIsInstance( - expression_builder.function.args, ir_data_utils._IrDataSequenceBuilder - ) - - def test_ir_data_builder_sequence_scalar(self): - """Tests that sequences of scalars function properly""" - # We start with an empty type - structure = ir_data.Structure() - - # Now setup a builder - builder = ir_data_utils.builder(structure) - - # Assign to a scalar sequence - builder.fields_in_dependency_order.append(12) - builder.fields_in_dependency_order.append(11) - - self.assertTrue(structure.HasField("fields_in_dependency_order")) - self.assertEqual(len(structure.fields_in_dependency_order), 2) - self.assertEqual(structure.fields_in_dependency_order[0], 12) - self.assertEqual(structure.fields_in_dependency_order[1], 11) - self.assertEqual(builder.fields_in_dependency_order, [12, 11]) - - new_structure = ir_data.Structure(fields_in_dependency_order=[12, 11]) - self.assertEqual(structure, new_structure) - - def test_ir_data_builder_oneof(self): - value = ir_data.AttributeValue( - expression=ir_data.Expression( - boolean_constant=ir_data.BooleanConstant() + """Tests for IrDataBuilder""" + + def test_ir_data_builder(self): + """Tests that basic builder chains work""" + # We start with an empty type + type_def = ir_data.TypeDefinition() + self.assertFalse(type_def.HasField("name")) + self.assertIsNone(type_def.name) + + # Now setup a builder + builder = ir_data_utils.builder(type_def) + + # Assign to a sub-child + builder.name.name = ir_data.Word(text="phil") + + # Verify the wrapped struct is updated + self.assertIsNotNone(type_def.name) + self.assertIsNotNone(type_def.name.name) + 
self.assertIsNotNone(type_def.name.name.text)
+        self.assertEqual(type_def.name.name.text, "phil")
+
+    def test_ir_data_builder_bad_field(self):
+        """Tests accessing an undefined field name fails"""
+        type_def = ir_data.TypeDefinition()
+        builder = ir_data_utils.builder(type_def)
+        self.assertRaises(AttributeError, lambda: builder.foo)
+        # Make sure it's not set on our IR data class either
+        self.assertRaises(AttributeError, getattr, type_def, "foo")
+
+    def test_ir_data_builder_sequence(self):
+        """Tests that sequences are properly wrapped"""
+        # We start with an empty type
+        type_def = ir_data.TypeDefinition()
+        self.assertTrue(type_def.HasField("attribute"))
+        self.assertEqual(len(type_def.attribute), 0)
+
+        # Now setup a builder
+        builder = ir_data_utils.builder(type_def)
+
+        # Assign to a sequence
+        attribute = ir_data.Attribute(
+            value=ir_data.AttributeValue(expression=ir_data.Expression()),
+            name=ir_data.Word(text="phil"),
+        )
+
+        builder.attribute.append(attribute)
+        self.assertEqual(builder.attribute, [attribute])
+        self.assertTrue(type_def.HasField("attribute"))
+        self.assertEqual(len(type_def.attribute), 1)
+        self.assertEqual(type_def.attribute[0], attribute)
+
+        # Let's make it longer and then try iterating
+        builder.attribute.append(attribute)
+        self.assertEqual(len(type_def.attribute), 2)
+        for attr in builder.attribute:
+            # Modify the attributes
+            attr.name.text = "bob"
+
+        # Make sure we can build up auto-default entries from a sequence item
+        builder.attribute.append(ir_data.Attribute())
+        builder.attribute[-1].value.expression = ir_data.Expression()
+        builder.attribute[-1].name.text = "bob"
+
+        # Create an attribute to compare against
+        new_attribute = ir_data.Attribute(
+            value=ir_data.AttributeValue(expression=ir_data.Expression()),
+            name=ir_data.Word(text="bob"),
+        )
+
+        self.assertEqual(len(type_def.attribute), 3)
+        for attr in type_def.attribute:
+            self.assertEqual(attr, new_attribute)
+
+        # Make sure the list type is a CopyValuesList
+        self.assertIsInstance(
+            type_def.attribute,
+            ir_data_fields.CopyValuesList,
+            f"Instance is: {type(type_def.attribute)}",
+        )
+
+    def test_copy_from(self) -> None:
+        """Tests that `CopyFrom` works."""
+        location = ir_data.Location(
+            start=ir_data.Position(line=1, column=1),
+            end=ir_data.Position(line=1, column=2),
+        )
+        expression_ir = ir_data.Expression(source_location=location)
+        template: ir_data.Expression = expression_parser.parse("x + y")
+        expression = ir_data_utils.builder(expression_ir)
+        expression.CopyFrom(template)
+        self.assertIsNotNone(expression_ir.function)
+        self.assertIsInstance(expression.function, ir_data_utils._IrDataBuilder)
+        self.assertIsInstance(
+            expression.function.args, ir_data_utils._IrDataSequenceBuilder
+        )
+        self.assertTrue(expression_ir.function.args)
+
+    def test_copy_from_list(self):
+        specs = ir_data_utils.field_specs(ir_data.Function)
+        args_spec = specs["args"]
+        self.assertTrue(args_spec.is_dataclass)
+        template: ir_data.Expression = expression_parser.parse("x + y")
+        self.assertIsNotNone(template)
+        self.assertIsInstance(template, ir_data.Expression)
+        self.assertIsInstance(template.function, ir_data.Function)
+        self.assertIsInstance(template.function.args, ir_data_fields.CopyValuesList)
+
+        location = ir_data.Location(
+            start=ir_data.Position(line=1, column=1),
+            end=ir_data.Position(line=1, column=2),
+        )
+        expression_ir = ir_data.Expression(source_location=location)
+        self.assertIsInstance(expression_ir, ir_data.Expression)
+        self.assertIsNone(expression_ir.function)
+
+
expression_builder = ir_data_utils.builder(expression_ir) + self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder) + expression_builder.CopyFrom(template) + self.assertIsNotNone(expression_ir.function) + self.assertIsInstance(expression_ir.function, ir_data.Function) + self.assertIsNotNone(expression_ir.function.args) + self.assertIsInstance( + expression_ir.function.args, ir_data_fields.CopyValuesList ) - ) - builder = ir_data_utils.builder(value) - self.assertTrue(builder.HasField("expression")) - self.assertFalse(builder.expression.boolean_constant.value) - builder.expression.boolean_constant.value = True - self.assertTrue(builder.expression.boolean_constant.value) - self.assertTrue(value.expression.boolean_constant.value) - bool_constant = value.expression.boolean_constant - self.assertIsInstance(bool_constant, ir_data.BooleanConstant) + self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder) + self.assertIsInstance(expression_builder.function, ir_data_utils._IrDataBuilder) + self.assertIsInstance( + expression_builder.function.args, ir_data_utils._IrDataSequenceBuilder + ) + + def test_ir_data_builder_sequence_scalar(self): + """Tests that sequences of scalars function properly""" + # We start with an empty type + structure = ir_data.Structure() + + # Now setup a builder + builder = ir_data_utils.builder(structure) + + # Assign to a scalar sequence + builder.fields_in_dependency_order.append(12) + builder.fields_in_dependency_order.append(11) + + self.assertTrue(structure.HasField("fields_in_dependency_order")) + self.assertEqual(len(structure.fields_in_dependency_order), 2) + self.assertEqual(structure.fields_in_dependency_order[0], 12) + self.assertEqual(structure.fields_in_dependency_order[1], 11) + self.assertEqual(builder.fields_in_dependency_order, [12, 11]) + + new_structure = ir_data.Structure(fields_in_dependency_order=[12, 11]) + self.assertEqual(structure, new_structure) + + def test_ir_data_builder_oneof(self): + value = ir_data.AttributeValue( + expression=ir_data.Expression(boolean_constant=ir_data.BooleanConstant()) + ) + builder = ir_data_utils.builder(value) + self.assertTrue(builder.HasField("expression")) + self.assertFalse(builder.expression.boolean_constant.value) + builder.expression.boolean_constant.value = True + self.assertTrue(builder.expression.boolean_constant.value) + self.assertTrue(value.expression.boolean_constant.value) + + bool_constant = value.expression.boolean_constant + self.assertIsInstance(bool_constant, ir_data.BooleanConstant) class IrDataSerializerTest(unittest.TestCase): - """Tests for IrDataSerializer""" - - def test_ir_data_serializer_to_dict(self): - """Tests serialization with `IrDataSerializer.to_dict` with default settings""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - - serializer = ir_data_utils.IrDataSerializer(attribute) - raw_dict = serializer.to_dict() - expected = { - "name": {"text": "phil", "source_location": None}, - "value": { - "expression": { - "constant": None, - "constant_reference": None, - "function": None, - "field_reference": None, - "boolean_constant": None, - "builtin_reference": None, - "type": None, + """Tests for IrDataSerializer""" + + def test_ir_data_serializer_to_dict(self): + """Tests serialization with `IrDataSerializer.to_dict` with default settings""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + 
name=ir_data.Word(text="phil"), + ) + + serializer = ir_data_utils.IrDataSerializer(attribute) + raw_dict = serializer.to_dict() + expected = { + "name": {"text": "phil", "source_location": None}, + "value": { + "expression": { + "constant": None, + "constant_reference": None, + "function": None, + "field_reference": None, + "boolean_constant": None, + "builtin_reference": None, + "type": None, + "source_location": None, + }, + "string_constant": None, "source_location": None, }, - "string_constant": None, + "back_end": None, + "is_default": None, "source_location": None, - }, - "back_end": None, - "is_default": None, - "source_location": None, - } - self.assertDictEqual(raw_dict, expected) - - def test_ir_data_serializer_to_dict_exclude_none(self): - """Tests serialization with `IrDataSerializer.to_dict` when excluding None values""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - serializer = ir_data_utils.IrDataSerializer(attribute) - raw_dict = serializer.to_dict(exclude_none=True) - expected = {"name": {"text": "phil"}, "value": {"expression": {}}} - self.assertDictEqual(raw_dict, expected) - - def test_ir_data_serializer_to_dict_enum(self): - """Tests that serialization of `enum.Enum` values works properly""" - type_def = ir_data.TypeDefinition( - addressable_unit=ir_data.AddressableUnit.BYTE - ) - serializer = ir_data_utils.IrDataSerializer(type_def) - raw_dict = serializer.to_dict(exclude_none=True) - expected = {"addressable_unit": ir_data.AddressableUnit.BYTE} - self.assertDictEqual(raw_dict, expected) - - def test_ir_data_serializer_from_dict(self): - """Tests deserializing IR data from a serialized dict""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - serializer = ir_data_utils.IrDataSerializer(attribute) - raw_dict = serializer.to_dict(exclude_none=False) - new_attribute = serializer.from_dict(ir_data.Attribute, raw_dict) - self.assertEqual(attribute, new_attribute) - - def test_ir_data_serializer_from_dict_enum(self): - """Tests that deserializing `enum.Enum` values works properly""" - type_def = ir_data.TypeDefinition( - addressable_unit=ir_data.AddressableUnit.BYTE - ) - - serializer = ir_data_utils.IrDataSerializer(type_def) - raw_dict = serializer.to_dict(exclude_none=False) - new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict) - self.assertEqual(type_def, new_type_def) - - def test_ir_data_serializer_from_dict_enum_is_str(self): - """Tests that deserializing `enum.Enum` values works properly when string constant is used""" - type_def = ir_data.TypeDefinition( - addressable_unit=ir_data.AddressableUnit.BYTE - ) - raw_dict = {"addressable_unit": "BYTE"} - serializer = ir_data_utils.IrDataSerializer(type_def) - new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict) - self.assertEqual(type_def, new_type_def) - - def test_ir_data_serializer_from_dict_exclude_none(self): - """Tests that deserializing from a dict that excluded None values works properly""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - - serializer = ir_data_utils.IrDataSerializer(attribute) - raw_dict = serializer.to_dict(exclude_none=True) - new_attribute = ir_data_utils.IrDataSerializer.from_dict( - ir_data.Attribute, raw_dict - ) - self.assertEqual(attribute, new_attribute) - - def 
test_from_dict_list(self): - function_args = [ - { - "constant": { - "value": "0", + } + self.assertDictEqual(raw_dict, expected) + + def test_ir_data_serializer_to_dict_exclude_none(self): + """Tests serialization with `IrDataSerializer.to_dict` when excluding None values""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + serializer = ir_data_utils.IrDataSerializer(attribute) + raw_dict = serializer.to_dict(exclude_none=True) + expected = {"name": {"text": "phil"}, "value": {"expression": {}}} + self.assertDictEqual(raw_dict, expected) + + def test_ir_data_serializer_to_dict_enum(self): + """Tests that serialization of `enum.Enum` values works properly""" + type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) + serializer = ir_data_utils.IrDataSerializer(type_def) + raw_dict = serializer.to_dict(exclude_none=True) + expected = {"addressable_unit": ir_data.AddressableUnit.BYTE} + self.assertDictEqual(raw_dict, expected) + + def test_ir_data_serializer_from_dict(self): + """Tests deserializing IR data from a serialized dict""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + serializer = ir_data_utils.IrDataSerializer(attribute) + raw_dict = serializer.to_dict(exclude_none=False) + new_attribute = serializer.from_dict(ir_data.Attribute, raw_dict) + self.assertEqual(attribute, new_attribute) + + def test_ir_data_serializer_from_dict_enum(self): + """Tests that deserializing `enum.Enum` values works properly""" + type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) + + serializer = ir_data_utils.IrDataSerializer(type_def) + raw_dict = serializer.to_dict(exclude_none=False) + new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict) + self.assertEqual(type_def, new_type_def) + + def test_ir_data_serializer_from_dict_enum_is_str(self): + """Tests that deserializing `enum.Enum` values works properly when string constant is used""" + type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE) + raw_dict = {"addressable_unit": "BYTE"} + serializer = ir_data_utils.IrDataSerializer(type_def) + new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict) + self.assertEqual(type_def, new_type_def) + + def test_ir_data_serializer_from_dict_exclude_none(self): + """Tests that deserializing from a dict that excluded None values works properly""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + + serializer = ir_data_utils.IrDataSerializer(attribute) + raw_dict = serializer.to_dict(exclude_none=True) + new_attribute = ir_data_utils.IrDataSerializer.from_dict( + ir_data.Attribute, raw_dict + ) + self.assertEqual(attribute, new_attribute) + + def test_from_dict_list(self): + function_args = [ + { + "constant": { + "value": "0", + "source_location": { + "start": {"line": 421, "column": 3}, + "end": {"line": 421, "column": 4}, + "is_synthetic": False, + }, + }, + "type": { + "integer": { + "modulus": "infinity", + "modular_value": "0", + "minimum_value": "0", + "maximum_value": "0", + } + }, "source_location": { "start": {"line": 421, "column": 3}, "end": {"line": 421, "column": 4}, "is_synthetic": False, }, }, - "type": { - "integer": { - "modulus": "infinity", - "modular_value": "0", - "minimum_value": "0", - "maximum_value": "0", - } - }, - 
"source_location": { - "start": {"line": 421, "column": 3}, - "end": {"line": 421, "column": 4}, - "is_synthetic": False, - }, - }, - { - "constant": { - "value": "1", + { + "constant": { + "value": "1", + "source_location": { + "start": {"line": 421, "column": 11}, + "end": {"line": 421, "column": 12}, + "is_synthetic": False, + }, + }, + "type": { + "integer": { + "modulus": "infinity", + "modular_value": "1", + "minimum_value": "1", + "maximum_value": "1", + } + }, "source_location": { "start": {"line": 421, "column": 11}, "end": {"line": 421, "column": 12}, "is_synthetic": False, }, }, - "type": { - "integer": { - "modulus": "infinity", - "modular_value": "1", - "minimum_value": "1", - "maximum_value": "1", - } - }, - "source_location": { - "start": {"line": 421, "column": 11}, - "end": {"line": 421, "column": 12}, - "is_synthetic": False, - }, - }, - ] - function_data = {"args": function_args} - func = ir_data_utils.IrDataSerializer.from_dict( - ir_data.Function, function_data - ) - self.assertIsNotNone(func) - - def test_ir_data_serializer_copy_from_dict(self): - """Tests that updating an IR data struct from a dict works properly""" - attribute = ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil"), - ) - serializer = ir_data_utils.IrDataSerializer(attribute) - raw_dict = serializer.to_dict(exclude_none=False) - - new_attribute = ir_data.Attribute() - new_serializer = ir_data_utils.IrDataSerializer(new_attribute) - new_serializer.copy_from_dict(raw_dict) - self.assertEqual(attribute, new_attribute) + ] + function_data = {"args": function_args} + func = ir_data_utils.IrDataSerializer.from_dict(ir_data.Function, function_data) + self.assertIsNotNone(func) + + def test_ir_data_serializer_copy_from_dict(self): + """Tests that updating an IR data struct from a dict works properly""" + attribute = ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ) + serializer = ir_data_utils.IrDataSerializer(attribute) + raw_dict = serializer.to_dict(exclude_none=False) + + new_attribute = ir_data.Attribute() + new_serializer = ir_data_utils.IrDataSerializer(new_attribute) + new_serializer.copy_from_dict(raw_dict) + self.assertEqual(attribute, new_attribute) class ReadOnlyFieldCheckerTest(unittest.TestCase): - """Tests the ReadOnlyFieldChecker""" - - def test_basic_wrapper(self): - """Tests basic field checker actions""" - union = ClassWithTwoUnions( - opaque=Opaque(), boolean=True, non_union_field=10 - ) - field_checker = ir_data_utils.reader(union) - - # All accesses should return a wrapper object - self.assertIsNotNone(field_checker.opaque) - self.assertIsNotNone(field_checker.integer) - self.assertIsNotNone(field_checker.boolean) - self.assertIsNotNone(field_checker.enumeration) - self.assertIsNotNone(field_checker.non_union_field) - # Scalar field should pass through - self.assertEqual(field_checker.non_union_field, 10) - - # Make sure HasField works - self.assertTrue(field_checker.HasField("opaque")) - self.assertFalse(field_checker.HasField("integer")) - self.assertTrue(field_checker.HasField("boolean")) - self.assertFalse(field_checker.HasField("enumeration")) - self.assertTrue(field_checker.HasField("non_union_field")) - - def test_construct_from_field_checker(self): - """Tests that constructing from another field checker works""" - union = ClassWithTwoUnions( - opaque=Opaque(), boolean=True, non_union_field=10 - ) - field_checker_orig = 
ir_data_utils.reader(union) - field_checker = ir_data_utils.reader(field_checker_orig) - self.assertIsNotNone(field_checker) - self.assertEqual(field_checker.ir_or_spec, union) - - # All accesses should return a wrapper object - self.assertIsNotNone(field_checker.opaque) - self.assertIsNotNone(field_checker.integer) - self.assertIsNotNone(field_checker.boolean) - self.assertIsNotNone(field_checker.enumeration) - self.assertIsNotNone(field_checker.non_union_field) - # Scalar field should pass through - self.assertEqual(field_checker.non_union_field, 10) - - # Make sure HasField works - self.assertTrue(field_checker.HasField("opaque")) - self.assertFalse(field_checker.HasField("integer")) - self.assertTrue(field_checker.HasField("boolean")) - self.assertFalse(field_checker.HasField("enumeration")) - self.assertTrue(field_checker.HasField("non_union_field")) - - def test_read_only(self) -> None: - """Tests that the read only wrapper really is read only""" - union = ClassWithTwoUnions( - opaque=Opaque(), boolean=True, non_union_field=10 - ) - field_checker = ir_data_utils.reader(union) - - def set_field(): - field_checker.opaque = None - - self.assertRaises(AttributeError, set_field) + """Tests the ReadOnlyFieldChecker""" + + def test_basic_wrapper(self): + """Tests basic field checker actions""" + union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10) + field_checker = ir_data_utils.reader(union) + + # All accesses should return a wrapper object + self.assertIsNotNone(field_checker.opaque) + self.assertIsNotNone(field_checker.integer) + self.assertIsNotNone(field_checker.boolean) + self.assertIsNotNone(field_checker.enumeration) + self.assertIsNotNone(field_checker.non_union_field) + # Scalar field should pass through + self.assertEqual(field_checker.non_union_field, 10) + + # Make sure HasField works + self.assertTrue(field_checker.HasField("opaque")) + self.assertFalse(field_checker.HasField("integer")) + self.assertTrue(field_checker.HasField("boolean")) + self.assertFalse(field_checker.HasField("enumeration")) + self.assertTrue(field_checker.HasField("non_union_field")) + + def test_construct_from_field_checker(self): + """Tests that constructing from another field checker works""" + union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10) + field_checker_orig = ir_data_utils.reader(union) + field_checker = ir_data_utils.reader(field_checker_orig) + self.assertIsNotNone(field_checker) + self.assertEqual(field_checker.ir_or_spec, union) + + # All accesses should return a wrapper object + self.assertIsNotNone(field_checker.opaque) + self.assertIsNotNone(field_checker.integer) + self.assertIsNotNone(field_checker.boolean) + self.assertIsNotNone(field_checker.enumeration) + self.assertIsNotNone(field_checker.non_union_field) + # Scalar field should pass through + self.assertEqual(field_checker.non_union_field, 10) + + # Make sure HasField works + self.assertTrue(field_checker.HasField("opaque")) + self.assertFalse(field_checker.HasField("integer")) + self.assertTrue(field_checker.HasField("boolean")) + self.assertFalse(field_checker.HasField("enumeration")) + self.assertTrue(field_checker.HasField("non_union_field")) + + def test_read_only(self) -> None: + """Tests that the read only wrapper really is read only""" + union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10) + field_checker = ir_data_utils.reader(union) + + def set_field(): + field_checker.opaque = None + + self.assertRaises(AttributeError, set_field) 
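+
+# Illustrative note (an assumption based on the call below, not documented
+# API): cache_message_specs appears to precompute field specs for this
+# module's test-only dataclasses (Opaque, ClassWithUnion, ClassWithTwoUnions)
+# so that helpers like ir_data_utils.field_specs() and ir_data_utils.reader()
+# can treat them like real ir_data messages, e.g.:
+#
+#   assert "seq_field" in ir_data_utils.field_specs(ClassWithTwoUnions)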
 ir_data_fields.cache_message_specs(
-    sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message)
+    sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message
+)


 if __name__ == "__main__":
-  unittest.main()
+    unittest.main()
diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py
index 603c0c0..2d8763b 100644
--- a/compiler/util/ir_util.py
+++ b/compiler/util/ir_util.py
@@ -24,376 +24,384 @@


 def get_attribute(attribute_list, name):
-  """Finds name in attribute_list and returns a AttributeValue or None."""
-  if not attribute_list:
-    return None
-  attribute_value = None
-  for attr in attribute_list:
-    if attr.name.text == name and not attr.is_default:
-      assert attribute_value is None, 'Duplicate attribute "{}".'.format(name)
-      attribute_value = attr.value
-  return attribute_value
+    """Finds name in attribute_list and returns an AttributeValue or None."""
+    if not attribute_list:
+        return None
+    attribute_value = None
+    for attr in attribute_list:
+        if attr.name.text == name and not attr.is_default:
+            assert attribute_value is None, 'Duplicate attribute "{}".'.format(name)
+            attribute_value = attr.value
+    return attribute_value


 def get_boolean_attribute(attribute_list, name, default_value=None):
-  """Returns the boolean value of an attribute, if any, or default_value.
-
-  Arguments:
-    attribute_list: A list of attributes to search.
-    name: The name of the desired attribute.
-    default_value: A value to return if name is not found in attribute_list,
-      or the attribute does not have a boolean value.
-
-  Returns:
-    The boolean value of the requested attribute, or default_value if the
-    requested attribute is not found or has a non-boolean value.
-  """
-  attribute_value = get_attribute(attribute_list, name)
-  if (not attribute_value or
-      not attribute_value.expression.HasField("boolean_constant")):
-    return default_value
-  return attribute_value.expression.boolean_constant.value
+    """Returns the boolean value of an attribute, if any, or default_value.
+
+    Arguments:
+      attribute_list: A list of attributes to search.
+      name: The name of the desired attribute.
+      default_value: A value to return if name is not found in attribute_list,
+        or the attribute does not have a boolean value.
+
+    Returns:
+      The boolean value of the requested attribute, or default_value if the
+      requested attribute is not found or has a non-boolean value.
+    """
+    attribute_value = get_attribute(attribute_list, name)
+    if not attribute_value or not attribute_value.expression.HasField(
+        "boolean_constant"
+    ):
+        return default_value
+    return attribute_value.expression.boolean_constant.value


 def get_integer_attribute(attribute_list, name, default_value=None):
-  """Returns the integer value of an attribute, if any, or default_value.
-
-  Arguments:
-    attribute_list: A list of attributes to search.
-    name: The name of the desired attribute.
-    default_value: A value to return if name is not found in attribute_list,
-      or the attribute does not have an integer value.
-
-  Returns:
-    The integer value of the requested attribute, or default_value if the
-    requested attribute is not found or has a non-integer value.
-  """
-  attribute_value = get_attribute(attribute_list, name)
-  if (not attribute_value or
-      attribute_value.expression.type.WhichOneof("type") != "integer" or
-      not is_constant(attribute_value.expression)):
-    return default_value
-  return constant_value(attribute_value.expression)
+    """Returns the integer value of an attribute, if any, or default_value.
+ + Arguments: + attribute_list: A list of attributes to search. + name: The name of the desired attribute. + default_value: A value to return if name is not found in attribute_list, + or the attribute does not have an integer value. + + Returns: + The integer value of the requested attribute, or default_value if the + requested attribute is not found or has a non-integer value. + """ + attribute_value = get_attribute(attribute_list, name) + if ( + not attribute_value + or attribute_value.expression.type.WhichOneof("type") != "integer" + or not is_constant(attribute_value.expression) + ): + return default_value + return constant_value(attribute_value.expression) def is_constant(expression, bindings=None): - return constant_value(expression, bindings) is not None + return constant_value(expression, bindings) is not None def is_constant_type(expression_type): - """Returns True if expression_type is inhabited by a single value.""" - expression_type = ir_data_utils.reader(expression_type) - return (expression_type.integer.modulus == "infinity" or - expression_type.boolean.HasField("value") or - expression_type.enumeration.HasField("value")) + """Returns True if expression_type is inhabited by a single value.""" + expression_type = ir_data_utils.reader(expression_type) + return ( + expression_type.integer.modulus == "infinity" + or expression_type.boolean.HasField("value") + or expression_type.enumeration.HasField("value") + ) def constant_value(expression, bindings=None): - """Evaluates expression with the given bindings.""" - if expression is None: - return None - expression = ir_data_utils.reader(expression) - if expression.WhichOneof("expression") == "constant": - return int(expression.constant.value or 0) - elif expression.WhichOneof("expression") == "constant_reference": - # We can't look up the constant reference without the IR, but by the time - # constant_value is called, the actual values should have been propagated to - # the type information. 
- if expression.type.WhichOneof("type") == "integer": - assert expression.type.integer.modulus == "infinity" - return int(expression.type.integer.modular_value) - elif expression.type.WhichOneof("type") == "boolean": - assert expression.type.boolean.HasField("value") - return expression.type.boolean.value - elif expression.type.WhichOneof("type") == "enumeration": - assert expression.type.enumeration.HasField("value") - return int(expression.type.enumeration.value) - else: - assert False, "Unexpected expression type {}".format( - expression.type.WhichOneof("type")) - elif expression.WhichOneof("expression") == "function": - return _constant_value_of_function(expression.function, bindings) - elif expression.WhichOneof("expression") == "field_reference": - return None - elif expression.WhichOneof("expression") == "boolean_constant": - return expression.boolean_constant.value - elif expression.WhichOneof("expression") == "builtin_reference": - name = expression.builtin_reference.canonical_name.object_path[0] - if bindings and name in bindings: - return bindings[name] + """Evaluates expression with the given bindings.""" + if expression is None: + return None + expression = ir_data_utils.reader(expression) + if expression.WhichOneof("expression") == "constant": + return int(expression.constant.value or 0) + elif expression.WhichOneof("expression") == "constant_reference": + # We can't look up the constant reference without the IR, but by the time + # constant_value is called, the actual values should have been propagated to + # the type information. + if expression.type.WhichOneof("type") == "integer": + assert expression.type.integer.modulus == "infinity" + return int(expression.type.integer.modular_value) + elif expression.type.WhichOneof("type") == "boolean": + assert expression.type.boolean.HasField("value") + return expression.type.boolean.value + elif expression.type.WhichOneof("type") == "enumeration": + assert expression.type.enumeration.HasField("value") + return int(expression.type.enumeration.value) + else: + assert False, "Unexpected expression type {}".format( + expression.type.WhichOneof("type") + ) + elif expression.WhichOneof("expression") == "function": + return _constant_value_of_function(expression.function, bindings) + elif expression.WhichOneof("expression") == "field_reference": + return None + elif expression.WhichOneof("expression") == "boolean_constant": + return expression.boolean_constant.value + elif expression.WhichOneof("expression") == "builtin_reference": + name = expression.builtin_reference.canonical_name.object_path[0] + if bindings and name in bindings: + return bindings[name] + else: + return None + elif expression.WhichOneof("expression") is None: + return None else: - return None - elif expression.WhichOneof("expression") is None: - return None - else: - assert False, "Unexpected expression kind {}".format( - expression.WhichOneof("expression")) + assert False, "Unexpected expression kind {}".format( + expression.WhichOneof("expression") + ) def _constant_value_of_function(function, bindings): - """Returns the constant value of evaluating `function`, or None.""" - values = [constant_value(arg, bindings) for arg in function.args] - # Expressions like `$is_statically_sized && 1 <= $static_size_in_bits <= 64` - # should return False, not None, if `$is_statically_sized` is false, even - # though `$static_size_in_bits` is unknown. 
- # - # The easiest way to allow this is to use a three-way logic chart for each; - # specifically: - # - # AND: True False Unknown - # +-------------------------- - # True | True False Unknown - # False | False False False - # Unknown | Unknown False Unknown - # - # OR: True False Unknown - # +-------------------------- - # True | True True True - # False | True False Unknown - # Unknown | True Unknown Unknown - # - # This raises the question of just how many constant-from-nonconstant - # expressions Emboss should support. There are, after all, a vast number of - # constant expression patterns built from non-constant subexpressions, such as - # `0 * X` or `X == X` or `3 * X == X + X + X`. I (bolms@) am not implementing - # any further special cases because I do not see any practical use for them. - if function.function == ir_data.FunctionMapping.UNKNOWN: - return None - if function.function == ir_data.FunctionMapping.AND: - if any(value is False for value in values): - return False - elif any(value is None for value in values): - return None - else: - return True - elif function.function == ir_data.FunctionMapping.OR: - if any(value is True for value in values): - return True - elif any(value is None for value in values): - return None - else: - return False - elif function.function == ir_data.FunctionMapping.CHOICE: - if values[0] is None: - return None - else: - return values[1] if values[0] else values[2] - # Other than the logical operators and choice operator, the result of any - # function on an unknown value is, itself, considered unknown. - if any(value is None for value in values): - return None - functions = { - ir_data.FunctionMapping.ADDITION: operator.add, - ir_data.FunctionMapping.SUBTRACTION: operator.sub, - ir_data.FunctionMapping.MULTIPLICATION: operator.mul, - ir_data.FunctionMapping.EQUALITY: operator.eq, - ir_data.FunctionMapping.INEQUALITY: operator.ne, - ir_data.FunctionMapping.LESS: operator.lt, - ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le, - ir_data.FunctionMapping.GREATER: operator.gt, - ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge, - # Python's max([1, 2]) == 2; max(1, 2) == 2; max([1]) == 1; but max(1) - # throws a TypeError ("'int' object is not iterable"). - ir_data.FunctionMapping.MAXIMUM: lambda *x: max(x), - } - return functions[function.function](*values) + """Returns the constant value of evaluating `function`, or None.""" + values = [constant_value(arg, bindings) for arg in function.args] + # Expressions like `$is_statically_sized && 1 <= $static_size_in_bits <= 64` + # should return False, not None, if `$is_statically_sized` is false, even + # though `$static_size_in_bits` is unknown. + # + # The easiest way to allow this is to use a three-way logic chart for each; + # specifically: + # + # AND: True False Unknown + # +-------------------------- + # True | True False Unknown + # False | False False False + # Unknown | Unknown False Unknown + # + # OR: True False Unknown + # +-------------------------- + # True | True True True + # False | True False Unknown + # Unknown | True Unknown Unknown + # + # This raises the question of just how many constant-from-nonconstant + # expressions Emboss should support. There are, after all, a vast number of + # constant expression patterns built from non-constant subexpressions, such as + # `0 * X` or `X == X` or `3 * X == X + X + X`. I (bolms@) am not implementing + # any further special cases because I do not see any practical use for them. 
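# --- Minimal self-contained sketch of the three-valued AND described in the
# tables above, with None standing in for "unknown" (it mirrors, but is not,
# the FunctionMapping.AND handling that follows):
def _kleene_and(values):
    if any(v is False for v in values):
        return False  # a known False dominates, even when other operands are unknown
    if any(v is None for v in values):
        return None  # otherwise, any unknown operand leaves the result unknown
    return True


assert _kleene_and([True, None]) is None
assert _kleene_and([False, None]) is False
assert _kleene_and([True, True]) is True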
+ if function.function == ir_data.FunctionMapping.UNKNOWN: + return None + if function.function == ir_data.FunctionMapping.AND: + if any(value is False for value in values): + return False + elif any(value is None for value in values): + return None + else: + return True + elif function.function == ir_data.FunctionMapping.OR: + if any(value is True for value in values): + return True + elif any(value is None for value in values): + return None + else: + return False + elif function.function == ir_data.FunctionMapping.CHOICE: + if values[0] is None: + return None + else: + return values[1] if values[0] else values[2] + # Other than the logical operators and choice operator, the result of any + # function on an unknown value is, itself, considered unknown. + if any(value is None for value in values): + return None + functions = { + ir_data.FunctionMapping.ADDITION: operator.add, + ir_data.FunctionMapping.SUBTRACTION: operator.sub, + ir_data.FunctionMapping.MULTIPLICATION: operator.mul, + ir_data.FunctionMapping.EQUALITY: operator.eq, + ir_data.FunctionMapping.INEQUALITY: operator.ne, + ir_data.FunctionMapping.LESS: operator.lt, + ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le, + ir_data.FunctionMapping.GREATER: operator.gt, + ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge, + # Python's max([1, 2]) == 2; max(1, 2) == 2; max([1]) == 1; but max(1) + # throws a TypeError ("'int' object is not iterable"). + ir_data.FunctionMapping.MAXIMUM: lambda *x: max(x), + } + return functions[function.function](*values) def _hashable_form_of_name(name): - return (name.module_file,) + tuple(name.object_path) + return (name.module_file,) + tuple(name.object_path) def hashable_form_of_reference(reference): - """Returns a representation of reference that can be used as a dict key. + """Returns a representation of reference that can be used as a dict key. - Arguments: - reference: An ir_data.Reference or ir_data.NameDefinition. + Arguments: + reference: An ir_data.Reference or ir_data.NameDefinition. - Returns: - A tuple of the module_file and object_path. - """ - return _hashable_form_of_name(reference.canonical_name) + Returns: + A tuple of the module_file and object_path. + """ + return _hashable_form_of_name(reference.canonical_name) def hashable_form_of_field_reference(field_reference): - """Returns a representation of field_reference that can be used as a dict key. + """Returns a representation of field_reference that can be used as a dict key. - Arguments: - field_reference: An ir_data.FieldReference + Arguments: + field_reference: An ir_data.FieldReference - Returns: - A tuple of tuples of the module_files and object_paths. - """ - return tuple(_hashable_form_of_name(reference.canonical_name) - for reference in field_reference.path) + Returns: + A tuple of tuples of the module_files and object_paths. 
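# --- Hypothetical example (hedged; the module, type, and field names are
# invented for illustration): a two-element field reference path hashes to one
# _hashable_form_of_name() tuple per path element, giving a dict-safe key.
field_ref = ir_data.FieldReference(
    path=[
        ir_data.Reference(
            canonical_name=ir_data.CanonicalName(
                module_file="t.emb", object_path=["Foo", "foo"]
            )
        ),
        ir_data.Reference(
            canonical_name=ir_data.CanonicalName(
                module_file="t.emb", object_path=["Bar", "bar"]
            )
        ),
    ]
)
assert hashable_form_of_field_reference(field_ref) == (
    ("t.emb", "Foo", "foo"),
    ("t.emb", "Bar", "bar"),
)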
+ """ + return tuple( + _hashable_form_of_name(reference.canonical_name) + for reference in field_reference.path + ) def is_array(type_ir): - """Returns true if type_ir is an array type.""" - return type_ir.HasField("array_type") + """Returns true if type_ir is an array type.""" + return type_ir.HasField("array_type") def _find_path_in_structure_field(path, field): - if not path: - return field - return None + if not path: + return field + return None def _find_path_in_structure(path, type_definition): - for field in type_definition.structure.field: - if field.name.name.text == path[0]: - return _find_path_in_structure_field(path[1:], field) - return None + for field in type_definition.structure.field: + if field.name.name.text == path[0]: + return _find_path_in_structure_field(path[1:], field) + return None def _find_path_in_enumeration(path, type_definition): - if len(path) != 1: + if len(path) != 1: + return None + for value in type_definition.enumeration.value: + if value.name.name.text == path[0]: + return value return None - for value in type_definition.enumeration.value: - if value.name.name.text == path[0]: - return value - return None def _find_path_in_parameters(path, type_definition): - if len(path) > 1 or not type_definition.HasField("runtime_parameter"): + if len(path) > 1 or not type_definition.HasField("runtime_parameter"): + return None + for parameter in type_definition.runtime_parameter: + if ir_data_utils.reader(parameter).name.name.text == path[0]: + return parameter return None - for parameter in type_definition.runtime_parameter: - if ir_data_utils.reader(parameter).name.name.text == path[0]: - return parameter - return None def _find_path_in_type_definition(path, type_definition): - """Finds the object with the given path in the given type_definition.""" - if not path: - return type_definition - obj = _find_path_in_parameters(path, type_definition) - if obj: - return obj - if type_definition.HasField("structure"): - obj = _find_path_in_structure(path, type_definition) - elif type_definition.HasField("enumeration"): - obj = _find_path_in_enumeration(path, type_definition) - if obj: - return obj - else: - return _find_path_in_type_list(path, type_definition.subtype or []) + """Finds the object with the given path in the given type_definition.""" + if not path: + return type_definition + obj = _find_path_in_parameters(path, type_definition) + if obj: + return obj + if type_definition.HasField("structure"): + obj = _find_path_in_structure(path, type_definition) + elif type_definition.HasField("enumeration"): + obj = _find_path_in_enumeration(path, type_definition) + if obj: + return obj + else: + return _find_path_in_type_list(path, type_definition.subtype or []) def _find_path_in_type_list(path, type_list): - for type_definition in type_list: - if type_definition.name.name.text == path[0]: - return _find_path_in_type_definition(path[1:], type_definition) - return None + for type_definition in type_list: + if type_definition.name.name.text == path[0]: + return _find_path_in_type_definition(path[1:], type_definition) + return None def _find_path_in_module(path, module_ir): - if not path: - return module_ir - return _find_path_in_type_list(path, module_ir.type) + if not path: + return module_ir + return _find_path_in_type_list(path, module_ir.type) def find_object_or_none(name, ir): - """Finds the object with the given canonical name, if it exists..""" - if (isinstance(name, ir_data.Reference) or - isinstance(name, ir_data.NameDefinition)): - path = 
_hashable_form_of_name(name.canonical_name)
-  elif isinstance(name, ir_data.CanonicalName):
-    path = _hashable_form_of_name(name)
-  else:
-    path = name
+    """Finds the object with the given canonical name, if it exists."""
+    if isinstance(name, ir_data.Reference) or isinstance(name, ir_data.NameDefinition):
+        path = _hashable_form_of_name(name.canonical_name)
+    elif isinstance(name, ir_data.CanonicalName):
+        path = _hashable_form_of_name(name)
+    else:
+        path = name

-  for module in ir.module:
-    if module.source_file_name == path[0]:
-      return _find_path_in_module(path[1:], module)
+    for module in ir.module:
+        if module.source_file_name == path[0]:
+            return _find_path_in_module(path[1:], module)

-  return None
+    return None


 def find_object(name, ir):
-  """Finds the IR of the type, field, or value with the given canonical name."""
-  result = find_object_or_none(name, ir)
-  assert result is not None, "Bad reference {}".format(name)
-  return result
+    """Finds the IR of the type, field, or value with the given canonical name."""
+    result = find_object_or_none(name, ir)
+    assert result is not None, "Bad reference {}".format(name)
+    return result


 def find_parent_object(name, ir):
-  """Finds the parent object of the object with the given canonical name."""
-  if (isinstance(name, ir_data.Reference) or
-      isinstance(name, ir_data.NameDefinition)):
-    path = _hashable_form_of_name(name.canonical_name)
-  elif isinstance(name, ir_data.CanonicalName):
-    path = _hashable_form_of_name(name)
-  else:
-    path = name
-  return find_object(path[:-1], ir)
+    """Finds the parent object of the object with the given canonical name."""
+    if isinstance(name, ir_data.Reference) or isinstance(name, ir_data.NameDefinition):
+        path = _hashable_form_of_name(name.canonical_name)
+    elif isinstance(name, ir_data.CanonicalName):
+        path = _hashable_form_of_name(name)
+    else:
+        path = name
+    return find_object(path[:-1], ir)


 def get_base_type(type_ir):
-  """Returns the base type of the given type.
+    """Returns the base type of the given type.

-  Arguments:
-    type_ir: IR of a type reference.
+    Arguments:
+      type_ir: IR of a type reference.

-  Returns:
-    If type_ir corresponds to an atomic type (like "UInt"), returns type_ir. If
-    type_ir corresponds to an array type (like "UInt:8[12]" or "Square[8][8]"),
-    returns the type after stripping off the array types ("UInt" or "Square").
-  """
-  while type_ir.HasField("array_type"):
-    type_ir = type_ir.array_type.base_type
-  assert type_ir.HasField("atomic_type"), (
-      "Unknown kind of type {}".format(type_ir))
-  return type_ir
+    Returns:
+      If type_ir corresponds to an atomic type (like "UInt"), returns type_ir. If
+      type_ir corresponds to an array type (like "UInt:8[12]" or "Square[8][8]"),
+      returns the type after stripping off the array types ("UInt" or "Square").
+    """
+    while type_ir.HasField("array_type"):
+        type_ir = type_ir.array_type.base_type
+    assert type_ir.HasField("atomic_type"), "Unknown kind of type {}".format(type_ir)
+    return type_ir


 def fixed_size_of_type_in_bits(type_ir, ir):
-  """Returns the fixed, known size for the given type, in bits, or None.
-
-  Arguments:
-    type_ir: The IR of a type.
-    ir: A complete IR, used to resolve references to types.
-
-  Returns:
-    size if the size of the type can be determined, otherwise None.
- """ - array_multiplier = 1 - while type_ir.HasField("array_type"): - if type_ir.array_type.WhichOneof("size") == "automatic": - return None - else: - assert type_ir.array_type.WhichOneof("size") == "element_count", ( - 'Expected array size to be "automatic" or "element_count".') - element_count = type_ir.array_type.element_count - if not is_constant(element_count): - return None + """Returns the fixed, known size for the given type, in bits, or None. + + Arguments: + type_ir: The IR of a type. + ir: A complete IR, used to resolve references to types. + + Returns: + size if the size of the type can be determined, otherwise None. + """ + array_multiplier = 1 + while type_ir.HasField("array_type"): + if type_ir.array_type.WhichOneof("size") == "automatic": + return None + else: + assert ( + type_ir.array_type.WhichOneof("size") == "element_count" + ), 'Expected array size to be "automatic" or "element_count".' + element_count = type_ir.array_type.element_count + if not is_constant(element_count): + return None + else: + array_multiplier *= constant_value(element_count) + assert not type_ir.HasField( + "size_in_bits" + ), "TODO(bolms): implement explicitly-sized arrays" + type_ir = type_ir.array_type.base_type + assert type_ir.HasField("atomic_type"), "Unexpected type!" + if type_ir.HasField("size_in_bits"): + size = constant_value(type_ir.size_in_bits) else: - array_multiplier *= constant_value(element_count) - assert not type_ir.HasField("size_in_bits"), ( - "TODO(bolms): implement explicitly-sized arrays") - type_ir = type_ir.array_type.base_type - assert type_ir.HasField("atomic_type"), "Unexpected type!" - if type_ir.HasField("size_in_bits"): - size = constant_value(type_ir.size_in_bits) - else: - type_definition = find_object(type_ir.atomic_type.reference, ir) - size_attr = get_attribute(type_definition.attribute, _FIXED_SIZE_ATTRIBUTE) - if not size_attr: - return None - size = constant_value(size_attr.expression) - return size * array_multiplier + type_definition = find_object(type_ir.atomic_type.reference, ir) + size_attr = get_attribute(type_definition.attribute, _FIXED_SIZE_ATTRIBUTE) + if not size_attr: + return None + size = constant_value(size_attr.expression) + return size * array_multiplier def field_is_virtual(field_ir): - """Returns true if the field is virtual.""" - # TODO(bolms): Should there be a more explicit indicator that a field is - # virtual? - return not field_ir.HasField("location") + """Returns true if the field is virtual.""" + # TODO(bolms): Should there be a more explicit indicator that a field is + # virtual? + return not field_ir.HasField("location") def field_is_read_only(field_ir): - """Returns true if the field is read-only.""" - # For now, all virtual fields are read-only, and no non-virtual fields are - # read-only. - return ir_data_utils.reader(field_ir).write_method.read_only + """Returns true if the field is read-only.""" + # For now, all virtual fields are read-only, and no non-virtual fields are + # read-only. 
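# --- Hedged sketch (mirrors the unit tests below): a field with no location
# is virtual, and read-only-ness is carried by the field's write_method.
assert field_is_virtual(ir_data.Field())  # no location -> virtual
assert not field_is_virtual(ir_data.Field(location=ir_data.FieldLocation()))
assert field_is_read_only(
    ir_data.Field(write_method=ir_data.WriteMethod(read_only=True))
)
assert not field_is_read_only(ir_data.Field(write_method=ir_data.WriteMethod()))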
+ return ir_data_utils.reader(field_ir).write_method.read_only diff --git a/compiler/util/ir_util_test.py b/compiler/util/ir_util_test.py index 8e0b37a..07cf4dc 100644 --- a/compiler/util/ir_util_test.py +++ b/compiler/util/ir_util_test.py @@ -22,397 +22,559 @@ def _parse_expression(text): - return expression_parser.parse(text) + return expression_parser.parse(text) class IrUtilTest(unittest.TestCase): - """Tests for the miscellaneous utility functions in ir_util.py.""" - - def test_is_constant_integer(self): - self.assertTrue(ir_util.is_constant(_parse_expression("6"))) - expression = _parse_expression("12") - # The type information should be ignored for constants like this one. - ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType()) - self.assertTrue(ir_util.is_constant(expression)) - - def test_is_constant_boolean(self): - self.assertTrue(ir_util.is_constant(_parse_expression("true"))) - expression = _parse_expression("true") - # The type information should be ignored for constants like this one. - ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType()) - self.assertTrue(ir_util.is_constant(expression)) - - def test_is_constant_enum(self): - self.assertTrue(ir_util.is_constant(ir_data.Expression( - constant_reference=ir_data.Reference(), - type=ir_data.ExpressionType(enumeration=ir_data.EnumType(value="12"))))) - - def test_is_constant_integer_type(self): - self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType( - integer=ir_data.IntegerType( - modulus="10", - modular_value="5", - minimum_value="-5", - maximum_value="15")))) - self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType( - integer=ir_data.IntegerType( - modulus="infinity", - modular_value="5", - minimum_value="5", - maximum_value="5")))) - - def test_is_constant_boolean_type(self): - self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType( - boolean=ir_data.BooleanType()))) - self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType( - boolean=ir_data.BooleanType(value=True)))) - self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType( - boolean=ir_data.BooleanType(value=False)))) - - def test_is_constant_enumeration_type(self): - self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType( - enumeration=ir_data.EnumType()))) - self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType( - enumeration=ir_data.EnumType(value="0")))) - - def test_is_constant_opaque_type(self): - self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType( - opaque=ir_data.OpaqueType()))) - - def test_constant_value_of_integer(self): - self.assertEqual(6, ir_util.constant_value(_parse_expression("6"))) - - def test_constant_value_of_none(self): - self.assertIsNone(ir_util.constant_value(ir_data.Expression())) - - def test_constant_value_of_addition(self): - self.assertEqual(6, ir_util.constant_value(_parse_expression("2+4"))) - - def test_constant_value_of_subtraction(self): - self.assertEqual(-2, ir_util.constant_value(_parse_expression("2-4"))) - - def test_constant_value_of_multiplication(self): - self.assertEqual(8, ir_util.constant_value(_parse_expression("2*4"))) - - def test_constant_value_of_equality(self): - self.assertFalse(ir_util.constant_value(_parse_expression("2 == 4"))) - - def test_constant_value_of_inequality(self): - self.assertTrue(ir_util.constant_value(_parse_expression("2 != 4"))) - - def test_constant_value_of_less(self): - self.assertTrue(ir_util.constant_value(_parse_expression("2 < 4"))) - - def 
test_constant_value_of_less_or_equal(self): - self.assertTrue(ir_util.constant_value(_parse_expression("2 <= 4"))) - - def test_constant_value_of_greater(self): - self.assertFalse(ir_util.constant_value(_parse_expression("2 > 4"))) - - def test_constant_value_of_greater_or_equal(self): - self.assertFalse(ir_util.constant_value(_parse_expression("2 >= 4"))) - - def test_constant_value_of_and(self): - self.assertFalse(ir_util.constant_value(_parse_expression("true && false"))) - self.assertTrue(ir_util.constant_value(_parse_expression("true && true"))) - - def test_constant_value_of_or(self): - self.assertTrue(ir_util.constant_value(_parse_expression("true || false"))) - self.assertFalse( - ir_util.constant_value(_parse_expression("false || false"))) - - def test_constant_value_of_choice(self): - self.assertEqual( - 10, ir_util.constant_value(_parse_expression("false ? 20 : 10"))) - self.assertEqual( - 20, ir_util.constant_value(_parse_expression("true ? 20 : 10"))) - - def test_constant_value_of_choice_with_unknown_other_branch(self): - self.assertEqual( - 10, ir_util.constant_value(_parse_expression("false ? foo : 10"))) - self.assertEqual( - 20, ir_util.constant_value(_parse_expression("true ? 20 : foo"))) - - def test_constant_value_of_maximum(self): - self.assertEqual(10, - ir_util.constant_value(_parse_expression("$max(5, 10)"))) - self.assertEqual(10, - ir_util.constant_value(_parse_expression("$max(10)"))) - self.assertEqual( - 10, - ir_util.constant_value(_parse_expression("$max(5, 9, 7, 10, 6, -100)"))) - - def test_constant_value_of_boolean(self): - self.assertTrue(ir_util.constant_value(_parse_expression("true"))) - self.assertFalse(ir_util.constant_value(_parse_expression("false"))) - - def test_constant_value_of_enum(self): - self.assertEqual(12, ir_util.constant_value(ir_data.Expression( - constant_reference=ir_data.Reference(), - type=ir_data.ExpressionType(enumeration=ir_data.EnumType(value="12"))))) - - def test_constant_value_of_integer_reference(self): - self.assertEqual(12, ir_util.constant_value(ir_data.Expression( - constant_reference=ir_data.Reference(), - type=ir_data.ExpressionType( - integer=ir_data.IntegerType(modulus="infinity", - modular_value="12"))))) - - def test_constant_value_of_boolean_reference(self): - self.assertTrue(ir_util.constant_value(ir_data.Expression( - constant_reference=ir_data.Reference(), - type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=True))))) - - def test_constant_value_of_builtin_reference(self): - self.assertEqual(12, ir_util.constant_value( - ir_data.Expression( - builtin_reference=ir_data.Reference( - canonical_name=ir_data.CanonicalName(object_path=["$foo"]))), - {"$foo": 12})) - - def test_constant_value_of_field_reference(self): - self.assertIsNone(ir_util.constant_value(_parse_expression("foo"))) - - def test_constant_value_of_missing_builtin_reference(self): - self.assertIsNone(ir_util.constant_value( - _parse_expression("$static_size_in_bits"), {"$bar": 12})) - - def test_constant_value_of_present_builtin_reference(self): - self.assertEqual(12, ir_util.constant_value( - _parse_expression("$static_size_in_bits"), - {"$static_size_in_bits": 12})) - - def test_constant_false_value_of_operator_and_with_missing_value(self): - self.assertIs(False, ir_util.constant_value( - _parse_expression("false && foo"), {"bar": 12})) - self.assertIs(False, ir_util.constant_value( - _parse_expression("foo && false"), {"bar": 12})) - - def test_constant_false_value_of_operator_and_known_value(self): - 
self.assertTrue(ir_util.constant_value( - _parse_expression("true && $is_statically_sized"), - {"$is_statically_sized": True})) - - def test_constant_none_value_of_operator_and_with_missing_value(self): - self.assertIsNone(ir_util.constant_value( - _parse_expression("true && foo"), {"bar": 12})) - self.assertIsNone(ir_util.constant_value( - _parse_expression("foo && true"), {"bar": 12})) - - def test_constant_false_value_of_operator_or_with_missing_value(self): - self.assertTrue(ir_util.constant_value( - _parse_expression("true || foo"), {"bar": 12})) - self.assertTrue(ir_util.constant_value( - _parse_expression("foo || true"), {"bar": 12})) - - def test_constant_none_value_of_operator_or_with_missing_value(self): - self.assertIsNone(ir_util.constant_value( - _parse_expression("foo || false"), {"bar": 12})) - self.assertIsNone(ir_util.constant_value( - _parse_expression("false || foo"), {"bar": 12})) - - def test_constant_value_of_operator_plus_with_missing_value(self): - self.assertIsNone(ir_util.constant_value( - _parse_expression("12 + foo"), {"bar": 12})) - - def test_is_array(self): - self.assertTrue( - ir_util.is_array(ir_data.Type(array_type=ir_data.ArrayType()))) - self.assertFalse( - ir_util.is_array(ir_data.Type(atomic_type=ir_data.AtomicType()))) - - def test_get_attribute(self): - type_def = ir_data.TypeDefinition(attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("true")), - name=ir_data.Word(text="bob")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob2")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("true")), - name=ir_data.Word(text="bob2"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob3"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word()), - ]) - self.assertEqual( - ir_data.AttributeValue(expression=_parse_expression("true")), - ir_util.get_attribute(type_def.attribute, "bob")) - self.assertEqual( - ir_data.AttributeValue(expression=_parse_expression("false")), - ir_util.get_attribute(type_def.attribute, "bob2")) - self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "Bob")) - self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "bob3")) - - def test_get_boolean_attribute(self): - type_def = ir_data.TypeDefinition(attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("true")), - name=ir_data.Word(text="bob")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob2")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("true")), - name=ir_data.Word(text="bob2"), - is_default=True), - ir_data.Attribute( - 
value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob3"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word()), - ]) - self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, "bob")) - self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, - "bob", - default_value=False)) - self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute, "bob2")) - self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute, - "bob2", - default_value=True)) - self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "Bob")) - self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, - "Bob", - default_value=True)) - self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "bob3")) - - def test_get_integer_attribute(self): - type_def = ir_data.TypeDefinition(attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - type=ir_data.ExpressionType(integer=ir_data.IntegerType()))), - name=ir_data.Word(text="phil")), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="20"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="20", - modulus="infinity")))), - name=ir_data.Word(text="bob"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="10"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="10", - modulus="infinity")))), - name=ir_data.Word(text="bob")), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="5"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( + """Tests for the miscellaneous utility functions in ir_util.py.""" + + def test_is_constant_integer(self): + self.assertTrue(ir_util.is_constant(_parse_expression("6"))) + expression = _parse_expression("12") + # The type information should be ignored for constants like this one. + ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType()) + self.assertTrue(ir_util.is_constant(expression)) + + def test_is_constant_boolean(self): + self.assertTrue(ir_util.is_constant(_parse_expression("true"))) + expression = _parse_expression("true") + # The type information should be ignored for constants like this one. 
+ ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType()) + self.assertTrue(ir_util.is_constant(expression)) + + def test_is_constant_enum(self): + self.assertTrue( + ir_util.is_constant( + ir_data.Expression( + constant_reference=ir_data.Reference(), + type=ir_data.ExpressionType( + enumeration=ir_data.EnumType(value="12") + ), + ) + ) + ) + + def test_is_constant_integer_type(self): + self.assertFalse( + ir_util.is_constant_type( + ir_data.ExpressionType( + integer=ir_data.IntegerType( + modulus="10", modular_value="5", - modulus="infinity")))), - name=ir_data.Word(text="bob2")), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="0"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="0", - modulus="infinity")))), - name=ir_data.Word(text="bob2"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="30"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="30", - modulus="infinity")))), - name=ir_data.Word(text="bob3"), - is_default=True), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - function=ir_data.Function( - function=ir_data.FunctionMapping.ADDITION, - args=[ - ir_data.Expression( - constant=ir_data.NumericConstant(value="100"), - type=ir_data.ExpressionType( - integer=ir_data.IntegerType( - modular_value="100", - modulus="infinity"))), - ir_data.Expression( - constant=ir_data.NumericConstant(value="100"), - type=ir_data.ExpressionType( - integer=ir_data.IntegerType( - modular_value="100", - modulus="infinity"))) - ]), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="200", - modulus="infinity")))), - name=ir_data.Word(text="bob4")), - ir_data.Attribute( - value=ir_data.AttributeValue( - expression=ir_data.Expression( - constant=ir_data.NumericConstant(value="40"), - type=ir_data.ExpressionType(integer=ir_data.IntegerType( - modular_value="40", - modulus="infinity")))), - name=ir_data.Word()), - ]) - self.assertEqual(10, - ir_util.get_integer_attribute(type_def.attribute, "bob")) - self.assertEqual(5, - ir_util.get_integer_attribute(type_def.attribute, "bob2")) - self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "Bob")) - self.assertEqual(10, ir_util.get_integer_attribute(type_def.attribute, - "Bob", - default_value=10)) - self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "bob3")) - self.assertEqual(200, ir_util.get_integer_attribute(type_def.attribute, - "bob4")) - - def test_get_duplicate_attribute(self): - type_def = ir_data.TypeDefinition(attribute=[ - ir_data.Attribute( - value=ir_data.AttributeValue(expression=ir_data.Expression()), - name=ir_data.Word(text="phil")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("true")), - name=ir_data.Word(text="bob")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word(text="bob")), - ir_data.Attribute( - value=ir_data.AttributeValue(expression=_parse_expression("false")), - name=ir_data.Word()), - ]) - self.assertRaises(AssertionError, ir_util.get_attribute, type_def.attribute, - "bob") - - def test_find_object(self): - ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, - """{ + minimum_value="-5", + maximum_value="15", + ) + ) + ) + ) + self.assertTrue( + ir_util.is_constant_type( + 
ir_data.ExpressionType( + integer=ir_data.IntegerType( + modulus="infinity", + modular_value="5", + minimum_value="5", + maximum_value="5", + ) + ) + ) + ) + + def test_is_constant_boolean_type(self): + self.assertFalse( + ir_util.is_constant_type( + ir_data.ExpressionType(boolean=ir_data.BooleanType()) + ) + ) + self.assertTrue( + ir_util.is_constant_type( + ir_data.ExpressionType(boolean=ir_data.BooleanType(value=True)) + ) + ) + self.assertTrue( + ir_util.is_constant_type( + ir_data.ExpressionType(boolean=ir_data.BooleanType(value=False)) + ) + ) + + def test_is_constant_enumeration_type(self): + self.assertFalse( + ir_util.is_constant_type( + ir_data.ExpressionType(enumeration=ir_data.EnumType()) + ) + ) + self.assertTrue( + ir_util.is_constant_type( + ir_data.ExpressionType(enumeration=ir_data.EnumType(value="0")) + ) + ) + + def test_is_constant_opaque_type(self): + self.assertFalse( + ir_util.is_constant_type( + ir_data.ExpressionType(opaque=ir_data.OpaqueType()) + ) + ) + + def test_constant_value_of_integer(self): + self.assertEqual(6, ir_util.constant_value(_parse_expression("6"))) + + def test_constant_value_of_none(self): + self.assertIsNone(ir_util.constant_value(ir_data.Expression())) + + def test_constant_value_of_addition(self): + self.assertEqual(6, ir_util.constant_value(_parse_expression("2+4"))) + + def test_constant_value_of_subtraction(self): + self.assertEqual(-2, ir_util.constant_value(_parse_expression("2-4"))) + + def test_constant_value_of_multiplication(self): + self.assertEqual(8, ir_util.constant_value(_parse_expression("2*4"))) + + def test_constant_value_of_equality(self): + self.assertFalse(ir_util.constant_value(_parse_expression("2 == 4"))) + + def test_constant_value_of_inequality(self): + self.assertTrue(ir_util.constant_value(_parse_expression("2 != 4"))) + + def test_constant_value_of_less(self): + self.assertTrue(ir_util.constant_value(_parse_expression("2 < 4"))) + + def test_constant_value_of_less_or_equal(self): + self.assertTrue(ir_util.constant_value(_parse_expression("2 <= 4"))) + + def test_constant_value_of_greater(self): + self.assertFalse(ir_util.constant_value(_parse_expression("2 > 4"))) + + def test_constant_value_of_greater_or_equal(self): + self.assertFalse(ir_util.constant_value(_parse_expression("2 >= 4"))) + + def test_constant_value_of_and(self): + self.assertFalse(ir_util.constant_value(_parse_expression("true && false"))) + self.assertTrue(ir_util.constant_value(_parse_expression("true && true"))) + + def test_constant_value_of_or(self): + self.assertTrue(ir_util.constant_value(_parse_expression("true || false"))) + self.assertFalse(ir_util.constant_value(_parse_expression("false || false"))) + + def test_constant_value_of_choice(self): + self.assertEqual( + 10, ir_util.constant_value(_parse_expression("false ? 20 : 10")) + ) + self.assertEqual( + 20, ir_util.constant_value(_parse_expression("true ? 20 : 10")) + ) + + def test_constant_value_of_choice_with_unknown_other_branch(self): + self.assertEqual( + 10, ir_util.constant_value(_parse_expression("false ? foo : 10")) + ) + self.assertEqual( + 20, ir_util.constant_value(_parse_expression("true ? 
20 : foo")) + ) + + def test_constant_value_of_maximum(self): + self.assertEqual(10, ir_util.constant_value(_parse_expression("$max(5, 10)"))) + self.assertEqual(10, ir_util.constant_value(_parse_expression("$max(10)"))) + self.assertEqual( + 10, ir_util.constant_value(_parse_expression("$max(5, 9, 7, 10, 6, -100)")) + ) + + def test_constant_value_of_boolean(self): + self.assertTrue(ir_util.constant_value(_parse_expression("true"))) + self.assertFalse(ir_util.constant_value(_parse_expression("false"))) + + def test_constant_value_of_enum(self): + self.assertEqual( + 12, + ir_util.constant_value( + ir_data.Expression( + constant_reference=ir_data.Reference(), + type=ir_data.ExpressionType( + enumeration=ir_data.EnumType(value="12") + ), + ) + ), + ) + + def test_constant_value_of_integer_reference(self): + self.assertEqual( + 12, + ir_util.constant_value( + ir_data.Expression( + constant_reference=ir_data.Reference(), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modulus="infinity", modular_value="12" + ) + ), + ) + ), + ) + + def test_constant_value_of_boolean_reference(self): + self.assertTrue( + ir_util.constant_value( + ir_data.Expression( + constant_reference=ir_data.Reference(), + type=ir_data.ExpressionType( + boolean=ir_data.BooleanType(value=True) + ), + ) + ) + ) + + def test_constant_value_of_builtin_reference(self): + self.assertEqual( + 12, + ir_util.constant_value( + ir_data.Expression( + builtin_reference=ir_data.Reference( + canonical_name=ir_data.CanonicalName(object_path=["$foo"]) + ) + ), + {"$foo": 12}, + ), + ) + + def test_constant_value_of_field_reference(self): + self.assertIsNone(ir_util.constant_value(_parse_expression("foo"))) + + def test_constant_value_of_missing_builtin_reference(self): + self.assertIsNone( + ir_util.constant_value( + _parse_expression("$static_size_in_bits"), {"$bar": 12} + ) + ) + + def test_constant_value_of_present_builtin_reference(self): + self.assertEqual( + 12, + ir_util.constant_value( + _parse_expression("$static_size_in_bits"), {"$static_size_in_bits": 12} + ), + ) + + def test_constant_false_value_of_operator_and_with_missing_value(self): + self.assertIs( + False, + ir_util.constant_value(_parse_expression("false && foo"), {"bar": 12}), + ) + self.assertIs( + False, + ir_util.constant_value(_parse_expression("foo && false"), {"bar": 12}), + ) + + def test_constant_false_value_of_operator_and_known_value(self): + self.assertTrue( + ir_util.constant_value( + _parse_expression("true && $is_statically_sized"), + {"$is_statically_sized": True}, + ) + ) + + def test_constant_none_value_of_operator_and_with_missing_value(self): + self.assertIsNone( + ir_util.constant_value(_parse_expression("true && foo"), {"bar": 12}) + ) + self.assertIsNone( + ir_util.constant_value(_parse_expression("foo && true"), {"bar": 12}) + ) + + def test_constant_false_value_of_operator_or_with_missing_value(self): + self.assertTrue( + ir_util.constant_value(_parse_expression("true || foo"), {"bar": 12}) + ) + self.assertTrue( + ir_util.constant_value(_parse_expression("foo || true"), {"bar": 12}) + ) + + def test_constant_none_value_of_operator_or_with_missing_value(self): + self.assertIsNone( + ir_util.constant_value(_parse_expression("foo || false"), {"bar": 12}) + ) + self.assertIsNone( + ir_util.constant_value(_parse_expression("false || foo"), {"bar": 12}) + ) + + def test_constant_value_of_operator_plus_with_missing_value(self): + self.assertIsNone( + ir_util.constant_value(_parse_expression("12 + foo"), {"bar": 12}) + ) + + def 
test_is_array(self): + self.assertTrue(ir_util.is_array(ir_data.Type(array_type=ir_data.ArrayType()))) + self.assertFalse( + ir_util.is_array(ir_data.Type(atomic_type=ir_data.AtomicType())) + ) + + def test_get_attribute(self): + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("true")), + name=ir_data.Word(text="bob"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob2"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("true")), + name=ir_data.Word(text="bob2"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob3"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(), + ), + ] + ) + self.assertEqual( + ir_data.AttributeValue(expression=_parse_expression("true")), + ir_util.get_attribute(type_def.attribute, "bob"), + ) + self.assertEqual( + ir_data.AttributeValue(expression=_parse_expression("false")), + ir_util.get_attribute(type_def.attribute, "bob2"), + ) + self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "Bob")) + self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "bob3")) + + def test_get_boolean_attribute(self): + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("true")), + name=ir_data.Word(text="bob"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob2"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("true")), + name=ir_data.Word(text="bob2"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob3"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(), + ), + ] + ) + self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, "bob")) + self.assertTrue( + ir_util.get_boolean_attribute( + type_def.attribute, "bob", default_value=False + ) + ) + self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute, "bob2")) + self.assertFalse( + ir_util.get_boolean_attribute( + type_def.attribute, "bob2", default_value=True + ) + ) + self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "Bob")) + self.assertTrue( + ir_util.get_boolean_attribute(type_def.attribute, "Bob", default_value=True) + ) + self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "bob3")) + + def test_get_integer_attribute(self): + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + 
type=ir_data.ExpressionType(integer=ir_data.IntegerType()) + ) + ), + name=ir_data.Word(text="phil"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="20"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="20", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="10"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="10", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="5"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="5", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob2"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="0"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="0", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob2"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="30"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="30", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob3"), + is_default=True, + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + function=ir_data.Function( + function=ir_data.FunctionMapping.ADDITION, + args=[ + ir_data.Expression( + constant=ir_data.NumericConstant(value="100"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="100", modulus="infinity" + ) + ), + ), + ir_data.Expression( + constant=ir_data.NumericConstant(value="100"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="100", modulus="infinity" + ) + ), + ), + ], + ), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="200", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(text="bob4"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue( + expression=ir_data.Expression( + constant=ir_data.NumericConstant(value="40"), + type=ir_data.ExpressionType( + integer=ir_data.IntegerType( + modular_value="40", modulus="infinity" + ) + ), + ) + ), + name=ir_data.Word(), + ), + ] + ) + self.assertEqual(10, ir_util.get_integer_attribute(type_def.attribute, "bob")) + self.assertEqual(5, ir_util.get_integer_attribute(type_def.attribute, "bob2")) + self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "Bob")) + self.assertEqual( + 10, + ir_util.get_integer_attribute(type_def.attribute, "Bob", default_value=10), + ) + self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "bob3")) + self.assertEqual(200, ir_util.get_integer_attribute(type_def.attribute, "bob4")) + + def test_get_duplicate_attribute(self): + type_def = ir_data.TypeDefinition( + attribute=[ + ir_data.Attribute( + value=ir_data.AttributeValue(expression=ir_data.Expression()), + name=ir_data.Word(text="phil"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("true")), + name=ir_data.Word(text="bob"), + ), + ir_data.Attribute( + 
value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(text="bob"), + ), + ir_data.Attribute( + value=ir_data.AttributeValue(expression=_parse_expression("false")), + name=ir_data.Word(), + ), + ] + ) + self.assertRaises( + AssertionError, ir_util.get_attribute, type_def.attribute, "bob" + ) + + def test_find_object(self): + ir = ir_data_utils.IrDataSerializer.from_json( + ir_data.EmbossIr, + """{ "module": [ { "type": [ @@ -490,83 +652,128 @@ def test_find_object(self): "source_file_name": "" } ] - }""") - - # Test that find_object works with any of its four "name" types. - canonical_name_of_foo = ir_data.CanonicalName(module_file="test.emb", - object_path=["Foo"]) - self.assertEqual(ir.module[0].type[0], ir_util.find_object( - ir_data.Reference(canonical_name=canonical_name_of_foo), ir)) - self.assertEqual(ir.module[0].type[0], ir_util.find_object( - ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir)) - self.assertEqual(ir.module[0].type[0], - ir_util.find_object(canonical_name_of_foo, ir)) - self.assertEqual(ir.module[0].type[0], - ir_util.find_object(("test.emb", "Foo"), ir)) - - # Test that find_object works with objects other than structs. - self.assertEqual(ir.module[0].type[1], - ir_util.find_object(("test.emb", "Bar"), ir)) - self.assertEqual(ir.module[1].type[0], - ir_util.find_object(("", "UInt"), ir)) - self.assertEqual(ir.module[0].type[0].structure.field[0], - ir_util.find_object(("test.emb", "Foo", "field"), ir)) - self.assertEqual(ir.module[0].type[0].runtime_parameter[0], - ir_util.find_object(("test.emb", "Foo", "parameter"), ir)) - self.assertEqual(ir.module[0].type[1].enumeration.value[0], - ir_util.find_object(("test.emb", "Bar", "QUX"), ir)) - self.assertEqual(ir.module[0], ir_util.find_object(("test.emb",), ir)) - self.assertEqual(ir.module[1], ir_util.find_object(("",), ir)) - - # Test searching for non-present objects. - self.assertIsNone(ir_util.find_object_or_none(("not_test.emb",), ir)) - self.assertIsNone(ir_util.find_object_or_none(("test.emb", "NotFoo"), ir)) - self.assertIsNone( - ir_util.find_object_or_none(("test.emb", "Foo", "not_field"), ir)) - self.assertIsNone( - ir_util.find_object_or_none(("test.emb", "Foo", "field", "no_subfield"), - ir)) - self.assertIsNone( - ir_util.find_object_or_none(("test.emb", "Bar", "NOT_QUX"), ir)) - self.assertIsNone( - ir_util.find_object_or_none(("test.emb", "Bar", "QUX", "no_subenum"), - ir)) - - # Test that find_parent_object works with any of its four "name" types. - self.assertEqual(ir.module[0], ir_util.find_parent_object( - ir_data.Reference(canonical_name=canonical_name_of_foo), ir)) - self.assertEqual(ir.module[0], ir_util.find_parent_object( - ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir)) - self.assertEqual(ir.module[0], - ir_util.find_parent_object(canonical_name_of_foo, ir)) - self.assertEqual(ir.module[0], - ir_util.find_parent_object(("test.emb", "Foo"), ir)) - - # Test that find_parent_object works with objects other than structs. 
- self.assertEqual(ir.module[0], - ir_util.find_parent_object(("test.emb", "Bar"), ir)) - self.assertEqual(ir.module[1], ir_util.find_parent_object(("", "UInt"), ir)) - self.assertEqual(ir.module[0].type[0], - ir_util.find_parent_object(("test.emb", "Foo", "field"), - ir)) - self.assertEqual(ir.module[0].type[1], - ir_util.find_parent_object(("test.emb", "Bar", "QUX"), ir)) - - def test_hashable_form_of_reference(self): - self.assertEqual( - ("t.emb", "Foo", "Bar"), - ir_util.hashable_form_of_reference(ir_data.Reference( - canonical_name=ir_data.CanonicalName(module_file="t.emb", - object_path=["Foo", "Bar"])))) - self.assertEqual( - ("t.emb", "Foo", "Bar"), - ir_util.hashable_form_of_reference(ir_data.NameDefinition( - canonical_name=ir_data.CanonicalName(module_file="t.emb", - object_path=["Foo", "Bar"])))) - - def test_get_base_type(self): - array_type_ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + }""", + ) + + # Test that find_object works with any of its four "name" types. + canonical_name_of_foo = ir_data.CanonicalName( + module_file="test.emb", object_path=["Foo"] + ) + self.assertEqual( + ir.module[0].type[0], + ir_util.find_object( + ir_data.Reference(canonical_name=canonical_name_of_foo), ir + ), + ) + self.assertEqual( + ir.module[0].type[0], + ir_util.find_object( + ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir + ), + ) + self.assertEqual( + ir.module[0].type[0], ir_util.find_object(canonical_name_of_foo, ir) + ) + self.assertEqual( + ir.module[0].type[0], ir_util.find_object(("test.emb", "Foo"), ir) + ) + + # Test that find_object works with objects other than structs. + self.assertEqual( + ir.module[0].type[1], ir_util.find_object(("test.emb", "Bar"), ir) + ) + self.assertEqual(ir.module[1].type[0], ir_util.find_object(("", "UInt"), ir)) + self.assertEqual( + ir.module[0].type[0].structure.field[0], + ir_util.find_object(("test.emb", "Foo", "field"), ir), + ) + self.assertEqual( + ir.module[0].type[0].runtime_parameter[0], + ir_util.find_object(("test.emb", "Foo", "parameter"), ir), + ) + self.assertEqual( + ir.module[0].type[1].enumeration.value[0], + ir_util.find_object(("test.emb", "Bar", "QUX"), ir), + ) + self.assertEqual(ir.module[0], ir_util.find_object(("test.emb",), ir)) + self.assertEqual(ir.module[1], ir_util.find_object(("",), ir)) + + # Test searching for non-present objects. + self.assertIsNone(ir_util.find_object_or_none(("not_test.emb",), ir)) + self.assertIsNone(ir_util.find_object_or_none(("test.emb", "NotFoo"), ir)) + self.assertIsNone( + ir_util.find_object_or_none(("test.emb", "Foo", "not_field"), ir) + ) + self.assertIsNone( + ir_util.find_object_or_none(("test.emb", "Foo", "field", "no_subfield"), ir) + ) + self.assertIsNone( + ir_util.find_object_or_none(("test.emb", "Bar", "NOT_QUX"), ir) + ) + self.assertIsNone( + ir_util.find_object_or_none(("test.emb", "Bar", "QUX", "no_subenum"), ir) + ) + + # Test that find_parent_object works with any of its four "name" types. + self.assertEqual( + ir.module[0], + ir_util.find_parent_object( + ir_data.Reference(canonical_name=canonical_name_of_foo), ir + ), + ) + self.assertEqual( + ir.module[0], + ir_util.find_parent_object( + ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir + ), + ) + self.assertEqual( + ir.module[0], ir_util.find_parent_object(canonical_name_of_foo, ir) + ) + self.assertEqual( + ir.module[0], ir_util.find_parent_object(("test.emb", "Foo"), ir) + ) + + # Test that find_parent_object works with objects other than structs. 
+ self.assertEqual( + ir.module[0], ir_util.find_parent_object(("test.emb", "Bar"), ir) + ) + self.assertEqual(ir.module[1], ir_util.find_parent_object(("", "UInt"), ir)) + self.assertEqual( + ir.module[0].type[0], + ir_util.find_parent_object(("test.emb", "Foo", "field"), ir), + ) + self.assertEqual( + ir.module[0].type[1], + ir_util.find_parent_object(("test.emb", "Bar", "QUX"), ir), + ) + + def test_hashable_form_of_reference(self): + self.assertEqual( + ("t.emb", "Foo", "Bar"), + ir_util.hashable_form_of_reference( + ir_data.Reference( + canonical_name=ir_data.CanonicalName( + module_file="t.emb", object_path=["Foo", "Bar"] + ) + ) + ), + ) + self.assertEqual( + ("t.emb", "Foo", "Bar"), + ir_util.hashable_form_of_reference( + ir_data.NameDefinition( + canonical_name=ir_data.CanonicalName( + module_file="t.emb", object_path=["Foo", "Bar"] + ) + ) + ), + ) + + def test_get_base_type(self): + array_type_ir = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "array_type": { "element_count": { "constant": { "value": "20" } }, "base_type": { @@ -583,16 +790,19 @@ def test_get_base_type(self): }, "source_location": { "start": { "line": 3 } } } - }""") - base_type_ir = array_type_ir.array_type.base_type.array_type.base_type - self.assertEqual(base_type_ir, ir_util.get_base_type(array_type_ir)) - self.assertEqual(base_type_ir, ir_util.get_base_type( - array_type_ir.array_type.base_type)) - self.assertEqual(base_type_ir, ir_util.get_base_type(base_type_ir)) - - def test_size_of_type_in_bits(self): - ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, - """{ + }""", + ) + base_type_ir = array_type_ir.array_type.base_type.array_type.base_type + self.assertEqual(base_type_ir, ir_util.get_base_type(array_type_ir)) + self.assertEqual( + base_type_ir, ir_util.get_base_type(array_type_ir.array_type.base_type) + ) + self.assertEqual(base_type_ir, ir_util.get_base_type(base_type_ir)) + + def test_size_of_type_in_bits(self): + ir = ir_data_utils.IrDataSerializer.from_json( + ir_data.EmbossIr, + """{ "module": [{ "type": [{ "name": { @@ -637,20 +847,24 @@ def test_size_of_type_in_bits(self): }], "source_file_name": "" }] - }""") + }""", + ) - fixed_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + fixed_size_type = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "atomic_type": { "reference": { "canonical_name": { "module_file": "", "object_path": ["Byte"] } } } - }""") - self.assertEqual(8, ir_util.fixed_size_of_type_in_bits(fixed_size_type, ir)) + }""", + ) + self.assertEqual(8, ir_util.fixed_size_of_type_in_bits(fixed_size_type, ir)) - explicit_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + explicit_size_type = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "atomic_type": { "reference": { "canonical_name": { "module_file": "", "object_path": ["UInt"] } @@ -662,12 +876,13 @@ def test_size_of_type_in_bits(self): "integer": { "modular_value": "32", "modulus": "infinity" } } } - }""") - self.assertEqual(32, - ir_util.fixed_size_of_type_in_bits(explicit_size_type, ir)) + }""", + ) + self.assertEqual(32, ir_util.fixed_size_of_type_in_bits(explicit_size_type, ir)) - fixed_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + fixed_size_array = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "array_type": { "base_type": { "atomic_type": { @@ -683,12 +898,13 @@ def test_size_of_type_in_bits(self): } } } - }""") - self.assertEqual(40, - 
ir_util.fixed_size_of_type_in_bits(fixed_size_array, ir)) + }""", + ) + self.assertEqual(40, ir_util.fixed_size_of_type_in_bits(fixed_size_array, ir)) - fixed_size_2d_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + fixed_size_2d_array = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "array_type": { "base_type": { "array_type": { @@ -717,12 +933,15 @@ def test_size_of_type_in_bits(self): } } } - }""") - self.assertEqual( - 80, ir_util.fixed_size_of_type_in_bits(fixed_size_2d_array, ir)) - - automatic_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + }""", + ) + self.assertEqual( + 80, ir_util.fixed_size_of_type_in_bits(fixed_size_2d_array, ir) + ) + + automatic_size_array = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "array_type": { "base_type": { "array_type": { @@ -746,23 +965,25 @@ def test_size_of_type_in_bits(self): }, "automatic": { } } - }""") - self.assertIsNone( - ir_util.fixed_size_of_type_in_bits(automatic_size_array, ir)) + }""", + ) + self.assertIsNone(ir_util.fixed_size_of_type_in_bits(automatic_size_array, ir)) - variable_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + variable_size_type = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "atomic_type": { "reference": { "canonical_name": { "module_file": "", "object_path": ["UInt"] } } } - }""") - self.assertIsNone( - ir_util.fixed_size_of_type_in_bits(variable_size_type, ir)) + }""", + ) + self.assertIsNone(ir_util.fixed_size_of_type_in_bits(variable_size_type, ir)) - no_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type, - """{ + no_size_type = ir_data_utils.IrDataSerializer.from_json( + ir_data.Type, + """{ "atomic_type": { "reference": { "canonical_name": { @@ -771,26 +992,35 @@ def test_size_of_type_in_bits(self): } } } - }""") - self.assertIsNone(ir_util.fixed_size_of_type_in_bits(no_size_type, ir)) - - def test_field_is_virtual(self): - self.assertTrue(ir_util.field_is_virtual(ir_data.Field())) - - def test_field_is_not_virtual(self): - self.assertFalse(ir_util.field_is_virtual( - ir_data.Field(location=ir_data.FieldLocation()))) - - def test_field_is_read_only(self): - self.assertTrue(ir_util.field_is_read_only(ir_data.Field( - write_method=ir_data.WriteMethod(read_only=True)))) - - def test_field_is_not_read_only(self): - self.assertFalse(ir_util.field_is_read_only( - ir_data.Field(location=ir_data.FieldLocation()))) - self.assertFalse(ir_util.field_is_read_only(ir_data.Field( - write_method=ir_data.WriteMethod()))) + }""", + ) + self.assertIsNone(ir_util.fixed_size_of_type_in_bits(no_size_type, ir)) + + def test_field_is_virtual(self): + self.assertTrue(ir_util.field_is_virtual(ir_data.Field())) + + def test_field_is_not_virtual(self): + self.assertFalse( + ir_util.field_is_virtual(ir_data.Field(location=ir_data.FieldLocation())) + ) + + def test_field_is_read_only(self): + self.assertTrue( + ir_util.field_is_read_only( + ir_data.Field(write_method=ir_data.WriteMethod(read_only=True)) + ) + ) + + def test_field_is_not_read_only(self): + self.assertFalse( + ir_util.field_is_read_only(ir_data.Field(location=ir_data.FieldLocation())) + ) + self.assertFalse( + ir_util.field_is_read_only( + ir_data.Field(write_method=ir_data.WriteMethod()) + ) + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/util/name_conversion.py b/compiler/util/name_conversion.py index 3f32809..b13d667 100644 --- a/compiler/util/name_conversion.py +++ 
b/compiler/util/name_conversion.py @@ -18,10 +18,10 @@ class Case(str, Enum): - SNAKE = "snake_case" - SHOUTY = "SHOUTY_CASE" - CAMEL = "CamelCase" - K_CAMEL = "kCamelCase" + SNAKE = "snake_case" + SHOUTY = "SHOUTY_CASE" + CAMEL = "CamelCase" + K_CAMEL = "kCamelCase" # Map of (from, to) cases to their conversion function. Initially only contains @@ -31,41 +31,42 @@ class Case(str, Enum): def _case_conversion(case_from, case_to): - """Decorator to dynamically dispatch case conversions at runtime.""" - def _func(f): - _case_conversions[case_from, case_to] = f - return f + """Decorator to dynamically dispatch case conversions at runtime.""" - return _func + def _func(f): + _case_conversions[case_from, case_to] = f + return f + + return _func @_case_conversion(Case.SNAKE, Case.CAMEL) @_case_conversion(Case.SHOUTY, Case.CAMEL) def snake_to_camel(name): - """Convert from snake_case to CamelCase. Also works from SHOUTY_CASE.""" - return "".join(word.capitalize() for word in name.split("_")) + """Convert from snake_case to CamelCase. Also works from SHOUTY_CASE.""" + return "".join(word.capitalize() for word in name.split("_")) @_case_conversion(Case.CAMEL, Case.K_CAMEL) def camel_to_k_camel(name): - """Convert from CamelCase to kCamelCase.""" - return "k" + name + """Convert from CamelCase to kCamelCase.""" + return "k" + name @_case_conversion(Case.SNAKE, Case.K_CAMEL) @_case_conversion(Case.SHOUTY, Case.K_CAMEL) def snake_to_k_camel(name): - """Converts from snake_case to kCamelCase. Also works from SHOUTY_CASE.""" - return camel_to_k_camel(snake_to_camel(name)) + """Converts from snake_case to kCamelCase. Also works from SHOUTY_CASE.""" + return camel_to_k_camel(snake_to_camel(name)) def convert_case(case_from, case_to, value): - """Converts cases based on runtime case values. + """Converts cases based on runtime case values. 
- Note: Cases can be strings or enum values.""" - return _case_conversions[case_from, case_to](value) + Note: Cases can be strings or enum values.""" + return _case_conversions[case_from, case_to](value) def is_case_conversion_supported(case_from, case_to): - """Determines if a case conversion would be supported""" - return (case_from, case_to) in _case_conversions + """Determines if a case conversion would be supported""" + return (case_from, case_to) in _case_conversions diff --git a/compiler/util/name_conversion_test.py b/compiler/util/name_conversion_test.py index 9e41ecc..91309ea 100644 --- a/compiler/util/name_conversion_test.py +++ b/compiler/util/name_conversion_test.py @@ -20,70 +20,77 @@ class NameConversionTest(unittest.TestCase): - def test_snake_to_camel(self): - self.assertEqual("", name_conversion.snake_to_camel("")) - self.assertEqual("Abc", name_conversion.snake_to_camel("abc")) - self.assertEqual("AbcDef", name_conversion.snake_to_camel("abc_def")) - self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def89")) - self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def_89")) - self.assertEqual("Abc89Def", name_conversion.snake_to_camel("abc_89_def")) - self.assertEqual("Abc89def", name_conversion.snake_to_camel("abc_89def")) + def test_snake_to_camel(self): + self.assertEqual("", name_conversion.snake_to_camel("")) + self.assertEqual("Abc", name_conversion.snake_to_camel("abc")) + self.assertEqual("AbcDef", name_conversion.snake_to_camel("abc_def")) + self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def89")) + self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def_89")) + self.assertEqual("Abc89Def", name_conversion.snake_to_camel("abc_89_def")) + self.assertEqual("Abc89def", name_conversion.snake_to_camel("abc_89def")) - def test_shouty_to_camel(self): - self.assertEqual("Abc", name_conversion.snake_to_camel("ABC")) - self.assertEqual("AbcDef", name_conversion.snake_to_camel("ABC_DEF")) - self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF89")) - self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF_89")) - self.assertEqual("Abc89Def", name_conversion.snake_to_camel("ABC_89_DEF")) - self.assertEqual("Abc89def", name_conversion.snake_to_camel("ABC_89DEF")) + def test_shouty_to_camel(self): + self.assertEqual("Abc", name_conversion.snake_to_camel("ABC")) + self.assertEqual("AbcDef", name_conversion.snake_to_camel("ABC_DEF")) + self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF89")) + self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF_89")) + self.assertEqual("Abc89Def", name_conversion.snake_to_camel("ABC_89_DEF")) + self.assertEqual("Abc89def", name_conversion.snake_to_camel("ABC_89DEF")) - def test_camel_to_k_camel(self): - self.assertEqual("kFoo", name_conversion.camel_to_k_camel("Foo")) - self.assertEqual("kFooBar", name_conversion.camel_to_k_camel("FooBar")) - self.assertEqual("kAbc123", name_conversion.camel_to_k_camel("Abc123")) + def test_camel_to_k_camel(self): + self.assertEqual("kFoo", name_conversion.camel_to_k_camel("Foo")) + self.assertEqual("kFooBar", name_conversion.camel_to_k_camel("FooBar")) + self.assertEqual("kAbc123", name_conversion.camel_to_k_camel("Abc123")) - def test_snake_to_k_camel(self): - self.assertEqual("kAbc", name_conversion.snake_to_k_camel("abc")) - self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("abc_def")) - self.assertEqual("kAbcDef89", - name_conversion.snake_to_k_camel("abc_def89")) - 
self.assertEqual("kAbcDef89", - name_conversion.snake_to_k_camel("abc_def_89")) - self.assertEqual("kAbc89Def", - name_conversion.snake_to_k_camel("abc_89_def")) - self.assertEqual("kAbc89def", - name_conversion.snake_to_k_camel("abc_89def")) + def test_snake_to_k_camel(self): + self.assertEqual("kAbc", name_conversion.snake_to_k_camel("abc")) + self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("abc_def")) + self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("abc_def89")) + self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("abc_def_89")) + self.assertEqual("kAbc89Def", name_conversion.snake_to_k_camel("abc_89_def")) + self.assertEqual("kAbc89def", name_conversion.snake_to_k_camel("abc_89def")) - def test_shouty_to_k_camel(self): - self.assertEqual("kAbc", name_conversion.snake_to_k_camel("ABC")) - self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("ABC_DEF")) - self.assertEqual("kAbcDef89", - name_conversion.snake_to_k_camel("ABC_DEF89")) - self.assertEqual("kAbcDef89", - name_conversion.snake_to_k_camel("ABC_DEF_89")) - self.assertEqual("kAbc89Def", - name_conversion.snake_to_k_camel("ABC_89_DEF")) - self.assertEqual("kAbc89def", - name_conversion.snake_to_k_camel("ABC_89DEF")) + def test_shouty_to_k_camel(self): + self.assertEqual("kAbc", name_conversion.snake_to_k_camel("ABC")) + self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("ABC_DEF")) + self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("ABC_DEF89")) + self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("ABC_DEF_89")) + self.assertEqual("kAbc89Def", name_conversion.snake_to_k_camel("ABC_89_DEF")) + self.assertEqual("kAbc89def", name_conversion.snake_to_k_camel("ABC_89DEF")) - def test_convert_case(self): - self.assertEqual("foo_bar_123", name_conversion.convert_case( - "snake_case", "snake_case", "foo_bar_123")) - self.assertEqual("FOO_BAR_123", name_conversion.convert_case( - "SHOUTY_CASE", "SHOUTY_CASE", "FOO_BAR_123")) - self.assertEqual("kFooBar123", name_conversion.convert_case( - "kCamelCase", "kCamelCase", "kFooBar123")) - self.assertEqual("FooBar123", name_conversion.convert_case( - "CamelCase", "CamelCase", "FooBar123")) - self.assertEqual("kAbcDef", name_conversion.convert_case( - "snake_case", "kCamelCase", "abc_def")) - self.assertEqual("AbcDef", name_conversion.convert_case( - "snake_case", "CamelCase", "abc_def")) - self.assertEqual("kAbcDef", name_conversion.convert_case( - "SHOUTY_CASE", "kCamelCase", "ABC_DEF")) - self.assertEqual("AbcDef", name_conversion.convert_case( - "SHOUTY_CASE", "CamelCase", "ABC_DEF")) + def test_convert_case(self): + self.assertEqual( + "foo_bar_123", + name_conversion.convert_case("snake_case", "snake_case", "foo_bar_123"), + ) + self.assertEqual( + "FOO_BAR_123", + name_conversion.convert_case("SHOUTY_CASE", "SHOUTY_CASE", "FOO_BAR_123"), + ) + self.assertEqual( + "kFooBar123", + name_conversion.convert_case("kCamelCase", "kCamelCase", "kFooBar123"), + ) + self.assertEqual( + "FooBar123", + name_conversion.convert_case("CamelCase", "CamelCase", "FooBar123"), + ) + self.assertEqual( + "kAbcDef", + name_conversion.convert_case("snake_case", "kCamelCase", "abc_def"), + ) + self.assertEqual( + "AbcDef", name_conversion.convert_case("snake_case", "CamelCase", "abc_def") + ) + self.assertEqual( + "kAbcDef", + name_conversion.convert_case("SHOUTY_CASE", "kCamelCase", "ABC_DEF"), + ) + self.assertEqual( + "AbcDef", + name_conversion.convert_case("SHOUTY_CASE", "CamelCase", "ABC_DEF"), + ) if __name__ == "__main__": 
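All of the convert_case assertions above route through one (from, to) dispatch table that name_conversion.py fills in via its registration decorator. A reduced, standalone sketch of that pattern (every name below is invented for illustration, not part of the module):

```
# Registry keyed by (from, to) pairs; the decorator registers each converter.
_conversions = {}


def _conversion(case_from, case_to):
    def _register(func):
        _conversions[case_from, case_to] = func
        return func

    return _register


@_conversion("snake_case", "CamelCase")
def _snake_to_camel(name):
    return "".join(word.capitalize() for word in name.split("_"))


def convert(case_from, case_to, value):
    # A KeyError here means the (from, to) pair was never registered.
    return _conversions[case_from, case_to](value)


assert convert("snake_case", "CamelCase", "abc_def") == "AbcDef"
```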
- unittest.main() + unittest.main() diff --git a/compiler/util/parser_types.py b/compiler/util/parser_types.py index 5b63ffa..fffe086 100644 --- a/compiler/util/parser_types.py +++ b/compiler/util/parser_types.py @@ -25,92 +25,97 @@ def _make_position(line, column): - """Makes an ir_data.Position from line, column ints.""" - if not isinstance(line, int): - raise ValueError("Bad line {!r}".format(line)) - elif not isinstance(column, int): - raise ValueError("Bad column {!r}".format(column)) - return ir_data.Position(line=line, column=column) + """Makes an ir_data.Position from line, column ints.""" + if not isinstance(line, int): + raise ValueError("Bad line {!r}".format(line)) + elif not isinstance(column, int): + raise ValueError("Bad column {!r}".format(column)) + return ir_data.Position(line=line, column=column) def _parse_position(text): - """Parses an ir_data.Position from "line:column" (e.g., "1:2").""" - line, column = text.split(":") - return _make_position(int(line), int(column)) + """Parses an ir_data.Position from "line:column" (e.g., "1:2").""" + line, column = text.split(":") + return _make_position(int(line), int(column)) def format_position(position): - """formats an ir_data.Position to "line:column" form.""" - return "{}:{}".format(position.line, position.column) + """formats an ir_data.Position to "line:column" form.""" + return "{}:{}".format(position.line, position.column) def make_location(start, end, is_synthetic=False): - """Makes an ir_data.Location from (line, column) tuples or ir_data.Positions.""" - if isinstance(start, tuple): - start = _make_position(*start) - if isinstance(end, tuple): - end = _make_position(*end) - if not isinstance(start, ir_data.Position): - raise ValueError("Bad start {!r}".format(start)) - elif not isinstance(end, ir_data.Position): - raise ValueError("Bad end {!r}".format(end)) - elif start.line > end.line or ( - start.line == end.line and start.column > end.column): - raise ValueError("Start {} is after end {}".format(format_position(start), - format_position(end))) - return ir_data.Location(start=start, end=end, is_synthetic=is_synthetic) + """Makes an ir_data.Location from (line, column) tuples or ir_data.Positions.""" + if isinstance(start, tuple): + start = _make_position(*start) + if isinstance(end, tuple): + end = _make_position(*end) + if not isinstance(start, ir_data.Position): + raise ValueError("Bad start {!r}".format(start)) + elif not isinstance(end, ir_data.Position): + raise ValueError("Bad end {!r}".format(end)) + elif start.line > end.line or ( + start.line == end.line and start.column > end.column + ): + raise ValueError( + "Start {} is after end {}".format( + format_position(start), format_position(end) + ) + ) + return ir_data.Location(start=start, end=end, is_synthetic=is_synthetic) def format_location(location): - """Formats an ir_data.Location in format "1:2-3:4" ("start-end").""" - return "{}-{}".format(format_position(location.start), - format_position(location.end)) + """Formats an ir_data.Location in format "1:2-3:4" ("start-end").""" + return "{}-{}".format( + format_position(location.start), format_position(location.end) + ) def parse_location(text): - """Parses an ir_data.Location from format "1:2-3:4" ("start-end").""" - start, end = text.split("-") - return make_location(_parse_position(start), _parse_position(end)) + """Parses an ir_data.Location from format "1:2-3:4" ("start-end").""" + start, end = text.split("-") + return make_location(_parse_position(start), _parse_position(end)) -class Token( - 
collections.namedtuple("Token", ["symbol", "text", "source_location"])): - """A Token is a chunk of text from a source file, and a classification. +class Token(collections.namedtuple("Token", ["symbol", "text", "source_location"])): + """A Token is a chunk of text from a source file, and a classification. - Attributes: - symbol: The name of this token ("Indent", "SnakeWord", etc.) - text: The original text ("1234", "some_name", etc.) - source_location: Where this token came from in the original source file. - """ + Attributes: + symbol: The name of this token ("Indent", "SnakeWord", etc.) + text: The original text ("1234", "some_name", etc.) + source_location: Where this token came from in the original source file. + """ - def __str__(self): - return "{} {} {}".format(self.symbol, repr(self.text), - format_location(self.source_location)) + def __str__(self): + return "{} {} {}".format( + self.symbol, repr(self.text), format_location(self.source_location) + ) class Production(collections.namedtuple("Production", ["lhs", "rhs"])): - """A Production is a simple production from a context-free grammar. + """A Production is a simple production from a context-free grammar. - A Production takes the form: + A Production takes the form: - nonterminal -> symbol* + nonterminal -> symbol* - where "nonterminal" is an implicitly non-terminal symbol in the language, - and "symbol*" is zero or more terminal or non-terminal symbols which form the - non-terminal on the left. + where "nonterminal" is an implicitly non-terminal symbol in the language, + and "symbol*" is zero or more terminal or non-terminal symbols which form the + non-terminal on the left. - Attributes: - lhs: The non-terminal symbol on the left-hand-side of the production. - rhs: The sequence of symbols on the right-hand-side of the production. - """ + Attributes: + lhs: The non-terminal symbol on the left-hand-side of the production. + rhs: The sequence of symbols on the right-hand-side of the production. 
+ """ - def __str__(self): - return str(self.lhs) + " -> " + " ".join([str(r) for r in self.rhs]) + def __str__(self): + return str(self.lhs) + " -> " + " ".join([str(r) for r in self.rhs]) - @staticmethod - def parse(production_text): - """Parses a Production from a "symbol -> symbol symbol symbol" string.""" - words = production_text.split() - if words[1] != "->": - raise SyntaxError - return Production(words[0], tuple(words[2:])) + @staticmethod + def parse(production_text): + """Parses a Production from a "symbol -> symbol symbol symbol" string.""" + words = production_text.split() + if words[1] != "->": + raise SyntaxError + return Production(words[0], tuple(words[2:])) diff --git a/compiler/util/parser_types_test.py b/compiler/util/parser_types_test.py index 5e6fddf..097ece4 100644 --- a/compiler/util/parser_types_test.py +++ b/compiler/util/parser_types_test.py @@ -20,114 +20,141 @@ class PositionTest(unittest.TestCase): - """Tests for Position-related functions in parser_types.""" + """Tests for Position-related functions in parser_types.""" - def test_format_position(self): - self.assertEqual( - "1:2", parser_types.format_position(ir_data.Position(line=1, column=2))) + def test_format_position(self): + self.assertEqual( + "1:2", parser_types.format_position(ir_data.Position(line=1, column=2)) + ) class LocationTest(unittest.TestCase): - """Tests for Location-related functions in parser_types.""" - - def test_make_location(self): - self.assertEqual(ir_data.Location(start=ir_data.Position(line=1, - column=2), - end=ir_data.Position(line=3, - column=4), - is_synthetic=False), - parser_types.make_location((1, 2), (3, 4))) - self.assertEqual( - ir_data.Location(start=ir_data.Position(line=1, - column=2), - end=ir_data.Position(line=3, - column=4), - is_synthetic=False), - parser_types.make_location(ir_data.Position(line=1, - column=2), - ir_data.Position(line=3, - column=4))) - - def test_make_synthetic_location(self): - self.assertEqual( - ir_data.Location(start=ir_data.Position(line=1, column=2), - end=ir_data.Position(line=3, column=4), - is_synthetic=True), - parser_types.make_location((1, 2), (3, 4), True)) - self.assertEqual( - ir_data.Location(start=ir_data.Position(line=1, column=2), - end=ir_data.Position(line=3, column=4), - is_synthetic=True), - parser_types.make_location(ir_data.Position(line=1, column=2), - ir_data.Position(line=3, column=4), - True)) - - def test_make_location_type_checks(self): - self.assertRaises(ValueError, parser_types.make_location, [1, 2], (1, 2)) - self.assertRaises(ValueError, parser_types.make_location, (1, 2), [1, 2]) - - def test_make_location_logic_checks(self): - self.assertRaises(ValueError, parser_types.make_location, (3, 4), (1, 2)) - self.assertRaises(ValueError, parser_types.make_location, (1, 3), (1, 2)) - self.assertTrue(parser_types.make_location((1, 2), (1, 2))) - - def test_format_location(self): - self.assertEqual("1:2-3:4", - parser_types.format_location(parser_types.make_location( - (1, 2), (3, 4)))) - - def test_parse_location(self): - self.assertEqual(parser_types.make_location((1, 2), (3, 4)), - parser_types.parse_location("1:2-3:4")) - self.assertEqual(parser_types.make_location((1, 2), (3, 4)), - parser_types.parse_location(" 1 : 2 - 3 : 4 ")) + """Tests for Location-related functions in parser_types.""" + + def test_make_location(self): + self.assertEqual( + ir_data.Location( + start=ir_data.Position(line=1, column=2), + end=ir_data.Position(line=3, column=4), + is_synthetic=False, + ), + parser_types.make_location((1, 2), 
(3, 4)), + ) + self.assertEqual( + ir_data.Location( + start=ir_data.Position(line=1, column=2), + end=ir_data.Position(line=3, column=4), + is_synthetic=False, + ), + parser_types.make_location( + ir_data.Position(line=1, column=2), ir_data.Position(line=3, column=4) + ), + ) + + def test_make_synthetic_location(self): + self.assertEqual( + ir_data.Location( + start=ir_data.Position(line=1, column=2), + end=ir_data.Position(line=3, column=4), + is_synthetic=True, + ), + parser_types.make_location((1, 2), (3, 4), True), + ) + self.assertEqual( + ir_data.Location( + start=ir_data.Position(line=1, column=2), + end=ir_data.Position(line=3, column=4), + is_synthetic=True, + ), + parser_types.make_location( + ir_data.Position(line=1, column=2), + ir_data.Position(line=3, column=4), + True, + ), + ) + + def test_make_location_type_checks(self): + self.assertRaises(ValueError, parser_types.make_location, [1, 2], (1, 2)) + self.assertRaises(ValueError, parser_types.make_location, (1, 2), [1, 2]) + + def test_make_location_logic_checks(self): + self.assertRaises(ValueError, parser_types.make_location, (3, 4), (1, 2)) + self.assertRaises(ValueError, parser_types.make_location, (1, 3), (1, 2)) + self.assertTrue(parser_types.make_location((1, 2), (1, 2))) + + def test_format_location(self): + self.assertEqual( + "1:2-3:4", + parser_types.format_location(parser_types.make_location((1, 2), (3, 4))), + ) + + def test_parse_location(self): + self.assertEqual( + parser_types.make_location((1, 2), (3, 4)), + parser_types.parse_location("1:2-3:4"), + ) + self.assertEqual( + parser_types.make_location((1, 2), (3, 4)), + parser_types.parse_location(" 1 : 2 - 3 : 4 "), + ) class TokenTest(unittest.TestCase): - """Tests for parser_types.Token.""" + """Tests for parser_types.Token.""" - def test_str(self): - self.assertEqual("FOO 'bar' 1:2-3:4", str(parser_types.Token( - "FOO", "bar", parser_types.make_location((1, 2), (3, 4))))) + def test_str(self): + self.assertEqual( + "FOO 'bar' 1:2-3:4", + str( + parser_types.Token( + "FOO", "bar", parser_types.make_location((1, 2), (3, 4)) + ) + ), + ) class ProductionTest(unittest.TestCase): - """Tests for parser_types.Production.""" - - def test_parse(self): - self.assertEqual(parser_types.Production(lhs="A", - rhs=("B", "C")), - parser_types.Production.parse("A -> B C")) - self.assertEqual(parser_types.Production(lhs="A", - rhs=("B",)), - parser_types.Production.parse("A -> B")) - self.assertEqual(parser_types.Production(lhs="A", - rhs=("B", "C")), - parser_types.Production.parse(" A -> B C ")) - self.assertEqual(parser_types.Production(lhs="A", - rhs=tuple()), - parser_types.Production.parse("A ->")) - self.assertEqual(parser_types.Production(lhs="A", - rhs=tuple()), - parser_types.Production.parse("A -> ")) - self.assertEqual(parser_types.Production(lhs="FOO", - rhs=('"B"', "x*")), - parser_types.Production.parse('FOO -> "B" x*')) - self.assertRaises(SyntaxError, parser_types.Production.parse, "F-> A B") - self.assertRaises(SyntaxError, parser_types.Production.parse, "F B -> A B") - self.assertRaises(SyntaxError, parser_types.Production.parse, "-> A B") - - def test_str(self): - self.assertEqual(str(parser_types.Production(lhs="A", - rhs=("B", "C"))), "A -> B C") - self.assertEqual(str(parser_types.Production(lhs="A", - rhs=("B",))), "A -> B") - self.assertEqual(str(parser_types.Production(lhs="A", - rhs=tuple())), "A -> ") - self.assertEqual(str(parser_types.Production(lhs="FOO", - rhs=('"B"', "x*"))), - 'FOO -> "B" x*') + """Tests for parser_types.Production.""" + + 
def test_parse(self):
+        self.assertEqual(
+            parser_types.Production(lhs="A", rhs=("B", "C")),
+            parser_types.Production.parse("A -> B C"),
+        )
+        self.assertEqual(
+            parser_types.Production(lhs="A", rhs=("B",)),
+            parser_types.Production.parse("A -> B"),
+        )
+        self.assertEqual(
+            parser_types.Production(lhs="A", rhs=("B", "C")),
+            parser_types.Production.parse(" A -> B C "),
+        )
+        self.assertEqual(
+            parser_types.Production(lhs="A", rhs=tuple()),
+            parser_types.Production.parse("A ->"),
+        )
+        self.assertEqual(
+            parser_types.Production(lhs="A", rhs=tuple()),
+            parser_types.Production.parse("A -> "),
+        )
+        self.assertEqual(
+            parser_types.Production(lhs="FOO", rhs=('"B"', "x*")),
+            parser_types.Production.parse('FOO -> "B" x*'),
+        )
+        self.assertRaises(SyntaxError, parser_types.Production.parse, "F-> A B")
+        self.assertRaises(SyntaxError, parser_types.Production.parse, "F B -> A B")
+        self.assertRaises(SyntaxError, parser_types.Production.parse, "-> A B")
+
+    def test_str(self):
+        self.assertEqual(
+            str(parser_types.Production(lhs="A", rhs=("B", "C"))), "A -> B C"
+        )
+        self.assertEqual(str(parser_types.Production(lhs="A", rhs=("B",))), "A -> B")
+        self.assertEqual(str(parser_types.Production(lhs="A", rhs=tuple())), "A -> ")
+        self.assertEqual(
+            str(parser_types.Production(lhs="FOO", rhs=('"B"', "x*"))), 'FOO -> "B" x*'
+        )
 
 
 if __name__ == "__main__":
-  unittest.main()
+    unittest.main()
diff --git a/compiler/util/resources.py b/compiler/util/resources.py
index 606b1d9..5be66af 100644
--- a/compiler/util/resources.py
+++ b/compiler/util/resources.py
@@ -16,8 +16,10 @@
 
 import importlib.resources
 
+
 def load(package, file, encoding="utf-8"):
-  """Returns the contents of `file` from the Python package loader."""
-  with importlib.resources.files(
-      package).joinpath(file).open("r", encoding=encoding) as f:
-    return f.read()
+    """Returns the contents of `file` from the Python package loader."""
+    with importlib.resources.files(package).joinpath(file).open(
+        "r", encoding=encoding
+    ) as f:
+        return f.read()
diff --git a/compiler/util/simple_memoizer.py b/compiler/util/simple_memoizer.py
index 15a1719..fbd46c4 100644
--- a/compiler/util/simple_memoizer.py
+++ b/compiler/util/simple_memoizer.py
@@ -16,54 +16,54 @@
 
 def memoize(f):
-  """Memoizes f.
+    """Memoizes f.
 
-  The @memoize decorator returns a function which caches the results of f, and
-  returns directly from the cache instead of calling f when it is called again
-  with the same arguments.
+    The @memoize decorator returns a function which caches the results of f, and
+    returns directly from the cache instead of calling f when it is called again
+    with the same arguments.
 
-  Memoization has some caveats:
+    Memoization has some caveats:
 
-  Most importantly, the decorated function will not be called every time the
-  function is called. If the memoized function `f` performs I/O or relies on
-  or changes global state, it may not work correctly when memoized.
+    Most importantly, the decorated function will not be called every time the
+    function is called. If the memoized function `f` performs I/O or relies on
+    or changes global state, it may not work correctly when memoized.
 
-  This memoizer only works for functions taking positional arguments. It does
-  not handle keywork arguments.
+    This memoizer only works for functions taking positional arguments. It does
+    not handle keyword arguments.
 
-  This memoizer only works for hashable arguments -- tuples, ints, etc. It does
-  not work on most iterables.
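The hashability caveat above is the one most likely to bite in practice; a minimal illustration (assuming assertions are enabled; `total` is an invented example, not part of the module):

```
from compiler.util import simple_memoizer


@simple_memoizer.memoize
def total(values):
    return sum(values)


total((1, 2, 3))  # Fine: tuples are hashable and can serve as cache keys.
total([1, 2, 3])  # Fails the memoizer's hashability assertion: a list's
                  # __hash__ attribute is None.
```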
+ This memoizer only works for hashable arguments -- tuples, ints, etc. It does + not work on most iterables. - This memoizer returns a function whose __name__ and argument list may differ - from the memoized function under reflection. + This memoizer returns a function whose __name__ and argument list may differ + from the memoized function under reflection. - This memoizer never evicts anything from its cache, so its memory usage can - grow indefinitely. + This memoizer never evicts anything from its cache, so its memory usage can + grow indefinitely. - Depending on the workload and speed of `f`, the memoized `f` can be slower - than unadorned `f`; it is important to use profiling before and after - memoization. + Depending on the workload and speed of `f`, the memoized `f` can be slower + than unadorned `f`; it is important to use profiling before and after + memoization. - Usage: - @memoize - def function(arg, arg2, arg3): - ... + Usage: + @memoize + def function(arg, arg2, arg3): + ... - Arguments: - f: The function to memoize. + Arguments: + f: The function to memoize. - Returns: - A function which acts like f, but faster when called repeatedly with the - same arguments. - """ - cache = {} + Returns: + A function which acts like f, but faster when called repeatedly with the + same arguments. + """ + cache = {} - def _memoized(*args): - assert all(arg.__hash__ for arg in args), ( - "Arguments to memoized function {} must be hashable.".format( - f.__name__)) - if args not in cache: - cache[args] = f(*args) - return cache[args] + def _memoized(*args): + assert all( + arg.__hash__ for arg in args + ), "Arguments to memoized function {} must be hashable.".format(f.__name__) + if args not in cache: + cache[args] = f(*args) + return cache[args] - return _memoized + return _memoized diff --git a/compiler/util/simple_memoizer_test.py b/compiler/util/simple_memoizer_test.py index cfc35c7..855b3f4 100644 --- a/compiler/util/simple_memoizer_test.py +++ b/compiler/util/simple_memoizer_test.py @@ -20,54 +20,55 @@ class SimpleMemoizerTest(unittest.TestCase): - def test_memoized_function_returns_same_values(self): - @simple_memoizer.memoize - def add_one(n): - return n + 1 - - for i in range(100): - self.assertEqual(i + 1, add_one(i)) - - def test_memoized_function_is_only_called_once(self): - arguments = [] - - @simple_memoizer.memoize - def add_one_and_add_argument_to_list(n): - arguments.append(n) - return n + 1 - - self.assertEqual(1, add_one_and_add_argument_to_list(0)) - self.assertEqual([0], arguments) - self.assertEqual(1, add_one_and_add_argument_to_list(0)) - self.assertEqual([0], arguments) - - def test_memoized_function_with_multiple_arguments(self): - arguments = [] - - @simple_memoizer.memoize - def sum_arguments_and_add_arguments_to_list(n, m, o): - arguments.append((n, m, o)) - return n + m + o - - self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2)) - self.assertEqual([(0, 1, 2)], arguments) - self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2)) - self.assertEqual([(0, 1, 2)], arguments) - self.assertEqual(3, sum_arguments_and_add_arguments_to_list(2, 1, 0)) - self.assertEqual([(0, 1, 2), (2, 1, 0)], arguments) - - def test_memoized_function_with_no_arguments(self): - arguments = [] - - @simple_memoizer.memoize - def return_one_and_add_empty_tuple_to_list(): - arguments.append(()) - return 1 - - self.assertEqual(1, return_one_and_add_empty_tuple_to_list()) - self.assertEqual([()], arguments) - self.assertEqual(1, 
return_one_and_add_empty_tuple_to_list()) - self.assertEqual([()], arguments) - -if __name__ == '__main__': - unittest.main() + def test_memoized_function_returns_same_values(self): + @simple_memoizer.memoize + def add_one(n): + return n + 1 + + for i in range(100): + self.assertEqual(i + 1, add_one(i)) + + def test_memoized_function_is_only_called_once(self): + arguments = [] + + @simple_memoizer.memoize + def add_one_and_add_argument_to_list(n): + arguments.append(n) + return n + 1 + + self.assertEqual(1, add_one_and_add_argument_to_list(0)) + self.assertEqual([0], arguments) + self.assertEqual(1, add_one_and_add_argument_to_list(0)) + self.assertEqual([0], arguments) + + def test_memoized_function_with_multiple_arguments(self): + arguments = [] + + @simple_memoizer.memoize + def sum_arguments_and_add_arguments_to_list(n, m, o): + arguments.append((n, m, o)) + return n + m + o + + self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2)) + self.assertEqual([(0, 1, 2)], arguments) + self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2)) + self.assertEqual([(0, 1, 2)], arguments) + self.assertEqual(3, sum_arguments_and_add_arguments_to_list(2, 1, 0)) + self.assertEqual([(0, 1, 2), (2, 1, 0)], arguments) + + def test_memoized_function_with_no_arguments(self): + arguments = [] + + @simple_memoizer.memoize + def return_one_and_add_empty_tuple_to_list(): + arguments.append(()) + return 1 + + self.assertEqual(1, return_one_and_add_empty_tuple_to_list()) + self.assertEqual([()], arguments) + self.assertEqual(1, return_one_and_add_empty_tuple_to_list()) + self.assertEqual([()], arguments) + + +if __name__ == "__main__": + unittest.main() diff --git a/compiler/util/test_util.py b/compiler/util/test_util.py index 02ac9a3..d54af37 100644 --- a/compiler/util/test_util.py +++ b/compiler/util/test_util.py @@ -18,89 +18,92 @@ def proto_is_superset(proto, expected_values, path=""): - """Returns true if every value in expected_values is set in proto. - - This is intended to be used in assertTrue in a unit test, like so: - - self.assertTrue(*proto_is_superset(proto, expected)) - - Arguments: - proto: The proto to check. - expected_values: The reference proto. - path: The path to the elements being compared. Clients can generally leave - this at default. - - Returns: - A tuple; the first element is True if the fields set in proto are a strict - superset of the fields set in expected_values. The second element is an - informational string specifying the path of a value found in expected_values - but not in proto. - - Every atomic field that is set in expected_values must be set to the same - value in proto; every message field set in expected_values must have a - matching field in proto, such that proto_is_superset(proto.field, - expected_values.field) is true. - - For repeated fields in expected_values, each element in the expected_values - proto must have a corresponding element at the same index in proto; proto - may have additional elements. - """ - if path: - path += "." 
- for spec, expected_value in ir_data_utils.get_set_fields(expected_values): - name = spec.name - field_path = "{}{}".format(path, name) - value = getattr(proto, name) - if spec.is_dataclass: - if spec.is_sequence: - if len(expected_value) > len(value): - return False, "{}[{}] missing".format(field_path, - len(getattr(proto, name))) - for i in range(len(expected_value)): - result = proto_is_superset(value[i], expected_value[i], - "{}[{}]".format(field_path, i)) - if not result[0]: - return result - else: - if (expected_values.HasField(name) and - not proto.HasField(name)): - return False, "{} missing".format(field_path) - result = proto_is_superset(value, expected_value, field_path) - if not result[0]: - return result - else: - # Zero-length repeated fields and not-there repeated fields are "the - # same." - if (expected_value != value and - (not spec.is_sequence or - len(expected_value))): - if spec.is_sequence: - return False, "{} differs: found {}, expected {}".format( - field_path, list(value), list(expected_value)) + """Returns true if every value in expected_values is set in proto. + + This is intended to be used in assertTrue in a unit test, like so: + + self.assertTrue(*proto_is_superset(proto, expected)) + + Arguments: + proto: The proto to check. + expected_values: The reference proto. + path: The path to the elements being compared. Clients can generally leave + this at default. + + Returns: + A tuple; the first element is True if the fields set in proto are a strict + superset of the fields set in expected_values. The second element is an + informational string specifying the path of a value found in expected_values + but not in proto. + + Every atomic field that is set in expected_values must be set to the same + value in proto; every message field set in expected_values must have a + matching field in proto, such that proto_is_superset(proto.field, + expected_values.field) is true. + + For repeated fields in expected_values, each element in the expected_values + proto must have a corresponding element at the same index in proto; proto + may have additional elements. + """ + if path: + path += "." + for spec, expected_value in ir_data_utils.get_set_fields(expected_values): + name = spec.name + field_path = "{}{}".format(path, name) + value = getattr(proto, name) + if spec.is_dataclass: + if spec.is_sequence: + if len(expected_value) > len(value): + return False, "{}[{}] missing".format( + field_path, len(getattr(proto, name)) + ) + for i in range(len(expected_value)): + result = proto_is_superset( + value[i], expected_value[i], "{}[{}]".format(field_path, i) + ) + if not result[0]: + return result + else: + if expected_values.HasField(name) and not proto.HasField(name): + return False, "{} missing".format(field_path) + result = proto_is_superset(value, expected_value, field_path) + if not result[0]: + return result else: - return False, "{} differs: found {}, expected {}".format( - field_path, value, expected_value) - return True, "" + # Zero-length repeated fields and not-there repeated fields are "the + # same." + if expected_value != value and ( + not spec.is_sequence or len(expected_value) + ): + if spec.is_sequence: + return False, "{} differs: found {}, expected {}".format( + field_path, list(value), list(expected_value) + ) + else: + return False, "{} differs: found {}, expected {}".format( + field_path, value, expected_value + ) + return True, "" def dict_file_reader(file_dict): - """Returns a callable that retrieves entries from file_dict as files. 
+ """Returns a callable that retrieves entries from file_dict as files. - This can be used to call glue.parse_emboss_file with file text declared - inline. + This can be used to call glue.parse_emboss_file with file text declared + inline. - Arguments: - file_dict: A dictionary from "file names" to "contents." + Arguments: + file_dict: A dictionary from "file names" to "contents." - Returns: - A callable that can be passed to glue.parse_emboss_file in place of the - "read" builtin. - """ + Returns: + A callable that can be passed to glue.parse_emboss_file in place of the + "read" builtin. + """ - def read(file_name): - try: - return file_dict[file_name], None - except KeyError: - return None, ["File '{}' not found.".format(file_name)] + def read(file_name): + try: + return file_dict[file_name], None + except KeyError: + return None, ["File '{}' not found.".format(file_name)] - return read + return read diff --git a/compiler/util/test_util_test.py b/compiler/util/test_util_test.py index e82f3c7..58e1ad6 100644 --- a/compiler/util/test_util_test.py +++ b/compiler/util/test_util_test.py @@ -22,149 +22,196 @@ class ProtoIsSupersetTest(unittest.TestCase): - """Tests for test_util.proto_is_superset.""" - - def test_superset_extra_optional_field(self): - self.assertEqual( - (True, ""), - test_util.proto_is_superset( - ir_data.Structure( - field=[ir_data.Field()], - source_location=parser_types.parse_location("1:2-3:4")), - ir_data.Structure(field=[ir_data.Field()]))) - - def test_superset_extra_repeated_field(self): - self.assertEqual( - (True, ""), - test_util.proto_is_superset( - ir_data.Structure( - field=[ir_data.Field(), ir_data.Field()], - source_location=parser_types.parse_location("1:2-3:4")), - ir_data.Structure(field=[ir_data.Field()]))) - - def test_superset_missing_empty_repeated_field(self): - self.assertEqual( - (False, "field[0] missing"), - test_util.proto_is_superset( - ir_data.Structure( - field=[], - source_location=parser_types.parse_location("1:2-3:4")), - ir_data.Structure(field=[ir_data.Field(), ir_data.Field()]))) - - def test_superset_missing_empty_optional_field(self): - self.assertEqual((False, "source_location missing"), - test_util.proto_is_superset( - ir_data.Structure(field=[]), - ir_data.Structure(source_location=ir_data.Location()))) - - def test_array_element_differs(self): - self.assertEqual( - (False, - "field[0].source_location.start.line differs: found 1, expected 2"), - test_util.proto_is_superset( - ir_data.Structure( - field=[ir_data.Field(source_location=parser_types.parse_location( - "1:2-3:4"))]), - ir_data.Structure( - field=[ir_data.Field(source_location=parser_types.parse_location( - "2:2-3:4"))]))) - - def test_equal(self): - self.assertEqual( - (True, ""), - test_util.proto_is_superset(parser_types.parse_location("1:2-3:4"), - parser_types.parse_location("1:2-3:4"))) - - def test_superset_missing_optional_field(self): - self.assertEqual( - (False, "source_location missing"), - test_util.proto_is_superset( - ir_data.Structure(field=[ir_data.Field()]), - ir_data.Structure( - field=[ir_data.Field()], - source_location=parser_types.parse_location("1:2-3:4")))) - - def test_optional_field_differs(self): - self.assertEqual((False, "end.line differs: found 4, expected 3"), - test_util.proto_is_superset( - parser_types.parse_location("1:2-4:4"), - parser_types.parse_location("1:2-3:4"))) - - def test_non_message_repeated_field_equal(self): - self.assertEqual((True, ""), - test_util.proto_is_superset( - ir_data.CanonicalName(object_path=[]), - 
ir_data.CanonicalName(object_path=[]))) - - def test_non_message_repeated_field_missing_element(self): - self.assertEqual( - (False, "object_path differs: found {none!r}, expected {a!r}".format( - none=[], - a=[u"a"])), - test_util.proto_is_superset( - ir_data.CanonicalName(object_path=[]), - ir_data.CanonicalName(object_path=[u"a"]))) - - def test_non_message_repeated_field_element_differs(self): - self.assertEqual( - (False, "object_path differs: found {aa!r}, expected {ab!r}".format( - aa=[u"a", u"a"], - ab=[u"a", u"b"])), - test_util.proto_is_superset( - ir_data.CanonicalName(object_path=[u"a", u"a"]), - ir_data.CanonicalName(object_path=[u"a", u"b"]))) - - def test_non_message_repeated_field_extra_element(self): - # For repeated fields of int/bool/str values, the entire list is treated as - # an atomic unit, and should be equal. - self.assertEqual( - (False, - "object_path differs: found {!r}, expected {!r}".format( - [u"a", u"a"], [u"a"])), - test_util.proto_is_superset( - ir_data.CanonicalName(object_path=["a", "a"]), - ir_data.CanonicalName(object_path=["a"]))) - - def test_non_message_repeated_field_no_expected_value(self): - # When a repeated field is empty, it is the same as if it were entirely - # missing -- there is no way to differentiate those two conditions. - self.assertEqual((True, ""), - test_util.proto_is_superset( - ir_data.CanonicalName(object_path=["a", "a"]), - ir_data.CanonicalName(object_path=[]))) + """Tests for test_util.proto_is_superset.""" + + def test_superset_extra_optional_field(self): + self.assertEqual( + (True, ""), + test_util.proto_is_superset( + ir_data.Structure( + field=[ir_data.Field()], + source_location=parser_types.parse_location("1:2-3:4"), + ), + ir_data.Structure(field=[ir_data.Field()]), + ), + ) + + def test_superset_extra_repeated_field(self): + self.assertEqual( + (True, ""), + test_util.proto_is_superset( + ir_data.Structure( + field=[ir_data.Field(), ir_data.Field()], + source_location=parser_types.parse_location("1:2-3:4"), + ), + ir_data.Structure(field=[ir_data.Field()]), + ), + ) + + def test_superset_missing_empty_repeated_field(self): + self.assertEqual( + (False, "field[0] missing"), + test_util.proto_is_superset( + ir_data.Structure( + field=[], source_location=parser_types.parse_location("1:2-3:4") + ), + ir_data.Structure(field=[ir_data.Field(), ir_data.Field()]), + ), + ) + + def test_superset_missing_empty_optional_field(self): + self.assertEqual( + (False, "source_location missing"), + test_util.proto_is_superset( + ir_data.Structure(field=[]), + ir_data.Structure(source_location=ir_data.Location()), + ), + ) + + def test_array_element_differs(self): + self.assertEqual( + (False, "field[0].source_location.start.line differs: found 1, expected 2"), + test_util.proto_is_superset( + ir_data.Structure( + field=[ + ir_data.Field( + source_location=parser_types.parse_location("1:2-3:4") + ) + ] + ), + ir_data.Structure( + field=[ + ir_data.Field( + source_location=parser_types.parse_location("2:2-3:4") + ) + ] + ), + ), + ) + + def test_equal(self): + self.assertEqual( + (True, ""), + test_util.proto_is_superset( + parser_types.parse_location("1:2-3:4"), + parser_types.parse_location("1:2-3:4"), + ), + ) + + def test_superset_missing_optional_field(self): + self.assertEqual( + (False, "source_location missing"), + test_util.proto_is_superset( + ir_data.Structure(field=[ir_data.Field()]), + ir_data.Structure( + field=[ir_data.Field()], + source_location=parser_types.parse_location("1:2-3:4"), + ), + ), + ) + + def 
test_optional_field_differs(self): + self.assertEqual( + (False, "end.line differs: found 4, expected 3"), + test_util.proto_is_superset( + parser_types.parse_location("1:2-4:4"), + parser_types.parse_location("1:2-3:4"), + ), + ) + + def test_non_message_repeated_field_equal(self): + self.assertEqual( + (True, ""), + test_util.proto_is_superset( + ir_data.CanonicalName(object_path=[]), + ir_data.CanonicalName(object_path=[]), + ), + ) + + def test_non_message_repeated_field_missing_element(self): + self.assertEqual( + ( + False, + "object_path differs: found {none!r}, expected {a!r}".format( + none=[], a=["a"] + ), + ), + test_util.proto_is_superset( + ir_data.CanonicalName(object_path=[]), + ir_data.CanonicalName(object_path=["a"]), + ), + ) + + def test_non_message_repeated_field_element_differs(self): + self.assertEqual( + ( + False, + "object_path differs: found {aa!r}, expected {ab!r}".format( + aa=["a", "a"], ab=["a", "b"] + ), + ), + test_util.proto_is_superset( + ir_data.CanonicalName(object_path=["a", "a"]), + ir_data.CanonicalName(object_path=["a", "b"]), + ), + ) + + def test_non_message_repeated_field_extra_element(self): + # For repeated fields of int/bool/str values, the entire list is treated as + # an atomic unit, and should be equal. + self.assertEqual( + ( + False, + "object_path differs: found {!r}, expected {!r}".format( + ["a", "a"], ["a"] + ), + ), + test_util.proto_is_superset( + ir_data.CanonicalName(object_path=["a", "a"]), + ir_data.CanonicalName(object_path=["a"]), + ), + ) + + def test_non_message_repeated_field_no_expected_value(self): + # When a repeated field is empty, it is the same as if it were entirely + # missing -- there is no way to differentiate those two conditions. + self.assertEqual( + (True, ""), + test_util.proto_is_superset( + ir_data.CanonicalName(object_path=["a", "a"]), + ir_data.CanonicalName(object_path=[]), + ), + ) class DictFileReaderTest(unittest.TestCase): - """Tests for dict_file_reader.""" - - def test_empty_dict(self): - reader = test_util.dict_file_reader({}) - self.assertEqual((None, ["File 'anything' not found."]), reader("anything")) - self.assertEqual((None, ["File '' not found."]), reader("")) - - def test_one_element_dict(self): - reader = test_util.dict_file_reader({"m": "abc"}) - self.assertEqual((None, ["File 'not_there' not found."]), - reader("not_there")) - self.assertEqual((None, ["File '' not found."]), reader("")) - self.assertEqual(("abc", None), reader("m")) - - def test_two_element_dict(self): - reader = test_util.dict_file_reader({"m": "abc", "n": "def"}) - self.assertEqual((None, ["File 'not_there' not found."]), - reader("not_there")) - self.assertEqual((None, ["File '' not found."]), reader("")) - self.assertEqual(("abc", None), reader("m")) - self.assertEqual(("def", None), reader("n")) - - def test_dict_with_empty_key(self): - reader = test_util.dict_file_reader({"m": "abc", "": "def"}) - self.assertEqual((None, ["File 'not_there' not found."]), - reader("not_there")) - self.assertEqual((None, ["File 'None' not found."]), reader(None)) - self.assertEqual(("abc", None), reader("m")) - self.assertEqual(("def", None), reader("")) + """Tests for dict_file_reader.""" + + def test_empty_dict(self): + reader = test_util.dict_file_reader({}) + self.assertEqual((None, ["File 'anything' not found."]), reader("anything")) + self.assertEqual((None, ["File '' not found."]), reader("")) + + def test_one_element_dict(self): + reader = test_util.dict_file_reader({"m": "abc"}) + self.assertEqual((None, ["File 'not_there' 
not found."]), reader("not_there")) + self.assertEqual((None, ["File '' not found."]), reader("")) + self.assertEqual(("abc", None), reader("m")) + + def test_two_element_dict(self): + reader = test_util.dict_file_reader({"m": "abc", "n": "def"}) + self.assertEqual((None, ["File 'not_there' not found."]), reader("not_there")) + self.assertEqual((None, ["File '' not found."]), reader("")) + self.assertEqual(("abc", None), reader("m")) + self.assertEqual(("def", None), reader("n")) + + def test_dict_with_empty_key(self): + reader = test_util.dict_file_reader({"m": "abc", "": "def"}) + self.assertEqual((None, ["File 'not_there' not found."]), reader("not_there")) + self.assertEqual((None, ["File 'None' not found."]), reader(None)) + self.assertEqual(("abc", None), reader("m")) + self.assertEqual(("def", None), reader("")) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/compiler/util/traverse_ir.py b/compiler/util/traverse_ir.py index 93ffdc8..a6c3dda 100644 --- a/compiler/util/traverse_ir.py +++ b/compiler/util/traverse_ir.py @@ -23,373 +23,418 @@ class _FunctionCaller: - """Provides a template for setting up a generic call to a function. - - The function parameters are inspected at run-time to build up a set of valid - and required arguments. When invoking the function unneccessary parameters - will be trimmed out. If arguments are missing an assertion will be triggered. - - This is currently limited to functions that have at least one positional - parameter. - - Example usage: - ``` - def func_1(a, b, c=2): pass - def func_2(a, d): pass - caller_1 = _FunctionCaller(func_1) - caller_2 = _FunctionCaller(func_2) - generic_params = {"b": 2, "c": 3, "d": 4} - - # Equivalent of: func_1(a, b=2, c=3) - caller_1.invoke(a, generic_params) - - # Equivalent of: func_2(a, d=4) - caller_2.invoke(a, generic_params) - """ - - def __init__(self, function): - self.function = function - self.needs_filtering = True - self.valid_arg_names = set() - self.required_arg_names = set() - - argspec = inspect.getfullargspec(function) - if argspec.varkw: - # If the function accepts a kwargs parameter, then it will accept all - # arguments. - # Note: this isn't technically true if one of the keyword arguments has the - # same name as one of the positional arguments. - self.needs_filtering = False - else: - # argspec.args is a list of all parameter names excluding keyword only - # args. The first element is our required positional_arg and should be - # ignored. - args = argspec.args[1:] - self.valid_arg_names.update(args) - - # args.kwonlyargs gives us the list of keyword only args which are - # also valid. - self.valid_arg_names.update(argspec.kwonlyargs) - - # Required args are positional arguments that don't have defaults. - # Keyword only args are always optional and can be ignored. Args with - # defaults are the last elements of the argsepec.args list and should - # be ignored. - if argspec.defaults: - # Trim the arguments with defaults. - args = args[: -len(argspec.defaults)] - self.required_arg_names.update(args) - - def invoke(self, positional_arg, keyword_args): - """Invokes the function with the given args.""" - if self.needs_filtering: - # Trim to just recognized args. - matched_args = { - k: v for k, v in keyword_args.items() if k in self.valid_arg_names - } - # Check if any required args are missing. 
-      missing_args = self.required_arg_names.difference(matched_args.keys())
-      assert not missing_args, (
-          f"Attempting to call '{self.function.__name__}'; "
-          f"missing {missing_args} (have {set(keyword_args.keys())})"
-      )
-      keyword_args = matched_args
-
-    return self.function(positional_arg, **keyword_args)
+    """Provides a template for setting up a generic call to a function.
+
+    The function parameters are inspected at run-time to build up a set of valid
+    and required arguments. When invoking the function unnecessary parameters
+    will be trimmed out. If arguments are missing, an assertion will be triggered.
+
+    This is currently limited to functions that have at least one positional
+    parameter.
+
+    Example usage:
+    ```
+    def func_1(a, b, c=2): pass
+    def func_2(a, d): pass
+    caller_1 = _FunctionCaller(func_1)
+    caller_2 = _FunctionCaller(func_2)
+    generic_params = {"b": 2, "c": 3, "d": 4}
+
+    # Equivalent of: func_1(a, b=2, c=3)
+    caller_1.invoke(a, generic_params)
+
+    # Equivalent of: func_2(a, d=4)
+    caller_2.invoke(a, generic_params)
+    """
+
+    def __init__(self, function):
+        self.function = function
+        self.needs_filtering = True
+        self.valid_arg_names = set()
+        self.required_arg_names = set()
+
+        argspec = inspect.getfullargspec(function)
+        if argspec.varkw:
+            # If the function accepts a kwargs parameter, then it will accept all
+            # arguments.
+            # Note: this isn't technically true if one of the keyword arguments has the
+            # same name as one of the positional arguments.
+            self.needs_filtering = False
+        else:
+            # argspec.args is a list of all parameter names excluding keyword only
+            # args. The first element is our required positional_arg and should be
+            # ignored.
+            args = argspec.args[1:]
+            self.valid_arg_names.update(args)
+
+            # argspec.kwonlyargs gives us the list of keyword only args which are
+            # also valid.
+            self.valid_arg_names.update(argspec.kwonlyargs)
+
+            # Required args are positional arguments that don't have defaults.
+            # Keyword only args are always optional and can be ignored. Args with
+            # defaults are the last elements of the argspec.args list and should
+            # be ignored.
+            if argspec.defaults:
+                # Trim the arguments with defaults.
+                args = args[: -len(argspec.defaults)]
+            self.required_arg_names.update(args)
+
+    def invoke(self, positional_arg, keyword_args):
+        """Invokes the function with the given args."""
+        if self.needs_filtering:
+            # Trim to just recognized args.
+            matched_args = {
+                k: v for k, v in keyword_args.items() if k in self.valid_arg_names
+            }
+            # Check if any required args are missing.
+ missing_args = self.required_arg_names.difference(matched_args.keys()) + assert not missing_args, ( + f"Attempting to call '{self.function.__name__}'; " + f"missing {missing_args} (have {set(keyword_args.keys())})" + ) + keyword_args = matched_args + + return self.function(positional_arg, **keyword_args) @simple_memoizer.memoize def _memoized_caller(function): - default_lambda_name = (lambda: None).__name__ - assert ( - callable(function) and not function.__name__ == default_lambda_name - ), "For performance reasons actions must be defined as static functions" - return _FunctionCaller(function) + default_lambda_name = (lambda: None).__name__ + assert ( + callable(function) and not function.__name__ == default_lambda_name + ), "For performance reasons actions must be defined as static functions" + return _FunctionCaller(function) def _call_with_optional_args(function, positional_arg, keyword_args): - """Calls function with whatever keyword_args it will accept.""" - caller = _memoized_caller(function) - return caller.invoke(positional_arg, keyword_args) - - -def _fast_traverse_proto_top_down(proto, incidental_actions, pattern, - skip_descendants_of, action, parameters): - """Traverses an IR, calling `action` on some nodes.""" - - # Parameters are scoped to the branch of the tree, so make a copy here, before - # any action or incidental_action can update them. - parameters = parameters.copy() - - # If there is an incidental action for this node type, run it. - if type(proto) in incidental_actions: # pylint: disable=unidiomatic-typecheck - for incidental_action in incidental_actions[type(proto)]: - parameters.update(_call_with_optional_args( - incidental_action, proto, parameters) or {}) - - # If we are at the end of pattern, check to see if we should call action. - if len(pattern) == 1: - new_pattern = pattern - if pattern[0] == type(proto): - parameters.update( - _call_with_optional_args(action, proto, parameters) or {}) - else: - # Otherwise, if this node's type matches the head of pattern, recurse with - # the tail of the pattern. - if pattern[0] == type(proto): - new_pattern = pattern[1:] + """Calls function with whatever keyword_args it will accept.""" + caller = _memoized_caller(function) + return caller.invoke(positional_arg, keyword_args) + + +def _fast_traverse_proto_top_down( + proto, incidental_actions, pattern, skip_descendants_of, action, parameters +): + """Traverses an IR, calling `action` on some nodes.""" + + # Parameters are scoped to the branch of the tree, so make a copy here, before + # any action or incidental_action can update them. + parameters = parameters.copy() + + # If there is an incidental action for this node type, run it. + if type(proto) in incidental_actions: # pylint: disable=unidiomatic-typecheck + for incidental_action in incidental_actions[type(proto)]: + parameters.update( + _call_with_optional_args(incidental_action, proto, parameters) or {} + ) + + # If we are at the end of pattern, check to see if we should call action. + if len(pattern) == 1: + new_pattern = pattern + if pattern[0] == type(proto): + parameters.update(_call_with_optional_args(action, proto, parameters) or {}) else: - new_pattern = pattern - - # If the current node's type is one of the types whose branch should be - # skipped, then bail. This has to happen after `action` is called, because - # clients rely on being able to, e.g., get a callback for the "root" - # Expression without getting callbacks for every sub-Expression. 
- # pylint: disable=unidiomatic-typecheck - if type(proto) in skip_descendants_of: - return - - # Otherwise, recurse. _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET tells us, given - # the current node's type and the current target type, which fields to check. - singular_fields, repeated_fields = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[ - type(proto), new_pattern[0]] - for member_name in singular_fields: - if proto.HasField(member_name): - _fast_traverse_proto_top_down(getattr(proto, member_name), - incidental_actions, new_pattern, - skip_descendants_of, action, parameters) - for member_name in repeated_fields: - for array_element in getattr(proto, member_name) or []: - _fast_traverse_proto_top_down(array_element, incidental_actions, - new_pattern, skip_descendants_of, action, - parameters) + # Otherwise, if this node's type matches the head of pattern, recurse with + # the tail of the pattern. + if pattern[0] == type(proto): + new_pattern = pattern[1:] + else: + new_pattern = pattern + + # If the current node's type is one of the types whose branch should be + # skipped, then bail. This has to happen after `action` is called, because + # clients rely on being able to, e.g., get a callback for the "root" + # Expression without getting callbacks for every sub-Expression. + # pylint: disable=unidiomatic-typecheck + if type(proto) in skip_descendants_of: + return + + # Otherwise, recurse. _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET tells us, given + # the current node's type and the current target type, which fields to check. + singular_fields, repeated_fields = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[ + type(proto), new_pattern[0] + ] + for member_name in singular_fields: + if proto.HasField(member_name): + _fast_traverse_proto_top_down( + getattr(proto, member_name), + incidental_actions, + new_pattern, + skip_descendants_of, + action, + parameters, + ) + for member_name in repeated_fields: + for array_element in getattr(proto, member_name) or []: + _fast_traverse_proto_top_down( + array_element, + incidental_actions, + new_pattern, + skip_descendants_of, + action, + parameters, + ) def _fields_to_scan_by_current_and_target(): - """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET.""" - # In order to avoid spending a *lot* of time just walking the IR, this - # function sets up a dict that allows `_fast_traverse_proto_top_down()` to - # skip traversing large portions of the IR, depending on what node types it is - # targeting. - # - # Without this branch culling scheme, the Emboss front end (at time of - # writing) spends roughly 70% (19s out of 31s) of its time just walking the - # IR. With branch culling, that goes down to 6% (0.7s out of 12.2s). - - # type_to_fields is a map of types to maps of field names to field types. - # That is, type_to_fields[ir_data.Module]["type"] == ir_data.AddressableUnit. - type_to_fields = {} - - # Later, we need to know which fields are singular and which are repeated, - # because the access methods are not uniform. This maps (type, field_name) - # tuples to descriptor labels: type_fields_to_cardinality[ir_data.Module, - # "type"] == ir_data.Repeated. - type_fields_to_cardinality = {} - - # Fill out the above maps by recursively walking the IR type tree, starting - # from the root. 
-  types_to_check = [ir_data.EmbossIr]
-  while types_to_check:
-    type_to_check = types_to_check.pop()
-    if type_to_check in type_to_fields:
-      continue
-    fields = {}
-    for field_name, field_type in ir_data_utils.field_specs(type_to_check).items():
-      if field_type.is_dataclass:
-        fields[field_name] = field_type.data_type
-        types_to_check.append(field_type.data_type)
-        type_fields_to_cardinality[type_to_check, field_name] = (
-            field_type.container)
-    type_to_fields[type_to_check] = fields
-
-  # type_to_descendant_types is a map of all types that can be reached from a
-  # particular type. After the setup, type_to_descendant_types[ir_data.EmbossIr]
-  # == set(<all types>) and type_to_descendant_types[ir_data.Reference] ==
-  # {ir_data.CanonicalName, ir_data.Word, ir_data.Location} and
-  # type_to_descendant_types[ir_data.Word] == set().
-  #
-  # The while loop basically ors in the known descendants of each known
-  # descendant of each type until the dict stops changing, which is a bit
-  # brute-force, but in practice only iterates a few times.
-  type_to_descendant_types = {}
-  for parent_type, field_map in type_to_fields.items():
-    type_to_descendant_types[parent_type] = set(field_map.values())
-  previous_map = {}
-  while type_to_descendant_types != previous_map:
-    # In order to check the previous iteration against the current iteration, it
-    # is necessary to make a two-level copy. Otherwise, the updates to the
-    # values will also update previous_map's values, which causes the loop to
-    # exit prematurely.
-    previous_map = {k: set(v) for k, v in type_to_descendant_types.items()}
-    for ancestor_type, descendents in previous_map.items():
-      for descendent in descendents:
-        type_to_descendant_types[ancestor_type] |= previous_map[descendent]
-
-  # Finally, we have all of the information we need to make the map we really
-  # want: given a current node type and a target node type, which fields should
-  # be checked? (This implicitly skips fields that *can't* contain the target
-  # type.)
-  fields_to_scan_by_current_and_target = {}
-  for current_node_type in type_to_fields:
-    for target_node_type in type_to_fields:
-      singular_fields_to_scan = []
-      repeated_fields_to_scan = []
-      for field_name, field_type in type_to_fields[current_node_type].items():
-        # If the target node type cannot contain another instance of itself, it
-        # is still necessary to scan fields that have the actual target type.
-        if (target_node_type == field_type or
-            target_node_type in type_to_descendant_types[field_type]):
-          # Singular and repeated fields go to different lists, so that they can
-          # be handled separately.
-          if (type_fields_to_cardinality[current_node_type, field_name] is not
-              ir_data_fields.FieldContainer.LIST):
-            singular_fields_to_scan.append(field_name)
-          else:
-            repeated_fields_to_scan.append(field_name)
-      fields_to_scan_by_current_and_target[
-          current_node_type, target_node_type] = (
-              singular_fields_to_scan, repeated_fields_to_scan)
-  return fields_to_scan_by_current_and_target
+    """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET."""
+    # In order to avoid spending a *lot* of time just walking the IR, this
+    # function sets up a dict that allows `_fast_traverse_proto_top_down()` to
+    # skip traversing large portions of the IR, depending on what node types it is
+    # targeting.
+    #
+    # Without this branch culling scheme, the Emboss front end (at time of
+    # writing) spends roughly 70% (19s out of 31s) of its time just walking the
+    # IR. With branch culling, that goes down to 6% (0.7s out of 12.2s).
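The `type_to_descendant_types` fixed point in this function (old version above, reformatted version below) is an ordinary reachability computation, i.e. a transitive closure over the type graph. A standalone sketch with invented names:

    def transitive_closure(edges):
        # edges: {node: set of direct successors}. Repeatedly grow each
        # node's reachable set by the reachable sets of its current members
        # until the map stops changing, the same fixed point as above.
        closure = {node: set(targets) for node, targets in edges.items()}
        changed = True
        while changed:
            changed = False
            for reachable in closure.values():
                extra = set().union(*(closure[t] for t in reachable)) - reachable
                if extra:
                    reachable |= extra
                    changed = True
        return closure

    # transitive_closure({"Ir": {"Module"}, "Module": {"Type"}, "Type": set()})
    # == {"Ir": {"Module", "Type"}, "Module": {"Type"}, "Type": set()}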
+
+    # type_to_fields is a map of types to maps of field names to field types.
+    # That is, type_to_fields[ir_data.Module]["type"] == ir_data.AddressableUnit.
+    type_to_fields = {}
+
+    # Later, we need to know which fields are singular and which are repeated,
+    # because the access methods are not uniform. This maps (type, field_name)
+    # tuples to descriptor labels: type_fields_to_cardinality[ir_data.Module,
+    # "type"] == ir_data.Repeated.
+    type_fields_to_cardinality = {}
+
+    # Fill out the above maps by recursively walking the IR type tree, starting
+    # from the root.
+    types_to_check = [ir_data.EmbossIr]
+    while types_to_check:
+        type_to_check = types_to_check.pop()
+        if type_to_check in type_to_fields:
+            continue
+        fields = {}
+        for field_name, field_type in ir_data_utils.field_specs(type_to_check).items():
+            if field_type.is_dataclass:
+                fields[field_name] = field_type.data_type
+                types_to_check.append(field_type.data_type)
+                type_fields_to_cardinality[type_to_check, field_name] = (
+                    field_type.container
+                )
+        type_to_fields[type_to_check] = fields
+
+    # type_to_descendant_types is a map of all types that can be reached from a
+    # particular type. After the setup, type_to_descendant_types[ir_data.EmbossIr]
+    # == set(<all types>) and type_to_descendant_types[ir_data.Reference] ==
+    # {ir_data.CanonicalName, ir_data.Word, ir_data.Location} and
+    # type_to_descendant_types[ir_data.Word] == set().
+    #
+    # The while loop basically ors in the known descendants of each known
+    # descendant of each type until the dict stops changing, which is a bit
+    # brute-force, but in practice only iterates a few times.
+    type_to_descendant_types = {}
+    for parent_type, field_map in type_to_fields.items():
+        type_to_descendant_types[parent_type] = set(field_map.values())
+    previous_map = {}
+    while type_to_descendant_types != previous_map:
+        # In order to check the previous iteration against the current iteration, it
+        # is necessary to make a two-level copy. Otherwise, the updates to the
+        # values will also update previous_map's values, which causes the loop to
+        # exit prematurely.
+        previous_map = {k: set(v) for k, v in type_to_descendant_types.items()}
+        for ancestor_type, descendents in previous_map.items():
+            for descendent in descendents:
+                type_to_descendant_types[ancestor_type] |= previous_map[descendent]
+
+    # Finally, we have all of the information we need to make the map we really
+    # want: given a current node type and a target node type, which fields should
+    # be checked? (This implicitly skips fields that *can't* contain the target
+    # type.)
+    fields_to_scan_by_current_and_target = {}
+    for current_node_type in type_to_fields:
+        for target_node_type in type_to_fields:
+            singular_fields_to_scan = []
+            repeated_fields_to_scan = []
+            for field_name, field_type in type_to_fields[current_node_type].items():
+                # If the target node type cannot contain another instance of itself, it
+                # is still necessary to scan fields that have the actual target type.
+                if (
+                    target_node_type == field_type
+                    or target_node_type in type_to_descendant_types[field_type]
+                ):
+                    # Singular and repeated fields go to different lists, so that they can
+                    # be handled separately.
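The "two-level copy" comment preserved above is the load-bearing detail of that convergence loop: a plain `dict(type_to_descendant_types)` would share the set objects between `previous_map` and the live map, so the inequality test could never observe a change. A quick standalone illustration:

    current = {"a": {1}}
    shallow = dict(current)    # Copies the dict but shares the set objects.
    current["a"].add(2)
    assert shallow == current  # Still "equal", so the loop would stop early.

    two_level = {k: set(v) for k, v in current.items()}
    current["a"].add(3)
    assert two_level != current  # Independent sets: changes are now visible.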
+ if ( + type_fields_to_cardinality[current_node_type, field_name] + is not ir_data_fields.FieldContainer.LIST + ): + singular_fields_to_scan.append(field_name) + else: + repeated_fields_to_scan.append(field_name) + fields_to_scan_by_current_and_target[ + current_node_type, target_node_type + ] = (singular_fields_to_scan, repeated_fields_to_scan) + return fields_to_scan_by_current_and_target _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target() -def _emboss_ir_action(ir): - return {"ir": ir} - -def _module_action(m): - return {"source_file_name": m.source_file_name} - -def _type_definition_action(t): - return {"type_definition": t} - -def _field_action(f): - return {"field": f} - -def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None, - skip_descendants_of=(), parameters=None): - """Traverses an IR from the top down, executing the given actions. - `fast_traverse_ir_top_down` walks the given IR in preorder traversal, - specifically looking for nodes whose path from the root of the tree matches - `pattern`. For every node which matches `pattern`, `action` will be called. - - `pattern` is just a list of node types. For example, to execute `print` on - every `ir_data.Word` in the IR: - - fast_traverse_ir_top_down(ir, [ir_data.Word], print) - - If more than one type is specified, then each one must be found inside the - previous. For example, to print only the Words inside of import statements: - - fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print) - - The optional arguments provide additional control. - - `skip_descendants_of` is a list of types that should be treated as if they are - leaf nodes when they are encountered. That is, traversal will skip any - nodes with any ancestor node whose type is in `skip_descendants_of`. For - example, to `do_something` only on outermost `Expression`s: +def _emboss_ir_action(ir): + return {"ir": ir} - fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something, - skip_descendants_of={ir_data.Expression}) - `parameters` specifies a dictionary of initial parameters which can be passed - as arguments to `action` and `incidental_actions`. Note that the parameters - can be overridden for parts of the tree by `action` and `incidental_actions`. - Parameters can be used to set an object which may be updated by `action`, such - as a list of errors generated by some check in `action`: +def _module_action(m): + return {"source_file_name": m.source_file_name} - def check_structure(structure, errors): - if structure_is_bad(structure): - errors.append(error_for_structure(structure)) - errors = [] - fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure, - parameters={"errors": errors}) - if errors: - print("Errors: {}".format(errors)) - sys.exit(1) +def _type_definition_action(t): + return {"type_definition": t} - `incidental_actions` is a map from node types to functions (or tuples of - functions or lists of functions) which should be called on those nodes. - Because `fast_traverse_ir_top_down` may skip branches that can't contain - `pattern`, functions in `incidental_actions` should generally not have any - side effects: instead, they may return a dictionary, which will be used to - override `parameters` for any children of the node they were called on. 
For - example: - def do_something(expression, field_name=None): - if field_name: - print("Found {} inside {}".format(expression, field_name)) - else: - print("Found {} not in any field".format(expression)) - - fast_traverse_ir_top_down( - ir, [ir_data.Expression], do_something, - incidental_actions={ir_data.Field: lambda f: {"field_name": f.name}}) - - (The `action` may also return a dict in the same way.) - - A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so - that certain parameters are contextually available with well-known names: - - ir: The complete IR (the root ir_data.EmbossIr node). - source_file_name: The file name from which the current node was sourced. - type_definition: The most-immediate ancestor type definition. - field: The field containing the current node, if any. - - Arguments: - ir: An ir_data.Ir object to walk. - pattern: A list of node types to match. - action: A callable, which will be called on nodes matching `pattern`. - incidental_actions: A dict of node types to callables, which can be used to - set new parameters for `action` for part of the IR tree. - skip_descendants_of: A list of types whose children should be skipped when - traversing `ir`. - parameters: A list of top-level parameters. - - Returns: - None - """ - all_incidental_actions = { - ir_data.EmbossIr: [_emboss_ir_action], - ir_data.Module: [_module_action], - ir_data.TypeDefinition: [_type_definition_action], - ir_data.Field: [_field_action], - } - if incidental_actions: - for key, incidental_action in incidental_actions.items(): - if not isinstance(incidental_action, (list, tuple)): - incidental_action = [incidental_action] - all_incidental_actions.setdefault(key, []).extend(incidental_action) - _fast_traverse_proto_top_down(ir, all_incidental_actions, pattern, - skip_descendants_of, action, parameters or {}) - - -def fast_traverse_node_top_down(node, pattern, action, incidental_actions=None, - skip_descendants_of=(), parameters=None): - """Traverse a subtree of an IR, executing the given actions. - - fast_traverse_node_top_down is like fast_traverse_ir_top_down, except that: - - It may be called on a subtree, instead of the top of the IR. - - It does not have any built-in incidental actions. - - Arguments: - node: An ir_data.Ir object to walk. - pattern: A list of node types to match. - action: A callable, which will be called on nodes matching `pattern`. - incidental_actions: A dict of node types to callables, which can be used to - set new parameters for `action` for part of the IR tree. - skip_descendants_of: A list of types whose children should be skipped when - traversing `node`. - parameters: A list of top-level parameters. - - Returns: - None - """ - _fast_traverse_proto_top_down(node, incidental_actions or {}, pattern, - skip_descendants_of or {}, action, - parameters or {}) +def _field_action(f): + return {"field": f} + + +def fast_traverse_ir_top_down( + ir, + pattern, + action, + incidental_actions=None, + skip_descendants_of=(), + parameters=None, +): + """Traverses an IR from the top down, executing the given actions. + + `fast_traverse_ir_top_down` walks the given IR in preorder traversal, + specifically looking for nodes whose path from the root of the tree matches + `pattern`. For every node which matches `pattern`, `action` will be called. + + `pattern` is just a list of node types. 
For example, to execute `print` on + every `ir_data.Word` in the IR: + + fast_traverse_ir_top_down(ir, [ir_data.Word], print) + + If more than one type is specified, then each one must be found inside the + previous. For example, to print only the Words inside of import statements: + + fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print) + + The optional arguments provide additional control. + + `skip_descendants_of` is a list of types that should be treated as if they are + leaf nodes when they are encountered. That is, traversal will skip any + nodes with any ancestor node whose type is in `skip_descendants_of`. For + example, to `do_something` only on outermost `Expression`s: + + fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something, + skip_descendants_of={ir_data.Expression}) + + `parameters` specifies a dictionary of initial parameters which can be passed + as arguments to `action` and `incidental_actions`. Note that the parameters + can be overridden for parts of the tree by `action` and `incidental_actions`. + Parameters can be used to set an object which may be updated by `action`, such + as a list of errors generated by some check in `action`: + + def check_structure(structure, errors): + if structure_is_bad(structure): + errors.append(error_for_structure(structure)) + + errors = [] + fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure, + parameters={"errors": errors}) + if errors: + print("Errors: {}".format(errors)) + sys.exit(1) + + `incidental_actions` is a map from node types to functions (or tuples of + functions or lists of functions) which should be called on those nodes. + Because `fast_traverse_ir_top_down` may skip branches that can't contain + `pattern`, functions in `incidental_actions` should generally not have any + side effects: instead, they may return a dictionary, which will be used to + override `parameters` for any children of the node they were called on. For + example: + + def do_something(expression, field_name=None): + if field_name: + print("Found {} inside {}".format(expression, field_name)) + else: + print("Found {} not in any field".format(expression)) + + fast_traverse_ir_top_down( + ir, [ir_data.Expression], do_something, + incidental_actions={ir_data.Field: lambda f: {"field_name": f.name}}) + + (The `action` may also return a dict in the same way.) + + A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so + that certain parameters are contextually available with well-known names: + + ir: The complete IR (the root ir_data.EmbossIr node). + source_file_name: The file name from which the current node was sourced. + type_definition: The most-immediate ancestor type definition. + field: The field containing the current node, if any. + + Arguments: + ir: An ir_data.Ir object to walk. + pattern: A list of node types to match. + action: A callable, which will be called on nodes matching `pattern`. + incidental_actions: A dict of node types to callables, which can be used to + set new parameters for `action` for part of the IR tree. + skip_descendants_of: A list of types whose children should be skipped when + traversing `ir`. + parameters: A list of top-level parameters. 
+ + Returns: + None + """ + all_incidental_actions = { + ir_data.EmbossIr: [_emboss_ir_action], + ir_data.Module: [_module_action], + ir_data.TypeDefinition: [_type_definition_action], + ir_data.Field: [_field_action], + } + if incidental_actions: + for key, incidental_action in incidental_actions.items(): + if not isinstance(incidental_action, (list, tuple)): + incidental_action = [incidental_action] + all_incidental_actions.setdefault(key, []).extend(incidental_action) + _fast_traverse_proto_top_down( + ir, + all_incidental_actions, + pattern, + skip_descendants_of, + action, + parameters or {}, + ) + + +def fast_traverse_node_top_down( + node, + pattern, + action, + incidental_actions=None, + skip_descendants_of=(), + parameters=None, +): + """Traverse a subtree of an IR, executing the given actions. + + fast_traverse_node_top_down is like fast_traverse_ir_top_down, except that: + + It may be called on a subtree, instead of the top of the IR. + + It does not have any built-in incidental actions. + + Arguments: + node: An ir_data.Ir object to walk. + pattern: A list of node types to match. + action: A callable, which will be called on nodes matching `pattern`. + incidental_actions: A dict of node types to callables, which can be used to + set new parameters for `action` for part of the IR tree. + skip_descendants_of: A list of types whose children should be skipped when + traversing `node`. + parameters: A list of top-level parameters. + + Returns: + None + """ + _fast_traverse_proto_top_down( + node, + incidental_actions or {}, + pattern, + skip_descendants_of or {}, + action, + parameters or {}, + ) diff --git a/compiler/util/traverse_ir_test.py b/compiler/util/traverse_ir_test.py index ff54d63..504eb88 100644 --- a/compiler/util/traverse_ir_test.py +++ b/compiler/util/traverse_ir_test.py @@ -22,7 +22,9 @@ from compiler.util import ir_data_utils from compiler.util import traverse_ir -_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{ +_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json( + ir_data.EmbossIr, + """{ "module": [ { "type": [ @@ -175,165 +177,252 @@ "source_file_name": "" } ] -}""") +}""", +) def _count_entries(sequence): - counts = collections.Counter() - for entry in sequence: - counts[entry] += 1 - return counts + counts = collections.Counter() + for entry in sequence: + counts[entry] += 1 + return counts def _record_constant(constant, constant_list): - constant_list.append(int(constant.value)) + constant_list.append(int(constant.value)) def _record_field_name_and_constant(constant, constant_list, field): - constant_list.append((field.name.name.text, int(constant.value))) + constant_list.append((field.name.name.text, int(constant.value))) def _record_file_name_and_constant(constant, constant_list, source_file_name): - constant_list.append((source_file_name, int(constant.value))) + constant_list.append((source_file_name, int(constant.value))) -def _record_location_parameter_and_constant(constant, constant_list, - location=None): - constant_list.append((location, int(constant.value))) +def _record_location_parameter_and_constant(constant, constant_list, location=None): + constant_list.append((location, int(constant.value))) def _record_kind_and_constant(constant, constant_list, type_definition): - if type_definition.HasField("enumeration"): - constant_list.append(("enumeration", int(constant.value))) - elif type_definition.HasField("structure"): - constant_list.append(("structure", int(constant.value))) - elif 
type_definition.HasField("external"): - constant_list.append(("external", int(constant.value))) - else: - assert False, "Shouldn't be here." + if type_definition.HasField("enumeration"): + constant_list.append(("enumeration", int(constant.value))) + elif type_definition.HasField("structure"): + constant_list.append(("structure", int(constant.value))) + elif type_definition.HasField("external"): + constant_list.append(("external", int(constant.value))) + else: + assert False, "Shouldn't be here." class TraverseIrTest(unittest.TestCase): - def test_filter_on_type(self): - constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant, - parameters={"constant_list": constants}) - self.assertEqual( - _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]), - _count_entries(constants)) - - def test_filter_on_type_in_type(self): - constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, - [ir_data.Function, ir_data.Expression, ir_data.NumericConstant], - _record_constant, - parameters={"constant_list": constants}) - self.assertEqual([1, 1], constants) - - def test_filter_on_type_star_type(self): - struct_constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.Structure, ir_data.NumericConstant], - _record_constant, - parameters={"constant_list": struct_constants}) - self.assertEqual(_count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]), - _count_entries(struct_constants)) - enum_constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.Enum, ir_data.NumericConstant], _record_constant, - parameters={"constant_list": enum_constants}) - self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants)) - - def test_filter_on_not_type(self): - notstruct_constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant, - skip_descendants_of=(ir_data.Structure,), - parameters={"constant_list": notstruct_constants}) - self.assertEqual(_count_entries([1, 1, 1, 64]), - _count_entries(notstruct_constants)) - - def test_field_is_populated(self): - constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.Field, ir_data.NumericConstant], - _record_field_name_and_constant, - parameters={"constant_list": constants}) - self.assertEqual(_count_entries([ - ("field1", 0), ("field1", 8), ("field2", 8), ("field2", 8), - ("field2", 16), ("bar_field1", 24), ("bar_field1", 32), - ("bar_field2", 16), ("bar_field2", 32), ("bar_field2", 320) - ]), _count_entries(constants)) - - def test_file_name_is_populated(self): - constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.NumericConstant], _record_file_name_and_constant, - parameters={"constant_list": constants}) - self.assertEqual(_count_entries([ - ("t.emb", 0), ("t.emb", 8), ("t.emb", 8), ("t.emb", 8), ("t.emb", 16), - ("t.emb", 24), ("t.emb", 32), ("t.emb", 16), ("t.emb", 32), - ("t.emb", 320), ("t.emb", 1), ("t.emb", 1), ("t.emb", 1), ("", 64) - ]), _count_entries(constants)) - - def test_type_definition_is_populated(self): - constants = [] - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.NumericConstant], _record_kind_and_constant, - parameters={"constant_list": constants}) - self.assertEqual(_count_entries([ - ("structure", 0), ("structure", 8), ("structure", 8), ("structure", 8), - ("structure", 16), ("structure", 24), ("structure", 32), - ("structure", 16), ("structure", 32), ("structure", 320), - ("enumeration", 1), 
("enumeration", 1), ("enumeration", 1), - ("external", 64) - ]), _count_entries(constants)) - - def test_keyword_args_dict_in_action(self): - call_counts = {"populated": 0, "not": 0} - - def check_field_is_populated(node, **kwargs): - del node # Unused. - self.assertTrue(kwargs["field"]) - call_counts["populated"] += 1 - - def check_field_is_not_populated(node, **kwargs): - del node # Unused. - self.assertFalse("field" in kwargs) - call_counts["not"] += 1 - - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated) - self.assertEqual(7, call_counts["populated"]) - - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue], - check_field_is_not_populated) - self.assertEqual(2, call_counts["not"]) - - def test_pass_only_to_sub_nodes(self): - constants = [] - - def pass_location_down(field): - return { - "location": (int(field.location.start.constant.value), - int(field.location.size.constant.value)) - } + def test_filter_on_type(self): + constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.NumericConstant], + _record_constant, + parameters={"constant_list": constants}, + ) + self.assertEqual( + _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]), + _count_entries(constants), + ) + + def test_filter_on_type_in_type(self): + constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.Function, ir_data.Expression, ir_data.NumericConstant], + _record_constant, + parameters={"constant_list": constants}, + ) + self.assertEqual([1, 1], constants) + + def test_filter_on_type_star_type(self): + struct_constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.Structure, ir_data.NumericConstant], + _record_constant, + parameters={"constant_list": struct_constants}, + ) + self.assertEqual( + _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]), + _count_entries(struct_constants), + ) + enum_constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.Enum, ir_data.NumericConstant], + _record_constant, + parameters={"constant_list": enum_constants}, + ) + self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants)) + + def test_filter_on_not_type(self): + notstruct_constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.NumericConstant], + _record_constant, + skip_descendants_of=(ir_data.Structure,), + parameters={"constant_list": notstruct_constants}, + ) + self.assertEqual( + _count_entries([1, 1, 1, 64]), _count_entries(notstruct_constants) + ) + + def test_field_is_populated(self): + constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.Field, ir_data.NumericConstant], + _record_field_name_and_constant, + parameters={"constant_list": constants}, + ) + self.assertEqual( + _count_entries( + [ + ("field1", 0), + ("field1", 8), + ("field2", 8), + ("field2", 8), + ("field2", 16), + ("bar_field1", 24), + ("bar_field1", 32), + ("bar_field2", 16), + ("bar_field2", 32), + ("bar_field2", 320), + ] + ), + _count_entries(constants), + ) + + def test_file_name_is_populated(self): + constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.NumericConstant], + _record_file_name_and_constant, + parameters={"constant_list": constants}, + ) + self.assertEqual( + _count_entries( + [ + ("t.emb", 0), + ("t.emb", 8), + ("t.emb", 8), + ("t.emb", 8), + ("t.emb", 16), + ("t.emb", 24), + ("t.emb", 32), + ("t.emb", 16), + ("t.emb", 32), + 
("t.emb", 320), + ("t.emb", 1), + ("t.emb", 1), + ("t.emb", 1), + ("", 64), + ] + ), + _count_entries(constants), + ) + + def test_type_definition_is_populated(self): + constants = [] + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.NumericConstant], + _record_kind_and_constant, + parameters={"constant_list": constants}, + ) + self.assertEqual( + _count_entries( + [ + ("structure", 0), + ("structure", 8), + ("structure", 8), + ("structure", 8), + ("structure", 16), + ("structure", 24), + ("structure", 32), + ("structure", 16), + ("structure", 32), + ("structure", 320), + ("enumeration", 1), + ("enumeration", 1), + ("enumeration", 1), + ("external", 64), + ] + ), + _count_entries(constants), + ) + + def test_keyword_args_dict_in_action(self): + call_counts = {"populated": 0, "not": 0} + + def check_field_is_populated(node, **kwargs): + del node # Unused. + self.assertTrue(kwargs["field"]) + call_counts["populated"] += 1 + + def check_field_is_not_populated(node, **kwargs): + del node # Unused. + self.assertFalse("field" in kwargs) + call_counts["not"] += 1 + + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated + ) + self.assertEqual(7, call_counts["populated"]) + + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue], check_field_is_not_populated + ) + self.assertEqual(2, call_counts["not"]) + + def test_pass_only_to_sub_nodes(self): + constants = [] + + def pass_location_down(field): + return { + "location": ( + int(field.location.start.constant.value), + int(field.location.size.constant.value), + ) + } - traverse_ir.fast_traverse_ir_top_down( - _EXAMPLE_IR, [ir_data.NumericConstant], - _record_location_parameter_and_constant, - incidental_actions={ir_data.Field: pass_location_down}, - parameters={"constant_list": constants, "location": None}) - self.assertEqual(_count_entries([ - ((0, 8), 0), ((0, 8), 8), ((8, 16), 8), ((8, 16), 8), ((8, 16), 16), - ((24, 32), 24), ((24, 32), 32), ((32, 320), 16), ((32, 320), 32), - ((32, 320), 320), (None, 1), (None, 1), (None, 1), (None, 64) - ]), _count_entries(constants)) + traverse_ir.fast_traverse_ir_top_down( + _EXAMPLE_IR, + [ir_data.NumericConstant], + _record_location_parameter_and_constant, + incidental_actions={ir_data.Field: pass_location_down}, + parameters={"constant_list": constants, "location": None}, + ) + self.assertEqual( + _count_entries( + [ + ((0, 8), 0), + ((0, 8), 8), + ((8, 16), 8), + ((8, 16), 8), + ((8, 16), 16), + ((24, 32), 24), + ((24, 32), 32), + ((32, 320), 16), + ((32, 320), 32), + ((32, 320), 320), + (None, 1), + (None, 1), + (None, 1), + (None, 64), + ] + ), + _count_entries(constants), + ) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/doc/__init__.py b/doc/__init__.py index 2c31d84..086a24e 100644 --- a/doc/__init__.py +++ b/doc/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/runtime/cpp/generators/all_known.py b/runtime/cpp/generators/all_known.py index 5efa03f..c3c1108 100644 --- a/runtime/cpp/generators/all_known.py +++ b/runtime/cpp/generators/all_known.py @@ -22,7 +22,8 @@ OVERLOADS = 64 # Copyright header in the generated code complies with Google policies. 
-print("""// Copyright 2020 Google LLC +print( + """// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,18 +38,27 @@ // limitations under the License. // GENERATED CODE. DO NOT EDIT. REGENERATE WITH -// runtime/cpp/generators/all_known.py""") +// runtime/cpp/generators/all_known.py""" +) for i in range(1, OVERLOADS + 1): - print(""" + print( + """ template <{}> inline constexpr bool AllKnown({}) {{ return {}; }}""".format( - ", ".join(["typename T{}".format(n) for n in range(i)] + - (["typename... RestT"] if i == OVERLOADS else [])), - ", ".join(["T{} v{}".format(n, n) for n in range(i)] + - (["RestT... rest"] if i == OVERLOADS else [])), - " && ".join(["v{}.Known()".format(n) for n in range(i)] + - (["AllKnown(rest...)"] if i == OVERLOADS else [])))) - + ", ".join( + ["typename T{}".format(n) for n in range(i)] + + (["typename... RestT"] if i == OVERLOADS else []) + ), + ", ".join( + ["T{} v{}".format(n, n) for n in range(i)] + + (["RestT... rest"] if i == OVERLOADS else []) + ), + " && ".join( + ["v{}.Known()".format(n) for n in range(i)] + + (["AllKnown(rest...)"] if i == OVERLOADS else []) + ), + ) + ) diff --git a/runtime/cpp/generators/maximum_operation_do.py b/runtime/cpp/generators/maximum_operation_do.py index 9609873..b0087a3 100644 --- a/runtime/cpp/generators/maximum_operation_do.py +++ b/runtime/cpp/generators/maximum_operation_do.py @@ -24,7 +24,8 @@ OVERLOADS = 64 # Copyright header in the generated code complies with Google policies. -print("""// Copyright 2020 Google LLC +print( + """// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,17 +40,21 @@ // limitations under the License. // GENERATED CODE. DO NOT EDIT. REGENERATE WITH -// runtime/cpp/generators/maximum_operation_do.py""") +// runtime/cpp/generators/maximum_operation_do.py""" +) for i in range(5, OVERLOADS + 1): - print(""" + print( + """ template static inline constexpr T Do({0}) {{ return Do(Do({1}), Do({2})); }}""".strip().format( - ", ".join(["T v{}".format(n) for n in range(i)]), - ", ".join(["v{}".format(n) for n in range(i // 2)]), - ", ".join(["v{}".format(n) for n in range(i // 2, i)]))) + ", ".join(["T v{}".format(n) for n in range(i)]), + ", ".join(["v{}".format(n) for n in range(i // 2)]), + ", ".join(["v{}".format(n) for n in range(i // 2, i)]), + ) + ) # The "more than OVERLOADS arguments" overload uses a variadic template to # handle the remaining arguments, even though all arguments should have the @@ -59,11 +64,14 @@ # # This also uses one explicit argument, rest0, to ensure that it does not get # confused with the last non-variadic overload. -print(""" +print( + """ template static inline constexpr T Do({0}, T rest0, RestT... rest) {{ return Do(Do({1}), Do(rest0, rest...)); }}""".format( - ", ".join(["T v{}".format(n) for n in range(OVERLOADS)]), - ", ".join(["v{}".format(n) for n in range(OVERLOADS)]), - OVERLOADS)) + ", ".join(["T v{}".format(n) for n in range(OVERLOADS)]), + ", ".join(["v{}".format(n) for n in range(OVERLOADS)]), + OVERLOADS, + ) +) diff --git a/testdata/__init__.py b/testdata/__init__.py index 2c31d84..086a24e 100644 --- a/testdata/__init__.py +++ b/testdata/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - diff --git a/testdata/format/__init__.py b/testdata/format/__init__.py index 2c31d84..086a24e 100644 --- a/testdata/format/__init__.py +++ b/testdata/format/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/testdata/golden/__init__.py b/testdata/golden/__init__.py index 2c31d84..086a24e 100644 --- a/testdata/golden/__init__.py +++ b/testdata/golden/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -
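A closing note on the maximum_operation_do.py generator above: the generated Do() overloads split their argument list in half rather than peeling off one argument at a time, presumably so the recursion depth grows as log2(n) instead of n, which is friendlier to compilers' constexpr recursion limits. A Python model of the same reduction shape, for illustration only:

    def do_max(*values):
        # Divide-and-conquer maximum, mirroring the generated
        # Do(Do(first half), Do(second half)) structure.
        if len(values) == 1:
            return values[0]
        mid = len(values) // 2
        return max(do_max(*values[:mid]), do_max(*values[mid:]))

    assert do_max(3, 1, 4, 1, 5, 9, 2, 6) == 9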