Skip to content

Commit

Permalink
Add Greater and others for v9 (#270)
Browse files Browse the repository at this point in the history
Add v9 support for Greater, Constant, Flatten, Gemm,
MatMul, and PRelu. No logic changes seem needed, just to
declare support of v9.
  • Loading branch information
chinhuang007 authored and tjingrant committed Oct 4, 2018
1 parent 026bd29 commit 6013e3a
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 15 deletions.
12 changes: 6 additions & 6 deletions doc/support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ______
|Ceil|1, 6|
|Clip|1, 6|
|Concat|1, 4|
|Constant|1|
|Constant|1, 9|
|ConstantFill|1|
|ConstantLike|9|
|Conv|1|
Expand All @@ -38,17 +38,17 @@ ______
|Exp|1, 6|
|Expand|8|
|EyeLike|N/A|
|Flatten|1|
|Flatten|1, 9|
|Floor|1, 6|
|GRU|1, 3, 7|
|GRUUnit|N/A|
|Gather|1|
|Gemm|1, 6, 7|
|Gemm|1, 6, 7, 9|
|GivenTensorFill|N/A|
|GlobalAveragePool|1|
|GlobalLpPool|1, 2|
|GlobalMaxPool|1|
|Greater|1, 7|
|Greater|1, 7, 9|
|HardSigmoid|1, 6|
|Hardmax|1|
|Identity|1|
Expand All @@ -64,7 +64,7 @@ ______
|Loop|N/A|
|LpNormalization|1|
|LpPool|N/A|
|MatMul|1|
|MatMul|1, 9|
|Max|1, 6, 8|
|MaxPool|1, 8|
|MaxRoiPool|N/A|
Expand All @@ -76,7 +76,7 @@ ______
|Neg|1, 6|
|Not|1|
|Or|1, 7|
|PRelu|1, 6, 7|
|PRelu|1, 6, 7, 9|
|Pad|1, 2|
|ParametricSoftplus|N/A|
|Pow|1, 7|
Expand Down
10 changes: 9 additions & 1 deletion onnx_tf/handlers/backend/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@
class Constant(BackendHandler):

@classmethod
def version_1(cls, node, **kwargs):
def _common(cls, node, **kwargs):
    """Materialize the node's "value" attribute as a constant tensor.

    The ONNX TensorProto stored under the "value" attribute is decoded
    into a numpy array via numpy_helper, its ONNX element type is mapped
    to the TF dtype, and both are handed to make_tensor_from_onnx_node.
    Shared by all supported opset versions of Constant.
    """
    tensor_proto = node.attrs["value"]
    tf_dtype = data_type.onnx2tf(tensor_proto.data_type)
    np_value = numpy_helper.to_array(tensor_proto)
    output = cls.make_tensor_from_onnx_node(
        node, inputs=[np_value], attrs={"dtype": tf_dtype})
    return [output]

@classmethod
def version_1(cls, node, **kwargs):
    """Opset 1 Constant: behavior is shared, so defer to _common."""
    return cls._common(node, **kwargs)

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 Constant: no semantic change, so defer to _common."""
    return cls._common(node, **kwargs)
11 changes: 10 additions & 1 deletion onnx_tf/handlers/backend/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Flatten(BackendHandler):

@classmethod
def version_1(cls, node, **kwargs):
def _common(cls, node, **kwargs):
x = kwargs["tensor_dict"][node.inputs[0]]
shape = tf.shape(x)
x_rank = len(x.shape)
Expand All @@ -25,3 +25,12 @@ def version_1(cls, node, **kwargs):
cal_shape = (tf.reduce_prod(shape[0:axis]),
tf.reduce_prod(shape[axis:tf.size(shape)]))
return [tf.reshape(x, cal_shape)]

@classmethod
def version_1(cls, node, **kwargs):
    """Opset 1 Flatten: the shared _common implementation does the work."""
    outputs = cls._common(node, **kwargs)
    return outputs

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 Flatten: semantics unchanged from v1, so reuse _common."""
    outputs = cls._common(node, **kwargs)
    return outputs

4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ def version_6(cls, node, **kwargs):
@classmethod
def version_7(cls, node, **kwargs):
    """Opset 7 Gemm: delegate to the version-independent _common."""
    return cls._common(node, **kwargs)

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 Gemm: no logic change versus earlier opsets; reuse _common."""
    return cls._common(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_7(cls, node, **kwargs):
    """Opset 7 Greater: build the TF comparison node directly."""
    result = cls.make_tensor_from_onnx_node(node, **kwargs)
    return [result]

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 Greater: identical handling to opset 7."""
    result = cls.make_tensor_from_onnx_node(node, **kwargs)
    return [result]
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/mat_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ class MatMul(BackendHandler):
@classmethod
def version_1(cls, node, **kwargs):
    """Opset 1 MatMul: map straight onto the TF op."""
    result = cls.make_tensor_from_onnx_node(node, **kwargs)
    return [result]

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 MatMul: identical handling to opset 1."""
    result = cls.make_tensor_from_onnx_node(node, **kwargs)
    return [result]
4 changes: 4 additions & 0 deletions onnx_tf/handlers/backend/p_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ def version_6(cls, node, **kwargs):
@classmethod
def version_7(cls, node, **kwargs):
    """Opset 7 PRelu: delegate to the version-independent _common."""
    return cls._common(node, **kwargs)

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 PRelu: no logic change versus earlier opsets; reuse _common."""
    return cls._common(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_7(cls, node, **kwargs):
    """Opset 7 Greater (frontend): emit via the shared comparison helper."""
    return cls.comparison_op(node, **kwargs)

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 Greater (frontend): identical handling to opset 7."""
    return cls.comparison_op(node, **kwargs)
10 changes: 9 additions & 1 deletion onnx_tf/handlers/frontend/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class Matmul(FrontendHandler):

@classmethod
def version_1(cls, node, **kwargs):
def _common(cls, node, **kwargs):
transpose_a = node.attr.get("transpose_a", False)
transpose_b = node.attr.get("transpose_b", False)
input_a = node.inputs[0]
Expand All @@ -33,3 +33,11 @@ def version_1(cls, node, **kwargs):
nodes.append(transposed_b)
nodes.append(cls.make_node_from_tf_node(node, [input_a, input_b]))
return nodes

@classmethod
def version_1(cls, node, **kwargs):
    """Opset 1 MatMul (frontend): the shared _common builds the node(s)."""
    emitted = cls._common(node, **kwargs)
    return emitted

@classmethod
def version_9(cls, node, **kwargs):
    """Opset 9 MatMul (frontend): semantics unchanged from v1; reuse _common."""
    emitted = cls._common(node, **kwargs)
    return emitted
12 changes: 6 additions & 6 deletions onnx_tf/opset_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
'Ceil': [1, 6],
'Clip': [1, 6],
'Concat': [1, 4],
'Constant': [1],
'Constant': [1, 9],
'ConstantFill': [1],
'ConstantLike': [9],
'Conv': [1],
Expand All @@ -31,17 +31,17 @@
'Exp': [1, 6],
'Expand': [8],
'EyeLike': [],
'Flatten': [1],
'Flatten': [1, 9],
'Floor': [1, 6],
'GRU': [1, 3, 7],
'GRUUnit': [],
'Gather': [1],
'Gemm': [1, 6, 7],
'Gemm': [1, 6, 7, 9],
'GivenTensorFill': [],
'GlobalAveragePool': [1],
'GlobalLpPool': [1, 2],
'GlobalMaxPool': [1],
'Greater': [1, 7],
'Greater': [1, 7, 9],
'HardSigmoid': [1, 6],
'Hardmax': [1],
'Identity': [1],
Expand All @@ -57,7 +57,7 @@
'Loop': [],
'LpNormalization': [1],
'LpPool': [],
'MatMul': [1],
'MatMul': [1, 9],
'Max': [1, 6, 8],
'MaxPool': [1, 8],
'MaxRoiPool': [],
Expand All @@ -69,7 +69,7 @@
'Neg': [1, 6],
'Not': [1],
'Or': [1, 7],
'PRelu': [1, 6, 7],
'PRelu': [1, 6, 7, 9],
'Pad': [1, 2],
'ParametricSoftplus': [],
'Pow': [1, 7],
Expand Down

0 comments on commit 6013e3a

Please sign in to comment.