Skip to content

Commit

Permalink
Change decomposition of zeros_like op
Browse files Browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 committed Jan 20, 2025
1 parent 4304c2f commit 392f255
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,8 @@ def zeros(self, inputs, input_types):
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype

data = data[0] if isinstance(data, list) else data
return self.full_impl(data, 0, dtype)

def zero_(self, inputs, input_types):
Expand All @@ -960,16 +962,10 @@ def zero_(self, inputs, input_types):

def zeros_like(self, inputs, input_types):
data = inputs[0]
out = _op.zeros_like(data)

# If the input and the output datatype is different, do a cast
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
if input_types[0] not in dtype:
out = _op.cast(out, dtype)


inputs = [[_infer_shape(data)], inputs[1]]
out = self.zeros(inputs, input_types[0])

return out


Expand Down Expand Up @@ -6244,4 +6240,4 @@ def from_pytorch(
if export_renamed_c_graph_path:
export_c_graph(export_renamed_c_graph_path, graph)

return transform.RemoveUnusedFunctions()(mod), tvm_params
return transform.RemoveUnusedFunctions()(mod), tvm_params

0 comments on commit 392f255

Please sign in to comment.