diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 37f17057b..2213fe121 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -952,6 +952,8 @@ def zeros(self, inputs, input_types): dtype = _convert_dtype_value(inputs[1]) else: dtype = self.default_dtype + + data = data[0] if isinstance(data, list) else data return self.full_impl(data, 0, dtype) def zero_(self, inputs, input_types): @@ -960,16 +962,10 @@ def zero_(self, inputs, input_types): def zeros_like(self, inputs, input_types): data = inputs[0] - out = _op.zeros_like(data) - - # If the input and the output datatype is different, do a cast - if inputs[1] is not None: - dtype = _convert_dtype_value(inputs[1]) - else: - dtype = self.default_dtype - if input_types[0] not in dtype: - out = _op.cast(out, dtype) - + + inputs = [[_infer_shape(data)], inputs[1]] + out = self.zeros(inputs, input_types[0]) + return out @@ -6244,4 +6240,4 @@ def from_pytorch( if export_renamed_c_graph_path: export_c_graph(export_renamed_c_graph_path, graph) - return transform.RemoveUnusedFunctions()(mod), tvm_params + return transform.RemoveUnusedFunctions()(mod), tvm_params \ No newline at end of file