diff --git a/et_converter/text2chakra_converter.py b/et_converter/text2chakra_converter.py index c5fb826c..f6cf05ba 100644 --- a/et_converter/text2chakra_converter.py +++ b/et_converter/text2chakra_converter.py @@ -14,6 +14,7 @@ ALL_TO_ALL, ALL_GATHER, REDUCE_SCATTER, + GlobalMetadata ) class Layer: @@ -66,6 +67,17 @@ def __init__( self.num_passes = num_passes self.logger = logger self.next_node_id = 0 + + def get_global_metadata(self): + input_text = "" + with open(self.input_filename, "r") as input_file: + input_text = input_file.read() + attr = [ + ChakraAttr(name="schema", string_val="text2chakra_converter"), + ChakraAttr(name="input_file", string_val=input_text) + ] + metadata = GlobalMetadata(attr=attr) + return metadata def get_layers( self, @@ -126,6 +138,10 @@ def get_comm_coll_node( node.attr.append( ChakraAttr(name="comm_type", int64_val=self.get_comm_type(comm_type))) + node.attr.append( + ChakraAttr(name="comm_size", + uint64_val = comm_size) + ) return node def add_parent( @@ -133,7 +149,7 @@ def add_parent( child_node: Any, parent_node: Any ) -> None: - child_node.parent.append(parent_node.id) + child_node.data_deps.append(parent_node.id) def convert(self) -> None: with open(self.input_filename, "r") as f: @@ -167,6 +183,8 @@ def convert_microbenchmark( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): for layer in layers: bwd_wg_comm_node = self.get_comm_coll_node( @@ -190,6 +208,8 @@ def convert_data_parallel( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): fwd_comp_node = None @@ -252,6 +272,8 @@ def convert_model_parallel( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): fwd_comm_node = None @@ -327,6 +349,8 @@ def convert_hybrid_data_model( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): fwd_comm_node = None @@ -416,6 +440,8 @@ def convert_hybrid_model_data( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): fwd_comm_node = None @@ -504,6 +530,8 @@ def convert_hybrid_dlrm( for npu_id in range(self.num_npus): output_filename = "%s.%d.et" % (self.output_filename, npu_id) with open(output_filename, "wb") as g: + global_metadata = self.get_global_metadata() + encode_message(g, global_metadata) for i in range(self.num_passes): fwd_comp_node = None